multimodalart HF staff commited on
Commit
29dc892
1 Parent(s): 69b2b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +345 -52
app.py CHANGED
@@ -5,6 +5,7 @@ from controlnet_flux import FluxControlNetModel
5
  from transformer_flux import FluxTransformer2DModel
6
  from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
7
  from PIL import Image, ImageDraw
 
8
  import spaces
9
 
10
  # Load models
@@ -21,78 +22,370 @@ pipe = FluxControlNetInpaintingPipeline.from_pretrained(
21
  pipe.transformer.to(torch.bfloat16)
22
  pipe.controlnet.to(torch.bfloat16)
23
 
24
- def prepare_image_and_mask(image, width, height, overlap_percentage):
25
- # Resize the input image to fit within the target size
26
- image.thumbnail((width, height), Image.LANCZOS)
27
-
28
- # Create a new white background image of the target size
29
- background = Image.new('RGB', (width, height), (255, 255, 255))
30
-
31
- # Paste the resized image onto the background
32
- offset = ((width - image.width) // 2, (height - image.height) // 2)
33
- background.paste(image, offset)
34
-
35
- # Create a mask
36
- mask = Image.new('L', (width, height), 255)
37
- draw = ImageDraw.Draw(mask)
38
 
39
- # Calculate the overlap area
40
- overlap_x = int(image.width * overlap_percentage / 100)
41
- overlap_y = int(image.height * overlap_percentage / 100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Draw the mask (black area is where we want to inpaint)
44
- draw.rectangle([
45
- (offset[0] + overlap_x, offset[1] + overlap_y),
46
- (offset[0] + image.width - overlap_x, offset[1] + image.height - overlap_y)
 
 
 
 
 
 
 
 
 
47
  ], fill=0)
48
-
49
  return background, mask
50
 
51
  @spaces.GPU
52
- def inpaint(image, prompt, width, height, overlap_percentage, num_inference_steps, guidance_scale):
53
- # Prepare image and mask
54
- image, mask = prepare_image_and_mask(image, width, height, overlap_percentage)
55
 
56
- # Set up generator for reproducibility
 
 
 
 
 
 
 
57
  generator = torch.Generator(device="cuda").manual_seed(42)
58
-
59
- # Run inpainting
60
  result = pipe(
61
- prompt=prompt,
62
  height=height,
63
  width=width,
64
- control_image=image,
65
  control_mask=mask,
66
  num_inference_steps=num_inference_steps,
67
  generator=generator,
68
  controlnet_conditioning_scale=0.9,
69
- guidance_scale=guidance_scale,
70
  negative_prompt="",
71
- true_guidance_scale=guidance_scale
72
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- return result
75
-
76
- # Gradio interface
77
- with gr.Blocks() as demo:
78
- gr.Markdown("# FLUX Outpainting Demo")
79
- with gr.Row():
80
- with gr.Column():
81
- input_image = gr.Image(type="pil", label="Input Image")
82
- prompt_input = gr.Textbox(label="Prompt")
83
- width_slider = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=768)
84
- height_slider = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=768)
85
- overlap_slider = gr.Slider(label="Overlap Percentage", minimum=0, maximum=50, step=1, value=10)
86
- steps_slider = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)
87
- guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=10.0, step=0.1, value=3.5)
88
- run_button = gr.Button("Generate")
89
- with gr.Column():
90
- output_image = gr.Image(label="Output Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  run_button.click(
 
 
 
 
93
  fn=inpaint,
94
- inputs=[input_image, prompt_input, width_slider, height_slider, overlap_slider, steps_slider, guidance_slider],
95
- outputs=output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
 
98
- demo.launch()
 
5
  from transformer_flux import FluxTransformer2DModel
6
  from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
7
  from PIL import Image, ImageDraw
8
+ import numpy as np
9
  import spaces
10
 
11
  # Load models
 
22
  pipe.transformer.to(torch.bfloat16)
23
  pipe.controlnet.to(torch.bfloat16)
24
 
25
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
26
+ if alignment in ("Left", "Right") and source_width >= target_width:
27
+ return False
28
+ if alignment in ("Top", "Bottom") and source_height >= target_height:
29
+ return False
30
+ return True
31
+
32
+ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
33
+ target_size = (width, height)
34
+
35
+ # Calculate the scaling factor to fit the image within the target size
36
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
37
+ new_width = int(image.width * scale_factor)
38
+ new_height = int(image.height * scale_factor)
39
 
40
+ # Resize the source image to fit within target size
41
+ source = image.resize((new_width, new_height), Image.LANCZOS)
42
+
43
+ # Apply resize option using percentages
44
+ if resize_option == "Full":
45
+ resize_percentage = 100
46
+ elif resize_option == "50%":
47
+ resize_percentage = 50
48
+ elif resize_option == "33%":
49
+ resize_percentage = 33
50
+ elif resize_option == "25%":
51
+ resize_percentage = 25
52
+ else: # Custom
53
+ resize_percentage = custom_resize_percentage
54
+
55
+ # Calculate new dimensions based on percentage
56
+ resize_factor = resize_percentage / 100
57
+ new_width = int(source.width * resize_factor)
58
+ new_height = int(source.height * resize_factor)
59
+
60
+ # Ensure minimum size of 64 pixels
61
+ new_width = max(new_width, 64)
62
+ new_height = max(new_height, 64)
63
+
64
+ # Resize the image
65
+ source = source.resize((new_width, new_height), Image.LANCZOS)
66
+
67
+ # Calculate the overlap in pixels based on the percentage
68
+ overlap_x = int(new_width * (overlap_percentage / 100))
69
+ overlap_y = int(new_height * (overlap_percentage / 100))
70
+
71
+ # Ensure minimum overlap of 1 pixel
72
+ overlap_x = max(overlap_x, 1)
73
+ overlap_y = max(overlap_y, 1)
74
+
75
+ # Calculate margins based on alignment
76
+ if alignment == "Middle":
77
+ margin_x = (target_size[0] - new_width) // 2
78
+ margin_y = (target_size[1] - new_height) // 2
79
+ elif alignment == "Left":
80
+ margin_x = 0
81
+ margin_y = (target_size[1] - new_height) // 2
82
+ elif alignment == "Right":
83
+ margin_x = target_size[0] - new_width
84
+ margin_y = (target_size[1] - new_height) // 2
85
+ elif alignment == "Top":
86
+ margin_x = (target_size[0] - new_width) // 2
87
+ margin_y = 0
88
+ elif alignment == "Bottom":
89
+ margin_x = (target_size[0] - new_width) // 2
90
+ margin_y = target_size[1] - new_height
91
+
92
+ # Adjust margins to eliminate gaps
93
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
94
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
95
+
96
+ # Create a new background image and paste the resized source image
97
+ background = Image.new('RGB', target_size, (255, 255, 255))
98
+ background.paste(source, (margin_x, margin_y))
99
+
100
+ # Create the mask
101
+ mask = Image.new('L', target_size, 255)
102
+ mask_draw = ImageDraw.Draw(mask)
103
+
104
+ # Calculate overlap areas
105
+ white_gaps_patch = 2
106
+
107
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
108
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
109
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
110
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
111
 
112
+ if alignment == "Left":
113
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
114
+ elif alignment == "Right":
115
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
116
+ elif alignment == "Top":
117
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
118
+ elif alignment == "Bottom":
119
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
120
+
121
+ # Draw the mask
122
+ mask_draw.rectangle([
123
+ (left_overlap, top_overlap),
124
+ (right_overlap, bottom_overlap)
125
  ], fill=0)
126
+
127
  return background, mask
128
 
129
  @spaces.GPU
130
+ def inpaint(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
131
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
 
132
 
133
+ if not can_expand(background.width, background.height, width, height, alignment):
134
+ alignment = "Middle"
135
+
136
+ cnet_image = background.copy()
137
+ cnet_image.paste(0, (0, 0), mask)
138
+
139
+ final_prompt = f"{prompt_input} , high quality, 4k"
140
+
141
  generator = torch.Generator(device="cuda").manual_seed(42)
142
+
 
143
  result = pipe(
144
+ prompt=final_prompt,
145
  height=height,
146
  width=width,
147
+ control_image=cnet_image,
148
  control_mask=mask,
149
  num_inference_steps=num_inference_steps,
150
  generator=generator,
151
  controlnet_conditioning_scale=0.9,
152
+ guidance_scale=3.5,
153
  negative_prompt="",
154
+ true_guidance_scale=3.5
155
  ).images[0]
156
+
157
+ result = result.convert("RGBA")
158
+ cnet_image.paste(result, (0, 0), mask)
159
+
160
+ return background, cnet_image
161
+
162
+ def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
163
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
164
+
165
+ preview = background.copy().convert('RGBA')
166
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64))
167
+ red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
168
+ red_mask.paste(red_overlay, (0, 0), mask)
169
+ preview = Image.alpha_composite(preview, red_mask)
170
 
