ControlNet / app.py
RamAnanth1's picture
Update app.py
29e7450
raw
history blame
7.79 kB
import cv2
import einops
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
from PIL import Image
from controlnet_aux import OpenposeDetector
# Constants
# Default lower / upper hysteresis thresholds for cv2.Canny in the canny
# pre-processing step.
# NOTE(review): these are plain module-level globals; any later module-level
# assignment to the same names (e.g. a UI widget) replaces them for every
# function that reads them at call time.
low_threshold = 100
high_threshold = 200
# Models
# ---- Canny ControlNet pipeline ----
# ControlNet weights conditioned on Canny edge maps, loaded as float16.
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
# Stable Diffusion v1.5 combined with the canny ControlNet, in float16 and with
# the safety checker disabled (safety_checker=None).
pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet_canny, safety_checker=None, torch_dtype=torch.float16
)
# Replace the pipeline's default scheduler with UniPC, built from the existing
# scheduler's config.
pipe_canny.scheduler = UniPCMultistepScheduler.from_config(pipe_canny.scheduler.config)
# This command loads the individual model components on GPU on-demand. So, we don't
# need to explicitly call pipe.to("cuda").
pipe_canny.enable_model_cpu_offload()
# Memory-efficient attention via xformers.
# NOTE(review): presumably fails if the xformers package is not installed -
# confirm it is listed in the Space's requirements.
pipe_canny.enable_xformers_memory_efficient_attention()
# Single RNG created once at import time with seed 0; process_canny and
# process_pose both pass this same object as `generator=`.
# NOTE(review): since the object is shared by every request, results presumably
# depend on how many images were generated before (not reproducible per
# request) - confirm whether that is intended.
generator = torch.manual_seed(0)
# ---- OpenPose ControlNet pipeline ----
# Pose annotator; get_pose() calls it to build the conditioning image for pipe_pose.
pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
# ControlNet weights conditioned on OpenPose images, loaded as float16.
controlnet_pose = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
)
# Second Stable Diffusion v1.5 pipeline, configured the same way as pipe_canny
# (float16, no safety checker, UniPC, CPU offload, xformers) but with the pose
# ControlNet.
pipe_pose = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet_pose, safety_checker=None, torch_dtype=torch.float16
)
pipe_pose.scheduler = UniPCMultistepScheduler.from_config(pipe_pose.scheduler.config)
# This command loads the individual model components on GPU on-demand. So, we don't
# need to explicitly call pipe.to("cuda").
pipe_pose.enable_model_cpu_offload()
# xformers memory-efficient attention, as for pipe_canny.
pipe_pose.enable_xformers_memory_efficient_attention()
def get_canny_filter(image, low_threshold=100, high_threshold=200):
    """Return the Canny edge map of ``image`` as a 3-channel PIL image.

    Args:
        image: Input picture, either a NumPy array or anything ``np.asarray``
            can convert (e.g. a PIL image).
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        PIL.Image.Image: the single-channel edge map replicated into three
        identical channels, suitable as a ControlNet conditioning image.
    """
    # The thresholds are parameters (defaults equal the module constants) so
    # this function never depends on module globals that other module-level
    # code may rebind - a rebinding to non-numeric objects made cv2.Canny fail.
    edges = cv2.Canny(np.asarray(image), low_threshold, high_threshold)
    # Replicate the (H, W) edge map into (H, W, 3) so the pipeline gets RGB input.
    edges_rgb = np.stack([edges, edges, edges], axis=-1)
    return Image.fromarray(edges_rgb)
def get_pose(image):
    """Run the module-level OpenPose detector on ``image``.

    The detector's output is returned unchanged; process_pose uses it as the
    conditioning image for the pose pipeline.
    """
    detected = pose_model(image)
    return detected
