# lnyan's picture
# Update app
# 89ee8cb
# raw
# history blame
# 14.8 kB
import io
import base64
import os
import numpy as np
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
from PIL import Image
from PIL import ImageOps
import gradio as gr
import base64
import skimage
import skimage.measure
from utils import *
# Probe for CUDA once at import time; the rest of the module keys off
# `cuda_available` and `device`.
try:
    cuda_available = torch.cuda.is_available()
except Exception:
    # Best-effort probe: any failure while querying CUDA means "no GPU".
    # `except Exception` (instead of a bare `except:`) lets
    # KeyboardInterrupt / SystemExit propagate.
    cuda_available = False
device = "cuda" if cuda_available else "cpu"
if device != "cuda":
    import contextlib

    # Off-GPU, replace torch's autocast with a no-op context manager so the
    # `with autocast("cuda"):` call sites keep working unchanged.
    autocast = contextlib.nullcontext
def load_html():
    """Return the text of ``index.html`` with ``canvas.py`` inlined into it.

    Both files are read from the current working directory as UTF-8. Two
    literal lines are stripped from the page and the statement
    ``from canvas import InfCanvas`` is replaced by the full source of
    ``canvas.py``.

    Returns:
        str: the rewritten page.
    """
    with open("index.html", encoding="utf8") as page_file:
        page = page_file.read()
    with open("canvas.py", encoding="utf8") as canvas_file:
        canvas_source = canvas_file.read()
    # Presumably these two lines declare canvas.py as an external path in the
    # page; they are dropped because the module is inlined below -- TODO confirm.
    for stale_line in ("- paths:\n", " - ./canvas.py\n"):
        page = page.replace(stale_line, "")
    return page.replace("from canvas import InfCanvas", canvas_source)
def test(x):
    """Build the ``<iframe>`` markup that embeds the canvas page.

    Args:
        x: ignored; the page content always comes from ``load_html()``.

    Returns:
        str: an ``<iframe id="sdinfframe">`` element whose ``srcdoc``
        attribute holds the page.
    """
    # NOTE(review): the page is interpolated unescaped into a single-quoted
    # attribute, so a "'" inside it would terminate ``srcdoc`` early.
    page = load_html()
    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{page}'></iframe>"""
# When True, the otherwise hidden "Proceed" button and the input/output/upload
# textboxes are made visible in the UI.
DEBUG_MODE = False

# Resampling filter used for every resize in this module. Prefer the
# `Image.Resampling` namespace and fall back to the module-level constant
# when that lookup fails (presumably on older Pillow releases).
try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except Exception:
    SAMPLING_MODE = Image.LANCZOS
# Use Pillow's own `ImageOps.contain` when it exists; otherwise define an
# equivalent fallback.
try:
    contain_func = ImageOps.contain
except Exception:

    def contain_func(image, size, method=SAMPLING_MODE):
        """Resize ``image`` to fit inside ``size`` keeping its aspect ratio.

        Adapted from PIL:
        https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain

        Args:
            image: object exposing ``width``, ``height`` and ``resize``.
            size: ``(max_width, max_height)`` bounding box.
            method: resampling filter forwarded to ``image.resize``.

        Returns:
            The result of ``image.resize`` at the fitted size.
        """
        source_ratio = image.width / image.height
        box_ratio = size[0] / size[1]
        if source_ratio > box_ratio:
            # Relatively wider than the box: width is the limiting side.
            fitted_height = int(image.height / image.width * size[0])
            if fitted_height != size[1]:
                size = (size[0], fitted_height)
        elif source_ratio < box_ratio:
            # Relatively taller than the box: height is the limiting side.
            fitted_width = int(image.width / image.height * size[1])
            if fitted_width != size[0]:
                size = (fitted_width, size[1])
        return image.resize(size, resample=method)
# Glyphs offered as the choices of the canvas "Control" radio.
PAINT_SELECTION = "✥"
IMAGE_SELECTION = "🖼️"
BRUSH_SELECTION = "🖌️"

# Root Gradio container; the UI is built inside `with blocks as demo:`.
blocks = gr.Blocks()

# Module-wide mutable state: canvas dimensions and selection-box size.
# get_model() later caches the loaded pipelines in this same dict.
model = {"width": 1500, "height": 600, "sel_size": 256}
def get_token():
    """Return the ``hftoken`` environment variable, or ``""`` when unset."""
    return os.environ.get("hftoken", "")
def save_token(token):
    """Placeholder hook called after the pipelines load; does nothing.

    Args:
        token: ignored.

    Returns:
        None
    """
    return None