171
+ return preview
172
+
173
+ def clear_result():
174
+ return gr.update(value=None)
175
+
176
+ def preload_presets(target_ratio, ui_width, ui_height):
177
+ if target_ratio == "9:16":
178
+ return 720, 1280, gr.update()
179
+ elif target_ratio == "16:9":
180
+ return 1280, 720, gr.update()
181
+ elif target_ratio == "1:1":
182
+ return 1024, 1024, gr.update()
183
+ elif target_ratio == "Custom":
184
+ return ui_width, ui_height, gr.update(open=True)
185
+
186
+ def select_the_right_preset(user_width, user_height):
187
+ if user_width == 720 and user_height == 1280:
188
+ return "9:16"
189
+ elif user_width == 1280 and user_height == 720:
190
+ return "16:9"
191
+ elif user_width == 1024 and user_height == 1024:
192
+ return "1:1"
193
+ else:
194
+ return "Custom"
195
+
196
+ def toggle_custom_resize_slider(resize_option):
197
+ return gr.update(visible=(resize_option == "Custom"))
198
+
199
+ def update_history(new_image, history):
200
+ if history is None:
201
+ history = []
202
+ history.insert(0, new_image)
203
+ return history
204
+
205
+ css = """
206
+ .gradio-container {
207
+ width: 1200px !important;
208
+ }
209
+ """
210
+
211
+ title = """<h1 align="center">FLUX Image Outpaint</h1>
212
+ <div align="center">Drop an image you would like to extend, pick your expected ratio and hit Generate.</div>
213
+ """
214
+
215
+ with gr.Blocks(css=css) as demo:
216
+ with gr.Column():
217
+ gr.HTML(title)
218
+
219
+ with gr.Row():
220
+ with gr.Column():
221
+ input_image = gr.Image(
222
+ type="pil",
223
+ label="Input Image"
224
+ )
225
+
226
+ with gr.Row():
227
+ with gr.Column(scale=2):
228
+ prompt_input = gr.Textbox(label="Prompt (Optional)")
229
+ with gr.Column(scale=1):
230
+ run_button = gr.Button("Generate")
231
+
232
+ with gr.Row():
233
+ target_ratio = gr.Radio(
234
+ label="Expected Ratio",
235
+ choices=["9:16", "16:9", "1:1", "Custom"],
236
+ value="9:16",
237
+ scale=2
238
+ )
239
+
240
+ alignment_dropdown = gr.Dropdown(
241
+ choices=["Middle", "Left", "Right", "Top", "Bottom"],
242
+ value="Middle",
243
+ label="Alignment"
244
+ )
245
+
246
+ with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
247
+ with gr.Column():
248
+ with gr.Row():
249
+ width_slider = gr.Slider(
250
+ label="Target Width",
251
+ minimum=720,
252
+ maximum=1536,
253
+ step=8,
254
+ value=720,
255
+ )
256
+ height_slider = gr.Slider(
257
+ label="Target Height",
258
+ minimum=720,
259
+ maximum=1536,
260
+ step=8,
261
+ value=1280,
262
+ )
263
+
264
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
265
+ with gr.Group():
266
+ overlap_percentage = gr.Slider(
267
+ label="Mask overlap (%)",
268
+ minimum=1,
269
+ maximum=50,
270
+ value=10,
271
+ step=1
272
+ )
273
+ with gr.Row():
274
+ overlap_top = gr.Checkbox(label="Overlap Top", value=True)
275
+ overlap_right = gr.Checkbox(label="Overlap Right", value=True)
276
+ with gr.Row():
277
+ overlap_left = gr.Checkbox(label="Overlap Left", value=True)
278
+ overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
279
+ with gr.Row():
280
+ resize_option = gr.Radio(
281
+ label="Resize input image",
282
+ choices=["Full", "50%", "33%", "25%", "Custom"],
283
+ value="Full"
284
+ )
285
+ custom_resize_percentage = gr.Slider(
286
+ label="Custom resize (%)",
287
+ minimum=1,
288
+ maximum=100,
289
+ step=1,
290
+ value=50,
291
+ visible=False
292
+ )
293
+
294
+ with gr.Column():
295
+ preview_button = gr.Button("Preview alignment and mask")
296
+
297
+ with gr.Column():
298
+ result = gr.Image(
299
+ interactive=False,
300
+ label="Generated Image",
301
+ )
302
+ use_as_input_button = gr.Button("Use as Input Image", visible=False)
303
+
304
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
305
+ preview_image = gr.Image(label="Preview")
306
+
307
+ def use_output_as_input(output_image):
308
+ return gr.update(value=output_image[1])
309
+
310
+ use_as_input_button.click(
311
+ fn=use_output_as_input,
312
+ inputs=[result],
313
+ outputs=[input_image]
314
+ )
315
+
316
+ target_ratio.change(
317
+ fn=preload_presets,
318
+ inputs=[target_ratio, width_slider, height_slider],
319
+ outputs=[width_slider, height_slider, settings_panel],
320
+ queue=False
321
+ )
322
+
323
+ width_slider.change(
324
+ fn=select_the_right_preset,
325
+ inputs=[width_slider, height_slider],
326
+ outputs=[target_ratio],
327
+ queue=False
328
+ )
329
+
330
+ height_slider.change(
331
+ fn=select_the_right_preset,
332
+ inputs=[width_slider, height_slider],
333
+ outputs=[target_ratio],
334
+ queue=False
335
+ )
336
+
337
+ resize_option.change(
338
+ fn=toggle_custom_resize_slider,
339
+ inputs=[resize_option],
340
+ outputs=[custom_resize_percentage],
341
+ queue=False
342
+ )
343
 
