divimund95 commited on
Commit
d911050
1 Parent(s): cc8944c

use original PyTorch model for inference

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ lama/
2
+ big-lama/
3
+
LaMa.mlpackage/Data/com.apple.CoreML/model.mlmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:289f2c611bd3e52805ee3e686e290981d96d3b9674db93fe6bf30962f7e60d87
3
- size 1166404
 
 
 
 
LaMa.mlpackage/Data/com.apple.CoreML/weights/weight.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:aae26da8deca02ead81120f1d683b6c38361cd593c5a685e543c4b84726500e1
3
- size 204086656
 
 
 
 
LaMa.mlpackage/Manifest.json DELETED
@@ -1,18 +0,0 @@
1
- {
2
- "fileFormatVersion": "1.0.0",
3
- "itemInfoEntries": {
4
- "058403EC-D454-47EC-9C08-D1149DC8311C": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Specification",
7
- "name": "model.mlmodel",
8
- "path": "com.apple.CoreML/model.mlmodel"
9
- },
10
- "BCCB46DC-D6B9-4B28-8D24-B59CF8160E49": {
11
- "author": "com.apple.CoreML",
12
- "description": "CoreML Model Weights",
13
- "name": "weights",
14
- "path": "com.apple.CoreML/weights"
15
- }
16
- },
17
- "rootModelIdentifier": "058403EC-D454-47EC-9C08-D1149DC8311C"
18
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,30 +1,101 @@
1
  import gradio as gr
2
- import coremltools as ct
3
  import numpy as np
 
4
  from PIL import Image
5
  import io
 
 
 
 
 
 
 
 
 
6
 
7
  # Load the model
8
- coreml_model_file_name = "LaMa.mlpackage"
9
- loaded_model = ct.models.MLModel(coreml_model_file_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def inpaint(input_dict):
12
- # Resize input image and mask to 800x800
13
- input_image = input_dict["background"].convert("RGB").resize((800, 800), Image.LANCZOS)
14
- input_mask = pil_to_binary_mask(input_dict['layers'][0].resize((800, 800), Image.NEAREST))
 
 
 
 
15
 
16
- # Convert mask to grayscale
17
- input_mask = input_mask.convert("L")
 
 
 
 
18
 
19
- # Run inference
20
- prediction = loaded_model.predict({"image": input_image, "mask": input_mask})
21
 
22
- # Access the output
23
- output_image = prediction["output"]
24
 
25
- return output_image, input_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def pil_to_binary_mask(pil_image, threshold=0):
 
 
 
 
 
 
 
 
 
 
28
  np_image = np.array(pil_image)
29
  grayscale_image = Image.fromarray(np_image).convert("L")
30
  binary_mask = np.array(grayscale_image) > threshold
@@ -35,7 +106,8 @@ def pil_to_binary_mask(pil_image, threshold=0):
35
  mask[i,j] = 1
36
  mask = (mask*255).astype(np.uint8)
37
  output_mask = Image.fromarray(mask)
38
- return output_mask
 
39
 
40
  # Create Gradio interface
41
  with gr.Blocks() as demo:
@@ -43,13 +115,13 @@ with gr.Blocks() as demo:
43
  gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
44
 
45
  with gr.Row():
46
- input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True)
47
  output_image = gr.Image(type="pil", label="Output Image")
48
- with gr.Column():
49
- masked_image = gr.Image(label="Masked image", type="pil")
50
 
51
  inpaint_button = gr.Button("Inpaint")
52
- inpaint_button.click(fn=inpaint, inputs=[input_image], outputs=[output_image, masked_image])
53
 
54
  # Launch the interface
55
  if __name__ == "__main__":
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ import torch
4
  from PIL import Image
5
  import io
6
+ from omegaconf import OmegaConf
7
+
8
+ import sys
9
+ import os
10
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))
11
+
12
+ from lama.saicinpainting.evaluation.refinement import refine_predict
13
+ from lama.saicinpainting.training.trainers import load_checkpoint
14
+
15
 
16
  # Load the model
17
+ def get_inpaint_model():
18
+ """
19
+ Loads and initializes the inpainting model.
20
+ Returns: Tuple of (model, predict_config)
21
+ """
22
+ predict_config = OmegaConf.load('./default.yaml')
23
+ predict_config.model.path = './big-lama/models/'
24
+ predict_config.refiner.gpu_ids = '0'
25
+
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ # Instead of setting device directly, we'll use it when loading the model
28
+ predict_config.device = str(device) # Store as string in config
29
+ train_config_path = './big-lama/config.yaml'
30
+
31
+ train_config = OmegaConf.load(train_config_path)
32
+ train_config.training_model.predict_only = True
33
+ train_config.visualizer.kind = 'noop'
34
+
35
+ checkpoint_path = os.path.join(predict_config.model.path,
36
+ predict_config.model.checkpoint)
37
+
38
+ model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location=device)
39
+ model.freeze()
40
+ model.to(device)
41
+ return model, predict_config
42
 
43
  def inpaint(input_dict):
44
+ """
45
+ Performs image inpainting on the input image using the provided mask.
46
+ Args: input_dict containing 'background' (image) and 'layers' (mask)
47
+ Returns: Tuple of (output_image, input_mask)
48
+ """
49
+ input_image = input_dict["background"].convert("RGB")
50
+ input_mask = pil_to_binary_mask(input_dict['layers'][0])
51
 
