Spaces:
Runtime error
Runtime error
# This file is adapted from gradio_*.py in https://github.com/lllyasviel/ControlNet/tree/f4748e3630d8141d7765e2bd9b1e348f47847707 | |
# The original license file is LICENSE.ControlNet in this repo. | |
from __future__ import annotations

import gc
import pathlib
import sys
import warnings

import cv2
import numpy as np
import PIL.Image
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)
# The upstream ControlNet checkout lives in a 'ControlNet' sub-directory next
# to this file; append it to the import path so the `annotator` package
# imported below can be found.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir.joinpath('ControlNet')
sys.path.append(submodule_dir.as_posix())
from annotator.mlsd import apply_mlsd | |
from annotator.uniformer import apply_uniformer | |
from annotator.util import HWC3, resize_image | |
# Task name -> model id handed to `ControlNetModel.from_pretrained` for that
# task. Only the M-LSD line detector ('hough') is wired up in this file.
CONTROLNET_MODEL_IDS = dict(
    hough='lllyasviel/sd-controlnet-mlsd',
)
def download_all_controlnet_weights() -> None:
    """Call ``ControlNetModel.from_pretrained`` once per id in CONTROLNET_MODEL_IDS.

    The loaded models are discarded immediately. The call is made only for
    its side effect, presumably to pre-populate the local model cache before
    the app starts serving requests.
    """
    model_ids = list(CONTROLNET_MODEL_IDS.values())
    for repo_id in model_ids:
        ControlNetModel.from_pretrained(repo_id)
class Model:
    """Stable Diffusion + ControlNet wrapper for the M-LSD ("hough") task.

    Holds one ``StableDiffusionControlNetPipeline`` in ``self.pipe``. It
    remembers the base model id and ControlNet task the pipeline was built
    for, so that requests for the same configuration reuse the loaded weights.
    """

    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'hough'):
        """Select a device and build the initial pipeline.

        Args:
            base_model_id: Model id of the Stable Diffusion base checkpoint.
            task_name: Key into ``CONTROLNET_MODEL_IDS`` selecting the
                ControlNet weights.
        """
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # Empty sentinels so the first load_pipe() call never hits its cache.
        self.base_model_id = ''
        self.task_name = ''
        self.pipe = self.load_pipe(base_model_id, task_name)

    @property
    def _dtype(self) -> torch.dtype:
        """Weight dtype for ``self.device``.

        float16 on CUDA. float32 otherwise, because diffusers does not
        support running float16 pipelines on CPU.
        """
        return torch.float16 if self.device.type == 'cuda' else torch.float32

    def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
        """Return a pipeline for ``base_model_id`` with the ``task_name`` ControlNet.

        If ``self.pipe`` already exists and was built for the same pair, it is
        returned unchanged. Otherwise the weights are loaded, the scheduler
        is replaced by UniPC, and the pipeline is moved to ``self.device``.
        The pair is then recorded in ``self.base_model_id`` /
        ``self.task_name``. The caller must assign the result to
        ``self.pipe``.

        Raises:
            KeyError: If ``task_name`` is not in ``CONTROLNET_MODEL_IDS``.
        """
        if (base_model_id == self.base_model_id
                and task_name == self.task_name and hasattr(self, 'pipe')):
            return self.pipe
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=self._dtype)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id,
            safety_checker=None,
            controlnet=controlnet,
            torch_dtype=self._dtype)
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config)
        if self.device.type == 'cuda':
            # xformers is GPU-only and purely an optimization. Diffusers raises
            # if it is missing or unusable, so fall back to default attention.
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except (ImportError, RuntimeError, ValueError) as err:
                warnings.warn(
                    f'xformers memory efficient attention disabled: {err}')
        pipe.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Rebuild the pipeline on a different base model.

        An empty id, or the id that is already loaded, is a no-op. If loading
        the new model fails, a warning is emitted and the previous base model
        is reloaded (best effort).

        Returns:
            The base model id that is active afterwards. Callers can compare
            it with the requested id to see whether the switch happened.
        """
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        # Free the current weights before loading new ones to limit peak memory.
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception as err:  # deliberate best effort: keep the app usable
            warnings.warn(
                f'Failed to load base model {base_model_id!r} ({err}); '
                f'reverting to {self.base_model_id!r}.')
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Swap the pipeline's ControlNet for the one registered for ``task_name``.

        Does nothing when ``task_name`` is already loaded.

        Raises:
            KeyError: If ``task_name`` is not in ``CONTROLNET_MODEL_IDS``. The
                current ControlNet is left in place in that case.
        """
        if task_name == self.task_name:
            return
        # Resolve the id before dropping the current ControlNet, so an unknown
        # task cannot leave the pipeline without one.
        model_id = CONTROLNET_MODEL_IDS[task_name]
        del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=self._dtype)
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Join ``prompt`` and ``additional_prompt`` with ``', '``.

        Empty parts are skipped, so the result never has a leading or
        trailing separator.
        """
        return ', '.join(part for part in (prompt, additional_prompt) if part)

    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Run ``self.pipe`` conditioned on ``control_image``.

        Args:
            seed: Seed for a CPU ``torch.Generator``. ``-1`` draws a random
                seed.

        Returns:
            The ``images`` list produced by the pipeline.
        """
        if seed == -1:
            # Explicit int64: the default integer dtype is 32-bit on some
            # platforms (e.g. Windows with NumPy < 2), where this upper bound
            # would raise ValueError.
            seed = int(
                np.random.randint(0, np.iinfo(np.int64).max, dtype=np.int64))
        generator = torch.Generator().manual_seed(seed)
        return self.pipe(prompt=prompt,
                         negative_prompt=negative_prompt,
                         guidance_scale=guidance_scale,
                         num_images_per_prompt=num_images,
                         num_inference_steps=num_steps,
                         generator=generator,
                         image=control_image).images

    @staticmethod
    def preprocess_hough(
        input_image: np.ndarray,
        image_resolution: int,
        detect_resolution: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        """Build the M-LSD line control image for ``input_image``.

        Detection runs on a copy resized with ``resize_image`` to
        ``detect_resolution``. The detected line map is then resized
        (nearest neighbour) to match the input resized to
        ``image_resolution``.

        Returns:
            A ``(control_image, vis_control_image)`` pair. ``vis_control_image``
            is the line map dilated with a 3x3 kernel and inverted
            (``255 - x``) for display.
        """
        # HWC3 presumably normalizes to a 3-channel HWC uint8 array — confirm
        # in ControlNet's annotator.util.
        input_image = HWC3(input_image)
        control_image = apply_mlsd(
            resize_image(input_image, detect_resolution), value_threshold,
            distance_threshold)
        control_image = HWC3(control_image)
        image = resize_image(input_image, image_resolution)
        H, W = image.shape[:2]
        control_image = cv2.resize(control_image, (W, H),
                                   interpolation=cv2.INTER_NEAREST)
        vis_control_image = 255 - cv2.dilate(
            control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
            vis_control_image)

    def process_hough(
        self,
        input_image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        detect_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        """Generate images guided by the M-LSD lines detected in ``input_image``.

        Returns only the generated images. The visualization produced by
        ``preprocess_hough`` is not returned.
        """
        control_image, _ = self.preprocess_hough(
            input_image=input_image,
            image_resolution=image_resolution,
            detect_resolution=detect_resolution,
            value_threshold=value_threshold,
            distance_threshold=distance_threshold,
        )
        self.load_controlnet_weight('hough')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return results