vinthony committed on
Commit 2299694
1 Parent(s): ff4c585

Update app.py

Files changed (1)
  1. app.py +382 -67
app.py CHANGED
@@ -1,72 +1,387 @@
- import os, sys
- import tempfile
- import gradio as gr
- from modules.text2speech import text2speech
- from modules.gfpgan_inference import gfpgan
- from modules.sadtalker_test import SadTalker
-
- def get_driven_audio(audio):
-     if os.path.isfile(audio):
-         return audio
-     else:
-         save_path = tempfile.NamedTemporaryFile(
-             delete=False,
-             suffix=("." + "wav"),
-         )
-         gen_audio = text2speech(audio, save_path.name)
-         return gen_audio, gen_audio
-
- def get_source_image(image):
-     return image
-
- def sadtalker_demo(result_dir):
-
-     sad_talker = SadTalker()
-     with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
-         with gr.Row().style(equal_height=False):
-             with gr.Column(variant='panel'):
-                 with gr.Tabs(elem_id="sadtalker_source_image"):
-                     with gr.TabItem('Upload image'):
-                         with gr.Row():
-                             source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
-
-                 with gr.Tabs(elem_id="sadtalker_driven_audio"):
-                     with gr.TabItem('Upload audio'):
-                         with gr.Column(variant='panel'):
-                             driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
-                             # submit_audio_1 = gr.Button('Submit', variant='primary')
-                             # submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)
-
-                 with gr.Tabs(elem_id="sadtalker_checkbox"):
-                     with gr.TabItem('Settings'):
-                         with gr.Column(variant='panel'):
-                             is_still_mode = gr.Checkbox(label="w/ Still Mode")
-                             enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
-
-             with gr.Column(variant='panel'):
-                 gen_video = gr.Video(label="Generated video", format="mp4").style(height=256,width=256)
-                 gen_text = gr.Textbox(visible=False)
-                 submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
-
-         submit.click(
-             fn=sad_talker.test,
-             inputs=[source_image,
-                     driven_audio,
-                     is_still_mode,
-                     enhancer,
-                     gr.Textbox(value=result_dir, visible=False)],
-             outputs=[gen_video, gen_text]
-             )
-
-     return sadtalker_interface
-
-
- if __name__ == "__main__":
-
-     current_code_path = sys.argv[0]
-     current_root_dir = os.path.split(current_code_path)[0]
-     sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
-     demo = sadtalker_demo(sadtalker_result_dir)
-     demo.launch()
-
-
+ import pickle
+ import time
+ import numpy as np
+ import scipy, cv2, os, sys, argparse
+ from tqdm import tqdm
+ import torch
+ import torch.nn as nn  # needed for nn.Module / nn.Linear in SimpleWrapperV2 below
+ import librosa
+ from networks import define_G
+ from pcavs.config.AudioConfig import AudioConfig
+
+ sys.path.append('spectre')
+ from config import cfg as spectre_cfg
+ from src.spectre import SPECTRE
+
+ from audio2mesh_helper import *
+ from pcavs.models import create_model, networks
+
+ torch.manual_seed(0)
+ from scipy.signal import savgol_filter
+
+
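+ # SimpleWrapperV2: a small audio-to-expression head. It encodes a mel-spectrogram
+ # clip with the PC-AVS sync audio encoder and maps it, together with a reference
+ # coefficient vector, to a 53-dim output (50 expression + 3 jaw-pose coefficients,
+ # going by the indexing used below); with use_ref=True the prediction is a residual on ref.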
+ class SimpleWrapperV2(nn.Module):
+     def __init__(self, cfg, use_ref=True, exp_dim=53, noload=False) -> None:
+         super().__init__()
+
+         self.audio_encoder = networks.define_A_sync(cfg)
+
+         self.mapping1 = nn.Linear(512+exp_dim, exp_dim)
+         nn.init.constant_(self.mapping1.weight, 0.)
+         nn.init.constant_(self.mapping1.bias, 0.)
+         self.use_ref = use_ref
+
+     def forward(self, x, ref, use_tanh=False):
+         x = self.audio_encoder.forward_feature(x).view(x.size(0), -1)
+         ref_reshape = ref.reshape(x.size(0), -1)  # flatten the reference coefficients
+
+         y = self.mapping1(torch.cat([x, ref_reshape], dim=1))
+
+         if self.use_ref:
+             out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref  # residual on the reference
+         else:
+             out = y.reshape(ref.shape[0], ref.shape[1], -1)
+
+         if use_tanh:
+             out[:, :50] = torch.tanh(out[:, :50]) * 3
+
+         return out
+
+ class Audio2Mesh(object):
+     def __init__(self, args) -> None:
+         self.args = args
+
+         spectre_cfg.model.use_tex = True
+         spectre_cfg.model.mask_type = args.mask_type
+         spectre_cfg.debug = self.args.debug
+         spectre_cfg.model.netA_sync = 'ressesync'
+         spectre_cfg.model.gpu_ids = [0]
+
+         self.spectre = SPECTRE(spectre_cfg)
+         self.spectre.eval()
+         self.face_tracker = None  # FaceTrackerV2() face landmark detection (disabled)
+         self.mel_step_size = 16
+         self.fps = args.fps
+         self.Nw = args.tframes
+         self.device = self.args.device
+         self.image_size = self.args.image_size
+
+         ### audio-only branch settings
+         args.netA_sync = 'ressesync'
+         args.gpu_ids = [0]
+         args.exp_dim = 53
+         args.use_tanh = False
+         args.K = 20
+
+         self.audio2exp = 'pcavs'
+
+         # audio-to-expression model
+         self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda()
+         self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt'])
+
+         # 5, 160 = 25fps
+         self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160)
+
+         with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f:  # pre-computed DECA fitting coefficients
+             self.fitting_coeffs = pickle.load(f, encoding='bytes')
+
+         self.coeffs_dict = {key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']}
+
+         #### find the frame with the most closed mouth (smallest expression magnitude)
+         exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1)
+         ssss, sorted_indices = torch.sort(exp_tensors)
+         self.exp_id = sorted_indices[0].item()
+
+         if '.ts' in args.render_path:
+             self.render = torch.jit.load(args.render_path).cuda()
+             self.trt = True
+         else:
+             self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda()
+             self.render.load_state_dict(torch.load(args.render_path))
+             self.trt = False
+
+         print('loaded cached images...')
+
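+     # cg2real: translate the CG renderings back to the photo-realistic domain.
+     # Each output frame is predicted from a temporal window of Nw frames
+     # (masked real frames + renderings) pushed through the image renderer in
+     # batches of args.renderbs; outputs are mapped from [-1, 1] back to [0, 1].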
+     @torch.no_grad()
+     def cg2real(self, rendedimages, start_frame=0):
+
+         ## load the original frames and their masks
+         self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'), \
+                                             resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
+         self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'), \
+                                            resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
+
+         self.source_masks = torch.FloatTensor(np.transpose(self.source_masks, (0,3,1,2))/255.)
+         self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images, (0,3,1,2))/255.)
+
+         ## pad the rendered images at both ends so every frame has a full Nw window
+         paded_tensor = torch.cat([rendedimages[0:1]] * (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]] * (self.Nw // 2)).contiguous()
+         paded_mask_tensor = torch.cat([self.source_masks[0:1]] * (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]] * (self.Nw // 2)).contiguous()
+         paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]] * (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]] * (self.Nw // 2)).contiguous()
+
+         # paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask)
+         padded_input = (paded_real_tensor - 0.5) * 2  # *(1-paded_mask_tensor)
+         padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
+         paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
+         paded_tensor = (paded_tensor - 0.5) * 2
+
+         result = []
+         for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'):
+             list_A = []
+             list_R = []
+             list_M = []
+             for i in range(self.args.renderbs):
+                 idx = index + i
+                 if idx + self.Nw > len(padded_input):
+                     list_A.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
+                     list_R.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
+                     list_M.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
+                 else:
+                     list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
+                     list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
+                     list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
+
+             list_A = torch.cat(list_A)
+             list_R = torch.cat(list_R)
+             list_M = torch.cat(list_M)
+
+             idx = (self.Nw//2) * 3
+             mask = list_M[:, idx:idx+3]
+
+             # list_A = padded_input
+             mask = maskErosion(mask, offY=self.args.mask)
+             list_A = list_A * (1 - mask[:, 0:1])
+             A = torch.cat([list_A, list_R], 1)
+
+             if self.trt:
+                 B = self.render(A.half().cuda())
+             elif self.args.netR == 'unet_256':
+                 # import pdb; pdb.set_trace()
+                 idx = (self.Nw//2) * 3
+                 mask = list_M[:, idx:idx+3].cuda()
+                 mask = maskErosion(mask, offY=self.args.mask)
+                 B0 = list_A[:, idx:idx+3].cuda()
+                 B = self.render(A.cuda()) * mask[:, 0:1] + (1 - mask[:, 0:1]) * B0
+             elif self.args.netR == 's2am':
+                 # import pdb; pdb.set_trace()
+                 idx = (self.Nw//2) * 3
+                 mask = list_M[:, idx:idx+3].cuda()
+                 mask = maskErosion(mask, offY=self.args.mask)
+                 B0 = list_A[:, idx:idx+3].cuda()
+                 B = self.render(A.cuda(), mask[:, 0:1]) * mask[:, 0:1] + (1 - mask[:, 0:1]) * B0
+             else:
+                 B = self.render(A.cuda())
+
+             result.append((B.cpu() + 1) * 0.5)  # [-1, 1] -> [0, 1]
+
+         return torch.cat(result)[:len(rendedimages)]
+
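+     # coeffs_to_img: decode FLAME vertices plus the predicted jaw coefficient into
+     # rendered face images and 2D landmarks, XK frames at a time. Shape and
+     # expression codes are zeroed because the geometry already comes from the
+     # predicted vertices; the jaw opening (pose[..., 3]) is taken from `coeffs`
+     # and clipped to 90% of the maximum observed in the source fitting.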
+     @torch.no_grad()
+     def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK=20):
+
+         xlen = vertices.shape[0]
+         all_shape_images = []
+         landmark2d = []
+
+         #### find the largest jaw-pose value (pose[..., 3]) in the fitted coeffs
+         max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1))
+
+         for i in tqdm(range(0, xlen, XK)):
+
+             if i + XK > xlen:
+                 XK = xlen - i
+
+             codedictdecoder = {}
+             codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda())
+             codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda()
+             codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda())  # all_exps[i:i+XK, :50].cuda()
+             codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK]
+             codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda()
+             codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda()
+             codedictdecoder['images'] = torch.zeros((XK, 3, 256, 256)).cuda()
+
+             codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51*0.9)  # predicted jaw opening
+             codedictdecoder['pose'][..., 4:6] = 0  # coeffs[i:i+XK, 50:]*(-0.25)
+
+             sub_vertices = vertices[i:i+XK].cuda()
+
+             opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False)
+
+             landmark2d.append(opdict['landmarks2d'].cpu())
+
+             all_shape_images.append(opdict['rendered_images'].cpu())
+
+         rendedimages = torch.cat(all_shape_images)
+
+         lmk2d = torch.cat(landmark2d)
+
+         return rendedimages, lmk2d
+
+
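+     # run_spectre_v3: predict a 53-dim expression/jaw coefficient per audio frame.
+     # Mel-spectrogram clips are batched (K frames at a time) through the
+     # audio-to-expression model, conditioned on the most-closed-mouth reference
+     # frame selected in __init__ (self.exp_id).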
+     @torch.no_grad()
+     def run_spectre_v3(self, wav=None, ds_features=None, L=20):
+
+         wav = audio_normalize(wav)
+         all_mel = self.audio.melspectrogram(wav).astype(np.float32).T
+         frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2)  # skip two frames at each end
+         audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame)
+
+         vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1]
+         vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1]
+
+         ref = torch.cat([vid_exps.view(1, 50), vid_poses[:, 3:].view(1, 3)], dim=-1)
+         ref = ref[..., :self.args.exp_dim]
+
+         K = 20
+         xlens = len(audio_inds)  # len(self.coeffs_dict['exp'])
+
+         exps = []
+         for i in tqdm(range(0, xlens, K), desc='S2 DECODER:' + str(xlens) + ' '):
+
+             mels = []
+             for j in range(K):
+                 if i + j < xlens:
+                     idx = i + j
+                     mel = load_spectrogram(all_mel, audio_inds[idx], self.audio.num_frames_per_clip * self.audio.num_bins_per_frame).cuda()
+                     mel = mel.view(-1, 1, 80, self.audio.num_frames_per_clip * self.audio.num_bins_per_frame)
+                     mels.append(mel)
+                 else:
+                     break
+
+             mels = torch.cat(mels, dim=0)
+             new_exp = self.avmodel(mels, ref.repeat(mels.shape[0], 1, 1).cuda(), self.args.use_tanh)  # 53-dim exp + jaw
+             exps += [new_exp.view(-1, 53)]
+
+         all_exps = torch.cat(exps, axis=0)
+
+         return all_exps
+
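+     # test_model: run FaceFormer on the input audio to predict a sequence of
+     # FLAME vertices (5023 x 3 per frame). wav2vec2 features are extracted with
+     # the facebook/wav2vec2-base-960h processor, and frame 33 of the tracked
+     # mesh (mesh_pose0.npy) is used as the identity template.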
+     @torch.no_grad()
+     def test_model(self, wav_path):
+
+         sys.path.append('../FaceFormer')
+         from faceformer import Faceformer
+         from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
+         from faceformer import PeriodicPositionalEncoding, init_biased_mask
+
+         # build the model
+         self.args.train_subjects = " ".join(["A"]*8)  # suitable for the pre-trained FaceFormer checkpoint
+         model = Faceformer(self.args)
+         model.load_state_dict(torch.load('/apdcephfs/private_shadowcun/Avatar2dFF/medias/videos/c8/mask5000_l2/6_model.pth'))  # ../packages/pretrained/28_ff_model.pth
+         model = model.to(torch.device(self.device))
+         model.eval()
+
+         # hack for long-audio generation: extend positional encoding and attention mask
+         model.PPE = PeriodicPositionalEncoding(self.args.feature_dim, period=self.args.period, max_seq_len=6000).cuda()
+         model.biased_mask = init_biased_mask(n_head=4, max_seq_len=6000, period=self.args.period).cuda()
+
+         train_subjects_list = ["A"] * 8
+
+         one_hot_labels = np.eye(len(train_subjects_list))
+         one_hot = one_hot_labels[0]
+         one_hot = np.reshape(one_hot, (-1, one_hot.shape[0]))
+         one_hot = torch.FloatTensor(one_hot).to(device=self.device)
+
+         vertices_npy = np.load(self.args.source_dir + '/mesh_pose0.npy')
+         vertices_npy = np.array(vertices_npy).reshape(-1, 5023*3)
+
+         temp = vertices_npy[33]  # 829
+
+         template = temp.reshape((-1))
+         template = np.reshape(template, (-1, template.shape[0]))
+         template = torch.FloatTensor(template).to(device=self.device)
+
+         speech_array, sampling_rate = librosa.load(os.path.join(wav_path), sr=16000)
+         processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+         audio_feature = np.squeeze(processor(speech_array, sampling_rate=16000).input_values)
+         audio_feature = np.reshape(audio_feature, (-1, audio_feature.shape[0]))
+         audio_feature = torch.FloatTensor(audio_feature).to(device=self.device)
+
+         prediction = model.predict(audio_feature, template, one_hot, 1.0)  # (1, seq_len, V*3)
+
+         return prediction.squeeze()
+
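+     # run: full audio-to-video pipeline. Audio drives both the expression/jaw
+     # coefficients (run_spectre_v3) and the FLAME vertices (FaceFormer via
+     # test_model); both streams are interpolated to the target frame count,
+     # smoothed, rendered with coeffs_to_img, and finally translated to
+     # photo-realistic frames with cg2real (debug videos are written along the way).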
+     @torch.no_grad()
+     def run(self, face, audio, start_frame=0):
+
+         wav, sr = librosa.load(audio, sr=16000)  # 16*80 ? 20*80
+         wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav)
+         _, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps)
+
+         ##### audio-guided: only the jaw movement is used
+         all_exps = self.run_spectre_v3(wav)
+
+         #### temporal interpolation to the target number of frames
+         all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear')
+         all_exps = all_exps.permute([0,2,1]).squeeze(0)
+
+         # run FaceFormer for face mesh generation
+         predicted_vertices = self.test_model(audio)
+         predicted_vertices = predicted_vertices.view(-1, 5023*3)
+
+         #### temporal interpolation
+         predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear')
+         predicted_vertices = predicted_vertices.permute([0,2,1]).squeeze(0).view(-1, 5023, 3)
+
+         all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu()  # Savitzky-Golay smoothing
+
+         rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True)
+         debug_video_gen(rendedimages, self.args.result_dir+"/debug_before_ff.mp4", wav_tensor, self.args.fps, sr)
+
+         # cg2real
+         debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir+"/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr)
+
+         exit()  # early exit: only the debug videos are produced for now
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing')
+     parser.add_argument('--face', default='examples', type=str, help='')
+     parser.add_argument('--audio', default='examples', type=str, help='')
+     parser.add_argument('--source_dir', default='examples', type=str, help='TODO')
+     parser.add_argument('--result_dir', default='examples', type=str, help='TODO')
+     parser.add_argument('--backend', default='wav2lip', type=str, help='wav2lip or pcavs')
+     parser.add_argument('--result_tag', default='result', type=str, help='TODO')
+     parser.add_argument('--netR', default='unet_256', type=str, help='TODO')
+     parser.add_argument('--render_path', default='', type=str, help='TODO')
+     parser.add_argument('--ngf', default=16, type=int, help='TODO')
+     parser.add_argument('--fps', default=20, type=int, help='TODO')
+     parser.add_argument('--mask', default=100, type=int, help='TODO')
+     parser.add_argument('--mask_type', default='v3', type=str, help='TODO')
+     parser.add_argument('--image_size', default=256, type=int, help='TODO')
+     parser.add_argument('--input_nc', default=21, type=int, help='TODO')
+     parser.add_argument('--output_nc', default=3, type=int, help='TODO')
+     parser.add_argument('--renderbs', default=16, type=int, help='TODO')
+     parser.add_argument('--tframes', default=1, type=int, help='TODO')
+     parser.add_argument('--debug', action='store_true')
+     parser.add_argument('--enhance', action='store_true')
+     parser.add_argument('--phone', action='store_true')
+
+     #### faceformer
+     parser.add_argument("--model_name", type=str, default="VOCA")
+     parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI')
+     parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI')
+     parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI')
+     parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI')
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ")
+     parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA")
+     parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects')
+     parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects')
+     parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background')
+     parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates')
+     parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology')
+
+     opt = parser.parse_args()
+
+     opt.img_size = 96
+     opt.static = True
+     opt.device = torch.device("cuda")
+
+     a2m = Audio2Mesh(opt)
+
+     print('link start!')
+     t = time.time()
+     # 02780
+     a2m.run(opt.face, opt.audio, 0)
+     print(time.time() - t)