344
  run_button.click(
345
+ fn=clear_result,
346
+ inputs=None,
347
+ outputs=result,
348
+ ).then(
349
  fn=inpaint,
350
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
351
+ resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
352
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
353
+ outputs=result,
354
+ ).then(
355
+ fn=lambda x, history: update_history(x[1], history),
356
+ inputs=[result, history_gallery],
357
+ outputs=history_gallery,
358
+ ).then(
359
+ fn=lambda: gr.update(visible=True),
360
+ inputs=None,
361
+ outputs=use_as_input_button,
362
+ )
363
+
364
+ prompt_input.submit(
365
+ fn=clear_result,
366
+ inputs=None,
367
+ outputs=result,
368
+ ).then(
369
+ fn=inpaint,
370
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
371
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
372
+ outputs=result,
373
+ ).then(
374
+ fn=lambda x, history: update_history(x[1], history),
375
+ inputs=[result, history_gallery],
376
+ outputs=history_gallery,
377
+ ).then(
378
+ fn=lambda: gr.update(visible=True),
379
+ inputs=None,
380
+ outputs=use_as_input_button,
381
+ )
382
+
383
+ preview_button.click(
384
+ fn=preview_image_and_mask,
385
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
386
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
387
+ outputs=preview_image,
388
+ queue=False
389
  )
390
 
391
+ demo.queue(max_size=12).launch(share=False)