import os
import subprocess
from collections import OrderedDict

import gradio as gr
import imageio
import numpy as np
import torch
import yaml
from skimage import img_as_ubyte
from skimage.transform import resize

import depth
import modules.generator as G
import modules.keypoint_detector as KPD
from demo_dagan import *  # provides find_best_frame and make_animation used below

examples = [
    ['project/cartoon2.jpg', 'project/video1.mp4'],
    ['project/cartoon3.jpg', 'project/video2.mp4'],
    ['project/celeb1.jpg', 'project/video1.mp4'],
    ['project/celeb2.jpg', 'project/video2.mp4'],
]



title = "DaGAN"
description = """
Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a> <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>
"""

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"


def inference(source_image, video):
    os.makedirs('temp', exist_ok=True)
    # Trim the driving video to its first 8 seconds without re-encoding (-c copy).
    cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy video_input.mp4"
    subprocess.run(cmd.split())
    driving_video = "video_input.mp4"
    output = "rst.mp4"
    with open("config/vox-adv-256.yaml") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params'])
    # The keypoint detector expects 4 input channels (RGB + depth), so num_channels is
    # bumped only after the generator has been built.
    config['model_params']['common_params']['num_channels'] = 4
    kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    g_checkpoint = torch.load("generator.pt", map_location=device)
    kp_checkpoint = torch.load("kp_detector.pt", map_location=device)

    ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
    generator.load_state_dict(ckp_generator)
    ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
    kp_detector.load_state_dict(ckp_kp_detector)
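
    # Face depth sub-network (ResNet-18 encoder plus depth decoder); its per-frame
    # depth maps provide the depth awareness used by the generator and keypoint detector.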

    depth_encoder = depth.ResnetEncoder(18, False)
    depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
    # map_location keeps the demo working on CPU-only machines.
    loaded_dict_enc = torch.load('encoder.pth', map_location=device)
    loaded_dict_dec = torch.load('depth.pth', map_location=device)
    # Keep only the keys that exist in each model before loading the state dicts.
    filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
    depth_encoder.load_state_dict(filtered_dict_enc)
    ckp_depth_decoder = {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
    depth_decoder.load_state_dict(ckp_depth_decoder)
    depth_encoder.eval()
    depth_decoder.eval()
            

    generator = generator.to(device)
    kp_detector = kp_detector.to(device)
    depth_encoder = depth_encoder.to(device)
    depth_decoder = depth_decoder.to(device)

    generator.eval()
    kp_detector.eval()
    depth_encoder.eval()
    depth_decoder.eval()


    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
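        # Read the source image and every frame of the trimmed driving clip.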
        source_image = imageio.imread(source_image)
        reader = imageio.get_reader(driving_video)
        fps = reader.get_meta_data()['fps']
        driving_video = []
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            pass
        reader.close()

        # The vox-adv-256 models operate on 256x256 inputs; resize and keep only the RGB channels.
        source_image = resize(source_image, (256, 256))[..., :3]
        driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]



        # Pick the driving frame whose pose best matches the source image.
        i = find_best_frame(source_image, driving_video)
        print(f"Best frame: {i}")
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i+1)][::-1]
        use_cpu = not torch.cuda.is_available()  # fall back to CPU when no GPU is present
        sources_forward, drivings_forward, predictions_forward, depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=use_cpu)
        sources_backward, drivings_backward, predictions_backward, depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=use_cpu)
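        # Reverse the backward half and splice it onto the forward half so the output
        # runs from the first driving frame through the last.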
        predictions = predictions_backward[::-1] + predictions_forward[1:]
        sources = sources_backward[::-1] + sources_forward[1:]
        drivings = drivings_backward[::-1] + drivings_forward[1:]
        depth_gray = depth_backward[::-1] + depth_forward[1:]

        imageio.mimsave(output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
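        # The per-frame grayscale depth maps go to their own clip and are merged below.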
        imageio.mimsave("gray.mp4", depth_gray, fps=fps)
        # Re-read the rendered clips and merge them into one strip:
        # source | driving | depth | prediction.
        animation = np.array(imageio.mimread(output,memtest=False))
        gray = np.array(imageio.mimread("gray.mp4",memtest=False))

        src_dst = animation[:,:,:512,:]
        animate = animation[:,:,512:,:]
        merge = np.concatenate((src_dst,gray,animate),2)
        imageio.mimsave(output, merge, fps=fps)

    return output

gr.Interface(
    inference,
    [
        gr.inputs.Image(type="filepath", label="Source Image"),
        gr.inputs.Video(type='mp4', label="Driving Video"),
    ],
    gr.outputs.Video(type="mp4", label="Output Video"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    examples=examples,
    allow_flagging=False,
).launch(debug=False, enable_queue=True)