harlanhong committed on
Commit
10cdcde
1 Parent(s): 1365072
Files changed (2)
  1. app.py +111 -37
  2. demo_dagan.py +83 -84
app.py CHANGED
@@ -2,17 +2,17 @@ import os
 import shutil
 import gradio as gr
 from PIL import Image
-
+ import subprocess
 #os.chdir('Restormer')
-
+ from demo_dagan import *
 # Download sample images


 examples = [['project/cartoon2.jpg','project/video1.mp4'],
-             ['project/cartoon3.jpg','project/video2.mp4'],
-             ['project/celeb1.jpg','project/video1.mp4'],
-             ['project/celeb2.jpg','project/video2.mp4'],
-             ]
+             ['project/cartoon3.jpg','project/video2.mp4'],
+             ['project/celeb1.jpg','project/video1.mp4'],
+             ['project/celeb2.jpg','project/video2.mp4'],
+             ]


 inference_on = ['Full Resolution Image', 'Downsampled Image']
@@ -27,36 +27,110 @@ Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head V
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"


- def inference(img, video):
+ def inference(source_image, video):
     if not os.path.exists('temp'):
-         os.system('mkdir temp')
-
-     #### Resize the longer edge of the input image
-     max_res = 256
-     width, height = img.size
-     if max(width,height) > max_res:
-         scale = max_res /max(width,height)
-         width = int(scale*width)
-         height = int(scale*height)
-         img = img.resize((width,height), Image.ANTIALIAS)
-
-     img.save("temp/image.jpg", "JPEG")
-     video.save('temp/video.mp4')
-     os.system("python demo_dagan.py --source_image 'temp/image.jpg' --driving_video 'temp/video.mp4/ --output 'temp/rst.mp4'")
-
-     return f'temp/rst.mp4'
-
+         os.system('mkdir temp')
+     cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy video_input.mp4"
+     subprocess.run(cmd.split())
+     driving_video = "video_input.mp4"
+     output = "rst.mp4"
+     with open("config/vox-adv-256.yaml") as f:
+         config = yaml.load(f)
+     generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
+     config['model_params']['common_params']['num_channels'] = 4
+     kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+     g_checkpoint = torch.load("generator.pt", map_location=device)
+     kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
+
+     ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
+     generator.load_state_dict(ckp_generator)
+     ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
+     kp_detector.load_state_dict(ckp_kp_detector)
+
+     depth_encoder = depth.ResnetEncoder(18, False)
+     depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
+     loaded_dict_enc = torch.load('encoder.pth')
+     loaded_dict_dec = torch.load('depth.pth')
+     filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
+     depth_encoder.load_state_dict(filtered_dict_enc)
+     ckp_depth_decoder = {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
+     depth_decoder.load_state_dict(ckp_depth_decoder)
+     depth_encoder.eval()
+     depth_decoder.eval()
+
+     # device = torch.device('cpu')
+     # stx()
+
+     generator = generator.to(device)
+     kp_detector = kp_detector.to(device)
+     depth_encoder = depth_encoder.to(device)
+     depth_decoder = depth_decoder.to(device)
+
+     generator.eval()
+     kp_detector.eval()
+     depth_encoder.eval()
+     depth_decoder.eval()
+
+     img_multiple_of = 8
+
+     with torch.inference_mode():
+         if torch.cuda.is_available():
+             torch.cuda.ipc_collect()
+             torch.cuda.empty_cache()
+         source_image = imageio.imread(source_image)
+         reader = imageio.get_reader(driving_video)
+         fps = reader.get_meta_data()['fps']
+         driving_video = []
+         try:
+             for im in reader:
+                 driving_video.append(im)
+         except RuntimeError:
+             pass
+         reader.close()
+
+         source_image = resize(source_image, (256, 256))[..., :3]
+         driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
+
+
+
+         i = find_best_frame(source_image, driving_video)
+         print ("Best frame: " + str(i))
+         driving_forward = driving_video[i:]
+         driving_backward = driving_video[:(i+1)][::-1]
+         sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
+         sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
+         predictions = predictions_backward[::-1] + predictions_forward[1:]
+         sources = sources_backward[::-1] + sources_forward[1:]
+         drivings = drivings_backward[::-1] + drivings_forward[1:]
+         depth_gray = depth_backward[::-1] + depth_forward[1:]
+
+         imageio.mimsave(output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
+         imageio.mimsave("gray.mp4", depth_gray, fps=fps)
+         # merge the gray video
+         animation = np.array(imageio.mimread(output,memtest=False))
+         gray = np.array(imageio.mimread("gray.mp4",memtest=False))
+
+         src_dst = animation[:,:,:512,:]
+         animate = animation[:,:,512:,:]
+         merge = np.concatenate((src_dst,gray,animate),2)
+         imageio.mimsave(output, merge, fps=fps)
+
+     return output
+
 gr.Interface(
-     inference,
-     [
-         gr.inputs.Image(type="filepath", label="Source Image"),
-         gr.inputs.Video(type='mp4',label="Driving Video"),
-     ],
-     gr.outputs.Video(type="mp4", label="Output Video"),
-     title=title,
-     description=description,
-     article=article,
-     theme ="huggingface",
-     examples=examples,
-     allow_flagging=False,
- ).launch(debug=False,enable_queue=True)
+     inference,
+     [
+         gr.inputs.Image(type="filepath", label="Source Image"),
+         gr.inputs.Video(type='mp4',label="Driving Video"),
+     ],
+     gr.outputs.Video(type="mp4", label="Output Video"),
+     title=title,
+     description=description,
+     article=article,
+     theme ="huggingface",
+     examples=examples,
+     allow_flagging=False,
+ ).launch(debug=False,enable_queue=True)
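Note: the rewritten inference() gets yaml, torch, imageio, np, resize, img_as_ubyte, OrderedDict, G, KPD, depth, find_best_frame and make_animation through the star import from demo_dagan, and reads the config with a bare yaml.load(f). The snippet below is a minimal sketch only, not part of this commit: it shows an explicit-import variant and a config loader that works on PyYAML 6, where yaml.load() requires a Loader argument. Which names demo_dagan actually re-exports is an assumption based on the removed script body.

# Sketch only (not part of the commit): explicit equivalent of `from demo_dagan import *`.
# The skimage imports and the set of re-exported names are assumptions.
from collections import OrderedDict

import imageio
import numpy as np
import torch
import yaml
from skimage import img_as_ubyte
from skimage.transform import resize

from demo_dagan import find_best_frame, make_animation  # plus G, KPD, depth in the Space

def load_config(path="config/vox-adv-256.yaml"):
    # PyYAML >= 6 makes the Loader argument to yaml.load() mandatory;
    # safe_load is sufficient for a plain config file like vox-adv-256.yaml.
    with open(path) as f:
        return yaml.safe_load(f)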
demo_dagan.py CHANGED
@@ -25,7 +25,7 @@ parser.add_argument('--driving_video', default='./temp/driving.mp4', type=str, h
 parser.add_argument('--output', default='./temp/result.mp4', type=str, help='Directory for driving video')


- args = parser.parse_args()
+ # args = parser.parse_args()
 def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                  use_relative_movement=False, use_relative_jacobian=False):
     if adapt_movement_scale:
@@ -71,7 +71,6 @@ def find_best_frame(source, driving, cpu=False):
             frame_num = i
     return frame_num

-
 def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
     sources = []
     drivings = []
@@ -121,88 +120,88 @@ def make_animation(source_image, driving_video, generator, kp_detector, relative
         predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
         depth_gray.append(gray_driving)
     return sources, drivings, predictions,depth_gray
- with open("config/vox-adv-256.yaml") as f:
-     config = yaml.load(f)
- generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
- config['model_params']['common_params']['num_channels'] = 4
- kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
- g_checkpoint = torch.load("generator.pt", map_location=device)
- kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
-
- ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
- generator.load_state_dict(ckp_generator)
- ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
- kp_detector.load_state_dict(ckp_kp_detector)
-
- depth_encoder = depth.ResnetEncoder(18, False)
- depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
- loaded_dict_enc = torch.load('encoder.pth')
- loaded_dict_dec = torch.load('depth.pth')
- filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
- depth_encoder.load_state_dict(filtered_dict_enc)
- ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
- depth_decoder.load_state_dict(ckp_depth_decoder)
- depth_encoder.eval()
- depth_decoder.eval()
+ # with open("config/vox-adv-256.yaml") as f:
+ #     config = yaml.load(f)
+ # generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
+ # config['model_params']['common_params']['num_channels'] = 4
+ # kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ # g_checkpoint = torch.load("generator.pt", map_location=device)
+ # kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
+
+ # ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
+ # generator.load_state_dict(ckp_generator)
+ # ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
+ # kp_detector.load_state_dict(ckp_kp_detector)
+
+ # depth_encoder = depth.ResnetEncoder(18, False)
+ # depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
+ # loaded_dict_enc = torch.load('encoder.pth')
+ # loaded_dict_dec = torch.load('depth.pth')
+ # filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
+ # depth_encoder.load_state_dict(filtered_dict_enc)
+ # ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
+ # depth_decoder.load_state_dict(ckp_depth_decoder)
+ # depth_encoder.eval()
+ # depth_decoder.eval()

- # device = torch.device('cpu')
- # stx()
-
- generator = generator.to(device)
- kp_detector = kp_detector.to(device)
- depth_encoder = depth_encoder.to(device)
- depth_decoder = depth_decoder.to(device)
-
- generator.eval()
- kp_detector.eval()
- depth_encoder.eval()
- depth_decoder.eval()
-
- img_multiple_of = 8
-
- with torch.inference_mode():
-     if torch.cuda.is_available():
-         torch.cuda.ipc_collect()
-         torch.cuda.empty_cache()
-     source_image = imageio.imread(args.source_image)
-     reader = imageio.get_reader(args.driving_video)
-     fps = reader.get_meta_data()['fps']
-     driving_video = []
-     try:
-         for im in reader:
-             driving_video.append(im)
-     except RuntimeError:
-         pass
-     reader.close()
-
-     source_image = resize(source_image, (256, 256))[..., :3]
-     driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
-
-
-
-     i = find_best_frame(source_image, driving_video)
-     print ("Best frame: " + str(i))
-     driving_forward = driving_video[i:]
-     driving_backward = driving_video[:(i+1)][::-1]
-     sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
-     sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
-     predictions = predictions_backward[::-1] + predictions_forward[1:]
-     sources = sources_backward[::-1] + sources_forward[1:]
-     drivings = drivings_backward[::-1] + drivings_forward[1:]
-     depth_gray = depth_backward[::-1] + depth_forward[1:]
-
-     imageio.mimsave(args.output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
-     imageio.mimsave("gray.mp4", depth_gray, fps=fps)
-     # merge the gray video
-     animation = np.array(imageio.mimread(args.output,memtest=False))
-     gray = np.array(imageio.mimread("gray.mp4",memtest=False))
-
-     src_dst = animation[:,:,:512,:]
-     animate = animation[:,:,512:,:]
-     merge = np.concatenate((src_dst,gray,animate),2)
-     imageio.mimsave(args.output, merge, fps=fps)
+ # # device = torch.device('cpu')
+ # # stx()
+
+ # generator = generator.to(device)
+ # kp_detector = kp_detector.to(device)
+ # depth_encoder = depth_encoder.to(device)
+ # depth_decoder = depth_decoder.to(device)
+
+ # generator.eval()
+ # kp_detector.eval()
+ # depth_encoder.eval()
+ # depth_decoder.eval()
+
+ # img_multiple_of = 8
+
+ # with torch.inference_mode():
+ #     if torch.cuda.is_available():
+ #         torch.cuda.ipc_collect()
+ #         torch.cuda.empty_cache()
+ #     source_image = imageio.imread(args.source_image)
+ #     reader = imageio.get_reader(args.driving_video)
+ #     fps = reader.get_meta_data()['fps']
+ #     driving_video = []
+ #     try:
+ #         for im in reader:
+ #             driving_video.append(im)
+ #     except RuntimeError:
+ #         pass
+ #     reader.close()
+
+ #     source_image = resize(source_image, (256, 256))[..., :3]
+ #     driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
+
+
+
+ #     i = find_best_frame(source_image, driving_video)
+ #     print ("Best frame: " + str(i))
+ #     driving_forward = driving_video[i:]
+ #     driving_backward = driving_video[:(i+1)][::-1]
+ #     sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
+ #     sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
+ #     predictions = predictions_backward[::-1] + predictions_forward[1:]
+ #     sources = sources_backward[::-1] + sources_forward[1:]
+ #     drivings = drivings_backward[::-1] + drivings_forward[1:]
+ #     depth_gray = depth_backward[::-1] + depth_forward[1:]
+
+ #     imageio.mimsave(args.output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
+ #     imageio.mimsave("gray.mp4", depth_gray, fps=fps)
+ #     # merge the gray video
+ #     animation = np.array(imageio.mimread(args.output,memtest=False))
+ #     gray = np.array(imageio.mimread("gray.mp4",memtest=False))
+
+ #     src_dst = animation[:,:,:512,:]
+ #     animate = animation[:,:,512:,:]
+ #     merge = np.concatenate((src_dst,gray,animate),2)
+ #     imageio.mimsave(args.output, merge, fps=fps)

 # print(f"\nRestored images are saved at {out_dir}")
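Commenting out the whole script body, including args = parser.parse_args(), makes demo_dagan.py importable from app.py but drops its command-line entry point. The snippet below is a sketch only, not part of this commit, of one way to keep both uses; the main() wrapper is hypothetical, and only the argument names and the two defaults shown above come from the file.

# Sketch only (not part of the commit): guard the script body instead of
# commenting it out, so demo_dagan.py works both as a module and as a CLI.
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_image', type=str, help='Path to the source image')
    parser.add_argument('--driving_video', default='./temp/driving.mp4', type=str, help='Path to the driving video')
    parser.add_argument('--output', default='./temp/result.mp4', type=str, help='Path for the result video')
    args = parser.parse_args()
    # The model-loading and animation block that this commit comments out would
    # run here, reading args.source_image, args.driving_video and args.output.

if __name__ == "__main__":
    main()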