import io import os import sys from typing import Union import cv2 import torch import numpy as np from PIL import Image sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") # 判断sys.argv 的长度,如果小于2,说明没有传入参数,设置初始参数 # if len(sys.argv) < 2: # sys.argv = [ # "C:/Users/27698/Desktop/LAITool/resources/scripts/lama/lama_inpaint.exe", # "-l", # "C:\\Users\\27698\\Desktop\\测试\\mjTest\\data\\mask\\temp\\1717508661218.png", # "C:\\Users\\27698\\Desktop\\测试\\mjTest\\data\\mask\\mask_temp_1717508662659.png", # "C:\\Users\\27698\\Desktop\\测试\\mjTest\\data\\mask\\temp\\1717508564042.png", # ] print(sys.argv) if getattr(sys, "frozen", False): cript_directory = os.path.dirname(sys.executable) elif __file__: cript_directory = os.path.dirname(__file__) link_name = os.path.join(os.path.expanduser("~"), "big_lama.pt") cu_name = os.path.join(cript_directory, "model\\big-lama.pt") mode_pa = link_name if len(sys.argv) < 2: # # 判断model_path是否存在,如果不存在,设置默认值 if not os.path.exists(link_name): os.system(f'mklink "{link_name}" "{cu_name}"') print("Params: ") sys.exit(0) def get_image(image): if isinstance(image, Image.Image): img = np.array(image) elif isinstance(image, np.ndarray): img = image.copy() else: raise Exception("Input image should be either PIL Image or numpy array!") if img.ndim == 3: img = np.transpose(img, (2, 0, 1)) # chw elif img.ndim == 2: img = img[np.newaxis, ...] assert img.ndim == 3 img = img.astype(np.float32) / 255 return img def ceil_modulo(x, mod): if x % mod == 0: return x return (x // mod + 1) * mod def scale_image(img, factor, interpolation=cv2.INTER_AREA): if img.shape[0] == 1: img = img[0] else: img = np.transpose(img, (1, 2, 0)) img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation) if img.ndim == 2: img = img[None, ...] else: img = np.transpose(img, (2, 0, 1)) return img def pad_img_to_modulo(img, mod): channels, height, width = img.shape out_height = ceil_modulo(height, mod) out_width = ceil_modulo(width, mod) return np.pad( img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode="symmetric", ) def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None): out_image = get_image(image) out_mask = get_image(mask) if scale_factor is not None: out_image = scale_image(out_image, 1) out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST) if pad_out_to_modulo is not None and pad_out_to_modulo > 1: out_image = pad_img_to_modulo(out_image, pad_out_to_modulo) out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo) out_image = torch.from_numpy(out_image).unsqueeze(0).to(device) out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device) out_mask = (out_mask > 0) * 1 return out_image, out_mask class LamaInpaint: def __init__( self, device, model_path=None, ) -> None: if model_path is None: model_path = os.path.join(cript_directory, "model\\big-lama.pt") self.model = torch.jit.load(model_path, map_location=device) self.model.eval() self.model.to(device) self.device = device def run( self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray], ): if isinstance(image, np.ndarray): orig_height, orig_width = image.shape[:2] else: orig_height, orig_width = np.array(image).shape[:2] # image_width = image.shape[1] # mask_width = mask.shape[1] scale = image.width / mask.width image, mask = prepare_img_and_mask(image, mask, self.device, 8, scale) with torch.inference_mode(): inpainted = self.model(image, mask) cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") cur_res = cur_res[:orig_height, :orig_width] return cur_res try: de = "cpu" if torch.cuda.is_available(): de = "cuda" lama = LamaInpaint(de, mode_pa) image_path = sys.argv[2] mask_path = sys.argv[3] output_path = sys.argv[4] # 若是没有传递mask_path,需要自己计算mask区域 # 使用Image.open打开图片 image = Image.open(image_path).convert("RGB") mask = Image.open(mask_path).convert("L") res = lama.run(image, mask) # 将修复后的图片保存到本地 img = Image.fromarray(res) # 使用 save 方法将图像保存到文件 img.save(output_path) sys.exit(0) except Exception as e: print(e) sys.exit(str(e))