def get_model(token=""):
    """Return the ``(text2img, inpaint)`` pipelines, loading them on first use.

    The first call loads ``CompVis/stable-diffusion-v1-4`` (fp16 weights when
    ``device`` is ``"cuda"``), builds an inpainting pipeline that shares the
    text-to-image pipeline's components, and caches both in the module-level
    ``model`` dict together with the original safety checker. Later calls
    return the cached objects and ignore ``token``.

    Args:
        token: forwarded to ``from_pretrained`` as ``use_auth_token``.

    Returns:
        tuple: ``(model["text2img"], model["inpaint"])``.
    """
    if "text2img" not in model:
        load_kwargs = {"use_auth_token": token}
        if device == "cuda":
            # Half-precision weights are only requested on the GPU.
            load_kwargs.update(revision="fp16", torch_dtype=torch.float16)
        text2img = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", **load_kwargs
        ).to(device)
        # Keep the original checker so run_outpaint can restore it after
        # the user has switched it off.
        model["safety_checker"] = text2img.safety_checker
        inpaint = StableDiffusionInpaintPipeline(
            vae=text2img.vae,
            text_encoder=text2img.text_encoder,
            tokenizer=text2img.tokenizer,
            unet=text2img.unet,
            scheduler=text2img.scheduler,
            safety_checker=text2img.safety_checker,
            feature_extractor=text2img.feature_extractor,
        ).to(device)
        save_token(token)
        try:
            total_memory = torch.cuda.get_device_properties(0).total_memory // (
                1024 ** 3
            )
            if total_memory <= 5:
                # Small GPU (<= 5 GiB after integer division): enable
                # attention slicing on the inpaint pipeline.
                inpaint.enable_attention_slicing()
        except Exception:
            # Deliberately best-effort: without a usable CUDA device the
            # probe fails and the optimisation is simply skipped.
            # `except Exception` keeps KeyboardInterrupt/SystemExit alive.
            pass
        model["text2img"] = text2img
        model["inpaint"] = inpaint
    return model["text2img"], model["inpaint"]
def _decode_selection(sel_buffer_str):
    """Decode a base64-encoded image string into a NumPy array."""
    raw = base64.b64decode(str(sel_buffer_str))
    return np.array(Image.open(io.BytesIO(raw)))


def _encode_png_base64(pixels):
    """Encode an image array as PNG and return it as base64 ASCII text."""
    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def _outpaint_response(encoded, state):
    """Build the three Gradio outputs returned by ``run_outpaint``."""
    return (
        gr.update(label=str(state + 1), value=encoded,),
        gr.update(label="Prompt"),
        state + 1,
    )


def _skip_safety_check(images, **kwargs):
    """Stand-in safety checker: return the images unchanged, flagged False."""
    return images, False


def run_outpaint(
    sel_buffer_str,
    prompt_text,
    strength,
    guidance,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    state,
):
    """Generate content for the current selection and return it as base64 PNG.

    Args:
        sel_buffer_str: base64-encoded image of the selection. The last
            channel is used as the mask and channel index 3 is written on the
            CPU path, so a 4-channel (presumably RGBA) image is expected.
        prompt_text: prompt passed to the pipeline.
        strength: ``strength`` argument of the inpaint pipeline.
        guidance: ``guidance_scale`` argument of the inpaint pipeline.
        step: ``num_inference_steps`` argument of the inpaint pipeline.
        resize_check: when truthy, process at 512x512; otherwise at
            ``model["sel_size"]``.
        fill_mode: key into ``functbl`` selecting the function that pre-fills
            the image and mask before inpainting.
        enable_safety: when truthy, use the cached safety checker; otherwise
            install a pass-through checker.
        state: counter; ``state + 1`` is returned and used as the output label.

    Returns:
        tuple: ``(update for the output textbox carrying the base64 PNG,
        update resetting the prompt label to "Prompt", state + 1)``.
    """
    if not cuda_available:
        # No GPU: skip the model entirely and return the selection with its
        # first channel and channel 3 forced to 255 as a placeholder.
        sel_buffer = _decode_selection(sel_buffer_str)
        sel_buffer[:, :, 3] = 255
        sel_buffer[:, :, 0] = 255
        return _outpaint_response(_encode_png_base64(sel_buffer), state)

    text2img, inpaint = get_model()
    checker = model["safety_checker"] if enable_safety else _skip_safety_check
    text2img.safety_checker = checker
    inpaint.safety_checker = checker

    sel_buffer = _decode_selection(sel_buffer_str)
    img = sel_buffer[:, :, 0:3]
    mask = sel_buffer[:, :, -1]
    process_size = 512 if resize_check else model["sel_size"]
    if mask.sum() > 0:
        # The selection already holds some content: inpaint around it.
        img, mask = functbl[fill_mode](img, mask)
        init_image = Image.fromarray(img)
        # Invert the mask, then make it constant over 8x8 tiles (max-pool by
        # 8 and repeat by 8 along both axes).
        mask = 255 - mask
        mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
        mask = mask.repeat(8, axis=0).repeat(8, axis=1)
        mask_image = Image.fromarray(mask)
        # Disabled alternative kept for reference:
        # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
        with autocast("cuda"):
            images = inpaint(
                prompt=prompt_text,
                init_image=init_image.resize(
                    (process_size, process_size), resample=SAMPLING_MODE
                ),
                mask_image=mask_image.resize((process_size, process_size)),
                strength=strength,
                num_inference_steps=step,
                guidance_scale=guidance,
            )["sample"]
    else:
        # Empty selection: plain text-to-image.
        # NOTE(review): `step` and `guidance` are not forwarded on this path.
        with autocast("cuda"):
            images = text2img(
                prompt=prompt_text, height=process_size, width=process_size,
            )["sample"]

    # Write the generated pixels back at selection size, fully opaque.
    out = sel_buffer.copy()
    out[:, :, 0:3] = np.array(
        images[0].resize(
            (model["sel_size"], model["sel_size"]), resample=SAMPLING_MODE,
        )
    )
    out[:, :, -1] = 255
    return _outpaint_response(_encode_png_base64(out), state)
