# Hugging Face Space: Depth Any Video, running on ZeroGPU.
import gc
import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler

from dav.models import UNetSpatioTemporalRopeConditionModel
from dav.pipelines import DAVPipeline
from dav.utils import img_utils
def seed_all(seed: int = 0):
    """
    Set random seeds for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
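# Each example row matches the input widgets wired up below:
# [video, denoise_steps, num_frames, decode_chunk_size,
#  num_interp_frames, num_overlap_frames, max_resolution].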
examples = [
    ["demos/wooly_mammoth.mp4", 3, 32, 16, 16, 6, 960],
]
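# Assemble the DAV pipeline from its parts: the temporal-decoder VAE, the
# flow-matching scheduler, and two UNets (key-frame generation and frame interpolation).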
def load_models(model_base, device):
    vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_base, subfolder="scheduler"
    )
    unet = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet"
    )
    unet_interp = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet_interp"
    )
    pipe = DAVPipeline(
        vae=vae,
        unet=unet,
        unet_interp=unet_interp,
        scheduler=scheduler,
    )
    pipe = pipe.to(device)
    return pipe
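# Download the weights from the Hugging Face Hub and build the pipeline once at startup.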
model_base = "hhyangcs/depth-any-video"
device_type = "cuda"
device = torch.device(device_type)
pipe = load_models(model_base, device)
# ZeroGPU attaches a GPU only while a function decorated with spaces.GPU runs,
# so the inference entry point carries the decorator.
@spaces.GPU
def infer_depth(
    file: str,
    denoise_steps: int = 3,
    num_frames: int = 32,
    decode_chunk_size: int = 16,
    num_interp_frames: int = 16,
    num_overlap_frames: int = 6,
    max_resolution: int = 1024,
    seed: int = 66,
    output_dir: str = "./outputs",
):
    seed_all(seed)
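    # Upper bound on frames to read: num_frames // 2 interpolation segments, each
    # contributing num_interp_frames + 2 frames minus the overlapping ones.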
    max_frames = (num_interp_frames + 2 - num_overlap_frames) * (num_frames // 2)
    image, fps = img_utils.read_video(file, max_frames=max_frames)
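    # Cap the longer side at max_resolution, crop to dimensions the model accepts,
    # and convert the frames to a normalized (N, C, H, W) float tensor.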
    image = img_utils.imresize_max(image, max_resolution)
    image = img_utils.imcrop_multi(image)
    image_tensor = np.ascontiguousarray(
        [_img.transpose(2, 0, 1) / 255.0 for _img in image]
    )
    image_tensor = torch.from_numpy(image_tensor).to(device)
    print(f"==> video name: {file}, frames shape: {image_tensor.shape}")
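    # Run the pipeline under fp16 autocast; it returns raw disparity maps together
    # with a colorized visualization of them.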
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.float16):
        pipe_out = pipe(
            image_tensor,
            num_frames=num_frames,
            num_overlap_frames=num_overlap_frames,
            num_interp_frames=num_interp_frames,
            decode_chunk_size=decode_chunk_size,
            num_inference_steps=denoise_steps,
        )

    disparity = pipe_out.disparity
    disparity_colored = pipe_out.disparity_colored
    image = pipe_out.image
    # Place the input frames and the colored disparity side by side: (N, H, 2 * W, 3).
    merged = np.concatenate(
        [
            image,
            disparity_colored,
        ],
        axis=2,
    )
    file_name = os.path.splitext(os.path.basename(file))[0]
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{file_name}_depth.mp4")
    img_utils.write_video(
        output_path,
        merged,
        fps,
    )
    # Clear the cache for the next video.
    gc.collect()
    torch.cuda.empty_cache()
    return output_path
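# Gradio Blocks UI: input video on the left, depth output on the right, tunable
# parameters in an "Advanced Settings" accordion.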
def construct_demo():
    with gr.Blocks(analytics_enabled=False) as depthanyvideo_iface:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=1):
                with gr.Row(equal_height=True):
                    output_video = gr.Video(
label="Ouput Video Depth", | |
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=1,
                    )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        denoise_steps = gr.Slider(
                            label="Denoise Steps",
                            minimum=1,
                            maximum=10,
                            value=3,
                            step=1,
                        )
                        num_frames = gr.Slider(
                            label="Number of Key Frames",
                            minimum=16,
                            maximum=32,
                            value=24,
                            step=2,
                        )
                        decode_chunk_size = gr.Slider(
                            label="Decode Chunk Size",
                            minimum=8,
                            maximum=32,
                            value=16,
                            step=1,
                        )
                        num_interp_frames = gr.Slider(
                            label="Number of Interpolation Frames",
                            minimum=8,
                            maximum=32,
                            value=16,
                            step=1,
                        )
                        num_overlap_frames = gr.Slider(
                            label="Number of Overlap Frames",
                            minimum=2,
                            maximum=10,
                            value=6,
                            step=1,
                        )
                        max_resolution = gr.Slider(
                            label="Maximum Resolution",
                            minimum=512,
                            maximum=2048,
                            value=1024,
                            step=32,
                        )
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass
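        # Bundled example; with lazy caching, the result is computed the first time
        # a visitor requests it and then reused.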
        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                denoise_steps,
                num_frames,
                decode_chunk_size,
                num_interp_frames,
                num_overlap_frames,
                max_resolution,
            ],
            outputs=output_video,
            fn=infer_depth,
            cache_examples="lazy",
        )

        generate_btn.click(
            fn=infer_depth,
            inputs=[
                input_video,
                denoise_steps,
                num_frames,
                decode_chunk_size,
                num_interp_frames,
                num_overlap_frames,
                max_resolution,
            ],
            outputs=output_video,
        )

    return depthanyvideo_iface
demo = construct_demo()

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)