|
import huggingface_hub |
|
|
|
# Fetch only the image-encoder folder ("models/") and the SDXL adapter weights
# ("sdxl_models/") from the h94/IP-Adapter hub repo into the working
# directory, written as regular files rather than cache symlinks. The setup
# code later reads "models/image_encoder" and "sdxl_models/..." relative to
# this directory.
_ip_adapter_patterns = ("models/**", "sdxl_models/**")
huggingface_hub.snapshot_download(
    repo_id='h94/IP-Adapter',
    allow_patterns=list(_ip_adapter_patterns),
    local_dir='./',
    local_dir_use_symlinks=False,
)
|
|
|
import gradio as gr |
|
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel |
|
from rembg import remove |
|
from PIL import Image |
|
import torch |
|
from ip_adapter import IPAdapterXL |
|
from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images |
|
from PIL import Image, ImageChops, ImageEnhance |
|
import numpy as np |
|
|
|
import os |
|
import glob |
|
import torch |
|
import cv2 |
|
import argparse |
|
|
|
import DPT.util.io |
|
|
|
from torchvision.transforms import Compose |
|
|
|
from DPT.dpt.models import DPTDepthModel |
|
from DPT.dpt.midas_net import MidasNet_large |
|
from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet |
|
|
|
""" |
|
Get ZeST Ready |
|
""" |
|
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" |
|
image_encoder_path = "models/image_encoder" |
|
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin" |
|
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0" |
|
device = "cuda" |
|
torch.cuda.empty_cache() |
|
|
|
|
|
controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device) |
|
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained( |
|
base_model_path, |
|
controlnet=controlnet, |
|
use_safetensors=True, |
|
torch_dtype=torch.float16, |
|
add_watermarker=False, |
|
).to(device) |
|
pipe.unet = register_cross_attention_hook(pipe.unet) |
|
|
|
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device) |
|
|
|
|
|
""" |
|
Get Depth Model Ready |
|
""" |
|
model_path = "DPT/weights/dpt_hybrid-midas-501f0c75.pt" |
|
net_w = net_h = 384 |
|
model = DPTDepthModel( |
|
path=model_path, |
|
backbone="vitb_rn50_384", |
|
non_negative=True, |
|
enable_attention_hooks=False, |
|
) |
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
|
|
|
transform = Compose( |
|
[ |
|
Resize( |
|
net_w, |
|
net_h, |
|
resize_target=None, |
|
keep_aspect_ratio=True, |
|
ensure_multiple_of=32, |
|
resize_method="minimal", |
|
image_interpolation_method=cv2.INTER_CUBIC, |
|
), |
|
normalization, |
|
PrepareForNet(), |
|
] |
|
) |
|
|
|
model.eval() |
|
|
|
|
|
def _depth_to_uint8(prediction):
    """Min-max normalize a raw depth prediction into an 8-bit array.

    The prediction is scaled to the 16-bit range [0, 65535] and then divided
    by 256 and truncated to uint8. A constant prediction (no depth range)
    maps to all zeros.

    Args:
        prediction: 2-D float numpy array of raw depth values.

    Returns:
        uint8 numpy array with the same shape as ``prediction``.
    """
    depth_min = prediction.min()
    depth_max = prediction.max()
    bits = 2
    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (prediction - depth_min) / (depth_max - depth_min)
    else:
        # Flat prediction: avoid dividing by ~0 and emit an all-zero map.
        out = np.zeros(prediction.shape, dtype=prediction.dtype)

    return (out / 256).astype("uint8")


def _cutout_to_mask(cutout):
    """Binarize a background-removed RGB cutout into a 0/255 mask.

    A pixel is foreground when ANY of its channels is non-zero. Thresholding
    each channel and then converting RGB->L (luma) would give non-binary
    values (e.g. 76 for a pure-red pixel). Such pixels would stay partially
    colored and fall outside the inpaint mask.

    Args:
        cutout: (H, W, 3) uint8 array or RGB PIL image.

    Returns:
        (H, W) uint8 numpy array containing only 0 and 255.
    """
    return np.where(np.asarray(cutout).any(axis=-1), 255, 0).astype(np.uint8)