def load_js(name):
    """Return the JavaScript snippet registered under ``name``.

    For ``"export"``, ``"commit"`` and ``"undo"`` the snippet is generated
    inline: a function that posts ``["click", name]`` to the ``#sdinfframe``
    iframe and returns its argument unchanged. Any other name is read from
    ``./js/<name>.js``.

    Args:
        name: snippet name.

    Returns:
        str: JavaScript source.

    Raises:
        OSError: if ``./js/<name>.js`` cannot be opened.
    """
    if name in ("export", "commit", "undo"):
        return f"""
function (x)
{{
let frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
frame.postMessage(["click","{name}"], "*");
return x;
}}
"""
    # Read as UTF-8 explicitly, like the other file reads in this module;
    # without it decoding depends on the platform locale.
    with open(f"./js/{name}.js", "r", encoding="utf8") as f:
        return f.read()
# JavaScript handlers loaded once at import time, one per ./js/<name>.js
# file, in this order.
(
    upload_button_js,
    outpaint_button_js,
    proceed_button_js,
    mode_js,
    setup_button_js,
) = (
    load_js(script_name)
    for script_name in ("upload", "outpaint", "proceed", "mode", "setup")
)
if not cuda_available:

    def get_model(token=""):
        """CPU stand-in for the real loader: load nothing, echo ``token``.

        Keeps the real ``get_model(token="")`` signature so that calls
        without an argument do not raise ``TypeError``.
        """
        return token


