from __future__ import annotations import datetime import os import pathlib import shlex import shutil import subprocess import sys import gradio as gr import slugify import torch import huggingface_hub from huggingface_hub import HfApi from omegaconf import OmegaConf ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero' SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID) class Runner: def __init__(self, hf_token: str | None = None): self.hf_token = hf_token self.checkpoint_dir = pathlib.Path('checkpoints') self.checkpoint_dir.mkdir(exist_ok=True) def download_base_model(self, base_model_id: str, token=None) -> str: model_dir = self.checkpoint_dir / base_model_id if not model_dir.exists(): org_name = base_model_id.split('/')[0] org_dir = self.checkpoint_dir / org_name org_dir.mkdir(exist_ok=True) print(f'https://huggingface.co/{base_model_id}') if token == None: subprocess.run(shlex.split( f'git clone https://huggingface.co/{base_model_id}'), cwd=org_dir) return model_dir.as_posix() else: temp_path = huggingface_hub.snapshot_download(base_model_id, use_auth_token=token) print(temp_path, org_dir) # subprocess.run(shlex.split(f'mv {temp_path} {model_dir.as_posix()}')) # return model_dir.as_posix() return temp_path def join_model_library_org(self, token: str) -> None: subprocess.run( shlex.split( f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}' )) def run_vid2vid_zero( self, model_path: str, input_video: str, prompt: str, n_sample_frames: int, sample_start_idx: int, sample_frame_rate: int, validation_prompt: str, guidance_scale: float, resolution: str, seed: int, remove_gpu_after_running: bool, input_token: str = None, ) -> str: if not torch.cuda.is_available(): raise gr.Error('CUDA is not available.') if input_video is None: raise gr.Error('You need to upload a video.') if not prompt: raise gr.Error('The input prompt is missing.') if not validation_prompt: raise gr.Error('The validation prompt is missing.') resolution = int(resolution) n_sample_frames = int(n_sample_frames) sample_start_idx = int(sample_start_idx) sample_frame_rate = int(sample_frame_rate) repo_dir = pathlib.Path(__file__).parent prompt_path = prompt.replace(' ', '_') output_dir = repo_dir / 'outputs' / prompt_path output_dir.mkdir(parents=True, exist_ok=True) config = OmegaConf.load('configs/black-swan.yaml') # config.pretrained_model_path = self.download_base_model(model_path, token=input_token) config.pretrained_model_path = "checkpoints/stable-diffusion-v1-4" # TODO config.output_dir = output_dir.as_posix() config.input_data.video_path = input_video.name # type: ignore config.input_data.prompt = prompt config.input_data.n_sample_frames = n_sample_frames config.input_data.width = resolution config.input_data.height = resolution config.input_data.sample_start_idx = sample_start_idx config.input_data.sample_frame_rate = sample_frame_rate config.validation_data.prompts = [validation_prompt] config.validation_data.video_length = 8 config.validation_data.width = resolution config.validation_data.height = resolution config.validation_data.num_inference_steps = 50 config.validation_data.guidance_scale = guidance_scale config.input_batch_size = 1 config.seed = seed config_path = output_dir / 'config.yaml' with open(config_path, 'w') as f: OmegaConf.save(config, f) command = f'accelerate launch test_vid2vid_zero.py --config {config_path}' subprocess.run(shlex.split(command)) output_video_path = os.path.join(output_dir, "sample-all.mp4") print(f"video path for gradio: {output_video_path}") message = 'Running completed!' print(message) if remove_gpu_after_running: space_id = os.getenv('SPACE_ID') if space_id: api = HfApi( token=self.hf_token if self.hf_token else input_token) api.request_space_hardware(repo_id=space_id, hardware='cpu-basic') return output_video_path