def _estimate_depth(image):
    """Run the DPT depth model on ``image``.

    Args:
        image: PIL image in RGB mode.

    Returns:
        uint8 numpy depth map with the same height/width as ``image``.
    """
    # DPT's reference pipeline (util.io.read_image) feeds RGB scaled to
    # [0, 1]; NormalizeImage(0.5, 0.5) then maps it to [-1, 1]. Passing raw
    # 0-255 pixels would push the network input far outside that range.
    img = np.asarray(image, dtype=np.float32) / 255.0
    img_input = transform({"image": img})["image"]

    with torch.no_grad():
        # Follow wherever the depth model lives (CPU unless moved elsewhere).
        model_device = next(model.parameters()).device
        sample = torch.from_numpy(img_input).unsqueeze(0).to(model_device)
        prediction = model.forward(sample)
        # Upsample the network-resolution prediction back to the input size.
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

    return _depth_to_uint8(prediction)


def _decolor_foreground(image):
    """Replace the foreground object of ``image`` with its grayscale version.

    Args:
        image: PIL image in RGB mode.

    Returns:
        ``(init_img, target_mask)``:
        - ``init_img`` has the foreground (as found by rembg's ``remove``)
          in grayscale and the background unchanged.
        - ``target_mask`` is an RGB image that is 255 on the foreground and 0
          elsewhere.
    """
    cutout = remove(image).convert("RGB")
    target_mask = Image.fromarray(_cutout_to_mask(cutout)).convert("RGB")

    gray = image.convert("L").convert("RGB")
    # With a strictly 0/255 mask, darker() is a per-pixel select: it keeps the
    # image where the mask is 255 and zeroes it where the mask is 0.
    foreground_gray = ImageChops.darker(gray, target_mask)
    background_color = ImageChops.darker(image, ImageChops.invert(target_mask))
    init_img = ImageChops.lighter(background_color, foreground_gray)
    return init_img, target_mask


def greet(input_image, material_exemplar):
    """Transfer the material of ``material_exemplar`` onto ``input_image``.

    Pipeline:
    1. Estimate a depth map of the input (used as the ControlNet condition).
    2. Gray out the foreground object, which is also the inpainting mask.
    3. Inpaint it with the IP-Adapter conditioned on the exemplar.

    Args:
        input_image: PIL image whose foreground object receives the material.
        material_exemplar: PIL image showing the material to transfer.

    Returns:
        The first image returned by ``ip_model.generate``.

    Raises:
        gr.Error: if either image is missing (Gradio shows it to the user).
    """
    if input_image is None or material_exemplar is None:
        raise gr.Error("Please provide both an input image and a material exemplar.")

    # The DPT transform and the ImageChops calls assume 3-channel RGB; an RGBA
    # or grayscale upload would otherwise fail with mismatched image modes.
    input_image = input_image.convert("RGB")

    depth_map = Image.fromarray(_estimate_depth(input_image)).resize((1024, 1024))
    init_img, target_mask = _decolor_foreground(input_image)

    # Bring every conditioning image to the 1024x1024 working resolution.
    ip_image = material_exemplar.convert("RGB").resize((1024, 1024))
    init_img = init_img.resize((1024, 1024))
    mask = target_mask.resize((1024, 1024))

    num_samples = 1
    images = ip_model.generate(
        pil_image=ip_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=0.9,
        num_samples=num_samples,
        num_inference_steps=30,
        seed=42,
    )

    return images[0]
|
|
|
css = """ |
|
#col-container{ |
|
margin: 0 auto; |
|
max-width: 960px; |
|
} |
|
""" |
|
|
|
with gr.Blocks(css=css) as demo: |
|
with gr.Column(elem_id="col-container"): |
|
gr.Markdown(""" |
|
# ZeST: Zero-Shot Material Transfer from a Single Image |
|
Upload two images -- input image and material exemplar. ZeST extracts the material from the exemplar and cast it onto the input image following the original lighting cues. |
|
""") |
|
with gr.Row(): |
|
with gr.Column(): |
|
with gr.Row(): |
|
input_image = gr.Image(type="pil", label="input image") |
|
input_image2 = gr.Image(type="pil", label = "material examplar") |
|
submit_btn = gr.Button("Submit") |
|
gr.Examples( |
|
examples = [["demo_assets/input_imgs/pumpkin.png", "demo_assets/material_exemplars/cup_glaze.png"]], |
|
inputs = [input_image, input_image2] |
|
) |
|
with gr.Column(): |
|
output_image = gr.Image(label="transfer result") |
|
submit_btn.click(fn=greet, inputs=[input_image, input_image2], outputs=[output_image]) |
|
|
|
demo.queue().launch() |
|
|