Commit 9ab094a (1 parent: 99e1f07), committed by shadowcun

new version of sadtalker

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- app.py +99 -27
- checkpoints/mapping_00229-model.pth.tar +1 -1
- src/__pycache__/generate_batch.cpython-38.pyc +0 -0
- src/__pycache__/generate_facerender_batch.cpython-38.pyc +0 -0
- src/__pycache__/test_audio2coeff.cpython-38.pyc +0 -0
- src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
- src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
- src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
- src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
- src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
- src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
- src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
- src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
- src/audio2pose_models/audio2pose.py +4 -4
- src/audio2pose_models/audio_encoder.py +7 -7
- src/config/similarity_Lm3D_all.mat +0 -0
- src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc +0 -0
- src/face3d/extract_kp_videos.py +2 -2
- src/face3d/extract_kp_videos_safe.py +151 -0
- src/face3d/models/__pycache__/__init__.cpython-38.pyc +0 -0
- src/face3d/models/__pycache__/base_model.cpython-38.pyc +0 -0
- src/face3d/models/__pycache__/networks.cpython-38.pyc +0 -0
- src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc +0 -0
- src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc +0 -0
- src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc +0 -0
- src/face3d/util/__pycache__/__init__.cpython-38.pyc +0 -0
- src/face3d/util/__pycache__/load_mats.cpython-38.pyc +0 -0
- src/face3d/util/__pycache__/preprocess.cpython-38.pyc +0 -0
- src/facerender/__pycache__/animate.cpython-38.pyc +0 -0
- src/facerender/animate.py +66 -22
- src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc +0 -0
- src/facerender/modules/__pycache__/generator.cpython-38.pyc +0 -0
- src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc +0 -0
- src/facerender/modules/__pycache__/make_animation.cpython-38.pyc +0 -0
- src/facerender/modules/__pycache__/mapping.cpython-38.pyc +0 -0
- src/facerender/modules/__pycache__/util.cpython-38.pyc +0 -0
- src/facerender/modules/make_animation.py +4 -4
- src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc +0 -0
- src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc +0 -0
- src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc +0 -0
- src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc +0 -0
- src/generate_batch.py +25 -20
- src/generate_facerender_batch.py +8 -6
- src/gradio_demo.py +83 -64
- src/src/audio2exp_models/audio2exp.py +41 -0
- src/src/audio2exp_models/networks.py +74 -0
- src/src/audio2pose_models/audio2pose.py +94 -0
- src/src/audio2pose_models/audio_encoder.py +64 -0
- src/src/audio2pose_models/cvae.py +149 -0
- src/src/audio2pose_models/discriminator.py +76 -0
app.py CHANGED
@@ -8,8 +8,27 @@ from huggingface_hub import snapshot_download
 def get_source_image(image):
     return image
 
+try:
+    import webui   # in webui
+    in_webui = True
+except:
+    in_webui = False
+
+
+def toggle_audio_file(choice):
+    if choice == False:
+        return gr.update(visible=True), gr.update(visible=False)
+    else:
+        return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+    if path_of_ref_video is not None:
+        return gr.update(value=True)
+    else:
+        return gr.update(value=False)
+
 def download_model():
-    REPO_ID = 'vinthony/SadTalker'
+    REPO_ID = 'vinthony/SadTalker-V002rc'
     snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
 
 def sadtalker_demo():
@@ -34,33 +53,96 @@ def sadtalker_demo():
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='panel'):
                 with gr.Tabs(elem_id="sadtalker_source_image"):
-                    with gr.TabItem('
+                    with gr.TabItem('Source image'):
                         with gr.Row():
-                            source_image = gr.Image(label="Source image", source="upload", type="filepath").style(
-
+                            source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
+
+
                 with gr.Tabs(elem_id="sadtalker_driven_audio"):
-                    with gr.TabItem('
-
-
-
-
-
-
-
-
+                    with gr.TabItem('Driving Methods'):
+                        gr.Markdown("Possible driving combinations: <br> 1. Audio only 2. Audio/IDLE Mode + Ref Video(pose, blink, pose+blink) 3. IDLE Mode only 4. Ref Video only (all) ")
+
+                        with gr.Row():
+                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
+                            driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
+
+                            with gr.Column():
+                                use_idle_mode = gr.Checkbox(label="Use Idle Animation")
+                                length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.")
+                                use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
+
+                        if sys.platform != 'win32' and not in_webui:
+                            with gr.Accordion('Generate Audio From TTS', open=False):
+                                from src.utils.text2speech import TTSTalker
+                                tts_talker = TTSTalker()
+                                with gr.Column(variant='panel'):
+                                    input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.")
+                                    tts = gr.Button('Generate audio', elem_id="sadtalker_audio_generate", variant='primary')
+                                    tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
+
+                        with gr.Row():
+                            ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref").style(width=512)
+
+                            with gr.Column():
+                                use_ref_video = gr.Checkbox(label="Use Reference Video")
+                                ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video', info="How to borrow from reference Video?((fully transfer, aka, video driving mode))")
+
+                            ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
+
 
             with gr.Column(variant='panel'):
                 with gr.Tabs(elem_id="sadtalker_checkbox"):
                     with gr.TabItem('Settings'):
+                        gr.Markdown("need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
                        with gr.Column(variant='panel'):
-
-
-
-
+                            # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
+                            # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
+                            with gr.Row():
+                                pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
+                                exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
+                                blink_every = gr.Checkbox(label="use eye blink", value=True)
 
+                            with gr.Row():
+                                size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
+                                preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
+
+                            with gr.Row():
+                                is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
+                                facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?")
+
+                            with gr.Row():
+                                batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=1)
+                                enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
+
+                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+
                 with gr.Tabs(elem_id="sadtalker_genearted"):
                     gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
 
+
+
+        submit.click(
+                    fn=sad_talker.test,
+                    inputs=[source_image,
+                            driven_audio,
+                            preprocess_type,
+                            is_still_mode,
+                            enhancer,
+                            batch_size,
+                            size_of_image,
+                            pose_style,
+                            facerender,
+                            exp_weight,
+                            use_ref_video,
+                            ref_video,
+                            ref_info,
+                            use_idle_mode,
+                            length_of_audio,
+                            blink_every
+                            ],
+                    outputs=[gen_video]
+                    )
+
        with gr.Row():
            examples = [
                [
@@ -138,16 +220,6 @@ def sadtalker_demo():
                        fn=sad_talker.test,
                        cache_examples=os.getenv('SYSTEM') == 'spaces') #
 
-        submit.click(
-                    fn=sad_talker.test,
-                    inputs=[source_image,
-                            driven_audio,
-                            preprocess_type,
-                            is_still_mode,
-                            enhancer],
-                    outputs=[gen_video]
-                    )
-
    return sadtalker_interface
 
 
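The reworked UI above collects many more controls and hands them to sad_talker.test through submit.click. As a self-contained illustration of the gr.update pattern that toggle_audio_file relies on (a minimal sketch, not part of the commit, assuming the same Gradio 3.x API the Space already uses):

    import gradio as gr

    def toggle_audio_file(choice):
        # unchecked -> show the real audio input, hide the IDLE placeholder
        if choice == False:
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    with gr.Blocks() as demo:
        use_idle_mode = gr.Checkbox(label="Use Idle Animation")
        driven_audio = gr.Audio(label="Input audio", type="filepath")
        driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", type="filepath", visible=False)
        use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no])

    demo.launch()

Checking the box hides the real audio input and reveals the IDLE placeholder, which is how the Space switches into idle-animation mode.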
checkpoints/mapping_00229-model.pth.tar CHANGED
@@ -1 +1 @@
-../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
+../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker-V002rc/blobs/62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
src/__pycache__/generate_batch.cpython-38.pyc CHANGED (binary files differ)
src/__pycache__/generate_facerender_batch.cpython-38.pyc CHANGED (binary files differ)
src/__pycache__/test_audio2coeff.cpython-38.pyc CHANGED (binary files differ)
src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc CHANGED (binary files differ)
src/audio2exp_models/__pycache__/networks.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/__pycache__/cvae.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/__pycache__/networks.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc CHANGED (binary files differ)
src/audio2pose_models/audio2pose.py CHANGED
@@ -25,8 +25,8 @@ class Audio2Pose(nn.Module):
 
         batch = {}
         coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
-        batch['pose_motion_gt'] = coeff_gt[:, 1:,
-        batch['ref'] = coeff_gt[:, 0,
+        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]  #bs frame_len 6
+        batch['ref'] = coeff_gt[:, 0, 64:70]  #bs 6
         batch['class'] = x['class'].squeeze(0).cuda() # bs
         indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
 
@@ -37,8 +37,8 @@ class Audio2Pose(nn.Module):
         batch = self.netG(batch)
 
         pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
-        pose_gt = coeff_gt[:, 1:,
-        pose_pred = coeff_gt[:, :1,
+        pose_gt = coeff_gt[:, 1:, 64:70].clone()  # bs frame_len 6
+        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred  # bs frame_len 6
 
         batch['pose_pred'] = pose_pred
         batch['pose_gt'] = pose_gt
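The change above pins down the pose slice: in the 73-column 3DMM coefficient tensors used here, columns 64:70 carry the six pose values, and the ground-truth pose motion is expressed relative to frame 0. A minimal sketch of that arithmetic with stand-in tensors (not from the repo):

    import torch

    bs, frame_len = 2, 32
    coeff_gt = torch.randn(bs, frame_len + 1, 73)   # frame 0 is the reference frame

    ref_pose       = coeff_gt[:, 0, 64:70]                                # bs x 6
    pose_motion_gt = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]      # bs x frame_len x 6, relative to frame 0
    pose_pred      = coeff_gt[:, :1, 64:70] + pose_motion_gt              # adding the motion back recovers the poses

    assert torch.allclose(pose_pred, coeff_gt[:, 1:, 64:70])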
src/audio2pose_models/audio_encoder.py CHANGED
@@ -41,14 +41,14 @@ class AudioEncoder(nn.Module):
             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
 
-        #### load the pre-trained audio_encoder
-        wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
-        state_dict = self.audio_encoder.state_dict()
+        #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
+        # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+        # state_dict = self.audio_encoder.state_dict()
 
-        for k,v in wav2lip_state_dict.items():
-
-
-        self.audio_encoder.load_state_dict(state_dict)
+        # for k,v in wav2lip_state_dict.items():
+        #     if 'audio_encoder' in k:
+        #         state_dict[k.replace('module.audio_encoder.', '')] = v
+        # self.audio_encoder.load_state_dict(state_dict)
 
 
     def forward(self, audio_sequences):
src/config/similarity_Lm3D_all.mat ADDED (binary file, 994 Bytes)

src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc DELETED (binary file, 3.59 kB)
src/face3d/extract_kp_videos.py CHANGED
@@ -13,7 +13,8 @@ from torch.multiprocessing import Pool, Process, set_start_method
 
 class KeypointExtractor():
     def __init__(self, device):
-        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
+        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
+                                                     device=device)
 
     def extract_keypoint(self, images, name=None, info=True):
         if isinstance(images, list):
@@ -40,7 +41,6 @@ class KeypointExtractor():
                     break
                 except RuntimeError as e:
                     if str(e).startswith('CUDA'):
-                        print(e)
                         print("Warning: out of memory, sleep for 1s")
                         time.sleep(1)
                     else:
src/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,151 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+from torch.multiprocessing import Pool, Process, set_start_method
+
+from facexlib.alignment import landmark_98_to_68
+from facexlib.detection import init_detection_model
+
+from facexlib.utils import load_file_from_url
+from facexlib.alignment.awing_arch import FAN
+
+def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
+    if model_name == 'awing_fan':
+        model = FAN(num_modules=4, num_landmarks=98, device=device)
+        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+
+    model_path = load_file_from_url(
+        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+    model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
+    model.eval()
+    model = model.to(device)
+    return model
+
+
+class KeypointExtractor():
+    def __init__(self, device='cuda'):
+
+        ### gfpgan/weights
+        try:
+            import webui  # in webui
+            root_path = 'extensions/SadTalker/gfpgan/weights'
+
+        except:
+            root_path = 'gfpgan/weights'
+
+        self.detector = init_alignment_model('awing_fan', device=device, model_rootpath=root_path)
+        self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath=root_path)
+
+    def extract_keypoint(self, images, name=None, info=True):
+        if isinstance(images, list):
+            keypoints = []
+            if info:
+                i_range = tqdm(images, desc='landmark Det:')
+            else:
+                i_range = images
+
+            for image in i_range:
+                current_kp = self.extract_keypoint(image)
+                # current_kp = self.detector.get_landmarks(np.array(image))
+                if np.mean(current_kp) == -1 and keypoints:
+                    keypoints.append(keypoints[-1])
+                else:
+                    keypoints.append(current_kp[None])
+
+            keypoints = np.concatenate(keypoints, 0)
+            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+            return keypoints
+        else:
+            while True:
+                try:
+                    with torch.no_grad():
+                        # face detection -> face alignment.
+                        img = np.array(images)
+                        bboxes = self.det_net.detect_faces(images, 0.97)
+
+                        bboxes = bboxes[0]
+                        img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+                        keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
+
+                        #### keypoints to the original location
+                        keypoints[:,0] += int(bboxes[0])
+                        keypoints[:,1] += int(bboxes[1])
+
+                        break
+                except RuntimeError as e:
+                    if str(e).startswith('CUDA'):
+                        print("Warning: out of memory, sleep for 1s")
+                        time.sleep(1)
+                    else:
+                        print(e)
+                        break
+                except TypeError:
+                    print('No face detected in this image')
+                    shape = [68, 2]
+                    keypoints = -1. * np.ones(shape)
+                    break
+            if name is not None:
+                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+            return keypoints
+
+def read_video(filename):
+    frames = []
+    cap = cv2.VideoCapture(filename)
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if ret:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            frames.append(frame)
+        else:
+            break
+    cap.release()
+    return frames
+
+def run(data):
+    filename, opt, device = data
+    os.environ['CUDA_VISIBLE_DEVICES'] = device
+    kp_extractor = KeypointExtractor()
+    images = read_video(filename)
+    name = filename.split('/')[-2:]
+    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+    kp_extractor.extract_keypoint(
+        images,
+        name=os.path.join(opt.output_dir, name[-2], name[-1])
+    )
+
+if __name__ == '__main__':
+    set_start_method('spawn')
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+    parser.add_argument('--device_ids', type=str, default='0,1')
+    parser.add_argument('--workers', type=int, default=4)
+
+    opt = parser.parse_args()
+    filenames = list()
+    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+    extensions = VIDEO_EXTENSIONS
+
+    for ext in extensions:
+        os.listdir(f'{opt.input_dir}')
+        print(f'{opt.input_dir}/*.{ext}')
+        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+    print('Total number of videos:', len(filenames))
+    pool = Pool(opt.workers)
+    args_list = cycle([opt])
+    device_ids = opt.device_ids.split(",")
+    device_ids = cycle(device_ids)
+    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+        None
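The new "safe" extractor swaps face_alignment for facexlib: RetinaFace finds the face box, the AWing FAN model predicts 98 landmarks inside the crop, and landmark_98_to_68 maps them back to the usual 68-point layout in original image coordinates. A hedged usage sketch (it assumes facexlib is installed, its weights can be downloaded, and a CUDA device is available; the image path is a placeholder):

    from PIL import Image
    from src.face3d.extract_kp_videos_safe import KeypointExtractor

    extractor = KeypointExtractor(device='cuda')                      # downloads facexlib weights on first use
    image = Image.open('examples/source_image/full_body_1.png')       # placeholder input image
    keypoints = extractor.extract_keypoint(image, name='full_body_1') # also writes full_body_1.txt
    print(keypoints.shape)   # (68, 2); filled with -1 when no face is detected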
src/face3d/models/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
src/face3d/models/__pycache__/base_model.cpython-38.pyc CHANGED (binary files differ)
src/face3d/models/__pycache__/networks.cpython-38.pyc CHANGED (binary files differ)
src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc CHANGED (binary files differ)
src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc CHANGED (binary files differ)
src/face3d/util/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
src/face3d/util/__pycache__/load_mats.cpython-38.pyc CHANGED (binary files differ)
src/face3d/util/__pycache__/preprocess.cpython-38.pyc CHANGED (binary files differ)
src/facerender/__pycache__/animate.cpython-38.pyc CHANGED (binary files differ)
src/facerender/animate.py CHANGED
@@ -4,11 +4,15 @@ import yaml
 import numpy as np
 import warnings
 from skimage import img_as_ubyte
-
+import safetensors
+import safetensors.torch
 warnings.filterwarnings('ignore')
 
+
 import imageio
 import torch
+import torchvision
+
 
 from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
 from src.facerender.modules.mapping import MappingNet
@@ -16,17 +20,21 @@ from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionA
 from src.facerender.modules.make_animation import make_animation
 
 from pydub import AudioSegment
-from src.utils.face_enhancer import
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
 from src.utils.paste_pic import paste_pic
 from src.utils.videoio import save_video_with_watermark
 
+try:
+    import webui   # in webui
+    in_webui = True
+except:
+    in_webui = False
 
 class AnimateFromCoeff():
 
-    def __init__(self,
-                 config_path, device):
+    def __init__(self, sadtalker_path, device):
 
-        with open(
+        with open(sadtalker_path['facerender_yaml']) as f:
             config = yaml.safe_load(f)
 
         generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
@@ -37,7 +45,6 @@ class AnimateFromCoeff():
                                                  **config['model_params']['common_params'])
         mapping = MappingNet(**config['model_params']['mapping_params'])
 
-
        generator.to(device)
        kp_extractor.to(device)
        he_estimator.to(device)
@@ -51,13 +58,16 @@ class AnimateFromCoeff():
         for param in mapping.parameters():
             param.requires_grad = False
 
-        if
-
+        if sadtalker_path is not None:
+            if 'checkpoint' in sadtalker_path: # use safe tensor
+                self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
+            else:
+                self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
         else:
             raise AttributeError("Checkpoint should be specified for video head pose estimator.")
 
-        if
-            self.load_cpk_mapping(
+        if sadtalker_path['mappingnet_checkpoint'] is not None:
+            self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
         else:
             raise AttributeError("Checkpoint should be specified for video head pose estimator.")
 
@@ -73,6 +83,33 @@ class AnimateFromCoeff():
 
         self.device = device
 
+    def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
+                                        kp_detector=None, he_estimator=None,
+                                        device="cpu"):
+
+        checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+        if generator is not None:
+            x_generator = {}
+            for k,v in checkpoint.items():
+                if 'generator' in k:
+                    x_generator[k.replace('generator.', '')] = v
+            generator.load_state_dict(x_generator)
+        if kp_detector is not None:
+            x_generator = {}
+            for k,v in checkpoint.items():
+                if 'kp_extractor' in k:
+                    x_generator[k.replace('kp_extractor.', '')] = v
+            kp_detector.load_state_dict(x_generator)
+        if he_estimator is not None:
+            x_generator = {}
+            for k,v in checkpoint.items():
+                if 'he_estimator' in k:
+                    x_generator[k.replace('he_estimator.', '')] = v
+            he_estimator.load_state_dict(x_generator)
+
+        return None
+
     def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
                              kp_detector=None, he_estimator=None, optimizer_generator=None,
                              optimizer_discriminator=None, optimizer_kp_detector=None,
@@ -117,7 +154,7 @@ class AnimateFromCoeff():
 
         return checkpoint['epoch']
 
-    def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
+    def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
 
         source_image=x['source_image'].type(torch.FloatTensor)
         source_semantics=x['source_semantics'].type(torch.FloatTensor)
@@ -157,14 +194,15 @@ class AnimateFromCoeff():
             video.append(image)
         result = img_as_ubyte(video)
 
-        ### the generated video is 256x256, so we
+        ### the generated video is 256x256, so we keep the aspect ratio,
         original_size = crop_info[0]
         if original_size:
-            result = [ cv2.resize(result_i,(
+            result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
 
         video_name = x['video_name'] + '.mp4'
         path = os.path.join(video_save_dir, 'temp_'+video_name)
-
+
+        imageio.mimsave(path, result, fps=float(25))
 
         av_path = os.path.join(video_save_dir, video_name)
         return_path = av_path
@@ -173,22 +211,23 @@ class AnimateFromCoeff():
         audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
         new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
         start_time = 0
-
+        # cog will not keep the .mp3 filename
+        sound = AudioSegment.from_file(audio_path)
         frames = frame_num
         end_time = start_time + frames*1/25*1000
         word1=sound.set_frame_rate(16000)
         word = word1[start_time:end_time]
         word.export(new_audio_path, format="wav")
 
-        save_video_with_watermark(path, new_audio_path, av_path, watermark=
-        print(f'The generated video is named {
+        save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+        print(f'The generated video is named {video_save_dir}/{video_name}')
 
-        if preprocess.lower()
+        if 'full' in preprocess.lower():
             # only add watermark to the full image.
             video_name_full = x['video_name'] + '_full.mp4'
             full_video_path = os.path.join(video_save_dir, video_name_full)
             return_path = full_video_path
-            paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
+            paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
             print(f'The generated video is named {video_save_dir}/{video_name_full}')
         else:
             full_video_path = av_path
@@ -199,10 +238,15 @@ class AnimateFromCoeff():
             enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
             av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
             return_path = av_path_enhancer
-
-
+
+            try:
+                enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+                imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+            except:
+                enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+                imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
 
-            save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark=
+            save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
             print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
             os.remove(enhanced_path)
 
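AnimateFromCoeff now takes a single sadtalker_path mapping instead of separate checkpoint and config arguments; when the mapping has a 'checkpoint' entry the combined safetensors file is loaded and split by the 'generator.', 'kp_extractor.' and 'he_estimator.' key prefixes, otherwise 'free_view_checkpoint' is used. An illustrative mapping follows (the key names are the ones read above; the file names are placeholders, and the real dict is built elsewhere, e.g. by src.utils.init_path in this commit's gradio_demo.py):

    # Key names from the diff above; the file paths below are placeholders.
    sadtalker_paths = {
        'checkpoint':            './checkpoints/SadTalker_V0.0.2_256.safetensors',  # combined facevid2vid weights (assumed file name)
        'mappingnet_checkpoint': './checkpoints/mapping_00229-model.pth.tar',
        'facerender_yaml':       './src/config/facerender.yaml',
    }

    from src.facerender.animate import AnimateFromCoeff
    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device='cuda')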
src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc CHANGED (binary files differ)
src/facerender/modules/__pycache__/generator.cpython-38.pyc CHANGED (binary files differ)
src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc CHANGED (binary files differ)
src/facerender/modules/__pycache__/make_animation.cpython-38.pyc CHANGED (binary files differ)
src/facerender/modules/__pycache__/mapping.cpython-38.pyc CHANGED (binary files differ)
src/facerender/modules/__pycache__/util.cpython-38.pyc CHANGED (binary files differ)
src/facerender/modules/make_animation.py CHANGED
@@ -29,7 +29,7 @@ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale
 def headpose_pred_to_degree(pred):
     device = pred.device
     idx_tensor = [idx for idx in range(66)]
-    idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+    idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
     pred = F.softmax(pred)
     degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
     return degree
@@ -102,7 +102,7 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
 def make_animation(source_image, source_semantics, target_semantics,
                    generator, kp_detector, he_estimator, mapping,
                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-                   use_exp=True):
+                   use_exp=True, use_half=False):
     with torch.no_grad():
         predictions = []
 
@@ -111,6 +111,8 @@ def make_animation(source_image, source_semantics, target_semantics,
         kp_source = keypoint_transformation(kp_canonical, he_source)
 
         for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
+            # still check the dimension
+            # print(target_semantics.shape, source_semantics.shape)
             target_semantics_frame = target_semantics[:, frame_idx]
             he_driving = mapping(target_semantics_frame)
             if yaw_c_seq is not None:
@@ -122,8 +124,6 @@ def make_animation(source_image, source_semantics, target_semantics,
 
             kp_driving = keypoint_transformation(kp_canonical, he_driving)
 
-            #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
-            #kp_driving_initial=kp_driving_initial)
             kp_norm = kp_driving
             out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
             '''
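For context, headpose_pred_to_degree treats the 66-way pose prediction as a soft classification over 3-degree bins and returns the expectation mapped to degrees, degree = sum(softmax(pred) * idx) * 3 - 99, i.e. roughly the range [-99, 96]. The added .type_as(pred) keeps the bin indices in the prediction's dtype, which matters once the new use_half flag is exercised. A standalone restatement with a dummy batch (not from the repo):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 66)                                    # batch of head-pose logits
    idx = torch.arange(66, dtype=pred.dtype, device=pred.device)
    degree = (F.softmax(pred, dim=1) * idx).sum(1) * 3 - 99      # expected bin index mapped to degrees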
src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc CHANGED (binary files differ)
src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc CHANGED (binary files differ)
src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc CHANGED (binary files differ)
src/generate_batch.py CHANGED
@@ -48,7 +48,7 @@ def generate_blink_seq_randomly(num_frames):
             break
     return ratio
 
-def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False):
+def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
 
     syncnet_mel_step_size = 16
     fps = 25
@@ -56,22 +56,27 @@ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, stil
     pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
     audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if idlemode:
+        num_frames = int(length_of_audio * 25)
+        indiv_mels = np.zeros((num_frames, 80, 16))
+    else:
+        wav = audio.load_wav(audio_path, 16000)
+        wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
+        wav = crop_pad_audio(wav, wav_length)
+        orig_mel = audio.melspectrogram(wav).T
+        spec = orig_mel.copy()         # nframes 80
+        indiv_mels = []
+
+        for i in tqdm(range(num_frames), 'mel:'):
+            start_frame_num = i-2
+            start_idx = int(80. * (start_frame_num / float(fps)))
+            end_idx = start_idx + syncnet_mel_step_size
+            seq = list(range(start_idx, end_idx))
+            seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
+            m = spec[seq, :]
+            indiv_mels.append(m.T)
+        indiv_mels = np.asarray(indiv_mels)         # T 80 16
 
     ratio = generate_blink_seq_randomly(num_frames)      # T
     source_semantics_path = first_coeff_path
@@ -96,10 +101,10 @@ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, stil
 
     indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16
 
-    if
-        ratio = torch.FloatTensor(ratio).unsqueeze(0)
+    if use_blink:
+        ratio = torch.FloatTensor(ratio).unsqueeze(0)                       # bs T
     else:
-        ratio = torch.FloatTensor(ratio).unsqueeze(0)
+        ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)
                                # bs T
     ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0)                        # bs 1 70
 
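The inlined mel loop above gives every 25-fps output frame a 16-step mel window that starts two video frames earlier (80 mel steps per second), clamping indices at the clip boundaries; in idle mode the windows are simply zeros for length_of_audio * 25 frames. A standalone restatement of the index arithmetic with a stand-in spectrogram (not from the repo):

    import numpy as np

    fps, syncnet_mel_step_size = 25, 16
    orig_mel = np.random.rand(400, 80)                              # (num mel frames, 80) stand-in

    def mel_window(i):
        start_idx = int(80. * ((i - 2) / float(fps)))               # 80 mel frames per second
        seq = range(start_idx, start_idx + syncnet_mel_step_size)
        seq = [min(max(j, 0), orig_mel.shape[0] - 1) for j in seq]  # clamp at the clip edges
        return orig_mel[seq, :].T                                   # 80 x 16

    print(mel_window(0).shape)   # (80, 16)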
src/generate_facerender_batch.py CHANGED
@@ -7,7 +7,7 @@ import scipy.io as scio
 
 def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
                         batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
-                        expression_scale=1.0, still_mode = False, preprocess='crop'):
+                        expression_scale=1.0, still_mode = False, preprocess='crop', size = 256):
 
     semantic_radius = 13
     video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
@@ -18,18 +18,22 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
     img1 = Image.open(pic_path)
     source_image = np.array(img1)
     source_image = img_as_float32(source_image)
-    source_image = transform.resize(source_image, (
+    source_image = transform.resize(source_image, (size, size, 3))
     source_image = source_image.transpose((2, 0, 1))
     source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)
     source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)
     data['source_image'] = source_image_ts
 
     source_semantics_dict = scio.loadmat(first_coeff_path)
+    generated_dict = scio.loadmat(coeff_path)
 
-    if preprocess.lower()
+    if 'full' not in preprocess.lower():
         source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70]         #1 70
+        generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
+
     else:
         source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73]         #1 70
+        generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
 
     source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
     source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
@@ -37,11 +41,9 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
     data['source_semantics'] = source_semantics_ts
 
     # target
-    generated_dict = scio.loadmat(coeff_path)
-    generated_3dmm = generated_dict['coeff_3dmm']
     generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
 
-    if preprocess.lower()
+    if 'full' in preprocess.lower():
         generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
 
     if still_mode:
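Only the first 64 columns, the expression coefficients, are multiplied by expression_scale; the pose columns after them pass through unchanged, and in 'full' mode the source frame's columns from 70 onward are appended so the renderer can paste the crop back into the original image. A tiny stand-in illustration (not from the repo):

    import numpy as np

    expression_scale = 1.2
    generated_3dmm = np.random.rand(100, 70)        # T x 70: 64 expression + 6 pose columns
    generated_3dmm[:, :64] *= expression_scale      # exaggerate the expression, leave the pose alone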
src/gradio_demo.py
CHANGED
@@ -6,8 +6,11 @@ from src.facerender.animate import AnimateFromCoeff
|
|
6 |
from src.generate_batch import get_data
|
7 |
from src.generate_facerender_batch import get_facerender_data
|
8 |
|
|
|
|
|
9 |
from pydub import AudioSegment
|
10 |
|
|
|
11 |
def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
|
12 |
mp3_file = AudioSegment.from_file(file=mp3_filename)
|
13 |
mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
|
@@ -28,55 +31,24 @@ class SadTalker():
|
|
28 |
|
29 |
self.checkpoint_path = checkpoint_path
|
30 |
self.config_path = config_path
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
self.
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
print(self.path_of_lm_croper)
|
50 |
-
self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
|
51 |
-
|
52 |
-
print(self.audio2pose_checkpoint)
|
53 |
-
self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
|
54 |
-
self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
|
55 |
-
|
56 |
-
def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, result_dir='./results/'):
|
57 |
-
|
58 |
-
### crop: only model,
|
59 |
-
|
60 |
-
if self.lazy_load:
|
61 |
-
#init model
|
62 |
-
print(self.path_of_lm_croper)
|
63 |
-
self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
|
64 |
-
|
65 |
-
print(self.audio2pose_checkpoint)
|
66 |
-
self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
|
67 |
-
self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
|
68 |
-
|
69 |
-
if preprocess == 'full':
|
70 |
-
self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00109-model.pth.tar')
|
71 |
-
self.facerender_yaml_path = os.path.join(self.config_path, 'facerender_still.yaml')
|
72 |
-
else:
|
73 |
-
self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00229-model.pth.tar')
|
74 |
-
self.facerender_yaml_path = os.path.join(self.config_path, 'facerender.yaml')
|
75 |
-
|
76 |
-
print(self.mapping_checkpoint)
|
77 |
-
print(self.free_view_checkpoint)
|
78 |
-
self.animate_from_coeff = AnimateFromCoeff(self.free_view_checkpoint, self.mapping_checkpoint,
|
79 |
-
self.facerender_yaml_path, self.device)
|
80 |
|
81 |
time_tag = str(uuid.uuid4())
|
82 |
save_dir = os.path.join(result_dir, time_tag)
|
@@ -89,7 +61,7 @@ class SadTalker():
|
|
89 |
pic_path = os.path.join(input_dir, os.path.basename(source_image))
|
90 |
shutil.move(source_image, input_dir)
|
91 |
|
92 |
-
if os.path.isfile(driven_audio):
|
93 |
audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
|
94 |
|
95 |
#### mp3 to wav
|
@@ -98,37 +70,84 @@ class SadTalker():
|
|
98 |
audio_path = audio_path.replace('.mp3', '.wav')
|
99 |
else:
|
100 |
shutil.move(driven_audio, input_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
else:
|
102 |
-
|
|
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
os.makedirs(save_dir, exist_ok=True)
|
106 |
-
|
107 |
#crop image and extract 3dmm from image
|
108 |
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
|
109 |
os.makedirs(first_frame_dir, exist_ok=True)
|
110 |
-
first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir,preprocess)
|
111 |
|
112 |
if first_coeff_path is None:
|
113 |
raise AttributeError("No face is detected")
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
#audio2ceoff
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
118 |
#coeff2video
|
119 |
-
batch_size =
|
120 |
-
|
121 |
-
return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess)
|
122 |
video_name = data['video_name']
|
123 |
print(f'The generated video is named {video_name} in {save_dir}')
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
del self.animate_from_coeff
|
129 |
|
130 |
-
torch.cuda.
|
131 |
-
|
|
|
|
|
132 |
import gc; gc.collect()
|
133 |
|
134 |
return return_path
|
|
|
6 |
from src.generate_batch import get_data
|
7 |
from src.generate_facerender_batch import get_facerender_data
|
8 |
|
9 |
+
from src.utils.init_path import init_path
|
10 |
+
|
11 |
from pydub import AudioSegment
|
12 |
|
13 |
+
|
14 |
def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
|
15 |
mp3_file = AudioSegment.from_file(file=mp3_filename)
|
16 |
mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
|
|
|
31 |
|
32 |
self.checkpoint_path = checkpoint_path
|
33 |
self.config_path = config_path
|
34 |
+
|
35 |
+
|
36 |
+
def test(self, source_image, driven_audio, preprocess='crop',
|
37 |
+
still_mode=False, use_enhancer=False, batch_size=1, size=256,
|
38 |
+
pose_style = 0, exp_scale=1.0,
|
39 |
+
use_ref_video = False,
|
40 |
+
ref_video = None,
|
41 |
+
ref_info = None,
|
42 |
+
use_idle_mode = False,
|
43 |
+
length_of_audio = 0, use_blink=True,
|
44 |
+
result_dir='./results/'):
|
45 |
+
|
46 |
+
self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
|
47 |
+
print(self.sadtalker_paths)
|
48 |
+
|
49 |
+
self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
|
50 |
+
self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
|
51 |
+
self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

        time_tag = str(uuid.uuid4())
        save_dir = os.path.join(result_dir, time_tag)

        pic_path = os.path.join(input_dir, os.path.basename(source_image))
        shutil.move(source_image, input_dir)

+       if driven_audio is not None and os.path.isfile(driven_audio):
            audio_path = os.path.join(input_dir, os.path.basename(driven_audio))

            #### mp3 to wav

                audio_path = audio_path.replace('.mp3', '.wav')
            else:
                shutil.move(driven_audio, input_dir)
+
+       elif use_idle_mode:
+           audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
+           from pydub import AudioSegment
+           one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio)  #duration in milliseconds
+           one_sec_segment.export(audio_path, format="wav")
        else:
+           print(use_ref_video, ref_info)
+           assert use_ref_video == True and ref_info == 'all'

+           if use_ref_video and ref_info == 'all': # full ref mode
+               ref_video_videoname = os.path.basename(ref_video)
+               audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
+               print('new audiopath:', audio_path)
+               # if ref_video contains audio, set the audio from ref_video.
+               cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
+               os.system(cmd)

        os.makedirs(save_dir, exist_ok=True)
+
        #crop image and extract 3dmm from image
        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
        os.makedirs(first_frame_dir, exist_ok=True)
+       first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)

        if first_coeff_path is None:
            raise AttributeError("No face is detected")

+       if use_ref_video:
+           print('using ref video for genreation')
+           ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
+           ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
+           os.makedirs(ref_video_frame_dir, exist_ok=True)
+           print('3DMM Extraction for the reference video providing pose')
+           ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
+       else:
+           ref_video_coeff_path = None
+
+       if use_ref_video:
+           if ref_info == 'pose':
+               ref_pose_coeff_path = ref_video_coeff_path
+               ref_eyeblink_coeff_path = None
+           elif ref_info == 'blink':
+               ref_pose_coeff_path = None
+               ref_eyeblink_coeff_path = ref_video_coeff_path
+           elif ref_info == 'pose+blink':
+               ref_pose_coeff_path = ref_video_coeff_path
+               ref_eyeblink_coeff_path = ref_video_coeff_path
+           elif ref_info == 'all':
+               ref_pose_coeff_path = None
+               ref_eyeblink_coeff_path = None
+           else:
+               raise('error in refinfo')
+       else:
+           ref_pose_coeff_path = None
+           ref_eyeblink_coeff_path = None
+
        #audio2ceoff
+       if use_ref_video and ref_info == 'all':
+           coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+       else:
+           batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
+           coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
        #coeff2video
+       data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale=exp_scale)
+       return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)

        video_name = data['video_name']
        print(f'The generated video is named {video_name} in {save_dir}')

+       del self.preprocess_model
+       del self.audio_to_coeff
+       del self.animate_from_coeff

+       if torch.cuda.is_available():
+           torch.cuda.empty_cache()
+           torch.cuda.synchronize()
+
        import gc; gc.collect()

        return return_path
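The reworked test() above runs the full pipeline in one pass: crop and 3DMM extraction, audio-to-coefficient prediction (or reuse of the reference video's coefficients when ref_info == 'all'), and coefficient-to-video rendering, followed by explicit model teardown and CUDA cache cleanup. One caveat visible in the added branch handling: raise('error in refinfo') passes a plain string, which Python 3 rejects with a TypeError rather than surfacing the intended message. Below is a hedged usage sketch, not part of the diff; the SadTalker class name, its constructor arguments ('checkpoints', 'src/config'), and the example file paths are assumptions, while the keyword arguments mirror the new test() signature shown above.

# assumed import location of the demo wrapper class
from src.gradio_demo import SadTalker

sad_talker = SadTalker('checkpoints', 'src/config')          # checkpoint dir and config dir are assumptions
video_path = sad_talker.test(
    source_image='examples/source_image/art_0.png',          # hypothetical example inputs
    driven_audio='examples/driven_audio/bus_chinese.wav',
    preprocess='crop',
    still_mode=False,
    use_enhancer=False,                                       # True would run the 'gfpgan' enhancer branch
    batch_size=2,
    size=256,                                                 # rendering resolution passed to init_path / facerender
    pose_style=0,
    exp_scale=1.0)
print(video_path)                                             # path returned by animate_from_coeff.generate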
src/src/audio2exp_models/audio2exp.py
ADDED
@@ -0,0 +1,41 @@
from tqdm import tqdm
import torch
from torch import nn


class Audio2Exp(nn.Module):
    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):

        mel_input = batch['indiv_mels']                         # bs T 1 80 16
        bs = mel_input.shape[0]
        T = mel_input.shape[1]

        exp_coeff_pred = []

        for i in tqdm(range(0, T, 10), 'audio2exp:'):           # every 10 frames

            current_mel_input = mel_input[:, i:i+10]

            #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
            ref = batch['ref'][:, :, :64][:, i:i+10]
            ratio = batch['ratio_gt'][:, i:i+10]                #bs T

            audiox = current_mel_input.view(-1, 1, 80, 16)      # bs*T 1 80 16

            curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64

            exp_coeff_pred += [curr_exp_coeff_pred]

        # BS x T x 64
        results_dict = {
            'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
        }
        return results_dict
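Audio2Exp.test consumes per-frame mel chunks in windows of 10 frames and returns expression coefficients of shape (bs, T, 64). A shape-checking sketch follows; it assumes SimpleWrapperV2 from the networks.py file added below is the netG (its 512+64+1 -> 64 mapping matches the ref/ratio slices used here), and that the modules resolve under the src.audio2exp_models package as in the repo's other imports.

import torch
from src.audio2exp_models.networks import SimpleWrapperV2   # assumed import path
from src.audio2exp_models.audio2exp import Audio2Exp

device = 'cpu'
netG = SimpleWrapperV2()
model = Audio2Exp(netG, cfg=None, device=device)            # cfg is stored but unused by test()

bs, T = 1, 20
batch = {
    'indiv_mels': torch.randn(bs, T, 1, 80, 16),            # per-frame mel chunks
    'ref':        torch.randn(bs, T, 70),                   # 3DMM coeffs; only the first 64 (expression) are read
    'ratio_gt':   torch.rand(bs, T, 1),                     # eye-blink ratio per frame (assumed trailing dim of 1)
}
with torch.no_grad():
    out = model.test(batch)
print(out['exp_coeff_pred'].shape)                          # expected: torch.Size([1, 20, 64])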
src/src/audio2exp_models/networks.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn.functional as F
from torch import nn

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x

        if self.use_act:
            return self.act(out)
        else:
            return out

class SimpleWrapperV2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        #### load the pre-trained audio_encoder
        #self.audio_encoder = self.audio_encoder.to(device)
        '''
        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
        state_dict = self.audio_encoder.state_dict()

        for k,v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                print('init:', k)
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)
        '''

        self.mapping1 = nn.Linear(512+64+1, 64)
        #self.mapping2 = nn.Linear(30, 64)
        #nn.init.constant_(self.mapping1.weight, 0.)
        nn.init.constant_(self.mapping1.bias, 0.)

    def forward(self, x, ref, ratio):
        x = self.audio_encoder(x).view(x.size(0), -1)
        ref_reshape = ref.reshape(x.size(0), -1)
        ratio = ratio.reshape(x.size(0), -1)

        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
        return out
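The mapping1 input size 512+64+1 corresponds to the flattened audio embedding, the 64 reference expression coefficients, and the scalar blink ratio for each frame. A dimensional sanity check under those assumptions (one 10-frame chunk, illustrative only, same assumed import path as above):

import torch
from src.audio2exp_models.networks import SimpleWrapperV2   # assumed import path

net = SimpleWrapperV2().eval()
frames = 10
audiox = torch.randn(frames, 1, 80, 16)    # one mel chunk per frame
ref    = torch.randn(1, frames, 64)        # reference expression coefficients
ratio  = torch.rand(1, frames, 1)          # eye-blink ratio
with torch.no_grad():
    out = net(audiox, ref, ratio)
print(out.shape)                            # expected: torch.Size([1, 10, 64])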
src/src/audio2pose_models/audio2pose.py
ADDED
@@ -0,0 +1,94 @@
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder

class Audio2Pose(nn.Module):
    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)


    def forward(self, x):

        batch = {}
        coeff_gt = x['gt'].cuda().squeeze(0)                                        #bs frame_len+1 73
        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]   #bs frame_len 6
        batch['ref'] = coeff_gt[:, 0, 64:70]                                        #bs 6
        batch['class'] = x['class'].squeeze(0).cuda()                               # bs
        indiv_mels = x['indiv_mels'].cuda().squeeze(0)                              # bs seq_len+1 80 16

        # forward
        audio_emb_list = []
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))        #bs seq_len 512
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']            # bs frame_len 6
        pose_gt = coeff_gt[:, 1:, 64:70].clone()                # bs frame_len 6
        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred   # bs frame_len 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt

        return batch

    def test(self, x):

        batch = {}
        ref = x['ref']                          #bs 1 70
        batch['ref'] = x['ref'][:, 0, -6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']            # bs T 1 80 16
        indiv_mels_use = indiv_mels[:, 1:]      # we regard the ref as the first frame
        num_frames = x['num_frames']
        num_frames = int(num_frames) - 1

        #
        div = num_frames // self.seq_len
        re = num_frames % self.seq_len
        audio_emb_list = []
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]

        for i in range(div):
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len, :, :, :])  #bs seq_len 512
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6

        if re != 0:
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:, :, :, :])  #bs seq_len 512
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len - audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1*re:, :])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6

        batch['pose_pred'] = pose_pred
        return batch
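At inference, Audio2Pose.test slides over the mel sequence in windows of seq_len frames, runs the CVAE decoder once per full window, and runs one extra padded pass for the remainder, keeping only its last `re` predictions; the reference pose stands in for frame 0. A small arithmetic illustration (seq_len = 32 is an assumed config value, not taken from this diff):

# windowing arithmetic used by Audio2Pose.test, shown standalone
num_frames = 200
seq_len = 32                      # assumption about cfg.MODEL.CVAE.SEQ_LEN
usable = num_frames - 1           # frame 0 is covered by the reference coefficients
div, re = divmod(usable, seq_len)
print(div, re)                    # 6 full windows, plus a tail of 7 frames handled by the padded pass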
src/src/audio2pose_models/audio_encoder.py
ADDED
@@ -0,0 +1,64 @@
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

class AudioEncoder(nn.Module):
    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
        # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
        # state_dict = self.audio_encoder.state_dict()

        # for k,v in wav2lip_state_dict.items():
        #     if 'audio_encoder' in k:
        #         state_dict[k.replace('module.audio_encoder.', '')] = v
        # self.audio_encoder.load_state_dict(state_dict)


    def forward(self, audio_sequences):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
        dim = audio_embedding.shape[1]
        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))

        return audio_embedding.squeeze(-1).squeeze(-1)  #B seq_len+1 512
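The forward pass folds the time axis into the batch axis, encodes every mel chunk to a 512-d vector, then unfolds back to (B, T, 512). A quick shape walk-through under the assumption that checkpoint loading stays commented out as above, so wav2lip_checkpoint=None is acceptable:

import torch
from src.audio2pose_models.audio_encoder import AudioEncoder

enc = AudioEncoder(wav2lip_checkpoint=None, device='cpu').eval()
mels = torch.randn(2, 33, 1, 80, 16)   # B=2 sequences of 33 mel chunks
with torch.no_grad():
    emb = enc(mels)
print(emb.shape)                        # expected: torch.Size([2, 33, 512])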
src/src/audio2pose_models/cvae.py
ADDED
@@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet

def class2onehot(idx, class_num):

    assert torch.max(idx).item() < class_num
    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
    onehot.scatter_(1, idx, 1)
    return onehot

class CVAE(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
        num_classes = cfg.DATASET.NUM_CLASSES
        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
        seq_len = cfg.MODEL.CVAE.SEQ_LEN

        self.latent_size = latent_size

        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)
        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, batch):
        batch = self.encoder(batch)
        mu = batch['mu']
        logvar = batch['logvar']
        z = self.reparameterize(mu, logvar)
        batch['z'] = z
        return self.decoder(batch)

    def test(self, batch):
        '''
        class_id = batch['class']
        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
        batch['z'] = z
        '''
        return self.decoder(batch)

class ENCODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()

        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']   #bs seq_len 6
        ref = batch['ref']                         #bs 6
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']              # bs seq_len audio_emb_in_size

        #pose encode
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  #bs 1 seq_len 6
        pose_emb = pose_emb.reshape(bs, -1)                   #bs seq_len*6

        #audio mapping
        print(audio_in.shape)
        audio_out = self.linear_audio(audio_in)    # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)

        class_bias = self.classbias[class_id]      #bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  #bs seq_len*(audio_emb_out_size+6)+latent_size
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        logvar = self.linear_means(x_out)          #bs latent_size

        batch.update({'mu': mu, 'logvar': logvar})
        return batch

class DECODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()

        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        input_size = latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):

        z = batch['z']                             #bs latent_size
        bs = z.shape[0]
        class_id = batch['class']
        ref = batch['ref']                         #bs 6
        audio_in = batch['audio_emb']              # bs seq_len audio_emb_in_size
        #print('audio_in: ', audio_in[:, :, :10])

        audio_out = self.linear_audio(audio_in)    # bs seq_len audio_emb_out_size
        #print('audio_out: ', audio_out[:, :, :10])
        audio_out = audio_out.reshape([bs, -1])    # bs seq_len*audio_emb_out_size
        class_bias = self.classbias[class_id]      #bs latent_size

        z = z + class_bias
        x_in = torch.cat([ref, z, audio_out], dim=-1)
        x_out = self.MLP(x_in)                     # bs layer_sizes[-1]
        x_out = x_out.reshape((bs, self.seq_len, -1))

        #print('x_out: ', x_out)

        pose_emb = self.resunet(x_out.unsqueeze(1))  #bs 1 seq_len 6

        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))  #bs seq_len 6

        batch.update({'pose_motion_pred': pose_motion_pred})
        return batch
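At inference the CVAE's test() path only runs the decoder, with z sampled externally in Audio2Pose.test, so the encoder (and the fact that, as written, it computes logvar through linear_means rather than the declared linear_logvar head) only matters for training. The reparameterization step used in forward() is the standard trick; a standalone illustration, with latent_size = 64 as an assumed value:

import torch

mu, logvar = torch.zeros(4, 64), torch.zeros(4, 64)   # latent_size = 64 is an assumption
std = torch.exp(0.5 * logvar)                          # sigma = exp(logvar / 2)
z = mu + torch.randn_like(std) * std                   # z ~ N(mu, sigma^2), differentiable w.r.t. mu and logvar
print(z.shape)                                         # torch.Size([4, 64])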
src/src/audio2pose_models/discriminator.py
ADDED
@@ -0,0 +1,76 @@
import torch
import torch.nn.functional as F
from torch import nn

class ConvNormRelu(nn.Module):
    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
        super().__init__()
        if kernel_size is None:
            if downsample:
                kernel_size, stride, padding = 4, 2, 1
            else:
                kernel_size, stride, padding = 3, 1, 1

        if conv_type == '2d':
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
            if norm == 'BN':
                self.norm = nn.BatchNorm2d(out_channels)
            elif norm == 'IN':
                self.norm = nn.InstanceNorm2d(out_channels)
            else:
                raise NotImplementedError
        elif conv_type == '1d':
            self.conv = nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
            if norm == 'BN':
                self.norm = nn.BatchNorm1d(out_channels)
            elif norm == 'IN':
                self.norm = nn.InstanceNorm1d(out_channels)
            else:
                raise NotImplementedError
        nn.init.kaiming_normal_(self.conv.weight)

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if isinstance(self.norm, nn.InstanceNorm1d):
            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
        else:
            x = self.norm(x)
        x = self.act(x)
        return x


class PoseSequenceDiscriminator(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU

        self.seq = nn.Sequential(
            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),                                     # B, 512, 32
            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),                 # B, 1024, 16
            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)                               # B, 1, 16
        )

    def forward(self, x):
        x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
        x = self.seq(x)
        x = x.squeeze(1)
        return x
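PoseSequenceDiscriminator treats the pose sequence as a 1-D signal over time (channels first) and returns one realism score per 4-frame chunk, since the two downsampling convolutions each halve the temporal length. A hypothetical forward pass with a minimal stand-in cfg (INPUT_CHANNELS=6 and LEAKY_RELU=True are assumptions about the real yaml config):

import torch
from types import SimpleNamespace
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator

cfg = SimpleNamespace(MODEL=SimpleNamespace(DISCRIMINATOR=SimpleNamespace(INPUT_CHANNELS=6, LEAKY_RELU=True)))
disc = PoseSequenceDiscriminator(cfg)
poses = torch.randn(2, 32, 6)      # B=2 sequences of 32 pose vectors (3 rotations + 3 translations)
scores = disc(poses)
print(scores.shape)                 # torch.Size([2, 8]): 32 time steps / 2 / 2 = 8 chunk scores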