52
+ # TODO: check if this is correct; (C,H,W) or (H,W,C)
53
+
54
+ # batch = dict(image=input_image, mask=input_mask[None, ...])
55
+ np_input_image = np.transpose(np.array(input_image), (2, 0, 1))
56
+ np_input_mask = np.array(input_mask)[None, :, :] # Add channel dimension for grayscale images
57
+ batch = dict(image=np_input_image, mask=np_input_mask)
58
 
59
+ print('lol', batch['image'].shape)
60
+ print('lol', batch['mask'].shape)
61
 
62
+ inpaint_model, predict_config = get_inpaint_model()
63
+ device = torch.device(predict_config.device)
64
 
65
+ batch['unpad_to_size'] = [torch.tensor([batch['image'].shape[1]]),torch.tensor([batch['image'].shape[2]])]
66
+ batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], predict_config.dataset.pad_out_to_modulo))[None].to(device)
67
+ batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], predict_config.dataset.pad_out_to_modulo))[None].float().to(device)
68
+
69
+ cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
70
+ cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
71
+
72
+ cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
73
+ output_image = Image.fromarray(cur_res)
74
+
75
+ return output_image
76
+
77
+ def ceil_modulo(x, mod):
78
+ if x % mod == 0:
79
+ return x
80
+ return (x // mod + 1) * mod
81
+
82
+ def pad_img_to_modulo(img, mod):
83
+ channels, height, width = img.shape
84
+ out_height = ceil_modulo(height, mod)
85
+ out_width = ceil_modulo(width, mod)
86
+ return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')
87
 
88
  def pil_to_binary_mask(pil_image, threshold=0):
89
+ """
90
+ Converts a PIL image to a binary mask.
91
+
92
+ Args:
93
+ pil_image (PIL.Image): The input PIL image.
94
+ threshold (int, optional): The threshold value for binarization. Defaults to 0.
95
+
96
+ Returns:
97
+ PIL.Image: A grayscale PIL image representing the binary mask.
98
+ """
99
  np_image = np.array(pil_image)
100
  grayscale_image = Image.fromarray(np_image).convert("L")
101
  binary_mask = np.array(grayscale_image) > threshold
 
106
  mask[i,j] = 1
107
  mask = (mask*255).astype(np.uint8)
108
  output_mask = Image.fromarray(mask)
109
+ # Convert mask to grayscale
110
+ return output_mask.convert("L")
111
 
112
  # Create Gradio interface
113
  with gr.Blocks() as demo:
 
115
  gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
116
 
117
  with gr.Row():
118
+ input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto")
119
  output_image = gr.Image(type="pil", label="Output Image")
120
+ # with gr.Column():
121
+ # masked_image = gr.Image(label="Masked image", type="pil")
122
 
123
  inpaint_button = gr.Button("Inpaint")
124
+ inpaint_button.click(fn=inpaint, inputs=[input_image], outputs=[output_image])
125
 
126
  # Launch the interface
127
  if __name__ == "__main__":
default.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ indir: no # to be overriden in CLI
2
+ outdir: no # to be overriden in CLI
3
+
4
+ model:
5
+ path: no # to be overriden in CLI
6
+ checkpoint: best.ckpt
7
+
8
+ dataset:
9
+ kind: default
10
+ img_suffix: .png
11
+ pad_out_to_modulo: 8
12
+
13
+ device: cuda
14
+ out_key: inpainted
15
+
16
+ refine: False # refiner will only run if this is True
17
+ refiner:
18
+ gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0,"
19
+ modulo: ${dataset.pad_out_to_modulo}
20
+ n_iters: 15 # number of iterations of refinement for each scale
21
+ lr: 0.002 # learning rate
22
+ min_side: 512 # all sides of image on all scales should be >= min_side / sqrt(2)
23
+ max_scales: 3 # max number of downscaling scales for the image-mask pyramid
24
+ px_budget: 1800000 # pixels budget. Any image will be resized to satisfy height*width <= px_budget
enter_env.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Initialize conda
4
+ eval "$(conda shell.bash hook)"
5
+
6
+ # Activate the cleanup environment
7
+ conda activate cleanup
8
+
9
+ # Additional commands or environment setup can be added here
10
+
11
+ export TORCH_HOME=$(pwd) && export PYTHONPATH=$(pwd)
requirements.txt CHANGED
@@ -1,4 +1,21 @@
1
  gradio
2
- coremltools
3
  numpy
4
- pillow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
 
2
  numpy
3
+ pillow
4
+ pyyaml
5
+ tqdm
6
+ easydict==1.9.0
7
+ scikit-image
8
+ scikit-learn
9
+ opencv-python
10
+ tensorflow
11
+ joblib
12
+ matplotlib
13
+ pandas
14
+ albumentations==0.5.2
15
+ hydra-core==1.1.0
16
+ pytorch-lightning==1.2.9
17
+ tabulate
18
+ kornia==0.5.0
19
+ webdataset
20
+ packaging
21
+ wldhx.yadisk-direct
setup.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ conda create -n cleanup python=3.10 -y
2
+ conda activate cleanup
3
+ # conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia
4
+ conda install pytorch torchvision -c pytorch -y
5
+
6
+ pip install -r requirements.txt
7
+
8
+
9
+ # Clone dependency repos
10
+ git clone https://github.com/advimman/lama.git
11
+
12
+ # Download big-lama model
13
+ curl -LJO https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
14
+ unzip big-lama.zip