|
|
|
import os |
|
import json |
|
import torch |
|
import random |
|
|
|
import gradio as gr |
|
from glob import glob |
|
from omegaconf import OmegaConf |
|
from datetime import datetime |
|
from safetensors import safe_open |
|
|
|
from diffusers import AutoencoderKL |
|
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler |
|
from diffusers.utils.import_utils import is_xformers_available |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
from animatediff.models.unet import UNet3DConditionModel |
|
from animatediff.pipelines.pipeline_animation import AnimationPipeline |
|
from animatediff.utils.util import save_videos_grid |
|
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint |
|
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora |
|
|
|
|
|
# Module-level index used to name saved sample videos ("<sample_idx>.mp4").
sample_idx = 0

# Sampler name shown in the UI -> diffusers scheduler class instantiated for it.
scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
}

# Extra CSS handed to gr.Blocks; ".toolbutton" is the class attached to the
# small refresh / dice buttons via elem_classes.
# The property was previously spelled "margin-buttom" with a four-value
# shorthand, which browsers drop as an unknown declaration.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
|
|
|
class AnimateController:
    """State holder and callback target for the AnimateDiff Gradio demo.

    The controller keeps the currently loaded tokenizer, text encoder, VAE and
    UNet, the lists of selectable model files found under ``models/``, and the
    LoRA tensors chosen in the UI. Its ``update_*`` methods are wired to
    dropdown ``change`` events and ``animate`` to the "Generate" button.
    """

    def __init__(self):
        # Model and output folders, all resolved against the working directory.
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")
        )
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)
        # Create the video folder up front so saving never depends on the
        # writer creating missing parent directories.
        os.makedirs(self.savedir_sample, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Model components; populated by update_stable_diffusion().
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")

    @staticmethod
    def _read_safetensors(path):
        """Return ``{key: tensor}`` for every tensor in the safetensors file at *path* (CPU)."""
        with safe_open(path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}

    @staticmethod
    def _seed_torch(seed_textbox):
        """Seed torch's default generator from the UI seed value.

        An empty value, ``None``, ``-1`` or the string ``"-1"`` requests a
        fresh random seed; anything else must be convertible with ``int``.
        The string form matters because a text box delivers ``"-1"``, which
        must not be forwarded to ``torch.manual_seed`` as a fixed seed.

        Returns:
            The seed now in effect (``torch.initial_seed()``).

        Raises:
            gr.Error: if the value is neither empty/-1 nor an integer.
        """
        if isinstance(seed_textbox, str):
            seed_textbox = seed_textbox.strip()
        if seed_textbox in (None, "", "-1", -1):
            torch.seed()
        else:
            try:
                torch.manual_seed(int(seed_textbox))
            except (TypeError, ValueError):
                raise gr.Error(f"Seed must be an integer, got {seed_textbox!r}.") from None
        return torch.initial_seed()

    def refresh_stable_diffusion(self):
        """Rescan the Stable Diffusion folder; entries are sub-directory paths."""
        self.stable_diffusion_list = sorted(glob(os.path.join(self.stable_diffusion_dir, "*/")))

    def refresh_motion_module(self):
        """Rescan the motion-module folder; entries are ``*.ckpt`` file names."""
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = sorted(os.path.basename(p) for p in motion_module_list)

    def refresh_personalized_model(self):
        """Rescan the DreamBooth/LoRA folder; entries are ``*.safetensors`` file names."""
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = sorted(os.path.basename(p) for p in personalized_model_list)

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        """Load tokenizer, text encoder, VAE and UNet from the selected model path.

        Args:
            stable_diffusion_dropdown: path passed to each ``from_pretrained`` call.

        Returns:
            A no-op ``gr.Dropdown.update()``.
        """
        if not stable_diffusion_dropdown:
            # Selection was cleared; keep whatever is currently loaded.
            return gr.Dropdown.update()

        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            stable_diffusion_dropdown, subfolder="text_encoder"
        ).cuda()
        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            stable_diffusion_dropdown,
            subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs),
        ).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        """Load the selected motion-module checkpoint into the current UNet.

        Args:
            motion_module_dropdown: file name inside ``self.motion_module_dir``.

        Returns:
            ``gr.Dropdown.update(value=None)`` when no UNet is loaded yet,
            otherwise a no-op update.

        Raises:
            gr.Error: if the checkpoint contains keys the UNet does not know.
        """
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        if not motion_module_dropdown:
            # Nothing selected (e.g. the dropdown was reset); nothing to load.
            return gr.Dropdown.update()

        motion_module_path = os.path.join(self.motion_module_dir, motion_module_dropdown)
        # NOTE(review): torch.load unpickles the file; only place trusted
        # checkpoints in the motion-module folder.
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        _missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
        if unexpected:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise gr.Error(
                f"Motion module has {len(unexpected)} unexpected keys; "
                "it does not match the loaded UNet."
            )
        return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        """Load the selected DreamBooth checkpoint into the VAE, UNet and text encoder.

        Args:
            base_model_dropdown: file name inside ``self.personalized_model_dir``.

        Returns:
            ``gr.Dropdown.update(value=None)`` when no UNet is loaded yet,
            otherwise a no-op update.
        """
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        if not base_model_dropdown:
            # Nothing selected (e.g. the dropdown was reset); nothing to load.
            return gr.Dropdown.update()

        base_model_path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        base_model_state_dict = self._read_safetensors(base_model_path)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
        # strict=False: the converted checkpoint is not required to cover every UNet key.
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
        return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        """Read the selected LoRA file into ``self.lora_model_state_dict``.

        The sentinel ``"none"`` (or an empty selection) clears the stored
        tensors. The sentinel is checked on the raw dropdown value, before it
        is joined onto the model folder, so that it can actually match.

        Args:
            lora_model_dropdown: ``"none"`` or a file name inside
                ``self.personalized_model_dir``.

        Returns:
            A no-op ``gr.Dropdown.update()``.
        """
        self.lora_model_state_dict = {}
        if lora_model_dropdown and lora_model_dropdown != "none":
            lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
            self.lora_model_state_dict = self._read_safetensors(lora_model_path)
        return gr.Dropdown.update()

    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Generate one animation, save it as mp4 and append its settings to logs.json.

        Args:
            stable_diffusion_dropdown: selected pretrained model path (not read here;
                the models were loaded when the selection changed).
            motion_module_dropdown: selected motion module; must be non-empty.
            base_model_dropdown: selected DreamBooth model; must be non-empty.
            lora_alpha_slider: ``alpha`` forwarded to ``convert_lora``.
            prompt_textbox: positive prompt.
            negative_prompt_textbox: negative prompt.
            sampler_dropdown: key into ``scheduler_dict``.
            sample_step_slider: number of inference steps.
            width_slider: output width.
            length_slider: ``video_length`` passed to the pipeline.
            height_slider: output height.
            cfg_scale_slider: guidance scale.
            seed_textbox: seed; empty or -1 picks a random one.

        Returns:
            ``gr.Video.update`` pointing at the saved sample.

        Raises:
            gr.Error: if no pretrained model, motion module or base model is
                selected, or the seed is not an integer.
        """
        global sample_idx

        if self.unet is None:
            raise gr.Error("Please select a pretrained model path.")
        # Truthiness covers both "" and None (an unset dropdown).
        if not motion_module_dropdown:
            raise gr.Error("Please select a motion module.")
        if not base_model_dropdown:
            raise gr.Error("Please select a base DreamBooth model.")

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        scheduler_kwargs = OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)
        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](**scheduler_kwargs)
        ).to("cuda")

        if self.lora_model_state_dict:
            # NOTE(review): if convert_lora merges weights into the shared
            # UNet/text encoder in place, repeated generations would stack the
            # LoRA - confirm against its implementation.
            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)

        pipeline.to("cuda")

        seed = self._seed_torch(seed_textbox)

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=width_slider,
            height=height_slider,
            video_length=length_slider,
        ).videos

        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)
        # Advance the counter so the next generation does not overwrite this file.
        sample_idx += 1

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a", encoding="utf-8") as f:
            f.write(json_str)
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)
|
|
|
|
|
# Module-level controller shared by every UI callback. Constructing it creates
# the samples folder, scans the model folders and loads the inference config.
controller = AnimateController()
|
|
|
|
|
def ui():
    """Build the Gradio Blocks interface and wire it to the module-level ``controller``.

    Panel 1 holds the model selectors (pretrained path, motion module,
    DreamBooth base model, optional LoRA) with refresh buttons; panel 2 holds
    the prompts, sampling settings, seed controls, the "Generate" button and
    the result video.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (select pretrained model path first).
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=controller.stable_diffusion_list,
                    interactive=True,
                )
                stable_diffusion_dropdown.change(
                    fn=controller.update_stable_diffusion,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion():
                    # Rescan the folder and push the new choices to the dropdown.
                    controller.refresh_stable_diffusion()
                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)

                stable_diffusion_refresh_button.click(
                    fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown]
                )

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(
                    fn=controller.update_motion_module,
                    inputs=[motion_module_dropdown],
                    outputs=[motion_module_dropdown],
                )

                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_motion_module():
                    # Rescan the folder and push the new choices to the dropdown.
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(choices=controller.motion_module_list)

                motion_module_refresh_button.click(
                    fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown]
                )

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                )
                base_model_dropdown.change(
                    fn=controller.update_base_model,
                    inputs=[base_model_dropdown],
                    outputs=[base_model_dropdown],
                )

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none"] + controller.personalized_model_list,
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(
                    fn=controller.update_lora_model,
                    inputs=[lora_model_dropdown],
                    outputs=[lora_model_dropdown],
                )

                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)

                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    # Base-model and LoRA dropdowns share one folder, so both
                    # are refreshed together; LoRA keeps its "none" entry.
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(choices=controller.personalized_model_list),
                        gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
                    ]

                personalized_refresh_button.click(
                    fn=update_personalized_model,
                    inputs=[],
                    outputs=[base_model_dropdown, lora_model_dropdown],
                )

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateDiff.
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(
                            label="Sampling method",
                            choices=list(scheduler_dict.keys()),
                            value=list(scheduler_dict.keys())[0],
                        )
                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)

                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # Integer bound: random.randint rejects the float 1e8
                        # on Python 3.12+ (deprecated since 3.10).
                        seed_button.click(
                            fn=lambda: gr.Textbox.update(value=random.randint(1, 10**8)),
                            inputs=[],
                            outputs=[seed_textbox],
                        )

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_video = gr.Video(label="Generated Animation", interactive=False)

            # The input order must match the positional parameters of
            # controller.animate (note: length comes before height).
            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo
|
|
|
|
|
# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    demo = ui()
    # NOTE(review): share=True is expected to request a public Gradio share
    # link - confirm that is intended wherever this script is deployed.
    demo.launch(share=True)
|
|