DaGAN / app.py
harlanhong
first
c45e94d
raw
history blame
6.07 kB
import os
import shutil
import gradio as gr
from PIL import Image
import subprocess
#os.chdir('Restormer')
from demo_dagan import *
# Download sample images
import torch
import torch.nn.functional as F
import os
from skimage import img_as_ubyte
import imageio
from skimage.transform import resize
import numpy as np
import modules.generator as G
import modules.keypoint_detector as KPD
import yaml
from collections import OrderedDict
import depth
# Example (source image, driving video) pairs listed under the Gradio interface.
examples = [['project/cartoon2.jpg','project/video1.mp4'],
            ['project/cartoon3.jpg','project/video2.mp4'],
            ['project/celeb1.jpg','project/video1.mp4'],
            ['project/celeb2.jpg','project/video2.mp4'],
            ]
# NOTE(review): not referenced in the visible part of this file (looks like a
# leftover from another demo); confirm it is unused before removing it.
inference_on = ['Full Resolution Image', 'Downsampled Image']
# Page title and the HTML blurb shown above the interface.
title = "DaGAN"
description = """
Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a><a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>\n
"""
# HTML footer shown below the interface: links to the paper and the code repository.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"
def _trim_driving_video(video, out_path, duration="00:00:08"):
    """Copy the first ``duration`` of ``video`` into ``out_path`` using ffmpeg.

    ffmpeg is invoked with an argument list (no shell, no whitespace
    splitting), so paths that contain spaces survive. ``check=True`` makes a
    failed trim raise instead of silently leaving a stale ``out_path`` from an
    earlier request in place.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    cmd = ["ffmpeg", "-y", "-ss", "00:00:00", "-i", video,
           "-to", duration, "-c", "copy", out_path]
    subprocess.run(cmd, check=True)


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with every 'module.' substring removed from its keys.

    Presumably the checkpoints were saved from a DataParallel-wrapped model,
    whose parameter names carry a 'module.' prefix the bare model lacks.
    """
    return OrderedDict((k.replace('module.', ''), v) for k, v in state_dict.items())


def _load_models(device):
    """Build DaGAN's generator, keypoint detector and depth networks on ``device``.

    Returns:
        (generator, kp_detector, depth_encoder, depth_decoder), each moved to
        ``device`` and switched to eval mode.
    """
    with open("config/vox-adv-256.yaml") as f:
        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6.
        config = yaml.safe_load(f)
    model_params = config['model_params']
    generator = G.SPADEDepthAwareGenerator(**model_params['generator_params'],
                                           **model_params['common_params'])
    # The keypoint detector (built after the generator) takes 4 input channels,
    # presumably RGB plus depth -- TODO confirm against KPDetector.
    model_params['common_params']['num_channels'] = 4
    kp_detector = KPD.KPDetector(**model_params['kp_detector_params'],
                                 **model_params['common_params'])

    generator.load_state_dict(
        _strip_module_prefix(torch.load("generator.pt", map_location=device)))
    kp_detector.load_state_dict(
        _strip_module_prefix(torch.load("kp_detector.pt", map_location=device)))

    depth_encoder = depth.ResnetEncoder(18, False)
    depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
    # map_location lets GPU-saved checkpoints load on a CPU-only host.
    loaded_dict_enc = torch.load('encoder.pth', map_location=device)
    loaded_dict_dec = torch.load('depth.pth', map_location=device)
    # Keep only the checkpoint entries the model actually has; compute the key
    # sets once instead of rebuilding state_dict() for every entry.
    enc_keys = depth_encoder.state_dict().keys()
    dec_keys = depth_decoder.state_dict().keys()
    depth_encoder.load_state_dict({k: v for k, v in loaded_dict_enc.items() if k in enc_keys})
    depth_decoder.load_state_dict({k: v for k, v in loaded_dict_dec.items() if k in dec_keys})

    return tuple(m.to(device).eval()
                 for m in (generator, kp_detector, depth_encoder, depth_decoder))


def _read_frames(path):
    """Decode every frame of the video at ``path``; return (frames, fps).

    A RuntimeError raised mid-decode ends the read early and keeps the frames
    decoded so far (best-effort, presumably for truncated files). The reader is
    always closed, even when another exception escapes.
    """
    reader = imageio.get_reader(path)
    try:
        fps = reader.get_meta_data()['fps']
        frames = []
        try:
            for im in reader:
                frames.append(im)
        except RuntimeError:
            pass
    finally:
        reader.close()
    return frames, fps


