Update backend to the new Inpainting model

#6
by multimodalart - opened
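In short: instead of assembling a StableDiffusionInpaintPipeline from the text2img components, the backend now loads the dedicated runwayml/stable-diffusion-inpainting checkpoint and passes the selection as image= rather than init_image=. A minimal sketch of the new call for context (the placeholder images, prompt, device handling, and the .images accessor are illustrative and depend on your diffusers version; in app.py the inputs come from the canvas selection and the result is still indexed with ["sample"]):

import torch
from PIL import Image
from diffusers import DiffusionPipeline

# Load the dedicated inpainting checkpoint instead of rebuilding an
# inpainting pipeline out of the text2img weights.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    use_auth_token=True,  # gated checkpoint; assumes you are logged in to the Hub
).to("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder 512x512 inputs; in app.py these come from the canvas selection.
init_image = Image.new("RGB", (512, 512), "gray")
mask_image = Image.new("L", (512, 512), 255)  # white = region to inpaint

result = pipe(
    prompt="a photo of a mountain lake",
    image=init_image,        # was init_image= with the old pipeline
    mask_image=mask_image,
    num_inference_steps=50,
    guidance_scale=7.5,
)
images = result.images  # the app's pinned diffusers version indexes this as result["sample"]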
Files changed (1)
  1. app.py +422 -427
app.py CHANGED
@@ -1,427 +1,422 @@
- import io
- import base64
- import os
-
- import numpy as np
- import torch
- from torch import autocast
- from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
- from PIL import Image
- from PIL import ImageOps
- import gradio as gr
- import base64
- import skimage
- import skimage.measure
- from utils import *
-
- try:
-     cuda_available = torch.cuda.is_available()
- except:
-     cuda_available = False
- finally:
-     if cuda_available:
-         device = "cuda"
-     else:
-         device = "cpu"
-
- if device != "cuda":
-     import contextlib
-     autocast = contextlib.nullcontext
-
- def load_html():
-     body, canvaspy = "", ""
-     with open("index.html", encoding="utf8") as f:
-         body = f.read()
-     with open("canvas.py", encoding="utf8") as f:
-         canvaspy = f.read()
-     body = body.replace("- paths:\n", "")
-     body = body.replace(" - ./canvas.py\n", "")
-     body = body.replace("from canvas import InfCanvas", canvaspy)
-     return body
-
-
- def test(x):
-     x = load_html()
-     return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
- display-capture; encrypted-media;" sandbox="allow-modals allow-forms
- allow-scripts allow-same-origin allow-popups
- allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
- allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
-
-
- DEBUG_MODE = False
-
- try:
-     SAMPLING_MODE = Image.Resampling.LANCZOS
- except Exception as e:
-     SAMPLING_MODE = Image.LANCZOS
-
- try:
-     contain_func = ImageOps.contain
- except Exception as e:
-
-     def contain_func(image, size, method=SAMPLING_MODE):
-         # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
-         im_ratio = image.width / image.height
-         dest_ratio = size[0] / size[1]
-         if im_ratio != dest_ratio:
-             if im_ratio > dest_ratio:
-                 new_height = int(image.height / image.width * size[0])
-                 if new_height != size[1]:
-                     size = (size[0], new_height)
-             else:
-                 new_width = int(image.width / image.height * size[1])
-                 if new_width != size[0]:
-                     size = (new_width, size[1])
-         return image.resize(size, resample=method)
-
-
- PAINT_SELECTION = "✥"
- IMAGE_SELECTION = "🖼️"
- BRUSH_SELECTION = "🖌️"
- blocks = gr.Blocks()
- model = {}
- model["width"] = 1500
- model["height"] = 600
- model["sel_size"] = 256
-
- def get_token():
-     token = ""
-     token = os.environ.get("hftoken", token)
-     return token
-
-
- def save_token(token):
-     return
-
-
- def get_model(token=""):
-     if "text2img" not in model:
-         if device=="cuda":
-             text2img = StableDiffusionPipeline.from_pretrained(
-                 "CompVis/stable-diffusion-v1-4",
-                 revision="fp16",
-                 torch_dtype=torch.float16,
-                 use_auth_token=token,
-             ).to(device)
-         else:
-             text2img = StableDiffusionPipeline.from_pretrained(
-                 "CompVis/stable-diffusion-v1-4",
-                 use_auth_token=token,
-             ).to(device)
-         model["safety_checker"] = text2img.safety_checker
-         inpaint = StableDiffusionInpaintPipeline(
-             vae=text2img.vae,
-             text_encoder=text2img.text_encoder,
-             tokenizer=text2img.tokenizer,
-             unet=text2img.unet,
-             scheduler=text2img.scheduler,
-             safety_checker=text2img.safety_checker,
-             feature_extractor=text2img.feature_extractor,
-         ).to(device)
-         save_token(token)
-         try:
-             total_memory = torch.cuda.get_device_properties(0).total_memory // (
-                 1024 ** 3
-             )
-             if total_memory <= 5:
-                 inpaint.enable_attention_slicing()
-         except:
-             pass
-         model["text2img"] = text2img
-         model["inpaint"] = inpaint
-     return model["text2img"], model["inpaint"]
-
-
- def run_outpaint(
-     sel_buffer_str,
-     prompt_text,
-     strength,
-     guidance,
-     step,
-     resize_check,
-     fill_mode,
-     enable_safety,
-     state,
- ):
-     base64_str = "base64"
-     if not cuda_available:
-         data = base64.b64decode(str(sel_buffer_str))
-         pil = Image.open(io.BytesIO(data))
-         sel_buffer = np.array(pil)
-         sel_buffer[:, :, 3]=255
-         sel_buffer[:, :, 0]=255
-         out_pil = Image.fromarray(sel_buffer)
-         out_buffer = io.BytesIO()
-         out_pil.save(out_buffer, format="PNG")
-         out_buffer.seek(0)
-         base64_bytes = base64.b64encode(out_buffer.read())
-         base64_str = base64_bytes.decode("ascii")
-         return (
-             gr.update(label=str(state + 1), value=base64_str,),
-             gr.update(label="Prompt"),
-             state + 1,
-         )
-     if True:
-         text2img, inpaint = get_model()
-         if enable_safety:
-             text2img.safety_checker = model["safety_checker"]
-             inpaint.safety_checker = model["safety_checker"]
-         else:
-             text2img.safety_checker = lambda images, **kwargs: (images, False)
-             inpaint.safety_checker = lambda images, **kwargs: (images, False)
-         data = base64.b64decode(str(sel_buffer_str))
-         pil = Image.open(io.BytesIO(data))
-         # base.output.clear_output()
-         # base.read_selection_from_buffer()
-         sel_buffer = np.array(pil)
-         img = sel_buffer[:, :, 0:3]
-         mask = sel_buffer[:, :, -1]
-         process_size = 512 if resize_check else model["sel_size"]
-         if mask.sum() > 0:
-             img, mask = functbl[fill_mode](img, mask)
-             init_image = Image.fromarray(img)
-             mask = 255 - mask
-             mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
-             mask = mask.repeat(8, axis=0).repeat(8, axis=1)
-             mask_image = Image.fromarray(mask)
-             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-             with autocast("cuda"):
-                 images = inpaint(
-                     prompt=prompt_text,
-                     init_image=init_image.resize(
-                         (process_size, process_size), resample=SAMPLING_MODE
-                     ),
-                     mask_image=mask_image.resize((process_size, process_size)),
-                     strength=strength,
-                     num_inference_steps=step,
-                     guidance_scale=guidance,
-                 )["sample"]
-         else:
-             with autocast("cuda"):
-                 images = text2img(
-                     prompt=prompt_text, height=process_size, width=process_size,
-                 )["sample"]
-         out = sel_buffer.copy()
-         out[:, :, 0:3] = np.array(
-             images[0].resize(
-                 (model["sel_size"], model["sel_size"]), resample=SAMPLING_MODE,
-             )
-         )
-         out[:, :, -1] = 255
-         out_pil = Image.fromarray(out)
-         out_buffer = io.BytesIO()
-         out_pil.save(out_buffer, format="PNG")
-         out_buffer.seek(0)
-         base64_bytes = base64.b64encode(out_buffer.read())
-         base64_str = base64_bytes.decode("ascii")
-         return (
-             gr.update(label=str(state + 1), value=base64_str,),
-             gr.update(label="Prompt"),
-             state + 1,
-         )
-
-
- def load_js(name):
-     if name in ["export", "commit", "undo"]:
-         return f"""
- function (x)
- {{
- let frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
- frame.postMessage(["click","{name}"], "*");
- return x;
- }}
- """
-     ret = ""
-     with open(f"./js/{name}.js", "r") as f:
-         ret = f.read()
-     return ret
-
-
- upload_button_js = load_js("upload")
- outpaint_button_js = load_js("outpaint")
- proceed_button_js = load_js("proceed")
- mode_js = load_js("mode")
- setup_button_js = load_js("setup")
- if not cuda_available:
-     get_model = lambda x:x
- get_model(get_token())
-
- with blocks as demo:
-     # title
-     title = gr.Markdown(
-         """
- **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
- """
-     )
-     # frame
-     frame = gr.HTML(test(2), visible=True)
-     # setup
-     # with gr.Row():
-     # token = gr.Textbox(
-     # label="Huggingface token",
-     # value="",
-     # placeholder="Input your token here",
-     # )
-     # canvas_width = gr.Number(
-     # label="Canvas width", value=1024, precision=0, elem_id="canvas_width"
-     # )
-     # canvas_height = gr.Number(
-     # label="Canvas height", value=600, precision=0, elem_id="canvas_height"
-     # )
-     # selection_size = gr.Number(
-     # label="Selection box size", value=256, precision=0, elem_id="selection_size"
-     # )
-     # setup_button = gr.Button("Start (may take a while)", variant="primary")
-     with gr.Row():
-         with gr.Column(scale=3, min_width=270):
-             # canvas control
-             canvas_control = gr.Radio(
-                 label="Control",
-                 choices=[PAINT_SELECTION, IMAGE_SELECTION, BRUSH_SELECTION],
-                 value=PAINT_SELECTION,
-                 elem_id="control",
-             )
-             with gr.Box():
-                 with gr.Group():
-                     run_button = gr.Button(value="Outpaint")
-                     export_button = gr.Button(value="Export")
-                     commit_button = gr.Button(value="")
-                     retry_button = gr.Button(value="⟳")
-                     undo_button = gr.Button(value="↶")
-         with gr.Column(scale=3, min_width=270):
-             sd_prompt = gr.Textbox(
-                 label="Prompt", placeholder="input your prompt here", lines=4
-             )
-         with gr.Column(scale=2, min_width=150):
-             with gr.Box():
-                 sd_resize = gr.Checkbox(label="Resize input to 515x512", value=True)
-                 safety_check = gr.Checkbox(label="Enable Safety Checker", value=True)
-                 sd_strength = gr.Slider(
-                     label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
-                 )
-         with gr.Column(scale=1, min_width=150):
-             sd_step = gr.Number(label="Step", value=50, precision=0)
-             sd_guidance = gr.Number(label="Guidance", value=7.5)
-     with gr.Row():
-         with gr.Column(scale=4, min_width=600):
-             init_mode = gr.Radio(
-                 label="Init mode",
-                 choices=[
-                     "patchmatch",
-                     "edge_pad",
-                     "cv2_ns",
-                     "cv2_telea",
-                     "gaussian",
-                     "perlin",
-                 ],
-                 value="patchmatch",
-                 type="value",
-             )
-
-     proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
-     # sd pipeline parameters
-     with gr.Accordion("Upload image", open=False):
-         image_box = gr.Image(image_mode="RGBA", source="upload", type="pil")
-         upload_button = gr.Button(
-             "Upload"
-         )
-     model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
-     model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
-     upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
-     model_output_state = gr.State(value=0)
-     upload_output_state = gr.State(value=0)
-     # canvas_state = gr.State({"width":1024,"height":600,"selection_size":384})
-
-     def upload_func(image, state):
-         pil = image.convert("RGBA")
-         w, h = pil.size
-         if w > model["width"] - 100 or h > model["height"] - 100:
-             pil = contain_func(pil, (model["width"] - 100, model["height"] - 100))
-         out_buffer = io.BytesIO()
-         pil.save(out_buffer, format="PNG")
-         out_buffer.seek(0)
-         base64_bytes = base64.b64encode(out_buffer.read())
-         base64_str = base64_bytes.decode("ascii")
-         return (
-             gr.update(label=str(state + 1), value=base64_str),
-             state + 1,
-         )
-
-     upload_button.click(
-         fn=upload_func,
-         inputs=[image_box, upload_output_state],
-         outputs=[upload_output, upload_output_state],
-         _js=upload_button_js,
-         queue=False
-     )
-
-     def setup_func(token_val, width, height, size):
-         model["width"] = width
-         model["height"] = height
-         model["sel_size"] = size
-         try:
-             get_model(token_val)
-         except Exception as e:
-             return {token: gr.update(value="Invalid token!")}
-         return {
-             token: gr.update(visible=False),
-             canvas_width: gr.update(visible=False),
-             canvas_height: gr.update(visible=False),
-             selection_size: gr.update(visible=False),
-             setup_button: gr.update(visible=False),
-             frame: gr.update(visible=True),
-             upload_button: gr.update(value="Upload"),
-         }
-
-     # setup_button.click(
-     # fn=setup_func,
-     # inputs=[token, canvas_width, canvas_height, selection_size],
-     # outputs=[
-     # token,
-     # canvas_width,
-     # canvas_height,
-     # selection_size,
-     # setup_button,
-     # frame,
-     # upload_button,
-     # ],
-     # _js=setup_button_js,
-     # )
-     run_button.click(
-         fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
-     )
-     retry_button.click(
-         fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
-     )
-     proceed_button.click(
-         fn=run_outpaint,
-         inputs=[
-             model_input,
-             sd_prompt,
-             sd_strength,
-             sd_guidance,
-             sd_step,
-             sd_resize,
-             init_mode,
-             safety_check,
-             model_output_state,
-         ],
-         outputs=[model_output, sd_prompt, model_output_state],
-         _js=proceed_button_js,
-     )
-     export_button.click(
-         fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("export")
-     )
-     commit_button.click(
-         fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("commit")
-     )
-     undo_button.click(
-         fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("undo")
-     )
-     canvas_control.change(
-         fn=None, inputs=[canvas_control], outputs=[canvas_control], _js=mode_js,
-     )
-
- demo.launch()
-
 
+ import io
+ import base64
+ import os
+
+ import numpy as np
+ import torch
+ from torch import autocast
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
+ from PIL import Image
+ from PIL import ImageOps
+ import gradio as gr
+ import base64
+ import skimage
+ import skimage.measure
+ from utils import *
+
+ try:
+     cuda_available = torch.cuda.is_available()
+ except:
+     cuda_available = False
+ finally:
+     if cuda_available:
+         device = "cuda"
+     else:
+         device = "cpu"
+
+ if device != "cuda":
+     import contextlib
+     autocast = contextlib.nullcontext
+
+ def load_html():
+     body, canvaspy = "", ""
+     with open("index.html", encoding="utf8") as f:
+         body = f.read()
+     with open("canvas.py", encoding="utf8") as f:
+         canvaspy = f.read()
+     body = body.replace("- paths:\n", "")
+     body = body.replace(" - ./canvas.py\n", "")
+     body = body.replace("from canvas import InfCanvas", canvaspy)
+     return body
+
+
+ def test(x):
+     x = load_html()
+     return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+ allow-scripts allow-same-origin allow-popups
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
+
+
+ DEBUG_MODE = False
+
+ try:
+     SAMPLING_MODE = Image.Resampling.LANCZOS
+ except Exception as e:
+     SAMPLING_MODE = Image.LANCZOS
+
+ try:
+     contain_func = ImageOps.contain
+ except Exception as e:
+
+     def contain_func(image, size, method=SAMPLING_MODE):
+         # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
+         im_ratio = image.width / image.height
+         dest_ratio = size[0] / size[1]
+         if im_ratio != dest_ratio:
+             if im_ratio > dest_ratio:
+                 new_height = int(image.height / image.width * size[0])
+                 if new_height != size[1]:
+                     size = (size[0], new_height)
+             else:
+                 new_width = int(image.width / image.height * size[1])
+                 if new_width != size[0]:
+                     size = (new_width, size[1])
+         return image.resize(size, resample=method)
+
+
+ PAINT_SELECTION = "✥"
+ IMAGE_SELECTION = "🖼️"
+ BRUSH_SELECTION = "🖌️"
+ blocks = gr.Blocks()
+ model = {}
+ model["width"] = 1500
+ model["height"] = 600
+ model["sel_size"] = 256
+
+ def get_token():
+     token = ""
+     token = os.environ.get("hftoken", token)
+     return token
+
+
+ def save_token(token):
+     return
+
+
+ def get_model(token=""):
+     if "text2img" not in model:
+         if device=="cuda":
+             text2img = StableDiffusionPipeline.from_pretrained(
+                 "CompVis/stable-diffusion-v1-4",
+                 revision="fp16",
+                 torch_dtype=torch.float16,
+                 use_auth_token=token,
+             ).to(device)
+         else:
+             text2img = StableDiffusionPipeline.from_pretrained(
+                 "CompVis/stable-diffusion-v1-4",
+                 use_auth_token=token,
+             ).to(device)
+         model["safety_checker"] = text2img.safety_checker
+         inpaint = DiffusionPipeline.from_pretrained(
+             "runwayml/stable-diffusion-inpainting",
+             use_auth_token=token,
+         ).to(device)
+         save_token(token)
+         try:
+             total_memory = torch.cuda.get_device_properties(0).total_memory // (
+                 1024 ** 3
+             )
+             if total_memory <= 5:
+                 inpaint.enable_attention_slicing()
+         except:
+             pass
+         model["text2img"] = text2img
+         model["inpaint"] = inpaint
+     return model["text2img"], model["inpaint"]
+
+
+ def run_outpaint(
+     sel_buffer_str,
+     prompt_text,
+     strength,
+     guidance,
+     step,
+     resize_check,
+     fill_mode,
+     enable_safety,
+     state,
+ ):
+     base64_str = "base64"
+     if not cuda_available:
+         data = base64.b64decode(str(sel_buffer_str))
+         pil = Image.open(io.BytesIO(data))
+         sel_buffer = np.array(pil)
+         sel_buffer[:, :, 3]=255
+         sel_buffer[:, :, 0]=255
+         out_pil = Image.fromarray(sel_buffer)
+         out_buffer = io.BytesIO()
+         out_pil.save(out_buffer, format="PNG")
+         out_buffer.seek(0)
+         base64_bytes = base64.b64encode(out_buffer.read())
+         base64_str = base64_bytes.decode("ascii")
+         return (
+             gr.update(label=str(state + 1), value=base64_str,),
+             gr.update(label="Prompt"),
+             state + 1,
+         )
+     if True:
+         text2img, inpaint = get_model()
+         if enable_safety:
+             text2img.safety_checker = model["safety_checker"]
+             inpaint.safety_checker = model["safety_checker"]
+         else:
+             text2img.safety_checker = lambda images, **kwargs: (images, False)
+             inpaint.safety_checker = lambda images, **kwargs: (images, False)
+         data = base64.b64decode(str(sel_buffer_str))
+         pil = Image.open(io.BytesIO(data))
+         # base.output.clear_output()
+         # base.read_selection_from_buffer()
+         sel_buffer = np.array(pil)
+         img = sel_buffer[:, :, 0:3]
+         mask = sel_buffer[:, :, -1]
+         process_size = 512 if resize_check else model["sel_size"]
+         if mask.sum() > 0:
+             img, mask = functbl[fill_mode](img, mask)
+             init_image = Image.fromarray(img)
+             mask = 255 - mask
+             mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
+             mask = mask.repeat(8, axis=0).repeat(8, axis=1)
+             mask_image = Image.fromarray(mask)
+             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
+             with autocast("cuda"):
+                 images = inpaint(
+                     prompt=prompt_text,
+                     image=init_image.resize(
+                         (process_size, process_size), resample=SAMPLING_MODE
+                     ),
+                     mask_image=mask_image.resize((process_size, process_size)),
+                     strength=strength,
+                     num_inference_steps=step,
+                     guidance_scale=guidance,
+                 )["sample"]
+         else:
+             with autocast("cuda"):
+                 images = text2img(
+                     prompt=prompt_text, height=process_size, width=process_size,
+                 )["sample"]
+         out = sel_buffer.copy()
+         out[:, :, 0:3] = np.array(
+             images[0].resize(
+                 (model["sel_size"], model["sel_size"]), resample=SAMPLING_MODE,
+             )
+         )
+         out[:, :, -1] = 255
+         out_pil = Image.fromarray(out)
+         out_buffer = io.BytesIO()
+         out_pil.save(out_buffer, format="PNG")
+         out_buffer.seek(0)
+         base64_bytes = base64.b64encode(out_buffer.read())
+         base64_str = base64_bytes.decode("ascii")
+         return (
+             gr.update(label=str(state + 1), value=base64_str,),
+             gr.update(label="Prompt"),
+             state + 1,
+         )
+
+
+ def load_js(name):
+     if name in ["export", "commit", "undo"]:
+         return f"""
+ function (x)
+ {{
+ let frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
+ frame.postMessage(["click","{name}"], "*");
+ return x;
+ }}
+ """
+     ret = ""
+     with open(f"./js/{name}.js", "r") as f:
+         ret = f.read()
+     return ret
+
+
+ upload_button_js = load_js("upload")
+ outpaint_button_js = load_js("outpaint")
+ proceed_button_js = load_js("proceed")
+ mode_js = load_js("mode")
+ setup_button_js = load_js("setup")
+ if not cuda_available:
+     get_model = lambda x:x
+ get_model(get_token())
+
+ with blocks as demo:
+     # title
+     title = gr.Markdown(
+         """
+ **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
+ """
+     )
+     # frame
+     frame = gr.HTML(test(2), visible=True)
+     # setup
+     # with gr.Row():
+     # token = gr.Textbox(
+     # label="Huggingface token",
+     # value="",
+     # placeholder="Input your token here",
+     # )
+     # canvas_width = gr.Number(
+     # label="Canvas width", value=1024, precision=0, elem_id="canvas_width"
+     # )
+     # canvas_height = gr.Number(
+     # label="Canvas height", value=600, precision=0, elem_id="canvas_height"
+     # )
+     # selection_size = gr.Number(
+     # label="Selection box size", value=256, precision=0, elem_id="selection_size"
+     # )
+     # setup_button = gr.Button("Start (may take a while)", variant="primary")
+     with gr.Row():
+         with gr.Column(scale=3, min_width=270):
+             # canvas control
+             canvas_control = gr.Radio(
+                 label="Control",
+                 choices=[PAINT_SELECTION, IMAGE_SELECTION, BRUSH_SELECTION],
+                 value=PAINT_SELECTION,
+                 elem_id="control",
+             )
+             with gr.Box():
+                 with gr.Group():
+                     run_button = gr.Button(value="Outpaint")
+                     export_button = gr.Button(value="Export")
+                     commit_button = gr.Button(value="✓")
+                     retry_button = gr.Button(value="⟳")
+                     undo_button = gr.Button(value="↶")
+         with gr.Column(scale=3, min_width=270):
+             sd_prompt = gr.Textbox(
+                 label="Prompt", placeholder="input your prompt here", lines=4
+             )
+         with gr.Column(scale=2, min_width=150):
+             with gr.Box():
+                 sd_resize = gr.Checkbox(label="Resize input to 515x512", value=True)
+                 safety_check = gr.Checkbox(label="Enable Safety Checker", value=True)
+                 sd_strength = gr.Slider(
+                     label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
+                 )
+         with gr.Column(scale=1, min_width=150):
+             sd_step = gr.Number(label="Step", value=50, precision=0)
+             sd_guidance = gr.Number(label="Guidance", value=7.5)
+     with gr.Row():
+         with gr.Column(scale=4, min_width=600):
+             init_mode = gr.Radio(
+                 label="Init mode",
+                 choices=[
+                     "patchmatch",
+                     "edge_pad",
+                     "cv2_ns",
+                     "cv2_telea",
+                     "gaussian",
+                     "perlin",
+                 ],
+                 value="patchmatch",
+                 type="value",
+             )
+
+     proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
+     # sd pipeline parameters
+     with gr.Accordion("Upload image", open=False):
+         image_box = gr.Image(image_mode="RGBA", source="upload", type="pil")
+         upload_button = gr.Button(
+             "Upload"
+         )
+     model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
+     model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
+     upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
+     model_output_state = gr.State(value=0)
+     upload_output_state = gr.State(value=0)
+     # canvas_state = gr.State({"width":1024,"height":600,"selection_size":384})
+
+     def upload_func(image, state):
+         pil = image.convert("RGBA")
+         w, h = pil.size
+         if w > model["width"] - 100 or h > model["height"] - 100:
+             pil = contain_func(pil, (model["width"] - 100, model["height"] - 100))
+         out_buffer = io.BytesIO()
+         pil.save(out_buffer, format="PNG")
+         out_buffer.seek(0)
+         base64_bytes = base64.b64encode(out_buffer.read())
+         base64_str = base64_bytes.decode("ascii")
+         return (
+             gr.update(label=str(state + 1), value=base64_str),
+             state + 1,
+         )
+
+     upload_button.click(
+         fn=upload_func,
+         inputs=[image_box, upload_output_state],
+         outputs=[upload_output, upload_output_state],
+         _js=upload_button_js,
+         queue=False
+     )
+
+     def setup_func(token_val, width, height, size):
+         model["width"] = width
+         model["height"] = height
+         model["sel_size"] = size
+         try:
+             get_model(token_val)
+         except Exception as e:
+             return {token: gr.update(value="Invalid token!")}
+         return {
+             token: gr.update(visible=False),
+             canvas_width: gr.update(visible=False),
+             canvas_height: gr.update(visible=False),
+             selection_size: gr.update(visible=False),
+             setup_button: gr.update(visible=False),
+             frame: gr.update(visible=True),
+             upload_button: gr.update(value="Upload"),
+         }
+
+     # setup_button.click(
+     # fn=setup_func,
+     # inputs=[token, canvas_width, canvas_height, selection_size],
+     # outputs=[
+     # token,
+     # canvas_width,
+     # canvas_height,
+     # selection_size,
+     # setup_button,
+     # frame,
+     # upload_button,
+     # ],
+     # _js=setup_button_js,
+     # )
+     run_button.click(
+         fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
+     )
+     retry_button.click(
+         fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
+     )
+     proceed_button.click(
+         fn=run_outpaint,
+         inputs=[
+             model_input,
+             sd_prompt,
+             sd_strength,
+             sd_guidance,
+             sd_step,
+             sd_resize,
+             init_mode,
+             safety_check,
+             model_output_state,
+         ],
+         outputs=[model_output, sd_prompt, model_output_state],
+         _js=proceed_button_js,
+     )
+     export_button.click(
+         fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("export")
+     )
+     commit_button.click(
+         fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("commit")
+     )
+     undo_button.click(
+         fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("undo")
+     )
+     canvas_control.change(
+         fn=None, inputs=[canvas_control], outputs=[canvas_control], _js=mode_js,
+     )
+
+ demo.launch()
+