radames committed on
Commit
bb5540a
1 Parent(s): 4bf9d3c

bg removal

Files changed (3)
  1. app.py +136 -172
  2. briarmbg.py +460 -0
  3. utils.py +114 -0
app.py CHANGED
@@ -1,11 +1,18 @@
1
  import sys
2
  import os
3
  import torch
4
- from pathlib import Path
5
- from huggingface_hub import hf_hub_download
6
- from PIL import Image, ImageSequence, ImageOps
7
  from typing import List
8
  import numpy as np
9
 
10
  sys.path.append(os.path.dirname("./ComfyUI/"))
11
  from ComfyUI.nodes import (
@@ -27,20 +34,11 @@ from ComfyUI.custom_nodes.layerdiffuse.layered_diffusion import (
27
  LayeredDiffusionCond,
28
  )
29
  import gradio as gr
 
30
 
 
31
 
32
- MODEL_PATH = hf_hub_download(
33
- repo_id="lllyasviel/fav_models",
34
- subfolder="fav",
35
- filename="juggernautXL_v8Rundiffusion.safetensors",
36
- )
37
- try:
38
- os.symlink(
39
- MODEL_PATH,
40
- Path("./ComfyUI/models/checkpoints/juggernautXL_v8Rundiffusion.safetensors"),
41
- )
42
- except FileExistsError:
43
- pass
44
 
45
  with torch.inference_mode():
46
  ckpt_load_checkpoint = CheckpointLoaderSimple().load_checkpoint
@@ -58,73 +56,14 @@ ld_decode = LayeredDiffusionDecode().decode
58
  mask_to_image = MaskToImage().mask_to_image
59
  invert_mask = InvertMask().invert
60
  join_image_with_alpha = JoinImageWithAlpha().join_image_with_alpha
61
-
62
-
63
- def tensor_to_pil(images: torch.Tensor | List[torch.Tensor]) -> List[Image.Image]:
64
- if not isinstance(images, list):
65
- images = [images]
66
- imgs = []
67
- for image in images:
68
- i = 255.0 * image.cpu().numpy()
69
- img = Image.fromarray(np.clip(np.squeeze(i), 0, 255).astype(np.uint8))
70
- imgs.append(img)
71
- return imgs
72
-
73
-
74
- def pad_image(input_image):
75
- pad_w, pad_h = (
76
- np.max(((2, 2), np.ceil(np.array(input_image.size) / 64).astype(int)), axis=0)
77
- * 64
78
- - input_image.size
79
- )
80
- im_padded = Image.fromarray(
81
- np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
82
- )
83
- w, h = im_padded.size
84
- if w == h:
85
- return im_padded
86
- elif w > h:
87
- new_image = Image.new(im_padded.mode, (w, w), (0, 0, 0))
88
- new_image.paste(im_padded, (0, (w - h) // 2))
89
- return new_image
90
- else:
91
- new_image = Image.new(im_padded.mode, (h, h), (0, 0, 0))
92
- new_image.paste(im_padded, ((h - w) // 2, 0))
93
- return new_image
94
-
95
-
96
- def pil_to_tensor(image: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
97
- output_images = []
98
- output_masks = []
99
- for i in ImageSequence.Iterator(image):
100
- i = ImageOps.exif_transpose(i)
101
- if i.mode == "I":
102
- i = i.point(lambda i: i * (1 / 255))
103
- image = i.convert("RGB")
104
- image = np.array(image).astype(np.float32) / 255.0
105
- image = torch.from_numpy(image)[None,]
106
- if "A" in i.getbands():
107
- mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
108
- mask = 1.0 - torch.from_numpy(mask)
109
- else:
110
- mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
111
- output_images.append(image)
112
- output_masks.append(mask.unsqueeze(0))
113
-
114
- if len(output_images) > 1:
115
- output_image = torch.cat(output_images, dim=0)
116
- output_mask = torch.cat(output_masks, dim=0)
117
- else:
118
- output_image = output_images[0]
119
- output_mask = output_masks[0]
120
-
121
- return (output_image, output_mask)
122
 
123
 
124
  def predict(
125
  prompt: str,
126
  negative_prompt: str,
127
  input_image: Image.Image | None,
 
128
  cond_mode: str,
129
  seed: int,
130
  sampler_name: str,
@@ -133,95 +72,115 @@ def predict(
133
  cfg: float,
134
  denoise: float,
135
  ):
136
- with torch.inference_mode():
137
- cliptextencode_prompt = cliptextencode(
138
- text=prompt,
139
- clip=ckpt[1],
140
- )
141
- cliptextencode_negative_prompt = cliptextencode(
142
- text=negative_prompt,
143
- clip=ckpt[1],
144
- )
145
- emptylatentimage_sample = emptylatentimage_generate(
146
- width=1024, height=1024, batch_size=1
147
- )
148
-
149
- if input_image is not None:
150
- img_tensor = pil_to_tensor(pad_image(input_image).resize((1024, 1024)))
151
- img_latent = vae_encode(pixels=img_tensor[0], vae=ckpt[2])
152
- layereddiffusionapply_sample = ld_cond_apply_layered_diffusion(
153
- config=cond_mode,
154
- weight=1,
155
- model=ckpt[0],
156
- cond=cliptextencode_prompt[0],
157
- uncond=cliptextencode_negative_prompt[0],
158
- latent=img_latent[0],
159
  )
160
- ksampler = ksampler_sample(
161
- steps=steps,
162
- cfg=cfg,
163
- sampler_name=sampler_name,
164
- scheduler=scheduler,
165
- seed=seed,
166
- model=layereddiffusionapply_sample[0],
167
- positive=layereddiffusionapply_sample[1],
168
- negative=layereddiffusionapply_sample[2],
169
- latent_image=emptylatentimage_sample[0],
170
- denoise=denoise,
171
- )
172
-
173
- vaedecode_sample = vae_decode(
174
- samples=ksampler[0],
175
- vae=ckpt[2],
176
  )
177
- layereddiffusiondecode_sample = ld_decode(
178
- sd_version="SDXL",
179
- sub_batch_size=16,
180
- samples=ksampler[0],
181
- images=vaedecode_sample[0],
182
  )
183
 
184
- rgb_img = tensor_to_pil(vaedecode_sample[0])
185
- rgb_img = tensor_to_pil(vaedecode_sample[0])
186
 
187
- else:
188
- layereddiffusionapply_sample = ld_fg_apply_layered_diffusion(
189
- config="SDXL, Conv Injection", weight=1, model=ckpt[0]
190
- )
191
- ksampler = ksampler_sample(
192
- steps=steps,
193
- cfg=cfg,
194
- sampler_name=sampler_name,
195
- scheduler=scheduler,
196
- seed=seed,
197
- model=layereddiffusionapply_sample[0],
198
- positive=cliptextencode_prompt[0],
199
- negative=cliptextencode_negative_prompt[0],
200
- latent_image=emptylatentimage_sample[0],
201
- denoise=denoise,
202
- )
203
 
204
- vaedecode_sample = vae_decode(
205
- samples=ksampler[0],
206
- vae=ckpt[2],
207
- )
208
- layereddiffusiondecode_sample = ld_decode(
209
- sd_version="SDXL",
210
- sub_batch_size=16,
211
- samples=ksampler[0],
212
- images=vaedecode_sample[0],
213
- )
214
- mask = mask_to_image(mask=layereddiffusiondecode_sample[1])
215
- ld_image = tensor_to_pil(layereddiffusiondecode_sample[0][0])
216
- inverted_mask = invert_mask(mask=layereddiffusiondecode_sample[1])
217
- rgba_img = join_image_with_alpha(
218
- image=layereddiffusiondecode_sample[0], alpha=inverted_mask[0]
219
- )
220
- rgba_img = tensor_to_pil(rgba_img[0])
221
- mask = tensor_to_pil(mask[0])
222
- rgb_img = tensor_to_pil(vaedecode_sample[0])
223
 
224
- return flatten([rgba_img, mask, rgb_img, ld_image])
225
 
226
 
227
  examples = [["An old men sit on a chair looking at the sky"]]
@@ -233,18 +192,18 @@ def flatten(l: List[List[any]]) -> List[any]:
233
 
234
  def predict_examples(prompt, negative_prompt):
235
  return predict(
236
- prompt, negative_prompt, None, None, 0, "euler", "normal", 20, 8.0, 1.0
237
  )
238
 
239
 
240
  css = """
241
  .gradio-container{
242
- max-width: 60rem;
243
  }
244
  """
245
  with gr.Blocks(css=css) as blocks:
246
  gr.Markdown("""# LayerDiffuse (unofficial)
247
-
248
  """)
249
 
250
  with gr.Row():
@@ -253,12 +212,18 @@ with gr.Blocks(css=css) as blocks:
253
  negative_prompt = gr.Text(label="Negative Prompt")
254
  button = gr.Button("Generate")
255
  with gr.Accordion(open=False, label="Input Images (Optional)"):
256
- cond_mode = gr.Radio(
257
- value="SDXL, Foreground",
258
- choices=["SDXL, Foreground", "SDXL, Background"],
259
- info="Whether to use input image as foreground or background",
260
- )
261
- input_image = gr.Image(label="Input Image", type="pil")
262
  with gr.Accordion(open=False, label="Advanced Options"):
263
  seed = gr.Slider(
264
  label="Seed",
@@ -278,8 +243,8 @@ with gr.Blocks(css=css) as blocks:
278
  label="Scheduler",
279
  value=samplers.KSampler.SCHEDULERS[0],
280
  )
281
- steps = gr.Number(
282
- label="Steps", value=20, minimum=1, maximum=10000, step=1
283
  )
284
  cfg = gr.Number(
285
  label="CFG", value=8.0, minimum=0.0, maximum=100.0, step=0.1
@@ -289,14 +254,13 @@ with gr.Blocks(css=css) as blocks:
289
  )
290
 
291
  with gr.Column(scale=1.8):
292
- gallery = gr.Gallery(
293
- columns=[2], rows=[2], object_fit="contain", height="unset"
294
- )
295
 
296
  inputs = [
297
  prompt,
298
  negative_prompt,
299
  input_image,
 
300
  cond_mode,
301
  seed,
302
  sampler_name,
 
1
  import sys
2
  import os
3
  import torch
4
+
5
+ from PIL import Image
 
6
  from typing import List
7
  import numpy as np
8
+ from utils import (
9
+ tensor_to_pil,
10
+ pil_to_tensor,
11
+ pad_image,
12
+ postprocess_image,
13
+ preprocess_image,
14
+ downloadModels,
15
+ )
16
 
17
  sys.path.append(os.path.dirname("./ComfyUI/"))
18
  from ComfyUI.nodes import (
 
34
  LayeredDiffusionCond,
35
  )
36
  import gradio as gr
37
+ from briarmbg import BriaRMBG
38
 
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
+ downloadModels()
42
 
43
  with torch.inference_mode():
44
  ckpt_load_checkpoint = CheckpointLoaderSimple().load_checkpoint
 
56
  mask_to_image = MaskToImage().mask_to_image
57
  invert_mask = InvertMask().invert
58
  join_image_with_alpha = JoinImageWithAlpha().join_image_with_alpha
59
+ rmbg_model = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device)
60
 
61
 
62
  def predict(
63
  prompt: str,
64
  negative_prompt: str,
65
  input_image: Image.Image | None,
66
+ remove_bg: bool,
67
  cond_mode: str,
68
  seed: int,
69
  sampler_name: str,
 
72
  cfg: float,
73
  denoise: float,
74
  ):
75
+ try:
76
+ with torch.inference_mode():
77
+ cliptextencode_prompt = cliptextencode(
78
+ text=prompt,
79
+ clip=ckpt[1],
80
  )
81
+ cliptextencode_negative_prompt = cliptextencode(
82
+ text=negative_prompt,
83
+ clip=ckpt[1],
84
  )
85
+ emptylatentimage_sample = emptylatentimage_generate(
86
+ width=1024, height=1024, batch_size=1
87
  )
88
 
89
+ if input_image is not None:
90
+ input_image = pad_image(input_image).resize((1024, 1024))
91
+ if remove_bg:
92
+ orig_im_size = input_image.size
93
+ image = preprocess_image(np.array(input_image), [1024, 1024]).to(
94
+ device
95
+ )
96
+
97
+ result = rmbg_model(image)
98
+ # post process
99
+ result_mask_image = postprocess_image(result[0][0], orig_im_size)
100
+
101
+ # save result
102
+ pil_mask = Image.fromarray(result_mask_image)
103
+ no_bg_image = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
104
+ no_bg_image.paste(input_image, mask=pil_mask)
105
+ input_image = no_bg_image
106
+
107
+ img_tensor = pil_to_tensor(input_image)
108
+ img_latent = vae_encode(pixels=img_tensor[0], vae=ckpt[2])
109
+ layereddiffusionapply_sample = ld_cond_apply_layered_diffusion(
110
+ config=cond_mode,
111
+ weight=1,
112
+ model=ckpt[0],
113
+ cond=cliptextencode_prompt[0],
114
+ uncond=cliptextencode_negative_prompt[0],
115
+ latent=img_latent[0],
116
+ )
117
+ ksampler = ksampler_sample(
118
+ steps=steps,
119
+ cfg=cfg,
120
+ sampler_name=sampler_name,
121
+ scheduler=scheduler,
122
+ seed=seed,
123
+ model=layereddiffusionapply_sample[0],
124
+ positive=layereddiffusionapply_sample[1],
125
+ negative=layereddiffusionapply_sample[2],
126
+ latent_image=emptylatentimage_sample[0],
127
+ denoise=denoise,
128
+ )
129
 
130
+ vaedecode_sample = vae_decode(
131
+ samples=ksampler[0],
132
+ vae=ckpt[2],
133
+ )
134
+ layereddiffusiondecode_sample = ld_decode(
135
+ sd_version="SDXL",
136
+ sub_batch_size=16,
137
+ samples=ksampler[0],
138
+ images=vaedecode_sample[0],
139
+ )
 
 
 
 
 
 
140
 
141
+ rgb_img = tensor_to_pil(vaedecode_sample[0])
142
+ return flatten([rgb_img])
143
+ else:
144
+ layereddiffusionapply_sample = ld_fg_apply_layered_diffusion(
145
+ config="SDXL, Conv Injection", weight=1, model=ckpt[0]
146
+ )
147
+ ksampler = ksampler_sample(
148
+ steps=steps,
149
+ cfg=cfg,
150
+ sampler_name=sampler_name,
151
+ scheduler=scheduler,
152
+ seed=seed,
153
+ model=layereddiffusionapply_sample[0],
154
+ positive=cliptextencode_prompt[0],
155
+ negative=cliptextencode_negative_prompt[0],
156
+ latent_image=emptylatentimage_sample[0],
157
+ denoise=denoise,
158
+ )
 
159
 
160
+ vaedecode_sample = vae_decode(
161
+ samples=ksampler[0],
162
+ vae=ckpt[2],
163
+ )
164
+ layereddiffusiondecode_sample = ld_decode(
165
+ sd_version="SDXL",
166
+ sub_batch_size=16,
167
+ samples=ksampler[0],
168
+ images=vaedecode_sample[0],
169
+ )
170
+ mask = mask_to_image(mask=layereddiffusiondecode_sample[1])
171
+ ld_image = tensor_to_pil(layereddiffusiondecode_sample[0][0])
172
+ inverted_mask = invert_mask(mask=layereddiffusiondecode_sample[1])
173
+ rgba_img = join_image_with_alpha(
174
+ image=layereddiffusiondecode_sample[0], alpha=inverted_mask[0]
175
+ )
176
+ rgba_img = tensor_to_pil(rgba_img[0])
177
+ mask = tensor_to_pil(mask[0])
178
+ rgb_img = tensor_to_pil(vaedecode_sample[0])
179
+
180
+ return flatten([rgba_img, mask])
181
+ # return flatten([rgba_img, mask, rgb_img, ld_image])
182
+ except Exception as e:
183
+ raise gr.Error(e)
184
 
185
 
186
  examples = [["An old men sit on a chair looking at the sky"]]
 
192
 
193
  def predict_examples(prompt, negative_prompt):
194
  return predict(
195
+ prompt, negative_prompt, None, False, None, 0, "euler", "normal", 20, 8.0, 1.0
196
  )
197
 
198
 
199
  css = """
200
  .gradio-container{
201
+ max-width: 50rem;
202
  }
203
  """
204
  with gr.Blocks(css=css) as blocks:
205
  gr.Markdown("""# LayerDiffuse (unofficial)
206
+ Using ComfyUI building blocks with custom node by [huchenlei](https://github.com/huchenlei/ComfyUI-layerdiffuse)
207
  """)
208
 
209
  with gr.Row():
 
212
  negative_prompt = gr.Text(label="Negative Prompt")
213
  button = gr.Button("Generate")
214
  with gr.Accordion(open=False, label="Input Images (Optional)"):
215
+ with gr.Group():
216
+ cond_mode = gr.Radio(
217
+ value="SDXL, Foreground",
218
+ choices=["SDXL, Foreground", "SDXL, Background"],
219
+ info="Whether to use input image as foreground or background",
220
+ )
221
+ remove_bg = gr.Checkbox(
222
+ info="Remove background using BriaRMBG",
223
+ label="Remove Background",
224
+ value=False,
225
+ )
226
+ input_image = gr.Image(label="Input Image", type="pil")
227
  with gr.Accordion(open=False, label="Advanced Options"):
228
  seed = gr.Slider(
229
  label="Seed",
 
243
  label="Scheduler",
244
  value=samplers.KSampler.SCHEDULERS[0],
245
  )
246
+ steps = gr.Slider(
247
+ label="Steps", value=20, minimum=1, maximum=30, step=1
248
  )
249
  cfg = gr.Number(
250
  label="CFG", value=8.0, minimum=0.0, maximum=100.0, step=0.1
 
254
  )
255
 
256
  with gr.Column(scale=1.8):
257
+ gallery = gr.Gallery(columns=[2], object_fit="contain", height="unset")
 
 
258
 
259
  inputs = [
260
  prompt,
261
  negative_prompt,
262
  input_image,
263
+ remove_bg,
264
  cond_mode,
265
  seed,
266
  sampler_name,
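
For reference, the remove_bg path added to predict() composes the helpers from utils.py with the BriaRMBG model before the image is encoded into the latent. A minimal sketch of that step in isolation, assuming briarmbg.py and utils.py from this commit are importable; the helper name remove_background is only illustrative, not part of app.py:

import numpy as np
import torch
from PIL import Image

from briarmbg import BriaRMBG
from utils import pad_image, preprocess_image, postprocess_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rmbg_model = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device)


def remove_background(input_image: Image.Image) -> Image.Image:
    # Illustrative helper mirroring the remove_bg branch in predict().
    # Same preprocessing as predict(): pad to a square and resize to 1024x1024.
    input_image = pad_image(input_image).resize((1024, 1024))
    orig_im_size = input_image.size
    # Normalize, run RMBG, and rescale the finest side output back to a uint8 mask.
    image = preprocess_image(np.array(input_image), [1024, 1024]).to(device)
    with torch.no_grad():
        result = rmbg_model(image)
    result_mask_image = postprocess_image(result[0][0], orig_im_size)
    # Composite the padded input over a transparent canvas using the mask.
    pil_mask = Image.fromarray(result_mask_image)
    no_bg_image = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
    no_bg_image.paste(input_image, mask=pil_mask)
    return no_bg_image

The resulting cut-out is then fed through pil_to_tensor and vae_encode exactly as before, so the layered-diffusion conditioning only ever sees the subject without its background.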
briarmbg.py ADDED
@@ -0,0 +1,460 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+
7
+ class REBNCONV(nn.Module):
8
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
9
+ super(REBNCONV, self).__init__()
10
+
11
+ self.conv_s1 = nn.Conv2d(
12
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
13
+ )
14
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
15
+ self.relu_s1 = nn.ReLU(inplace=True)
16
+
17
+ def forward(self, x):
18
+ hx = x
19
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20
+
21
+ return xout
22
+
23
+
24
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
25
+ def _upsample_like(src, tar):
26
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
27
+
28
+ return src
29
+
30
+
31
+ ### RSU-7 ###
32
+ class RSU7(nn.Module):
33
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
34
+ super(RSU7, self).__init__()
35
+
36
+ self.in_ch = in_ch
37
+ self.mid_ch = mid_ch
38
+ self.out_ch = out_ch
39
+
40
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
41
+
42
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
43
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
56
+
57
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
58
+
59
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
60
+
61
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
62
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
63
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
64
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
65
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
67
+
68
+ def forward(self, x):
69
+ b, c, h, w = x.shape
70
+
71
+ hx = x
72
+ hxin = self.rebnconvin(hx)
73
+
74
+ hx1 = self.rebnconv1(hxin)
75
+ hx = self.pool1(hx1)
76
+
77
+ hx2 = self.rebnconv2(hx)
78
+ hx = self.pool2(hx2)
79
+
80
+ hx3 = self.rebnconv3(hx)
81
+ hx = self.pool3(hx3)
82
+
83
+ hx4 = self.rebnconv4(hx)
84
+ hx = self.pool4(hx4)
85
+
86
+ hx5 = self.rebnconv5(hx)
87
+ hx = self.pool5(hx5)
88
+
89
+ hx6 = self.rebnconv6(hx)
90
+
91
+ hx7 = self.rebnconv7(hx6)
92
+
93
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
94
+ hx6dup = _upsample_like(hx6d, hx5)
95
+
96
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
97
+ hx5dup = _upsample_like(hx5d, hx4)
98
+
99
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
100
+ hx4dup = _upsample_like(hx4d, hx3)
101
+
102
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
103
+ hx3dup = _upsample_like(hx3d, hx2)
104
+
105
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
106
+ hx2dup = _upsample_like(hx2d, hx1)
107
+
108
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
109
+
110
+ return hx1d + hxin
111
+
112
+
113
+ ### RSU-6 ###
114
+ class RSU6(nn.Module):
115
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
116
+ super(RSU6, self).__init__()
117
+
118
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
119
+
120
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
121
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
122
+
123
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
124
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
125
+
126
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
127
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
128
+
129
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
130
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
131
+
132
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
133
+
134
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
135
+
136
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
137
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
138
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
141
+
142
+ def forward(self, x):
143
+ hx = x
144
+
145
+ hxin = self.rebnconvin(hx)
146
+
147
+ hx1 = self.rebnconv1(hxin)
148
+ hx = self.pool1(hx1)
149
+
150
+ hx2 = self.rebnconv2(hx)
151
+ hx = self.pool2(hx2)
152
+
153
+ hx3 = self.rebnconv3(hx)
154
+ hx = self.pool3(hx3)
155
+
156
+ hx4 = self.rebnconv4(hx)
157
+ hx = self.pool4(hx4)
158
+
159
+ hx5 = self.rebnconv5(hx)
160
+
161
+ hx6 = self.rebnconv6(hx5)
162
+
163
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
164
+ hx5dup = _upsample_like(hx5d, hx4)
165
+
166
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
167
+ hx4dup = _upsample_like(hx4d, hx3)
168
+
169
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
170
+ hx3dup = _upsample_like(hx3d, hx2)
171
+
172
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
173
+ hx2dup = _upsample_like(hx2d, hx1)
174
+
175
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
176
+
177
+ return hx1d + hxin
178
+
179
+
180
+ ### RSU-5 ###
181
+ class RSU5(nn.Module):
182
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
183
+ super(RSU5, self).__init__()
184
+
185
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
186
+
187
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
188
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
189
+
190
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
191
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
192
+
193
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
194
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
195
+
196
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
197
+
198
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
199
+
200
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
201
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
202
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
203
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
204
+
205
+ def forward(self, x):
206
+ hx = x
207
+
208
+ hxin = self.rebnconvin(hx)
209
+
210
+ hx1 = self.rebnconv1(hxin)
211
+ hx = self.pool1(hx1)
212
+
213
+ hx2 = self.rebnconv2(hx)
214
+ hx = self.pool2(hx2)
215
+
216
+ hx3 = self.rebnconv3(hx)
217
+ hx = self.pool3(hx3)
218
+
219
+ hx4 = self.rebnconv4(hx)
220
+
221
+ hx5 = self.rebnconv5(hx4)
222
+
223
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
224
+ hx4dup = _upsample_like(hx4d, hx3)
225
+
226
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
227
+ hx3dup = _upsample_like(hx3d, hx2)
228
+
229
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
230
+ hx2dup = _upsample_like(hx2d, hx1)
231
+
232
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
233
+
234
+ return hx1d + hxin
235
+
236
+
237
+ ### RSU-4 ###
238
+ class RSU4(nn.Module):
239
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
240
+ super(RSU4, self).__init__()
241
+
242
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
243
+
244
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
245
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
246
+
247
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
248
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
249
+
250
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
251
+
252
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
253
+
254
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
255
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
256
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
257
+
258
+ def forward(self, x):
259
+ hx = x
260
+
261
+ hxin = self.rebnconvin(hx)
262
+
263
+ hx1 = self.rebnconv1(hxin)
264
+ hx = self.pool1(hx1)
265
+
266
+ hx2 = self.rebnconv2(hx)
267
+ hx = self.pool2(hx2)
268
+
269
+ hx3 = self.rebnconv3(hx)
270
+
271
+ hx4 = self.rebnconv4(hx3)
272
+
273
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
274
+ hx3dup = _upsample_like(hx3d, hx2)
275
+
276
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
277
+ hx2dup = _upsample_like(hx2d, hx1)
278
+
279
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
280
+
281
+ return hx1d + hxin
282
+
283
+
284
+ ### RSU-4F ###
285
+ class RSU4F(nn.Module):
286
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
287
+ super(RSU4F, self).__init__()
288
+
289
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
290
+
291
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
292
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
293
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
294
+
295
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
296
+
297
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
298
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
299
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
300
+
301
+ def forward(self, x):
302
+ hx = x
303
+
304
+ hxin = self.rebnconvin(hx)
305
+
306
+ hx1 = self.rebnconv1(hxin)
307
+ hx2 = self.rebnconv2(hx1)
308
+ hx3 = self.rebnconv3(hx2)
309
+
310
+ hx4 = self.rebnconv4(hx3)
311
+
312
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
313
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
314
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
315
+
316
+ return hx1d + hxin
317
+
318
+
319
+ class myrebnconv(nn.Module):
320
+ def __init__(
321
+ self,
322
+ in_ch=3,
323
+ out_ch=1,
324
+ kernel_size=3,
325
+ stride=1,
326
+ padding=1,
327
+ dilation=1,
328
+ groups=1,
329
+ ):
330
+ super(myrebnconv, self).__init__()
331
+
332
+ self.conv = nn.Conv2d(
333
+ in_ch,
334
+ out_ch,
335
+ kernel_size=kernel_size,
336
+ stride=stride,
337
+ padding=padding,
338
+ dilation=dilation,
339
+ groups=groups,
340
+ )
341
+ self.bn = nn.BatchNorm2d(out_ch)
342
+ self.rl = nn.ReLU(inplace=True)
343
+
344
+ def forward(self, x):
345
+ return self.rl(self.bn(self.conv(x)))
346
+
347
+
348
+ class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
+ def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
350
+ super(BriaRMBG, self).__init__()
351
+ in_ch = config["in_ch"]
352
+ out_ch = config["out_ch"]
353
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
354
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
355
+
356
+ self.stage1 = RSU7(64, 32, 64)
357
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage2 = RSU6(64, 32, 128)
360
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage3 = RSU5(128, 64, 256)
363
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage4 = RSU4(256, 128, 512)
366
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage5 = RSU4F(512, 256, 512)
369
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage6 = RSU4F(512, 256, 512)
372
+
373
+ # decoder
374
+ self.stage5d = RSU4F(1024, 256, 512)
375
+ self.stage4d = RSU4(1024, 128, 256)
376
+ self.stage3d = RSU5(512, 64, 128)
377
+ self.stage2d = RSU6(256, 32, 64)
378
+ self.stage1d = RSU7(128, 16, 64)
379
+
380
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
381
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
382
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
383
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
384
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
385
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
386
+
387
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
388
+
389
+ def forward(self, x):
390
+ hx = x
391
+
392
+ hxin = self.conv_in(hx)
393
+ # hx = self.pool_in(hxin)
394
+
395
+ # stage 1
396
+ hx1 = self.stage1(hxin)
397
+ hx = self.pool12(hx1)
398
+
399
+ # stage 2
400
+ hx2 = self.stage2(hx)
401
+ hx = self.pool23(hx2)
402
+
403
+ # stage 3
404
+ hx3 = self.stage3(hx)
405
+ hx = self.pool34(hx3)
406
+
407
+ # stage 4
408
+ hx4 = self.stage4(hx)
409
+ hx = self.pool45(hx4)
410
+
411
+ # stage 5
412
+ hx5 = self.stage5(hx)
413
+ hx = self.pool56(hx5)
414
+
415
+ # stage 6
416
+ hx6 = self.stage6(hx)
417
+ hx6up = _upsample_like(hx6, hx5)
418
+
419
+ # -------------------- decoder --------------------
420
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
421
+ hx5dup = _upsample_like(hx5d, hx4)
422
+
423
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
424
+ hx4dup = _upsample_like(hx4d, hx3)
425
+
426
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
427
+ hx3dup = _upsample_like(hx3d, hx2)
428
+
429
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
430
+ hx2dup = _upsample_like(hx2d, hx1)
431
+
432
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
433
+
434
+ # side output
435
+ d1 = self.side1(hx1d)
436
+ d1 = _upsample_like(d1, x)
437
+
438
+ d2 = self.side2(hx2d)
439
+ d2 = _upsample_like(d2, x)
440
+
441
+ d3 = self.side3(hx3d)
442
+ d3 = _upsample_like(d3, x)
443
+
444
+ d4 = self.side4(hx4d)
445
+ d4 = _upsample_like(d4, x)
446
+
447
+ d5 = self.side5(hx5d)
448
+ d5 = _upsample_like(d5, x)
449
+
450
+ d6 = self.side6(hx6)
451
+ d6 = _upsample_like(d6, x)
452
+
453
+ return [
454
+ F.sigmoid(d1),
455
+ F.sigmoid(d2),
456
+ F.sigmoid(d3),
457
+ F.sigmoid(d4),
458
+ F.sigmoid(d5),
459
+ F.sigmoid(d6),
460
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
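
Shape-wise, the forward pass above returns a pair: a list of six sigmoid-activated side outputs (d1, upsampled to the input resolution, is what app.py consumes as result[0][0]) and the list of decoder feature maps. A quick sanity-check sketch with randomly initialized weights, not the pretrained briaai/RMBG-1.4 checkpoint that app.py loads via from_pretrained:

import torch
from briarmbg import BriaRMBG

net = BriaRMBG()  # default config: in_ch=3, out_ch=1
net.eval()

x = torch.rand(1, 3, 1024, 1024)  # dummy normalized image batch
with torch.no_grad():
    masks, features = net(x)

print(len(masks), masks[0].shape)         # 6 side outputs; d1 is (1, 1, 1024, 1024)
print(len(features), features[-1].shape)  # decoder/bottleneck features, last one is hx6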
utils.py ADDED
@@ -0,0 +1,114 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms.functional import normalize
6
+ from PIL import Image, ImageOps, ImageSequence
7
+ from typing import List
8
+ from pathlib import Path
9
+ from huggingface_hub import snapshot_download, hf_hub_download
10
+
11
+
12
+ def tensor_to_pil(images: torch.Tensor | List[torch.Tensor]) -> List[Image.Image]:
13
+ if not isinstance(images, list):
14
+ images = [images]
15
+ imgs = []
16
+ for image in images:
17
+ i = 255.0 * image.cpu().numpy()
18
+ img = Image.fromarray(np.clip(np.squeeze(i), 0, 255).astype(np.uint8))
19
+ imgs.append(img)
20
+ return imgs
21
+
22
+
23
+ def pad_image(input_image):
24
+ pad_w, pad_h = (
25
+ np.max(((2, 2), np.ceil(np.array(input_image.size) / 64).astype(int)), axis=0)
26
+ * 64
27
+ - input_image.size
28
+ )
29
+ im_padded = Image.fromarray(
30
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
31
+ )
32
+ w, h = im_padded.size
33
+ if w == h:
34
+ return im_padded
35
+ elif w > h:
36
+ new_image = Image.new(im_padded.mode, (w, w), (0, 0, 0))
37
+ new_image.paste(im_padded, (0, (w - h) // 2))
38
+ return new_image
39
+ else:
40
+ new_image = Image.new(im_padded.mode, (h, h), (0, 0, 0))
41
+ new_image.paste(im_padded, ((h - w) // 2, 0))
42
+ return new_image
43
+
44
+
45
+ def pil_to_tensor(image: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
46
+ output_images = []
47
+ output_masks = []
48
+ for i in ImageSequence.Iterator(image):
49
+ i = ImageOps.exif_transpose(i)
50
+ if i.mode == "I":
51
+ i = i.point(lambda i: i * (1 / 255))
52
+ image = i.convert("RGB")
53
+ image = np.array(image).astype(np.float32) / 255.0
54
+ image = torch.from_numpy(image)[None,]
55
+ if "A" in i.getbands():
56
+ mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
57
+ mask = 1.0 - torch.from_numpy(mask)
58
+ else:
59
+ mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
60
+ output_images.append(image)
61
+ output_masks.append(mask.unsqueeze(0))
62
+
63
+ if len(output_images) > 1:
64
+ output_image = torch.cat(output_images, dim=0)
65
+ output_mask = torch.cat(output_masks, dim=0)
66
+ else:
67
+ output_image = output_images[0]
68
+ output_mask = output_masks[0]
69
+
70
+ return (output_image, output_mask)
71
+
72
+
73
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
74
+ if len(im.shape) < 3:
75
+ im = im[:, :, np.newaxis]
76
+ # orig_im_size=im.shape[0:2]
77
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
78
+ im_tensor = F.interpolate(
79
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
80
+ ).type(torch.uint8)
81
+ image = torch.divide(im_tensor, 255.0)
82
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
83
+ return image
84
+
85
+
86
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
87
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
88
+ ma = torch.max(result)
89
+ mi = torch.min(result)
90
+ result = (result - mi) / (ma - mi)
91
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
92
+ im_array = np.squeeze(im_array)
93
+ return im_array
94
+
95
+
96
+ def downloadModels():
97
+ MODEL_PATH = hf_hub_download(
98
+ repo_id="lllyasviel/fav_models",
99
+ subfolder="fav",
100
+ filename="juggernautXL_v8Rundiffusion.safetensors",
101
+ )
102
+ LAYERS_PATH = snapshot_download(
103
+ repo_id="LayerDiffusion/layerdiffusion-v1", allow_patterns="*.safetensors"
104
+ )
105
+ for file in Path(LAYERS_PATH).glob("*.safetensors"):
106
+ target_path = Path(f"./ComfyUI/models/layer_model/{file.name}")
107
+ if not target_path.exists():
108
+ os.symlink(file, target_path)
109
+
110
+ model_target_path = Path(
111
+ "./ComfyUI/models/checkpoints/juggernautXL_v8Rundiffusion.safetensors"
112
+ )
113
+ if not model_target_path.exists():
114
+ os.symlink(MODEL_PATH, model_target_path)
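
As a usage note, preprocess_image and postprocess_image are meant to bracket the RMBG forward pass: the first turns an H x W x 3 uint8 array into a normalized (1, 3, H', W') float tensor at the model input size, and the second rescales a (1, 1, H', W') prediction back to the requested size and min-max normalizes it into a uint8 mask. A small round-trip sketch with a dummy prediction standing in for the model output (no model involved, just the helpers above):

import numpy as np
import torch
from utils import preprocess_image, postprocess_image

rgb = (np.random.rand(600, 400, 3) * 255).astype(np.uint8)  # dummy H x W x 3 image

inp = preprocess_image(rgb, [1024, 1024])   # float tensor, shape (1, 3, 1024, 1024)
fake_pred = torch.rand(1, 1, 1024, 1024)    # stand-in for result[0][0] from BriaRMBG
mask = postprocess_image(fake_pred, rgb.shape[0:2])  # uint8 array, shape (600, 400)

print(inp.shape, inp.dtype, mask.shape, mask.dtype)

downloadModels() itself only fetches the Juggernaut XL checkpoint and the LayerDiffusion weights from the Hub and symlinks them into the ComfyUI model folders, which is why app.py calls it before the checkpoint loader runs.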