def _to_rgb_256(image):
    """Resize ``image`` to 256x256 and return exactly three colour channels.

    Extra channels (e.g. alpha) are dropped. Grayscale input (2-D, or 1-2
    channels) is replicated to three channels first. Without that step the
    ``[..., :3]`` slice would cut the width axis of a 2-D image.
    """
    if image.ndim == 3 and image.shape[-1] < 3:
        image = image[..., 0]
    if image.ndim == 2:
        image = np.repeat(image[..., np.newaxis], 3, axis=-1)
    return resize(image, (256, 256))[..., :3]


def inference(source_image, video):
    """Animate ``source_image`` with the motion in ``video``; return the output path.

    Args:
        source_image: path to the source image (Gradio ``type="filepath"`` input).
        video: path to the driving video; only its first 8 seconds are used.

    Returns:
        Path of the rendered mp4 ("rst.mp4"). Each frame shows, left to right:
        the source, the driving frame, the depth frames returned by
        ``make_animation``, and the animated prediction.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails to trim the video.
        ValueError: if no frames could be decoded from the driving video.
    """
    os.makedirs('temp', exist_ok=True)
    driving_path = "video_input.mp4"
    output = "rst.mp4"
    _trim_driving_video(video, driving_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the depth networks loaded here are never passed to
    # make_animation (its call signature has no slot for them); confirm whether
    # demo_dagan uses its own copies, in which case this load can be dropped.
    generator, kp_detector, _depth_encoder, _depth_decoder = _load_models(device)
    # Previously hard-coded to cpu=False, which presumably sends tensors to CUDA
    # and fails on CPU-only hosts -- TODO confirm in demo_dagan.make_animation.
    use_cpu = device.type == 'cpu'

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
        source = _to_rgb_256(imageio.imread(source_image))
        raw_frames, fps = _read_frames(driving_path)
        if not raw_frames:
            raise ValueError(f"no frames could be decoded from {driving_path}")
        driving_frames = [_to_rgb_256(frame) for frame in raw_frames]

        best = find_best_frame(source, driving_frames)
        print("Best frame: " + str(best))
        # Animate forward from the best-matching frame and backward to the start,
        # then stitch the two halves. Both contain frame `best`, so drop one copy.
        src_f, drv_f, pred_f, depth_f = make_animation(
            source, driving_frames[best:], generator, kp_detector,
            relative=True, adapt_movement_scale=True, cpu=use_cpu)
        src_b, drv_b, pred_b, depth_b = make_animation(
            source, driving_frames[:best + 1][::-1], generator, kp_detector,
            relative=True, adapt_movement_scale=True, cpu=use_cpu)
        predictions = pred_b[::-1] + pred_f[1:]
        sources = src_b[::-1] + src_f[1:]
        drivings = drv_b[::-1] + drv_f[1:]
        depth_gray = depth_b[::-1] + depth_f[1:]

        imageio.mimsave(output, [np.concatenate((img_as_ubyte(s), img_as_ubyte(d), img_as_ubyte(p)), 1)
                                 for s, d, p in zip(sources, drivings, predictions)], fps=fps)
        imageio.mimsave("gray.mp4", depth_gray, fps=fps)
        # Splice the depth video between the driving frame and the prediction.
        # The first 512 columns are the two 256-px panels (source | driving).
        animation = np.array(imageio.mimread(output, memtest=False))
        gray = np.array(imageio.mimread("gray.mp4", memtest=False))
        src_dst = animation[:, :, :512, :]
        animate = animation[:, :, 512:, :]
        merge = np.concatenate((src_dst, gray, animate), 2)
        imageio.mimsave(output, merge, fps=fps)
    return output
def _build_demo():
    """Assemble the Gradio interface that wraps ``inference``.

    Inputs are a source-image file path and a driving video. The output is the
    rendered video. The module-level title/description/article/examples
    strings populate the page.
    """
    source_input = gr.inputs.Image(type="filepath", label="Source Image")
    driving_input = gr.inputs.Video(type='mp4', label="Driving Video")
    result_output = gr.outputs.Video(type="mp4", label="Output Video")
    return gr.Interface(
        inference,
        [source_input, driving_input],
        result_output,
        title=title,
        description=description,
        article=article,
        theme="huggingface",
        examples=examples,
        allow_flagging=False,
    )


# Start the web server (queued requests, debug output off).
_build_demo().launch(debug=False, enable_queue=True)