# Hugging Face Space (status when this page was captured: Sleeping)
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
import spaces | |
from omegaconf import OmegaConf | |
import subprocess | |
rc = subprocess.call("./setup.sh") | |
import sys | |
import os | |
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama')) | |
from lama.saicinpainting.evaluation.refinement import refine_predict | |
from lama.saicinpainting.training.trainers import load_checkpoint | |
from lama.saicinpainting.evaluation.utils import move_to_device | |
# Load the model
import functools


@functools.lru_cache(maxsize=1)
def get_inpaint_model():
    """
    Loads and initializes the inpainting model.

    The configs and checkpoint are read from disk only on the first
    successful call; the result is memoized, so later calls return the very
    same ``(model, predict_config)`` pair instead of reloading the weights.
    Callers should treat both objects as read-only.

    Returns: Tuple of (model, predict_config)
    """
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    predict_config.refiner.gpu_ids = '0'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The device is stored in the config as a plain string; callers rebuild
    # a torch.device from it when they need one.
    predict_config.device = str(device)

    train_config_path = './big-lama/config.yaml'
    train_config = OmegaConf.load(train_config_path)
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(predict_config.model.path,
                                   predict_config.model.checkpoint)
    model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location=device)
    model.freeze()
    model.to(device)
    return model, predict_config
def inpaint(input_dict, refinement_enabled=False):
    """
    Performs image inpainting on the input image using the provided mask.

    Args:
        input_dict: Mapping with a 'background' entry (the PIL image to
            repair) and a 'layers' entry (a list whose first item is the
            painted mask layer).
        refinement_enabled: When truthy, the result is produced by
            ``refine_predict``; otherwise by a single forward pass of the
            model under ``torch.no_grad()``.

    Returns: The inpainted result as a PIL image.

    Raises:
        gr.Error: If no background image or no mask layer was supplied.
    """
    background = input_dict.get("background") if input_dict else None
    layers = input_dict.get("layers") if input_dict else None
    if background is None:
        raise gr.Error("Please upload an image before inpainting.")
    if not layers:
        raise gr.Error("Please draw a mask over the area you want to remove.")

    # HWC float image in [0, 1] -> CHW, as the padding helper expects.
    image_hwc = np.array(background.convert("RGB")).astype('float32') / 255
    image_chw = np.transpose(image_hwc, (2, 0, 1))
    # Add channel dimension for the single-channel mask.
    mask_chw = np.array(pil_to_binary_mask(layers[0]))[None, :, :]
    height, width = image_chw.shape[1], image_chw.shape[2]

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)
    modulo = predict_config.dataset.pad_out_to_modulo

    batch = {
        'image': torch.tensor(pad_img_to_modulo(image_chw, modulo))[None].to(device),
        'mask': torch.tensor(pad_img_to_modulo(mask_chw, modulo))[None].float().to(device),
        # Original (pre-padding) size, kept in the batch for refine_predict.
        'unpad_to_size': [torch.tensor([height]), torch.tensor([width])],
    }

    if refinement_enabled:
        cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

    # Drop the padding added above. If the refined result is already
    # unpadded (presumably the case - TODO confirm), this crop is a no-op.
    cur_res = cur_res[:height, :width]

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
    return Image.fromarray(cur_res)
def ceil_modulo(x, mod):
    """Round ``x`` up to a multiple of ``mod``; exact multiples are returned unchanged."""
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
def pad_img_to_modulo(img, mod):
    """
    Pads a (channels, height, width) array so height and width are multiples of ``mod``.

    Padding is added only at the end of the height and width axes, using
    NumPy's 'symmetric' mode; the channel axis is left untouched.
    """
    _, rows, cols = img.shape
    extra_rows = ceil_modulo(rows, mod) - rows
    extra_cols = ceil_modulo(cols, mod) - cols
    pad_widths = ((0, 0), (0, extra_rows), (0, extra_cols))
    return np.pad(img, pad_widths, mode='symmetric')
def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Converts a PIL image to a binary mask.

    Args:
        pil_image (PIL.Image): The input PIL image.
        threshold (int, optional): Grayscale values strictly greater than
            this become foreground. Defaults to 0.
        max_scale (int, optional): Value written for foreground pixels;
            background pixels stay 0. Defaults to 1.
    Returns:
        PIL.Image: A grayscale ("L" mode) PIL image representing the binary mask.
    """
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    # Cast the boolean array straight to 0/1 in one vectorized step, then
    # scale foreground pixels to max_scale.
    mask = (binary_mask.astype(np.uint8) * max_scale).astype(np.uint8)
    output_mask = Image.fromarray(mask)
    # Convert mask to grayscale
    return output_mask.convert("L")
# Force a fixed display height on the image widgets.
css = ".output-image, .input-image, .image-preview {height: 600px !important}"

# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")

    # Editor (image + painted mask) on the left, result on the right.
    with gr.Row():
        input_image = gr.ImageEditor(
            type="pil",
            label='Input image & Mask',
            interactive=True,
            height="auto",
            width="auto",
            brush=gr.Brush(colors=['#f2e2cd'], default_size=25),
        )
        output_image = gr.Image(
            type="pil",
            label="Output Image",
            height="auto",
            width="auto",
        )

    # Controls: refinement toggle and the trigger button.
    with gr.Row():
        refine_checkbox = gr.Checkbox(
            label="Enable Refinement[SLOWER BUT BETTER]",
            value=False,
        )
        inpaint_button = gr.Button("Inpaint")

    def inpaint_with_refinement(image, enable_refinement):
        """Forward the editor value and checkbox state to ``inpaint``."""
        return inpaint(image, refinement_enabled=enable_refinement)

    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()