Update app.py
app.py CHANGED
@@ -1,72 +1,387 @@
[Removed: the previous 72-line app.py, a SadTalker Gradio demo. Only fragments of the deleted lines survive in this view: a few import statements and the closing calls demo = sadtalker_demo(sadtalker_result_dir) and demo.launch().]
import pickle
import time
import numpy as np
import scipy, cv2, os, sys, argparse
from tqdm import tqdm
import torch
from torch import nn  # needed by SimpleWrapperV2 below
import librosa
from networks import define_G
from pcavs.config.AudioConfig import AudioConfig

sys.path.append('spectre')
from config import cfg as spectre_cfg
from src.spectre import SPECTRE

from audio2mesh_helper import *
from pcavs.models import create_model, networks

torch.manual_seed(0)
from scipy.signal import savgol_filter

class SimpleWrapperV2(nn.Module):
    def __init__(self, cfg, use_ref=True, exp_dim=53, noload=False) -> None:
        super().__init__()

        self.audio_encoder = networks.define_A_sync(cfg)

        # map the 512-d audio feature plus the flattened reference coefficients to exp_dim values;
        # zero-initialized so the initial output equals the reference (pure residual)
        self.mapping1 = nn.Linear(512 + exp_dim, exp_dim)
        nn.init.constant_(self.mapping1.weight, 0.)
        nn.init.constant_(self.mapping1.bias, 0.)
        self.use_ref = use_ref

    def forward(self, x, ref, use_tanh=False):
        x = self.audio_encoder.forward_feature(x).view(x.size(0), -1)
        ref_reshape = ref.reshape(x.size(0), -1)  # 20, -1

        y = self.mapping1(torch.cat([x, ref_reshape], dim=1))

        if self.use_ref:
            out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref  # residual on top of the reference
        else:
            out = y.reshape(ref.shape[0], ref.shape[1], -1)

        if use_tanh:
            out[:, :50] = torch.tanh(out[:, :50]) * 3

        return out

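# Overview: Audio2Mesh drives the full dubbing pipeline:
#   1. the audio-to-expression model above predicts 53-d FLAME coefficients
#      (50 expression + 3 jaw pose) from mel-spectrogram clips,
#   2. a pre-trained FaceFormer predicts a 5023-vertex mesh sequence from wav2vec2 features,
#   3. SPECTRE renders the driven coefficients/vertices to CG face images,
#   4. a neural renderer (cg2real) composites the CG mouth region back into the
#      original video frames.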
class Audio2Mesh(object):
    def __init__(self, args) -> None:
        self.args = args

        spectre_cfg.model.use_tex = True
        spectre_cfg.model.mask_type = args.mask_type
        spectre_cfg.debug = self.args.debug
        spectre_cfg.model.netA_sync = 'ressesync'
        spectre_cfg.model.gpu_ids = [0]

        self.spectre = SPECTRE(spectre_cfg)
        self.spectre.eval()
        self.face_tracker = None  # FaceTrackerV2() # face landmark detection
        self.mel_step_size = 16
        self.fps = args.fps
        self.Nw = args.tframes
        self.device = self.args.device
        self.image_size = self.args.image_size

        ### audio-only settings for the audio-to-expression model
        args.netA_sync = 'ressesync'
        args.gpu_ids = [0]
        args.exp_dim = 53
        args.use_tanh = False
        args.K = 20

        self.audio2exp = 'pcavs'

        self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda()
        self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt'])

        # 5, 160 = 25fps
        self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160)

        # per-frame DECA/FLAME fitting coefficients of the source video
        with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f:
            self.fitting_coeffs = pickle.load(f, encoding='bytes')

        self.coeffs_dict = {key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']}

        #### pick the reference frame whose mouth is most closed (smallest expression sum)
        exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1)
        _, sorted_indices = torch.sort(exp_tensors)
        self.exp_id = sorted_indices[0].item()

        if '.ts' in args.render_path:
            # TorchScript renderer
            self.render = torch.jit.load(args.render_path).cuda()
            self.trt = True
        else:
            self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda()
            self.render.load_state_dict(torch.load(args.render_path))
            self.trt = False

        print('loaded cached images...')

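    # cg2real: translate the CG renders back into photo-realistic frames. Each
    # renderer input stacks Nw consecutive masked real frames with the Nw
    # corresponding CG renders (hence define_G(Nw*6, 3, ...) above); for the
    # unet_256 / s2am renderers the output is composited into the original
    # frame inside an eroded mouth-region mask.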
    @torch.no_grad()
    def cg2real(self, rendedimages, start_frame=0):

        ## load the original frames and their masks
        self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'),
                                            resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
        self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'),
                                           resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]

        self.source_masks = torch.FloatTensor(np.transpose(self.source_masks, (0, 3, 1, 2)) / 255.)
        self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images, (0, 3, 1, 2)) / 255.)

        ## pad the rendered images so every frame has a full Nw-frame window
        paded_tensor = torch.cat([rendedimages[0:1]] * (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]] * (self.Nw // 2)).contiguous()
        paded_mask_tensor = torch.cat([self.source_masks[0:1]] * (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]] * (self.Nw // 2)).contiguous()
        paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]] * (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]] * (self.Nw // 2)).contiguous()

        # paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask)
        padded_input = (paded_real_tensor - 0.5) * 2  # *(1-paded_mask_tensor)
        padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
        paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
        paded_tensor = (paded_tensor - 0.5) * 2

        result = []
        for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'):
            list_A = []   # masked real-frame windows
            list_R = []   # CG render windows
            list_M = []   # mask windows
            for i in range(self.args.renderbs):
                idx = index + i
                if idx + self.Nw > len(padded_input):
                    list_A.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
                    list_R.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
                    list_M.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
                else:
                    list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
                    list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
                    list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))

            list_A = torch.cat(list_A)
            list_R = torch.cat(list_R)
            list_M = torch.cat(list_M)

            idx = (self.Nw // 2) * 3
            mask = list_M[:, idx:idx+3]

            # list_A = padded_input
            mask = maskErosion(mask, offY=self.args.mask)
            list_A = list_A * (1 - mask[:, 0:1])
            A = torch.cat([list_A, list_R], 1)

            if self.trt:
                B = self.render(A.half().cuda())
            elif self.args.netR == 'unet_256':
                idx = (self.Nw // 2) * 3
                mask = list_M[:, idx:idx+3].cuda()
                mask = maskErosion(mask, offY=self.args.mask)
                B0 = list_A[:, idx:idx+3].cuda()
                B = self.render(A.cuda()) * mask[:, 0:1] + (1 - mask[:, 0:1]) * B0
            elif self.args.netR == 's2am':
                idx = (self.Nw // 2) * 3
                mask = list_M[:, idx:idx+3].cuda()
                mask = maskErosion(mask, offY=self.args.mask)
                B0 = list_A[:, idx:idx+3].cuda()
                B = self.render(A.cuda(), mask[:, 0:1]) * mask[:, 0:1] + (1 - mask[:, 0:1]) * B0
            else:
                B = self.render(A.cuda())

            result.append((B.cpu() + 1) * 0.5)  # -1,1 -> 0,1

        return torch.cat(result)[:len(rendedimages)]

    @torch.no_grad()
    def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK=20):

        xlen = vertices.shape[0]
        all_shape_images = []
        landmark2d = []

        #### find the largest jaw-opening value (pose[..., 3]) in the fitted coeffs
        max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1))

        for i in tqdm(range(0, xlen, XK)):

            if i + XK > xlen:
                XK = xlen - i

            codedictdecoder = {}
            codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda())
            codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda()
            codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda())
            codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK]
            codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda()
            codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda()
            codedictdecoder['images'] = torch.zeros((XK, 3, 256, 256)).cuda()

            # drive the jaw with the predicted coefficient, clipped to 90% of the
            # largest fitted value; zero out the remaining jaw rotations
            codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51*0.9)
            codedictdecoder['pose'][..., 4:6] = 0

            sub_vertices = vertices[i:i+XK].cuda()

            opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False)

            landmark2d.append(opdict['landmarks2d'].cpu())
            all_shape_images.append(opdict['rendered_images'].cpu())

        rendedimages = torch.cat(all_shape_images)
        lmk2d = torch.cat(landmark2d)

        return rendedimages, lmk2d

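    # run_spectre_v3: slice the mel-spectrogram into one clip per video frame,
    # run the audio-to-expression model in chunks of K=20 clips, and return one
    # 53-d coefficient vector (50 expression + 3 jaw pose) per frame, conditioned
    # on the closed-mouth reference frame selected in __init__.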
    @torch.no_grad()
    def run_spectre_v3(self, wav=None, ds_features=None, L=20):

        wav = audio_normalize(wav)
        all_mel = self.audio.melspectrogram(wav).astype(np.float32).T
        # skip the first and last two frames so every 5-frame clip has full context
        frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2)
        audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame)

        vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1]
        vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1]

        # reference = 50 expression + 3 jaw-pose coefficients of the reference frame
        ref = torch.cat([vid_exps.view(1, 50), vid_poses[:, 3:].view(1, 3)], dim=-1)
        ref = ref[..., :self.args.exp_dim]

        K = 20
        xlens = len(audio_inds)

        exps = []
        for i in tqdm(range(0, xlens, K), desc='S2 DECODER:' + str(xlens) + ' '):

            mels = []
            for j in range(K):
                if i + j < xlens:
                    idx = i + j
                    mel = load_spectrogram(all_mel, audio_inds[idx], self.audio.num_frames_per_clip * self.audio.num_bins_per_frame).cuda()
                    mel = mel.view(-1, 1, 80, self.audio.num_frames_per_clip * self.audio.num_bins_per_frame)
                    mels.append(mel)
                else:
                    break

            mels = torch.cat(mels, dim=0)
            new_exp = self.avmodel(mels, ref.repeat(mels.shape[0], 1, 1).cuda(), self.args.use_tanh)  # (B, 1, 53)
            exps += [new_exp.view(-1, 53)]

        all_exps = torch.cat(exps, axis=0)

        return all_exps

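    # test_model: load a pre-trained FaceFormer decoder, extract wav2vec2 features
    # for the input audio, and predict a (seq_len, 5023*3) FLAME vertex sequence,
    # using frame 33 of mesh_pose0.npy as the neutral template.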
    @torch.no_grad()
    def test_model(self, wav_path):

        sys.path.append('../FaceFormer')
        from faceformer import Faceformer
        from faceformer import PeriodicPositionalEncoding, init_biased_mask
        from transformers import Wav2Vec2Processor

        # build the FaceFormer model
        self.args.train_subjects = " ".join(["A"] * 8)  # matches the pre-trained faceformer checkpoint
        model = Faceformer(self.args)
        model.load_state_dict(torch.load('/apdcephfs/private_shadowcun/Avatar2dFF/medias/videos/c8/mask5000_l2/6_model.pth'))  # ../packages/pretrained/28_ff_model.pth
        model = model.to(torch.device(self.device))
        model.eval()

        # hack for long audio: enlarge the positional encoding and attention mask
        model.PPE = PeriodicPositionalEncoding(self.args.feature_dim, period=self.args.period, max_seq_len=6000).cuda()
        model.biased_mask = init_biased_mask(n_head=4, max_seq_len=6000, period=self.args.period).cuda()

        train_subjects_list = ["A"] * 8

        one_hot_labels = np.eye(len(train_subjects_list))
        one_hot = one_hot_labels[0]
        one_hot = np.reshape(one_hot, (-1, one_hot.shape[0]))
        one_hot = torch.FloatTensor(one_hot).to(device=self.device)

        vertices_npy = np.load(self.args.source_dir + '/mesh_pose0.npy')
        vertices_npy = np.array(vertices_npy).reshape(-1, 5023*3)

        temp = vertices_npy[33]  # 829

        template = temp.reshape((-1))
        template = np.reshape(template, (-1, template.shape[0]))
        template = torch.FloatTensor(template).to(device=self.device)

        speech_array, sampling_rate = librosa.load(os.path.join(wav_path), sr=16000)
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        audio_feature = np.squeeze(processor(speech_array, sampling_rate=16000).input_values)
        audio_feature = np.reshape(audio_feature, (-1, audio_feature.shape[0]))
        audio_feature = torch.FloatTensor(audio_feature).to(device=self.device)

        prediction = model.predict(audio_feature, template, one_hot, 1.0)  # (1, seq_len, V*3)

        return prediction.squeeze()

    @torch.no_grad()
    def run(self, face, audio, start_frame=0):

        wav, sr = librosa.load(audio, sr=16000)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav)
        _, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps)

        ##### audio-guided expressions; only the jaw movement is used for rendering
        all_exps = self.run_spectre_v3(wav)

        #### temporal interpolation to the video frame count
        all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0, 2, 1]), size=frames, mode='linear')
        all_exps = all_exps.permute([0, 2, 1]).squeeze(0)

        # run FaceFormer for face mesh generation
        predicted_vertices = self.test_model(audio)
        predicted_vertices = predicted_vertices.view(-1, 5023*3)

        #### temporal interpolation to the video frame count
        predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0, 2, 1]), size=frames, mode='linear')
        predicted_vertices = predicted_vertices.permute([0, 2, 1]).squeeze(0).view(-1, 5023, 3)

        all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu()  # temporal smoothing

        rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True)
        debug_video_gen(rendedimages, self.args.result_dir + "/debug_before_ff.mp4", wav_tensor, self.args.fps, sr)

        # cg2real: neural rendering back onto the original frames
        debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir + "/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr)

        exit()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing')
    parser.add_argument('--face', default='examples', type=str, help='')
    parser.add_argument('--audio', default='examples', type=str, help='')
    parser.add_argument('--source_dir', default='examples', type=str, help='TODO')
    parser.add_argument('--result_dir', default='examples', type=str, help='TODO')
    parser.add_argument('--backend', default='wav2lip', type=str, help='wav2lip or pcavs')
    parser.add_argument('--result_tag', default='result', type=str, help='TODO')
    parser.add_argument('--netR', default='unet_256', type=str, help='TODO')
    parser.add_argument('--render_path', default='', type=str, help='TODO')
    parser.add_argument('--ngf', default=16, type=int, help='TODO')
    parser.add_argument('--fps', default=20, type=int, help='TODO')
    parser.add_argument('--mask', default=100, type=int, help='TODO')
    parser.add_argument('--mask_type', default='v3', type=str, help='TODO')
    parser.add_argument('--image_size', default=256, type=int, help='TODO')
    parser.add_argument('--input_nc', default=21, type=int, help='TODO')
    parser.add_argument('--output_nc', default=3, type=int, help='TODO')
    parser.add_argument('--renderbs', default=16, type=int, help='TODO')
    parser.add_argument('--tframes', default=1, type=int, help='TODO')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--enhance', action='store_true')
    parser.add_argument('--phone', action='store_true')

    #### faceformer
    parser.add_argument("--model_name", type=str, default="VOCA")
    parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI')
    parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI')
    parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI')
    parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI')
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ")
    parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA")
    parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects')
    parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects')
    parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background')
    parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates')
    parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology')

    opt = parser.parse_args()

    opt.img_size = 96
    opt.static = True
    opt.device = torch.device("cuda")

    a2m = Audio2Mesh(opt)

    print('link start!')
    t = time.time()
    # 02780
    a2m.run(opt.face, opt.audio, 0)
    print(time.time() - t)
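
The temporal resampling used twice in run() (once for the 53-d expression sequence, once for the flattened vertex sequence) is plain 1-D linear interpolation along the time axis. A minimal, self-contained sketch of that step, with made-up sizes (40 input frames and 100 target video frames are arbitrary):

import torch

seq = torch.randn(40, 53)   # hypothetical: 40 audio-rate frames of 53-d coefficients
target_frames = 100         # hypothetical: video frame count derived from the audio length

# F.interpolate expects (N, C, L); treat the 53 coefficients as channels and
# resample the length-40 time axis to length-100, as run() does above.
resampled = torch.nn.functional.interpolate(
    seq.unsqueeze(0).permute(0, 2, 1),      # (1, 53, 40)
    size=target_frames, mode='linear', align_corners=False,
).permute(0, 2, 1).squeeze(0)               # (100, 53)

assert resampled.shape == (target_frames, 53)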