import io
import base64
import os

import numpy as np
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
from PIL import Image
from PIL import ImageOps
import gradio as gr
import base64
import skimage
import skimage.measure
from utils import *

# Pick the execution device. Any failure while probing CUDA is treated as
# "no GPU" so the app can still start on CPU-only hosts.
try:
    cuda_available = torch.cuda.is_available()
except Exception:
    cuda_available = False

device = "cuda" if cuda_available else "cpu"

if device != "cuda":
    import contextlib

    # ``autocast("cuda")`` is used unconditionally below; on CPU swap it for a
    # no-op context manager (nullcontext accepts and ignores the argument).
    autocast = contextlib.nullcontext


def load_html():
    """Build the canvas page that is embedded in the iframe.

    Reads ``index.html`` and replaces its ``from canvas import InfCanvas`` line
    with the full text of ``canvas.py``. It also strips the ``- paths:`` /
    ``- ./canvas.py`` config lines (presumably the external-file reference,
    which is no longer needed once the source is inlined).
    """
    with open("index.html", encoding="utf8") as f:
        body = f.read()
    with open("canvas.py", encoding="utf8") as f:
        canvaspy = f.read()
    body = body.replace("- paths:\n", "")
    body = body.replace("  - ./canvas.py\n", "")
    body = body.replace("from canvas import InfCanvas", canvaspy)
    return body


def test(x):
    """Return an ``<iframe id="sdinfframe">`` that hosts the canvas page.

    The argument is ignored; it is kept only so existing ``test(2)`` call
    sites keep working. The iframe content always comes from
    :func:`load_html`.
    """
    import html

    # NOTE(review): this previously returned an empty f-string, leaving the
    # frame blank even though every JS handler from ``load_js`` looks up
    # ``#sdinfframe``. The document is HTML-escaped so quotes inside
    # index.html / canvas.py cannot terminate the ``srcdoc`` attribute early.
    # The 700px height leaves headroom above model["height"] (600); adjust if
    # needed.
    x = load_html()
    return (
        '<iframe id="sdinfframe" style="width: 100%; height: 700px" '
        'name="result" allow="midi; geolocation; microphone; camera; '
        'display-capture; encrypted-media;" sandbox="allow-modals allow-forms '
        "allow-scripts allow-same-origin allow-popups "
        'allow-top-navigation-by-user-activation allow-downloads" '
        'allowfullscreen="" allowpaymentrequest="" frameborder="0" '
        f'srcdoc="{html.escape(x, quote=True)}"></iframe>'
    )


DEBUG_MODE = False

try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except AttributeError:
    # Older Pillow releases have no ``Image.Resampling`` enum.
    SAMPLING_MODE = Image.LANCZOS

try:
    contain_func = ImageOps.contain
except AttributeError:
    # Older Pillow releases lack ``ImageOps.contain``; this mirrors it.
    def contain_func(image, size, method=SAMPLING_MODE):
        """Resize *image* to fit inside *size* while keeping its aspect ratio."""
        # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
        im_ratio = image.width / image.height
        dest_ratio = size[0] / size[1]
        if im_ratio != dest_ratio:
            if im_ratio > dest_ratio:
                new_height = int(image.height / image.width * size[0])
                if new_height != size[1]:
                    size = (size[0], new_height)
            else:
                new_width = int(image.width / image.height * size[1])
                if new_width != size[0]:
                    size = (new_width, size[1])
        return image.resize(size, resample=method)


PAINT_SELECTION = "✥"
IMAGE_SELECTION = "🖼️"
BRUSH_SELECTION = "🖌️"

blocks = gr.Blocks()

# Process-wide state: canvas geometry plus the lazily loaded pipelines.
model = {}
model["width"] = 1500
model["height"] = 600
model["sel_size"] = 256


def _pil_to_base64(pil):
    """Encode a PIL image as PNG and return it as an ASCII base64 string."""
    out_buffer = io.BytesIO()
    pil.save(out_buffer, format="PNG")
    return base64.b64encode(out_buffer.getvalue()).decode("ascii")


def get_token():
    """Return the Hugging Face token from the ``hftoken`` env var, or ``""``."""
    return os.environ.get("hftoken", "")


def save_token(token):
    """Placeholder hook for persisting the token; currently a no-op."""
    return


def get_model(token=""):
    """Load (once) and return the ``(text2img, inpaint)`` pipelines.

    The first call loads Stable Diffusion v1-4 and builds an inpainting
    pipeline that reuses the same components. Both are cached in ``model``,
    and later calls return the cached objects and ignore *token*.
    """
    if "text2img" not in model:
        if device == "cuda":
            text2img = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                revision="fp16",
                torch_dtype=torch.float16,
                use_auth_token=token,
            ).to(device)
        else:
            text2img = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                use_auth_token=token,
            ).to(device)
        # Keep the real checker so run_outpaint can switch it back on.
        model["safety_checker"] = text2img.safety_checker
        inpaint = StableDiffusionInpaintPipeline(
            vae=text2img.vae,
            text_encoder=text2img.text_encoder,
            tokenizer=text2img.tokenizer,
            unet=text2img.unet,
            scheduler=text2img.scheduler,
            safety_checker=text2img.safety_checker,
            feature_extractor=text2img.feature_extractor,
        ).to(device)
        save_token(token)
        try:
            total_memory = torch.cuda.get_device_properties(0).total_memory // (
                1024 ** 3
            )
            # Trade speed for memory on small GPUs (<= 5 GiB).
            if total_memory <= 5:
                inpaint.enable_attention_slicing()
        except Exception:
            # Best effort: without device info, keep the default attention.
            pass
        model["text2img"] = text2img
        model["inpaint"] = inpaint
    return model["text2img"], model["inpaint"]