# Load the pipelines eagerly at import time (a no-op without CUDA).
get_model(get_token())
# Build the Gradio UI and wire its events.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order -- confirm it against the intended layout.
with blocks as demo:
    # title
    title = gr.Markdown(
        "\n"
        "**stablediffusion-infinity**: Outpainting with Stable Diffusion on an "
        "infinite canvas: "
        "[https://github.com/lkwq007/stablediffusion-infinity]"
        "(https://github.com/lkwq007/stablediffusion-infinity)\n"
    )
    # frame
    frame = gr.HTML(test(2), visible=True)
    # setup (disabled: the model is loaded at import time instead)
    # with gr.Row():
    #     token = gr.Textbox(
    #         label="Huggingface token",
    #         value="",
    #         placeholder="Input your token here",
    #     )
    #     canvas_width = gr.Number(
    #         label="Canvas width", value=1024, precision=0, elem_id="canvas_width"
    #     )
    #     canvas_height = gr.Number(
    #         label="Canvas height", value=600, precision=0, elem_id="canvas_height"
    #     )
    #     selection_size = gr.Number(
    #         label="Selection box size", value=256, precision=0, elem_id="selection_size"
    #     )
    # setup_button = gr.Button("Start (may take a while)", variant="primary")
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            # canvas control
            canvas_control = gr.Radio(
                label="Control",
                choices=[PAINT_SELECTION, IMAGE_SELECTION, BRUSH_SELECTION],
                value=PAINT_SELECTION,
                elem_id="control",
            )
            with gr.Box():
                with gr.Group():
                    run_button = gr.Button(value="Outpaint")
                    export_button = gr.Button(value="Export")
                    commit_button = gr.Button(value="✓")
                    retry_button = gr.Button(value="⟳")
                    undo_button = gr.Button(value="↶")
        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here", lines=4
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Box():
                # Label fixed from "515x512": run_outpaint processes at 512
                # when this box is ticked.
                sd_resize = gr.Checkbox(label="Resize input to 512x512", value=True)
                safety_check = gr.Checkbox(label="Enable Safety Checker", value=True)
            sd_strength = gr.Slider(
                label="Strength", minimum=0.0, maximum=1.0, value=0.75, step=0.01
            )
        with gr.Column(scale=1, min_width=150):
            sd_step = gr.Number(label="Step", value=50, precision=0)
            sd_guidance = gr.Number(label="Guidance", value=7.5)
    with gr.Row():
        with gr.Column(scale=4, min_width=600):
            init_mode = gr.Radio(
                label="Init mode",
                choices=[
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "gaussian",
                    "perlin",
                ],
                value="patchmatch",
                type="value",
            )

    # Only visible in debug mode; carries the handler that runs run_outpaint.
    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    # sd pipeline parameters
    with gr.Accordion("Upload image", open=False):
        image_box = gr.Image(image_mode="RGBA", source="upload", type="pil")
        upload_button = gr.Button(
            "Upload"
        )
    # Textboxes used as data channels (hidden unless DEBUG_MODE) and the
    # counters that are written into their labels.
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    # canvas_state = gr.State({"width":1024,"height":600,"selection_size":384})

    def upload_func(image, state):
        """Encode an uploaded image as base64 PNG for the canvas.

        The image is converted to RGBA and, when it exceeds the canvas size
        minus a 100-pixel margin in either dimension, shrunk to fit with
        ``contain_func``.

        Args:
            image: the uploaded PIL image, or ``None`` when nothing has been
                uploaded.
            state: counter; ``state + 1`` is returned and used as the label.

        Returns:
            tuple: ``(update for the upload textbox, new counter)``. With no
            image, an empty update and the unchanged counter.
        """
        if image is None:
            # Nothing to upload: leave the outputs untouched instead of
            # failing on ``None.convert``.
            return gr.update(), state
        pil = image.convert("RGBA")
        w, h = pil.size
        max_w, max_h = model["width"] - 100, model["height"] - 100
        if w > max_w or h > max_h:
            pil = contain_func(pil, (max_w, max_h))
        out_buffer = io.BytesIO()
        pil.save(out_buffer, format="PNG")
        base64_str = base64.b64encode(out_buffer.getvalue()).decode("ascii")
        return (
            gr.update(label=str(state + 1), value=base64_str),
            state + 1,
        )

    upload_button.click(
        fn=upload_func,
        inputs=[image_box, upload_output_state],
        outputs=[upload_output, upload_output_state],
        _js=upload_button_js,
        queue=False
    )

    def setup_func(token_val, width, height, size):
        """Store canvas settings, load the model and hide the setup widgets.

        NOTE(review): this handler is not attached to any event (the
        ``setup_button.click`` call below is commented out) and it refers to
        ``token``, ``canvas_width``, ``canvas_height``, ``selection_size`` and
        ``setup_button``, which are only defined in the commented-out setup
        row above. Calling it as-is would raise ``NameError``.

        Args:
            token_val: token forwarded to ``get_model``.
            width: stored as ``model["width"]``.
            height: stored as ``model["height"]``.
            size: stored as ``model["sel_size"]``.

        Returns:
            dict: component -> update mapping.
        """
        model["width"] = width
        model["height"] = height
        model["sel_size"] = size
        try:
            get_model(token_val)
        except Exception:
            return {token: gr.update(value="Invalid token!")}
        return {
            token: gr.update(visible=False),
            canvas_width: gr.update(visible=False),
            canvas_height: gr.update(visible=False),
            selection_size: gr.update(visible=False),
            setup_button: gr.update(visible=False),
            frame: gr.update(visible=True),
            upload_button: gr.update(value="Upload"),
        }

    # setup_button.click(
    #     fn=setup_func,
    #     inputs=[token, canvas_width, canvas_height, selection_size],
    #     outputs=[
    #         token,
    #         canvas_width,
    #         canvas_height,
    #         selection_size,
    #         setup_button,
    #         frame,
    #         upload_button,
    #     ],
    #     _js=setup_button_js,
    # )

    # "Outpaint" and retry run only client-side JavaScript (fn=None).
    run_button.click(
        fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
    )
    retry_button.click(
        fn=None, inputs=[run_button], outputs=[run_button], _js=outpaint_button_js,
    )
    proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        _js=proceed_button_js,
    )
    # Canvas actions: each button just posts a "click" message to the iframe.
    export_button.click(
        fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("export")
    )
    commit_button.click(
        fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("commit")
    )
    undo_button.click(
        fn=None, inputs=[export_button], outputs=[export_button], _js=load_js("undo")
    )
    canvas_control.change(
        fn=None, inputs=[canvas_control], outputs=[canvas_control], _js=mode_js,
    )

demo.launch()