import gradio as gr
import numpy as np
import torch
from PIL import Image
import spaces
from omegaconf import OmegaConf
import subprocess

# Run the project's setup script at import time, *before* the `lama` imports
# further down.
# NOTE(review): presumably setup.sh prepares the `lama` checkout and/or the
# model weights - confirm. The return code is stored but never checked.
rc = subprocess.call("./setup.sh")

import sys
import os

# Make the bundled `lama` directory (next to this file) importable.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))

from lama.saicinpainting.evaluation.refinement import refine_predict
from lama.saicinpainting.training.trainers import load_checkpoint
from lama.saicinpainting.evaluation.utils import move_to_device


# Load the model
def get_inpaint_model():
    """
    Load and initialize the inpainting model.

    Reads the prediction config from ``./default.yaml`` and the training
    config from ``./big-lama/config.yaml``, loads the checkpoint named by
    ``predict_config.model.checkpoint`` from ``./big-lama/models/``, freezes
    the model and moves it to CUDA when available (CPU otherwise).

    NOTE(review): the checkpoint is re-loaded from disk on every call;
    `inpaint` calls this once per request.

    Returns:
        Tuple of (model, predict_config). ``predict_config.device`` holds the
        selected device as a string.
    """
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    predict_config.refiner.gpu_ids = '0'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The device is stored in the config as a string; callers rebuild a
    # torch.device from it.
    predict_config.device = str(device)

    train_config_path = './big-lama/config.yaml'
    train_config = OmegaConf.load(train_config_path)
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(predict_config.model.path,
                                   predict_config.model.checkpoint)
    model = load_checkpoint(train_config, checkpoint_path,
                            strict=False, map_location=device)
    model.freeze()
    model.to(device)
    return model, predict_config


@spaces.GPU
def inpaint(input_dict, refinement_enabled=False):
    """
    Perform image inpainting on the input image using the drawn mask.

    Args:
        input_dict: Editor value with a 'background' PIL image and a
            'layers' list; the first layer is used as the mask (any pixel
            whose grayscale value is > 0 is inpainted).
        refinement_enabled: When truthy, run `refine_predict` with the
            ``predict_config.refiner`` settings instead of a single forward
            pass.

    Returns:
        PIL.Image: The inpainted image, cropped back to the input size.

    Raises:
        gr.Error: If no background image or no mask layer was provided.
    """
    # Fail with a readable UI message instead of an AttributeError/IndexError.
    background = input_dict.get("background") if input_dict else None
    if background is None:
        raise gr.Error("Please upload an image before inpainting.")
    layers = input_dict.get("layers") or []
    if not layers:
        raise gr.Error("Please draw a mask over the area to inpaint.")

    # HWC uint8 -> float32 in [0, 1], then to CHW.
    input_image = np.array(background.convert("RGB")).astype('float32') / 255
    input_mask = pil_to_binary_mask(layers[0])

    np_input_image = np.transpose(input_image, (2, 0, 1))
    np_input_mask = np.array(input_mask)[None, :, :]  # Add channel dimension for grayscale images

    batch = dict(image=np_input_image, mask=np_input_mask)

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)

    # Remember the unpadded size so the padding can be cropped off afterwards.
    batch['unpad_to_size'] = [torch.tensor([batch['image'].shape[1]]),
                              torch.tensor([batch['image'].shape[2]])]
    pad_modulo = predict_config.dataset.pad_out_to_modulo
    batch['image'] = torch.tensor(
        pad_img_to_modulo(batch['image'], pad_modulo))[None].to(device)
    batch['mask'] = torch.tensor(
        pad_img_to_modulo(batch['mask'], pad_modulo))[None].float().to(device)

    if refinement_enabled:
        cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

    # Crop away the modulo padding added above.
    unpad_to_size = batch.get('unpad_to_size', None)
    if unpad_to_size is not None:
        orig_height, orig_width = unpad_to_size
        cur_res = cur_res[:int(orig_height), :int(orig_width)]

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
    output_image = Image.fromarray(cur_res)
    return output_image


def ceil_modulo(x, mod):
    """Return the smallest multiple of `mod` that is >= `x`."""
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def pad_img_to_modulo(img, mod):
    """
    Pad a (channels, height, width) array so height and width are multiples
    of `mod`.

    Padding is appended at the bottom and right edges using NumPy's
    'symmetric' mode; the channel axis is left untouched.
    """
    channels, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(img,
                  ((0, 0), (0, out_height - height), (0, out_width - width)),
                  mode='symmetric')


def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Convert a PIL image to a binary mask.

    Args:
        pil_image (PIL.Image): The input PIL image.
        threshold (int, optional): Grayscale values strictly greater than
            this become foreground. Defaults to 0.
        max_scale (int, optional): Value written for foreground pixels
            (the 0/1 mask is multiplied by it and cast to uint8).
            Defaults to 1.

    Returns:
        PIL.Image: A grayscale ("L") PIL image holding 0 for background and
        `max_scale` for foreground.
    """
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold

    # Vectorized replacement for a per-pixel loop: True -> 1, False -> 0.
    mask = binary_mask.astype(np.uint8)
    mask = (mask * max_scale).astype(np.uint8)

    output_mask = Image.fromarray(mask)
    # Convert mask to grayscale
    return output_mask.convert("L")


css = ".output-image, .input-image, .image-preview {height: 600px !important}"

# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")

    with gr.Row():
        input_image = gr.ImageEditor(
            type="pil",
            label='Input image & Mask',
            interactive=True,
            height="auto",
            width="auto",
            brush=gr.Brush(colors=['#f2e2cd'], default_size=25)
        )
        output_image = gr.Image(
            type="pil",
            label="Output Image",
            height="auto",
            width="auto"
        )

    with gr.Row():
        refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
        inpaint_button = gr.Button("Inpaint")

    def inpaint_with_refinement(image, enable_refinement):
        """Forward the editor value and checkbox state to `inpaint`."""
        return inpaint(image, refinement_enabled=enable_refinement)

    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image]
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()