def process(input_image, prompt, input_control, *advanced_options):
    """Route one generation request to the pipeline for ``input_control``.

    Args:
        input_image: Image from the upload widget.
        prompt: Text prompt for the diffusion model.
        input_control: Selected control task. ``"Pose"`` uses the OpenPose
            pipeline; every other value ("Canny Edge Map", "Scribble", or an
            unrecognised name) uses the Canny pipeline.
        *advanced_options: Extra positional values the Run button's click
            handler sends after the first three inputs (added/negative prompt,
            sample count, resolution, steps, guidance scale, seed, eta, Canny
            thresholds). They are accepted so that handler no longer raises
            ``TypeError``, but they are not applied yet.

    Returns:
        list: ``[control_image, generated_image]`` for the result gallery.
    """
    # TODO: Add other control tasks and wire ``advanced_options`` into the
    # pipelines.
    if input_control == "Pose":
        return process_pose(input_image, prompt)
    # "Scribble" has no dedicated pipeline yet, so it shares the Canny path
    # with "Canny Edge Map" and any unknown task name.
    return process_canny(input_image, prompt)
def process_canny(input_image, prompt):
    """Generate one image conditioned on the Canny edge map of ``input_image``.

    Returns ``[edge_map, generated_image]`` so the gallery shows the control
    image next to the result.
    """
    edge_map = get_canny_filter(input_image)
    sampling_kwargs = dict(
        generator=generator,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )
    result = pipe_canny(prompt, edge_map, **sampling_kwargs)
    first_image = result.images[0]
    return [edge_map, first_image]
def process_pose(input_image, prompt):
    """Generate one image conditioned on the OpenPose rendering of ``input_image``.

    Returns ``[pose_image, generated_image]`` so the gallery shows the control
    image next to the result.
    """
    control_image = get_pose(input_image)
    sampling_kwargs = dict(
        generator=generator,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )
    result = pipe_pose(prompt, control_image, **sampling_kwargs)
    first_image = result.images[0]
    return [control_image, first_image]
# Build the Gradio UI, wire the Run button and cached examples to process(),
# then start the server.
block = gr.Blocks().queue()
# Choices for the "Control Task" dropdown.
control_task_list = [
    "Canny Edge Map",
    "Scribble",
    "Pose",
]
with block:
    gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
    gr.HTML('''
<p style="margin-bottom: 10px; font-size: 94%">
This is an unofficial demo for ControlNet, which is a neural network structure to control diffusion models by adding extra conditions such as canny edge detection. The demo is based on the <a href="https://github.com/lllyasviel/ControlNet" style="text-decoration: underline;" target="_blank"> Github </a> implementation.
</p>
''')
    gr.HTML("<p>You can duplicate this Space to run it privately without a queue and load additional checkpoints. : <a style='display:inline-block' href='https://huggingface.co/spaces/RamAnanth1/ControlNet?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a> </p>")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            input_control = gr.Dropdown(control_task_list, value="Scribble", label="Control Task")
            prompt = gr.Textbox(label="Prompt")
            # The button's caption is its ``value``; ``label`` is not displayed.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                # Named *_slider on purpose: `with` blocks do not create a new
                # scope, so plain `low_threshold` / `high_threshold` here would
                # rebind the module-level Canny threshold constants to Slider
                # objects.
                low_threshold_slider = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                high_threshold_slider = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
                eta = gr.Slider(label="eta (DDIM)", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value='longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
                )
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # Values are passed to process() positionally, in this order.
    ips = [
        input_image, prompt, input_control, a_prompt, n_prompt, num_samples,
        image_resolution, ddim_steps, scale, seed, eta,
        low_threshold_slider, high_threshold_slider,
    ]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    examples_list = [
        [
            "bird.png",
            "bird",
            "Canny Edge Map",
        ],
        # Disabled: these rows carry all 13 click inputs, while gr.Examples
        # below lists only 3 input components.
        # [
        #     "turtle.png",
        #     "turtle",
        #     "Scribble",
        #     "best quality, extremely detailed",
        #     'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
        #     1,
        #     512,
        #     20,
        #     9.0,
        #     123490213,
        #     0.0,
        #     100,
        #     200
        # ],
        # [
        #     "pose1.png",
        #     "Chef in the Kitchen",
        #     "Pose",
        #     "best quality, extremely detailed",
        #     'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
        #     1,
        #     512,
        #     20,
        #     9.0,
        #     123490213,
        #     0.0,
        #     100,
        #     200
        # ]
    ]
    # cache_examples=True runs process() on every example while the UI is built.
    examples = gr.Examples(
        examples=examples_list,
        inputs=[input_image, prompt, input_control],
        outputs=[result_gallery],
        cache_examples=True,
        fn=process,
    )
    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=RamAnanth1.ControlNet)")

block.launch(debug=True)