import gradio as gr
import torch
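# No-op patch that disables TorchScript compilation; this appears to work around
# scripting failures in the ZoeDepth/timm code paths, so it must run before timm is imported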
torch.jit.script = lambda f: f
import timm
import time
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
import lora
import copy
import json
import gc
import random
from urllib.parse import quote
import gdown
import os
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler
import cv2
import numpy as np
from PIL import Image
from io import BytesIO
import base64
import re
from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
from controlnet_aux import ZoeDetector
from compel import Compel, ReturnedEmbeddingsType
#import spaces
#from gradio_imageslider import ImageSlider
# Regex pattern to match data URI scheme
data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
def readb64(b64):
# Remove any data URI scheme prefix with regex
b64 = data_uri_pattern.sub("", b64)
# Decode and open the image with PIL
img = Image.open(BytesIO(base64.b64decode(b64)))
return img
# convert from PIL to base64
def writeb64(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
b64image = base64.b64encode(buffered.getvalue())
b64image_str = b64image.decode("utf-8")
return b64image_str
with open("sdxl_loras.json", "r") as file:
data = json.load(file)
sdxl_loras_raw = [
{
"image": item["image"],
"title": item["title"],
"repo": item["repo"],
"trigger_word": item["trigger_word"],
"weights": item["weights"],
"is_compatible": item["is_compatible"],
"is_pivotal": item.get("is_pivotal", False),
"text_embedding_weights": item.get("text_embedding_weights", None),
"likes": item.get("likes", 0),
"downloads": item.get("downloads", 0),
"is_nc": item.get("is_nc", False),
"new": item.get("new", False),
}
for item in data
]
with open("defaults_data.json", "r") as file:
lora_defaults = json.load(file)
def getLoraByRepoName(repo_name):
# Loop through each lora in sdxl_loras_raw
for lora in sdxl_loras_raw:
if lora["repo"] == repo_name:
# Return the lora if the repo name matches
return lora
# If no match is found, return the first lora in the array
return sdxl_loras_raw[0] if sdxl_loras_raw else None
# Return the default values specific to this particular LoRA model
def getLoraDefaultsByRepoName(repo_name):
# Loop through each lora in sdxl_loras_raw
for lora_defs in lora_defaults:
if lora_defs["model"] == repo_name:
# Return the lora if the repo name matches
return lora_defs
# If no match is found, return None
return None
device = "cuda"
state_dicts = {}
for item in sdxl_loras_raw:
saved_name = hf_hub_download(item["repo"], item["weights"])
if not saved_name.endswith('.safetensors'):
state_dict = torch.load(saved_name)
else:
state_dict = load_file(saved_name)
state_dicts[item["repo"]] = {
"saved_name": saved_name,
"state_dict": state_dict
}
# drop LoRAs flagged as "new" from the list
sdxl_loras_raw = [item for item in sdxl_loras_raw if not item.get("new")]
# download models
hf_hub_download(
repo_id="InstantX/InstantID",
filename="ControlNetModel/config.json",
local_dir="/data/checkpoints",
)
hf_hub_download(
repo_id="InstantX/InstantID",
filename="ControlNetModel/diffusion_pytorch_model.safetensors",
local_dir="/data/checkpoints",
)
hf_hub_download(
repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
)
hf_hub_download(
repo_id="latent-consistency/lcm-lora-sdxl",
filename="pytorch_lora_weights.safetensors",
local_dir="/data/checkpoints",
)
# download antelopev2
if not os.path.exists("/data/antelopev2.zip"):
gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
os.system("unzip /data/antelopev2.zip -d /data/models/")
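# InsightFace 'antelopev2' models handle face detection and identity embedding extraction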
app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
# prepare models under ./checkpoints
face_adapter = '/data/checkpoints/ip-adapter.bin'
controlnet_path = '/data/checkpoints/ControlNetModel'
# load IdentityNet
st = time.time()
identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0",torch_dtype=torch.float16)
et = time.time()
elapsed_time = et - st
print('Loading ControlNet took: ', elapsed_time, 'seconds')
st = time.time()
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
et = time.time()
elapsed_time = et - st
print('Loading VAE took: ', elapsed_time, 'seconds')
st = time.time()
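# Img2img InstantID pipeline with two ControlNets: IdentityNet (face identity) and Zoe depth (structure)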
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("rubbrband/albedobaseXL_v21",
vae=vae,
controlnet=[identitynet, zoedepthnet],
torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)
et = time.time()
elapsed_time = et - st
print('Loading pipeline took: ', elapsed_time, 'seconds')
st = time.time()
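# Compel builds weighted prompt embeddings for SDXL's two text encoders (pooled output from the second)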
compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
et = time.time()
elapsed_time = et - st
print('Loading Compel took: ', elapsed_time, 'seconds')
st = time.time()
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
et = time.time()
elapsed_time = et - st
print('Loading Zoe took: ', elapsed_time, 'seconds')
zoe.to(device)
pipe.to(device)
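# Track which LoRA is currently fused into the pipeline so repeated calls with the
# same LoRA can skip the costly unfuse/fuse cycle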
last_lora = ""
last_fused = False
def center_crop_image_as_square(img):
square_size = min(img.size)
left = (img.width - square_size) / 2
top = (img.height - square_size) / 2
right = (img.width + square_size) / 2
bottom = (img.height + square_size) / 2
img_cropped = img.crop((left, top, right, bottom))
return img_cropped
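# Merge LoRAs that diffusers cannot load directly by building a network from the raw
# weights and merging it into the text encoder and UNet; a "path;multiplier" suffix
# in the weights file name overrides the default lora_scale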
def merge_incompatible_lora(full_path_lora, lora_scale):
for weights_file in [full_path_lora]:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = lora_scale
lora_model, weights_sd = lora.create_network_from_weights(
multiplier,
full_path_lora,
pipe.vae,
pipe.text_encoder,
pipe.unet,
for_inference=True,
)
lora_model.merge_to(
pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
)
del weights_sd
del lora_model
#@spaces.GPU
def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, lora, full_path_lora, lora_scale, st):
et = time.time()
elapsed_time = et - st
print('Getting into the decorated function took: ', elapsed_time, 'seconds')
global last_fused, last_lora
print("Last LoRA: ", last_lora)
print("Current LoRA: ", lora["repo"])
print("Last fused: ", last_fused)
#prepare face zoe
st = time.time()
with torch.no_grad():
image_zoe = zoe(face_image)
width, height = face_kps.size
images = [face_kps, image_zoe.resize((width, height))] # PIL resize takes (width, height)
et = time.time()
elapsed_time = et - st
print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
if last_lora != lora["repo"]:
if(last_fused):
st = time.time()
pipe.unfuse_lora()
pipe.unload_lora_weights()
et = time.time()
elapsed_time = et - st
print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
st = time.time()
pipe.load_lora_weights(full_path_lora)
pipe.fuse_lora(lora_scale)
et = time.time()
elapsed_time = et - st
print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
last_fused = True
if(lora["is_pivotal"]):
#Add the textual inversion embeddings from pivotal tuning models
text_embedding_name = lora["text_embedding_weights"]
embedding_path = hf_hub_download(repo_id=lora["repo"], filename=text_embedding_name, repo_type="model")
state_dict_embedding = load_file(embedding_path)
pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
print("Processing prompt...")
st = time.time()
conditioning, pooled = compel(prompt)
print("Processing prompt...")
st = time.time()
conditioning, pooled = compel(prompt)
if(negative):
negative_conditioning, negative_pooled = compel(negative)
else:
negative_conditioning, negative_pooled = None, None
et = time.time()
elapsed_time = et - st
print('Prompt processing took: ', elapsed_time, 'seconds')
print("Processing image...")
st = time.time()
image = pipe(
prompt_embeds=conditioning,
pooled_prompt_embeds=pooled,
negative_prompt_embeds=negative_conditioning,
negative_pooled_prompt_embeds=negative_pooled,
width=1024,
height=1024,
image_embeds=face_emb,
image=face_image,
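# img2img denoising strength: higher image_strength keeps more of the input photo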
strength=1-image_strength,
control_image=images,
num_inference_steps=20,
guidance_scale=guidance_scale,
controlnet_conditioning_scale=[face_strength, depth_control_scale],
).images[0]
et = time.time()
elapsed_time = et - st
print('Image processing took: ', elapsed_time, 'seconds')
last_lora = lora["repo"]
return image
def run_lora(face_image, prompt, negative, lora_weight, face_strength, image_strength, guidance_scale, depth_control_scale, lora_repo_name):
if not lora_repo_name:
raise gr.Error("You must input a LoRA repo name")
# get the LoRA and its default values
lora = getLoraByRepoName(lora_repo_name)
default_values = getLoraDefaultsByRepoName(lora_repo_name)
st = time.time()
face_image = readb64(face_image)
face_image = center_crop_image_as_square(face_image)
# this is temporary, just to debug
# return writeb64(face_image)
try:
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1] # only use the largest face
face_emb = face_info['embedding']
face_kps = draw_kps(face_image, face_info['kps'])
except Exception:
raise gr.Error("No face found in your image. Only face images work here. Try again.")
et = time.time()
elapsed_time = et - st
print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
st = time.time()
if default_values:
prompt_full = default_values.get("prompt", None)
if(prompt_full):
prompt = prompt_full.replace("<subject>", prompt)
print("Prompt:", prompt)
if(prompt == ""):
prompt = "a person"
if negative == "":
negative = None
weight_name = lora["weights"]
full_path_lora = state_dicts[lora["repo"]]["saved_name"]
#loaded_state_dict = copy.deepcopy(state_dicts[lora_repo_name]["state_dict"])
cross_attention_kwargs = None
et = time.time()
elapsed_time = et - st
print('Small content processing took: ', elapsed_time, 'seconds')
st = time.time()
image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, lora, full_path_lora, lora_weight, st)
image_base64 = writeb64(image)
return image_base64
with gr.Blocks() as demo:
gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically generate an image from a face.</p>
<p style="color: black;">Interested in using it through an UI? Please use the <a href="https://huggingface.co/spaces/multimodalart/face-to-all" target="_blank">original space</a>, thank you!</p>
</div>
</div>""")
input_image_base64 = gr.Text()
lora_repo_name = gr.Text(label="Name of the LoRA repo on HF")
prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
negative = gr.Textbox(label="Negative Prompt")
# initial value was 0.9
lora_weight = gr.Slider(0, 10, value=6, step=0.1, label="LoRA weight")
# initial value was 0.85
face_strength = gr.Slider(0, 1, value=0.75, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the model's creative freedom")
# initial value was 0.15
image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
# initial value was 7
guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
# initial value was 1
depth_control_scale = gr.Slider(0, 4, value=0.8, step=0.01, label="Zoe Depth ControlNet strength")
button = gr.Button(value="Generate")
output_image_base64 = gr.Text()
button.click(
fn=run_lora,
inputs=[
input_image_base64,
prompt,
negative,
lora_weight,
face_strength,
image_strength,
guidance_scale,
depth_control_scale,
lora_repo_name
],
outputs=output_image_base64,
api_name='run',
)
demo.queue(max_size=20)
demo.launch()
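# A minimal client-side sketch (kept as a comment, not executed here) of calling the
# `run` endpoint exposed above via gradio_client. The space id below is an assumption
# based on this file's name, and the input image must be sent base64-encoded:
#
#   from gradio_client import Client
#   import base64
#
#   client = Client("jbilcke-hf/face-to-all-api")  # hypothetical space id
#   with open("face.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#   result = client.predict(
#       b64,                      # input_image_base64
#       "a person",               # prompt
#       "",                       # negative
#       6,                        # lora_weight
#       0.75,                     # face_strength
#       0.15,                     # image_strength
#       7,                        # guidance_scale
#       0.8,                      # depth_control_scale
#       "some-user/some-sdxl-lora",  # lora_repo_name: must match a repo in sdxl_loras.json
#       api_name="/run",
#   )
#   # `result` is the generated image as a base64-encoded PNG string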