def run_outpaint(
    sel_buffer_str,
    prompt_text,
    strength,
    guidance,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    state,
):
    """Generate content for the current canvas selection.

    Args:
        sel_buffer_str: base64-encoded image of the selection, sent by the
            front end. Its last channel is used as the mask (presumably RGBA
            alpha, with 0 meaning "empty" — TODO confirm against the JS).
        prompt_text: text prompt for Stable Diffusion.
        strength: inpainting strength (slider range 0..1).
        guidance: classifier-free guidance scale.
        step: number of denoising steps.
        resize_check: if true the pipeline runs at 512x512, otherwise at
            ``model["sel_size"]``.
        fill_mode: key into ``functbl`` (from ``utils``) used to pre-fill the
            selection before inpainting.
        enable_safety: whether the pipelines' safety checker is active.
        state: integer counter kept in a ``gr.State``.

    Returns:
        ``(output_update, prompt_update, state + 1)``. The output textbox
        receives the result as a base64 PNG, and its label is set to
        ``str(state + 1)``, so it changes on every call.
    """
    data = base64.b64decode(str(sel_buffer_str))
    pil = Image.open(io.BytesIO(data))
    sel_buffer = np.array(pil)

    if not cuda_available:
        # No model on CPU: echo the selection back with the 4th channel made
        # opaque and the first channel saturated, so the round trip is visible.
        sel_buffer[:, :, 3] = 255
        sel_buffer[:, :, 0] = 255
        base64_str = _pil_to_base64(Image.fromarray(sel_buffer))
        return (
            gr.update(label=str(state + 1), value=base64_str,),
            gr.update(label="Prompt"),
            state + 1,
        )

    text2img, inpaint = get_model()
    if enable_safety:
        text2img.safety_checker = model["safety_checker"]
        inpaint.safety_checker = model["safety_checker"]
    else:
        text2img.safety_checker = lambda images, **kwargs: (images, False)
        inpaint.safety_checker = lambda images, **kwargs: (images, False)

    img = sel_buffer[:, :, 0:3]
    mask = sel_buffer[:, :, -1]
    process_size = 512 if resize_check else model["sel_size"]
    if mask.sum() > 0:
        # Part of the selection already has content: pre-fill the rest, then
        # inpaint it.
        img, mask = functbl[fill_mode](img, mask)
        init_image = Image.fromarray(img)
        mask = 255 - mask
        # Max-pool then upsample in 8x8 tiles, so any nonzero pixel marks its
        # whole tile (presumably to line up with the VAE's 8x latent grid).
        mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
        mask = mask.repeat(8, axis=0).repeat(8, axis=1)
        mask_image = Image.fromarray(mask)
        with autocast("cuda"):
            images = inpaint(
                prompt=prompt_text,
                init_image=init_image.resize(
                    (process_size, process_size), resample=SAMPLING_MODE
                ),
                mask_image=mask_image.resize((process_size, process_size)),
                strength=strength,
                num_inference_steps=step,
                guidance_scale=guidance,
            )["sample"]
    else:
        # Empty selection: plain text-to-image. Honour the user's step and
        # guidance settings (previously ignored on this path).
        with autocast("cuda"):
            images = text2img(
                prompt=prompt_text,
                height=process_size,
                width=process_size,
                num_inference_steps=step,
                guidance_scale=guidance,
            )["sample"]

    out = sel_buffer.copy()
    out[:, :, 0:3] = np.array(
        images[0].resize(
            (model["sel_size"], model["sel_size"]), resample=SAMPLING_MODE,
        )
    )
    out[:, :, -1] = 255
    base64_str = _pil_to_base64(Image.fromarray(out))
    return (
        gr.update(label=str(state + 1), value=base64_str,),
        gr.update(label="Prompt"),
        state + 1,
    )


def load_js(name):
    """Return the JavaScript snippet for the button/control called *name*.

    ``export``, ``commit`` and ``undo`` get a generated function that posts
    ``["click", name]`` to the ``#sdinfframe`` iframe. Every other name is
    read from ``./js/<name>.js``.
    """
    if name in ["export", "commit", "undo"]:
        return f"""
function (x)
{{
    let frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
    frame.postMessage(["click","{name}"], "*");
    return x;
}}
"""
    with open(f"./js/{name}.js", "r", encoding="utf8") as f:
        return f.read()


upload_button_js = load_js("upload")
outpaint_button_js = load_js("outpaint")
proceed_button_js = load_js("proceed")
mode_js = load_js("mode")
setup_button_js = load_js("setup")

