PictureCleanUp / app.py
divimund95's picture
disable tensorflow library
6f3f66a
import gradio as gr
import numpy as np
import torch
from PIL import Image
import spaces
from omegaconf import OmegaConf
import subprocess
rc = subprocess.call("./setup.sh")
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))
from lama.saicinpainting.evaluation.refinement import refine_predict
from lama.saicinpainting.training.trainers import load_checkpoint
from lama.saicinpainting.evaluation.utils import move_to_device
# Load the model
def get_inpaint_model():
    """
    Build the frozen inpainting model together with its prediction config.

    Prediction settings are read from ``./default.yaml`` and training settings
    from ``./big-lama/config.yaml``. The checkpoint named by the prediction
    config is restored onto CUDA when it is available, otherwise onto the CPU.

    Returns: Tuple of (model, predict_config)
    """
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Prediction-time settings, patched with the local model location.
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    predict_config.refiner.gpu_ids = '0'
    # The device is kept in the config as a plain string rather than a torch.device.
    predict_config.device = str(target)

    # Training config is only needed to rebuild the network; switch it to inference mode.
    train_config = OmegaConf.load('./big-lama/config.yaml')
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    weights_file = os.path.join(
        predict_config.model.path,
        predict_config.model.checkpoint,
    )
    model = load_checkpoint(
        train_config,
        weights_file,
        strict=False,
        map_location=target,
    )
    model.freeze()
    model.to(target)
    return model, predict_config
@spaces.GPU
def inpaint(input_dict, refinement_enabled=False):
    """
    Performs image inpainting on the input image using the provided mask.

    Args:
        input_dict: Editor payload containing 'background' (PIL image) and
            'layers' (list of PIL images); the first layer is used as the mask.
        refinement_enabled: When truthy, run the slower ``refine_predict``
            pass instead of a single forward pass of the model.

    Returns:
        PIL.Image: The inpainted image, cropped back to the input size.

    Raises:
        gr.Error: If no image has been provided or no mask layer is present.
    """
    # Fail with a readable UI message instead of an AttributeError/IndexError
    # when the user presses the button before uploading or drawing.
    background = input_dict.get("background") if input_dict else None
    if background is None:
        raise gr.Error("Please upload an image before inpainting.")
    layers = input_dict.get("layers")
    if not layers:
        raise gr.Error("Please draw a mask over the area to remove.")

    input_image = np.array(background.convert("RGB")).astype('float32') / 255
    input_mask = pil_to_binary_mask(layers[0])

    # The model consumes channel-first arrays: (C, H, W) image and (1, H, W) mask.
    np_input_image = np.transpose(input_image, (2, 0, 1))
    np_input_mask = np.array(input_mask)[None, :, :]
    batch = dict(image=np_input_image, mask=np_input_mask)

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)
    modulo = predict_config.dataset.pad_out_to_modulo

    # Remember the un-padded size so the result can be cropped back afterwards.
    batch['unpad_to_size'] = [
        torch.tensor([batch['image'].shape[1]]),
        torch.tensor([batch['image'].shape[2]]),
    ]
    batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], modulo))[None].to(device)
    batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], modulo))[None].float().to(device)

    if refinement_enabled:
        cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

    unpad_to_size = batch.get('unpad_to_size', None)
    if unpad_to_size is not None:
        orig_height, orig_width = unpad_to_size
        # The sizes are one-element tensors; convert explicitly before slicing.
        cur_res = cur_res[:int(orig_height), :int(orig_width)]

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
    output_image = Image.fromarray(cur_res)
    return output_image
def ceil_modulo(x, mod):
    """Round `x` up to the nearest multiple of `mod`; exact multiples are returned unchanged."""
    quotient, remainder = divmod(x, mod)
    if remainder:
        return (quotient + 1) * mod
    return x
def pad_img_to_modulo(img, mod):
    """
    Pad a (channels, height, width) array so both spatial sizes are multiples of `mod`.

    Padding is appended after the existing rows and columns only, using
    NumPy's 'symmetric' mode; the channel axis is left untouched.
    """
    _, rows, cols = img.shape
    extra_rows = ceil_modulo(rows, mod) - rows
    extra_cols = ceil_modulo(cols, mod) - cols
    padding = ((0, 0), (0, extra_rows), (0, extra_cols))
    return np.pad(img, padding, mode='symmetric')
def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Converts a PIL image to a binary mask.

    Args:
        pil_image (PIL.Image): The input PIL image.
        threshold (int, optional): Grayscale values strictly greater than this
            are treated as foreground. Defaults to 0.
        max_scale (int, optional): Value written for foreground pixels.
            Defaults to 1.

    Returns:
        PIL.Image: A grayscale ("L") PIL image holding `max_scale` at
        foreground pixels and 0 everywhere else.
    """
    # Round-trip through an array first, exactly as before, then reduce to one channel.
    np_image = np.array(pil_image)
    grayscale = np.array(Image.fromarray(np_image).convert("L"))

    # Vectorized thresholding: the boolean array casts directly to 0/1,
    # replacing the former per-pixel Python loop.
    mask = (grayscale > threshold).astype(np.uint8)
    mask = (mask * max_scale).astype(np.uint8)

    output_mask = Image.fromarray(mask)
    # Convert mask to grayscale
    return output_mask.convert("L")
# CSS override passed to gr.Blocks: fixes the height of elements carrying
# these classes at 600px.
css = ".output-image, .input-image, .image-preview {height: 600px !important}"
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
    # Editor (image + brush mask) and result image.
    with gr.Row():
        input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto", brush=gr.Brush(colors=['#f2e2cd'], default_size=25))
        output_image = gr.Image(type="pil", label="Output Image", height="auto", width="auto")
    # NOTE(review): the original indentation was lost; the button is assumed
    # to share this row with the checkbox — confirm against the deployed layout.
    with gr.Row():
        refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
        inpaint_button = gr.Button("Inpaint")
    # Adapter that forwards the checkbox value to inpaint() as the
    # `refinement_enabled` keyword argument.
    def inpaint_with_refinement(image, enable_refinement):
        return inpaint(image, refinement_enabled=enable_refinement)
    # Clicking the button sends the editor value and checkbox state to the
    # adapter and shows its return value in the output image.
    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image]
    )
# Launch the interface
if __name__ == "__main__":
    demo.launch()