if not cuda_available:
    # No model is loaded on CPU; run_outpaint returns early via its echo path.
    get_model = lambda x: x
get_model(get_token())

with blocks as demo:
    # title
    title = gr.Markdown(
        """
    **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
    """
    )
    # frame
    frame = gr.HTML(test(2), visible=True)
    # setup
    # with gr.Row():
    #     token = gr.Textbox(
    #         label="Huggingface token",
    #         value="",
    #         placeholder="Input your token here",
    #     )
    #     canvas_width = gr.Number(
    #         label="Canvas width", value=1024, precision=0, elem_id="canvas_width"
    #     )
    #     canvas_height = gr.Number(
    #         label="Canvas height", value=600, precision=0, elem_id="canvas_height"
    #     )
    #     selection_size = gr.Number(
    #         label="Selection box size", value=256, precision=0, elem_id="selection_size"
    #     )
    # setup_button = gr.Button("Start (may take a while)", variant="primary")
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            # canvas control
            canvas_control = gr.Radio(
                label="Control",
                choices=[PAINT_SELECTION, IMAGE_SELECTION, BRUSH_SELECTION],
                value=PAINT_SELECTION,
                elem_id="control",
            )
            with gr.Box():
                with gr.Group():
                    run_button = gr.Button(value="Outpaint")
                    export_button = gr.Button(value="Export")
                    commit_button = gr.Button(value="✓")
                    retry_button = gr.Button(value="⟳")
                    undo_button = gr.Button(value="↶")
        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here", lines=4
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Box():
                # run_outpaint uses process_size = 512 when this is checked.
                sd_resize = gr.Checkbox(label="Resize input to 512x512", value=True)
                safety_check = gr.Checkbox(label="Enable Safety Checker", value=True)
            sd_strength = gr.Slider(
                label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
            )
        with gr.Column(scale=1, min_width=150):
            sd_step = gr.Number(label="Step", value=50, precision=0)
            sd_guidance = gr.Number(label="Guidance", value=7.5)
    with gr.Row():
        with gr.Column(scale=4, min_width=600):
            init_mode = gr.Radio(
                label="Init mode",
                choices=[
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "gaussian",
                    "perlin",
                ],
                value="patchmatch",
                type="value",
            )

    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    # sd pipeline parameters
    with gr.Accordion("Upload image", open=False):
        image_box = gr.Image(image_mode="RGBA", source="upload", type="pil")
        upload_button = gr.Button("Upload")
    # Hidden textboxes used to pass base64 images between JS and Python.
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    # canvas_state = gr.State({"width":1024,"height":600,"selection_size":384})

    def upload_func(image, state):
        """Shrink an uploaded image to fit the canvas and return it as base64.

        Returns ``(upload_output_update, new_state)``. Clicking Upload with
        no image is a no-op.
        """
        if image is None:
            return gr.update(), state
        pil = image.convert("RGBA")
        w, h = pil.size
        # Leave a 100px margin inside the canvas.
        if w > model["width"] - 100 or h > model["height"] - 100:
            pil = contain_func(pil, (model["width"] - 100, model["height"] - 100))
        base64_str = _pil_to_base64(pil)
        return (
            gr.update(label=str(state + 1), value=base64_str),
            state + 1,
        )

    upload_button.click(
        fn=upload_func,
        inputs=[image_box, upload_output_state],
        outputs=[upload_output, upload_output_state],
        _js=upload_button_js,
        queue=False
    )

    def setup_func(token_val, width, height, size):
        """Apply the canvas settings and load the model with *token_val*.

        NOTE(review): currently unused. It references ``token``,
        ``canvas_width``, ``canvas_height``, ``selection_size`` and
        ``setup_button``, whose definitions above are commented out, so it
        would raise NameError if re-wired without restoring them.
        """
        model["width"] = width
        model["height"] = height
        model["sel_size"] = size
        try:
            get_model(token_val)
        except Exception as e:
            return {token: gr.update(value="Invalid token!")}
        return {
            token: gr.update(visible=False),
            canvas_width: gr.update(visible=False),
            canvas_height: gr.update(visible=False),
            selection_size: gr.update(visible=False),
            setup_button: gr.update(visible=False),
            frame: gr.update(visible=True),
            upload_button: gr.update(value="Upload"),
        }

    # setup_button.click(
    #     fn=setup_func,
    #     inputs=[token, canvas_width, canvas_height, selection_size],
    #     outputs=[
    #         token,
    #         canvas_width,
    #         canvas_height,
    #         selection_size,
    #         setup_button,
    #         frame,
    #         upload_button,
    #     ],
    #     _js=setup_button_js,
    # )
    run_button.click(
        fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
    )
    retry_button.click(
        fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
    )
    proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        _js=proceed_button_js,
    )
    export_button.click(
        fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("export")
    )
    commit_button.click(
        fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("commit")
    )
    undo_button.click(
        fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("undo")
    )
    canvas_control.change(
        fn=None, inputs=[canvas_control], outputs=[canvas_control], _js=mode_js,
    )

demo.launch()