shadowcun commited on
Commit
9ab094a
1 Parent(s): 99e1f07

new version of sadtalker

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +99 -27
  2. checkpoints/mapping_00229-model.pth.tar +1 -1
  3. src/__pycache__/generate_batch.cpython-38.pyc +0 -0
  4. src/__pycache__/generate_facerender_batch.cpython-38.pyc +0 -0
  5. src/__pycache__/test_audio2coeff.cpython-38.pyc +0 -0
  6. src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
  7. src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
  8. src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
  9. src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
  10. src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
  11. src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
  12. src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
  13. src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
  14. src/audio2pose_models/audio2pose.py +4 -4
  15. src/audio2pose_models/audio_encoder.py +7 -7
  16. src/config/similarity_Lm3D_all.mat +0 -0
  17. src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc +0 -0
  18. src/face3d/extract_kp_videos.py +2 -2
  19. src/face3d/extract_kp_videos_safe.py +151 -0
  20. src/face3d/models/__pycache__/__init__.cpython-38.pyc +0 -0
  21. src/face3d/models/__pycache__/base_model.cpython-38.pyc +0 -0
  22. src/face3d/models/__pycache__/networks.cpython-38.pyc +0 -0
  23. src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc +0 -0
  24. src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc +0 -0
  25. src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc +0 -0
  26. src/face3d/util/__pycache__/__init__.cpython-38.pyc +0 -0
  27. src/face3d/util/__pycache__/load_mats.cpython-38.pyc +0 -0
  28. src/face3d/util/__pycache__/preprocess.cpython-38.pyc +0 -0
  29. src/facerender/__pycache__/animate.cpython-38.pyc +0 -0
  30. src/facerender/animate.py +66 -22
  31. src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc +0 -0
  32. src/facerender/modules/__pycache__/generator.cpython-38.pyc +0 -0
  33. src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc +0 -0
  34. src/facerender/modules/__pycache__/make_animation.cpython-38.pyc +0 -0
  35. src/facerender/modules/__pycache__/mapping.cpython-38.pyc +0 -0
  36. src/facerender/modules/__pycache__/util.cpython-38.pyc +0 -0
  37. src/facerender/modules/make_animation.py +4 -4
  38. src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc +0 -0
  39. src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc +0 -0
  40. src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc +0 -0
  41. src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc +0 -0
  42. src/generate_batch.py +25 -20
  43. src/generate_facerender_batch.py +8 -6
  44. src/gradio_demo.py +83 -64
  45. src/src/audio2exp_models/audio2exp.py +41 -0
  46. src/src/audio2exp_models/networks.py +74 -0
  47. src/src/audio2pose_models/audio2pose.py +94 -0
  48. src/src/audio2pose_models/audio_encoder.py +64 -0
  49. src/src/audio2pose_models/cvae.py +149 -0
  50. src/src/audio2pose_models/discriminator.py +76 -0
app.py CHANGED
@@ -8,8 +8,27 @@ from huggingface_hub import snapshot_download
8
  def get_source_image(image):
9
  return image
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def download_model():
12
- REPO_ID = 'vinthony/SadTalker'
13
  snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
14
 
15
  def sadtalker_demo():
@@ -34,33 +53,96 @@ def sadtalker_demo():
34
  with gr.Row().style(equal_height=False):
35
  with gr.Column(variant='panel'):
36
  with gr.Tabs(elem_id="sadtalker_source_image"):
37
- with gr.TabItem('Upload image'):
38
  with gr.Row():
39
- source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
40
-
 
41
  with gr.Tabs(elem_id="sadtalker_driven_audio"):
42
- with gr.TabItem('Upload or Generating from TTS'):
43
- with gr.Column(variant='panel'):
44
- driven_audio = gr.Audio(label="Input audio(.wav/.mp3)", source="upload", type="filepath")
45
-
46
- # with gr.Column(variant='panel'):
47
- # input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="Alternatively, you can genreate the audio from text using @Coqui.ai TTS.")
48
- # tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
49
- # tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
50
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  with gr.Column(variant='panel'):
53
  with gr.Tabs(elem_id="sadtalker_checkbox"):
54
  with gr.TabItem('Settings'):
 
55
  with gr.Column(variant='panel'):
56
- preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?")
57
- is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion, works with preprocess `full`)")
58
- enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
59
- submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  with gr.Tabs(elem_id="sadtalker_genearted"):
62
  gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  with gr.Row():
65
  examples = [
66
  [
@@ -138,16 +220,6 @@ def sadtalker_demo():
138
  fn=sad_talker.test,
139
  cache_examples=os.getenv('SYSTEM') == 'spaces') #
140
 
141
- submit.click(
142
- fn=sad_talker.test,
143
- inputs=[source_image,
144
- driven_audio,
145
- preprocess_type,
146
- is_still_mode,
147
- enhancer],
148
- outputs=[gen_video]
149
- )
150
-
151
  return sadtalker_interface
152
 
153
 
 
8
  def get_source_image(image):
9
  return image
10
 
11
+ try:
12
+ import webui # in webui
13
+ in_webui = True
14
+ except:
15
+ in_webui = False
16
+
17
+
18
+ def toggle_audio_file(choice):
19
+ if choice == False:
20
+ return gr.update(visible=True), gr.update(visible=False)
21
+ else:
22
+ return gr.update(visible=False), gr.update(visible=True)
23
+
24
+ def ref_video_fn(path_of_ref_video):
25
+ if path_of_ref_video is not None:
26
+ return gr.update(value=True)
27
+ else:
28
+ return gr.update(value=False)
29
+
30
  def download_model():
31
+ REPO_ID = 'vinthony/SadTalker-V002rc'
32
  snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
33
 
34
  def sadtalker_demo():
 
53
  with gr.Row().style(equal_height=False):
54
  with gr.Column(variant='panel'):
55
  with gr.Tabs(elem_id="sadtalker_source_image"):
56
+ with gr.TabItem('Source image'):
57
  with gr.Row():
58
+ source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
59
+
60
+
61
  with gr.Tabs(elem_id="sadtalker_driven_audio"):
62
+ with gr.TabItem('Driving Methods'):
63
+ gr.Markdown("Possible driving combinations: <br> 1. Audio only 2. Audio/IDLE Mode + Ref Video(pose, blink, pose+blink) 3. IDLE Mode only 4. Ref Video only (all) ")
64
+
65
+ with gr.Row():
66
+ driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
67
+ driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
68
+
69
+ with gr.Column():
70
+ use_idle_mode = gr.Checkbox(label="Use Idle Animation")
71
+ length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.")
72
+ use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
73
+
74
+ if sys.platform != 'win32' and not in_webui:
75
+ with gr.Accordion('Generate Audio From TTS', open=False):
76
+ from src.utils.text2speech import TTSTalker
77
+ tts_talker = TTSTalker()
78
+ with gr.Column(variant='panel'):
79
+ input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.")
80
+ tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
81
+ tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
82
+
83
+ with gr.Row():
84
+ ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref").style(width=512)
85
+
86
+ with gr.Column():
87
+ use_ref_video = gr.Checkbox(label="Use Reference Video")
88
+ ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))")
89
+
90
+ ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
91
+
92
 
93
  with gr.Column(variant='panel'):
94
  with gr.Tabs(elem_id="sadtalker_checkbox"):
95
  with gr.TabItem('Settings'):
96
+ gr.Markdown("need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
97
  with gr.Column(variant='panel'):
98
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
99
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
100
+ with gr.Row():
101
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
102
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
103
+ blink_every = gr.Checkbox(label="use eye blink", value=True)
104
 
105
+ with gr.Row():
106
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
107
+ preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
108
+
109
+ with gr.Row():
110
+ is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
111
+ facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?")
112
+
113
+ with gr.Row():
114
+ batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=1)
115
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
116
+
117
+ submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
118
+
119
  with gr.Tabs(elem_id="sadtalker_genearted"):
120
  gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
121
 
122
+
123
+
124
+ submit.click(
125
+ fn=sad_talker.test,
126
+ inputs=[source_image,
127
+ driven_audio,
128
+ preprocess_type,
129
+ is_still_mode,
130
+ enhancer,
131
+ batch_size,
132
+ size_of_image,
133
+ pose_style,
134
+ facerender,
135
+ exp_weight,
136
+ use_ref_video,
137
+ ref_video,
138
+ ref_info,
139
+ use_idle_mode,
140
+ length_of_audio,
141
+ blink_every
142
+ ],
143
+ outputs=[gen_video]
144
+ )
145
+
146
  with gr.Row():
147
  examples = [
148
  [
 
220
  fn=sad_talker.test,
221
  cache_examples=os.getenv('SYSTEM') == 'spaces') #
222
 
 
 
 
 
 
 
 
 
 
 
223
  return sadtalker_interface
224
 
225
 
checkpoints/mapping_00229-model.pth.tar CHANGED
@@ -1 +1 @@
1
- ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker-V002rc/blobs/62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
src/__pycache__/generate_batch.cpython-38.pyc CHANGED
Binary files a/src/__pycache__/generate_batch.cpython-38.pyc and b/src/__pycache__/generate_batch.cpython-38.pyc differ
 
src/__pycache__/generate_facerender_batch.cpython-38.pyc CHANGED
Binary files a/src/__pycache__/generate_facerender_batch.cpython-38.pyc and b/src/__pycache__/generate_facerender_batch.cpython-38.pyc differ
 
src/__pycache__/test_audio2coeff.cpython-38.pyc CHANGED
Binary files a/src/__pycache__/test_audio2coeff.cpython-38.pyc and b/src/__pycache__/test_audio2coeff.cpython-38.pyc differ
 
src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc CHANGED
Binary files a/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc and b/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc differ
 
src/audio2exp_models/__pycache__/networks.cpython-38.pyc CHANGED
Binary files a/src/audio2exp_models/__pycache__/networks.cpython-38.pyc and b/src/audio2exp_models/__pycache__/networks.cpython-38.pyc differ
 
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc CHANGED
Binary files a/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc and b/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc differ
 
src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc CHANGED
Binary files a/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc and b/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc differ
 
src/audio2pose_models/__pycache__/cvae.cpython-38.pyc CHANGED
Binary files a/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc and b/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc differ
 
src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc CHANGED
Binary files a/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc and b/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc differ
 
src/audio2pose_models/__pycache__/networks.cpython-38.pyc CHANGED
Binary files a/src/audio2pose_models/__pycache__/networks.cpython-38.pyc and b/src/audio2pose_models/__pycache__/networks.cpython-38.pyc differ
 
src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc CHANGED
Binary files a/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc and b/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc differ
 
src/audio2pose_models/audio2pose.py CHANGED
@@ -25,8 +25,8 @@ class Audio2Pose(nn.Module):
25
 
26
  batch = {}
27
  coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
- batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6
29
- batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6
30
  batch['class'] = x['class'].squeeze(0).cuda() # bs
31
  indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
 
@@ -37,8 +37,8 @@ class Audio2Pose(nn.Module):
37
  batch = self.netG(batch)
38
 
39
  pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
- pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6
41
- pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6
42
 
43
  batch['pose_pred'] = pose_pred
44
  batch['pose_gt'] = pose_gt
 
25
 
26
  batch = {}
27
  coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
  batch['class'] = x['class'].squeeze(0).cuda() # bs
31
  indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
 
 
37
  batch = self.netG(batch)
38
 
39
  pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
 
43
  batch['pose_pred'] = pose_pred
44
  batch['pose_gt'] = pose_gt
src/audio2pose_models/audio_encoder.py CHANGED
@@ -41,14 +41,14 @@ class AudioEncoder(nn.Module):
41
  Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
  Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
 
44
- #### load the pre-trained audio_encoder
45
- wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
- state_dict = self.audio_encoder.state_dict()
47
 
48
- for k,v in wav2lip_state_dict.items():
49
- if 'audio_encoder' in k:
50
- state_dict[k.replace('module.audio_encoder.', '')] = v
51
- self.audio_encoder.load_state_dict(state_dict)
52
 
53
 
54
  def forward(self, audio_sequences):
 
41
  Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
  Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
 
44
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
+ # state_dict = self.audio_encoder.state_dict()
47
 
48
+ # for k,v in wav2lip_state_dict.items():
49
+ # if 'audio_encoder' in k:
50
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ # self.audio_encoder.load_state_dict(state_dict)
52
 
53
 
54
  def forward(self, audio_sequences):
src/config/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes). View file
 
src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc DELETED
Binary file (3.59 kB)
 
src/face3d/extract_kp_videos.py CHANGED
@@ -13,7 +13,8 @@ from torch.multiprocessing import Pool, Process, set_start_method
13
 
14
  class KeypointExtractor():
15
  def __init__(self, device):
16
- self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu')
 
17
 
18
  def extract_keypoint(self, images, name=None, info=True):
19
  if isinstance(images, list):
@@ -40,7 +41,6 @@ class KeypointExtractor():
40
  break
41
  except RuntimeError as e:
42
  if str(e).startswith('CUDA'):
43
- print(e)
44
  print("Warning: out of memory, sleep for 1s")
45
  time.sleep(1)
46
  else:
 
13
 
14
  class KeypointExtractor():
15
  def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
+ device=device)
18
 
19
  def extract_keypoint(self, images, name=None, info=True):
20
  if isinstance(images, list):
 
41
  break
42
  except RuntimeError as e:
43
  if str(e).startswith('CUDA'):
 
44
  print("Warning: out of memory, sleep for 1s")
45
  time.sleep(1)
46
  else:
src/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+ from torch.multiprocessing import Pool, Process, set_start_method
12
+
13
+ from facexlib.alignment import landmark_98_to_68
14
+ from facexlib.detection import init_detection_model
15
+
16
+ from facexlib.utils import load_file_from_url
17
+ from facexlib.alignment.awing_arch import FAN
18
+
19
+ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
+ if model_name == 'awing_fan':
21
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
22
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
+ else:
24
+ raise NotImplementedError(f'{model_name} is not implemented.')
25
+
26
+ model_path = load_file_from_url(
27
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
+ model.eval()
30
+ model = model.to(device)
31
+ return model
32
+
33
+
34
+ class KeypointExtractor():
35
+ def __init__(self, device='cuda'):
36
+
37
+ ### gfpgan/weights
38
+ try:
39
+ import webui # in webui
40
+ root_path = 'extensions/SadTalker/gfpgan/weights'
41
+
42
+ except:
43
+ root_path = 'gfpgan/weights'
44
+
45
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
46
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
47
+
48
+ def extract_keypoint(self, images, name=None, info=True):
49
+ if isinstance(images, list):
50
+ keypoints = []
51
+ if info:
52
+ i_range = tqdm(images,desc='landmark Det:')
53
+ else:
54
+ i_range = images
55
+
56
+ for image in i_range:
57
+ current_kp = self.extract_keypoint(image)
58
+ # current_kp = self.detector.get_landmarks(np.array(image))
59
+ if np.mean(current_kp) == -1 and keypoints:
60
+ keypoints.append(keypoints[-1])
61
+ else:
62
+ keypoints.append(current_kp[None])
63
+
64
+ keypoints = np.concatenate(keypoints, 0)
65
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
66
+ return keypoints
67
+ else:
68
+ while True:
69
+ try:
70
+ with torch.no_grad():
71
+ # face detection -> face alignment.
72
+ img = np.array(images)
73
+ bboxes = self.det_net.detect_faces(images, 0.97)
74
+
75
+ bboxes = bboxes[0]
76
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
77
+
78
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
79
+
80
+ #### keypoints to the original location
81
+ keypoints[:,0] += int(bboxes[0])
82
+ keypoints[:,1] += int(bboxes[1])
83
+
84
+ break
85
+ except RuntimeError as e:
86
+ if str(e).startswith('CUDA'):
87
+ print("Warning: out of memory, sleep for 1s")
88
+ time.sleep(1)
89
+ else:
90
+ print(e)
91
+ break
92
+ except TypeError:
93
+ print('No face detected in this image')
94
+ shape = [68, 2]
95
+ keypoints = -1. * np.ones(shape)
96
+ break
97
+ if name is not None:
98
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
99
+ return keypoints
100
+
101
+ def read_video(filename):
102
+ frames = []
103
+ cap = cv2.VideoCapture(filename)
104
+ while cap.isOpened():
105
+ ret, frame = cap.read()
106
+ if ret:
107
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+ frame = Image.fromarray(frame)
109
+ frames.append(frame)
110
+ else:
111
+ break
112
+ cap.release()
113
+ return frames
114
+
115
+ def run(data):
116
+ filename, opt, device = data
117
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
118
+ kp_extractor = KeypointExtractor()
119
+ images = read_video(filename)
120
+ name = filename.split('/')[-2:]
121
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
122
+ kp_extractor.extract_keypoint(
123
+ images,
124
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
125
+ )
126
+
127
+ if __name__ == '__main__':
128
+ set_start_method('spawn')
129
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
130
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
131
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
132
+ parser.add_argument('--device_ids', type=str, default='0,1')
133
+ parser.add_argument('--workers', type=int, default=4)
134
+
135
+ opt = parser.parse_args()
136
+ filenames = list()
137
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
138
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
139
+ extensions = VIDEO_EXTENSIONS
140
+
141
+ for ext in extensions:
142
+ os.listdir(f'{opt.input_dir}')
143
+ print(f'{opt.input_dir}/*.{ext}')
144
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
145
+ print('Total number of videos:', len(filenames))
146
+ pool = Pool(opt.workers)
147
+ args_list = cycle([opt])
148
+ device_ids = opt.device_ids.split(",")
149
+ device_ids = cycle(device_ids)
150
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
151
+ None
src/face3d/models/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/src/face3d/models/__pycache__/__init__.cpython-38.pyc and b/src/face3d/models/__pycache__/__init__.cpython-38.pyc differ
 
src/face3d/models/__pycache__/base_model.cpython-38.pyc CHANGED
Binary files a/src/face3d/models/__pycache__/base_model.cpython-38.pyc and b/src/face3d/models/__pycache__/base_model.cpython-38.pyc differ
 
src/face3d/models/__pycache__/networks.cpython-38.pyc CHANGED
Binary files a/src/face3d/models/__pycache__/networks.cpython-38.pyc and b/src/face3d/models/__pycache__/networks.cpython-38.pyc differ
 
src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc differ
 
src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc CHANGED
Binary files a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc differ
 
src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc CHANGED
Binary files a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc differ
 
src/face3d/util/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/src/face3d/util/__pycache__/__init__.cpython-38.pyc and b/src/face3d/util/__pycache__/__init__.cpython-38.pyc differ
 
src/face3d/util/__pycache__/load_mats.cpython-38.pyc CHANGED
Binary files a/src/face3d/util/__pycache__/load_mats.cpython-38.pyc and b/src/face3d/util/__pycache__/load_mats.cpython-38.pyc differ
 
src/face3d/util/__pycache__/preprocess.cpython-38.pyc CHANGED
Binary files a/src/face3d/util/__pycache__/preprocess.cpython-38.pyc and b/src/face3d/util/__pycache__/preprocess.cpython-38.pyc differ
 
src/facerender/__pycache__/animate.cpython-38.pyc CHANGED
Binary files a/src/facerender/__pycache__/animate.cpython-38.pyc and b/src/facerender/__pycache__/animate.cpython-38.pyc differ
 
src/facerender/animate.py CHANGED
@@ -4,11 +4,15 @@ import yaml
4
  import numpy as np
5
  import warnings
6
  from skimage import img_as_ubyte
7
-
 
8
  warnings.filterwarnings('ignore')
9
 
 
10
  import imageio
11
  import torch
 
 
12
 
13
  from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
14
  from src.facerender.modules.mapping import MappingNet
@@ -16,17 +20,21 @@ from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionA
16
  from src.facerender.modules.make_animation import make_animation
17
 
18
  from pydub import AudioSegment
19
- from src.utils.face_enhancer import enhancer as face_enhancer
20
  from src.utils.paste_pic import paste_pic
21
  from src.utils.videoio import save_video_with_watermark
22
 
 
 
 
 
 
23
 
24
  class AnimateFromCoeff():
25
 
26
- def __init__(self, free_view_checkpoint, mapping_checkpoint,
27
- config_path, device):
28
 
29
- with open(config_path) as f:
30
  config = yaml.safe_load(f)
31
 
32
  generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
@@ -37,7 +45,6 @@ class AnimateFromCoeff():
37
  **config['model_params']['common_params'])
38
  mapping = MappingNet(**config['model_params']['mapping_params'])
39
 
40
-
41
  generator.to(device)
42
  kp_extractor.to(device)
43
  he_estimator.to(device)
@@ -51,13 +58,16 @@ class AnimateFromCoeff():
51
  for param in mapping.parameters():
52
  param.requires_grad = False
53
 
54
- if free_view_checkpoint is not None:
55
- self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
 
 
 
56
  else:
57
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
58
 
59
- if mapping_checkpoint is not None:
60
- self.load_cpk_mapping(mapping_checkpoint, mapping=mapping)
61
  else:
62
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
63
 
@@ -73,6 +83,33 @@ class AnimateFromCoeff():
73
 
74
  self.device = device
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
77
  kp_detector=None, he_estimator=None, optimizer_generator=None,
78
  optimizer_discriminator=None, optimizer_kp_detector=None,
@@ -117,7 +154,7 @@ class AnimateFromCoeff():
117
 
118
  return checkpoint['epoch']
119
 
120
- def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
121
 
122
  source_image=x['source_image'].type(torch.FloatTensor)
123
  source_semantics=x['source_semantics'].type(torch.FloatTensor)
@@ -157,14 +194,15 @@ class AnimateFromCoeff():
157
  video.append(image)
158
  result = img_as_ubyte(video)
159
 
160
- ### the generated video is 256x256, so we keep the aspect ratio,
161
  original_size = crop_info[0]
162
  if original_size:
163
- result = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in result ]
164
 
165
  video_name = x['video_name'] + '.mp4'
166
  path = os.path.join(video_save_dir, 'temp_'+video_name)
167
- imageio.mimsave(path, result, fps=float(25))
 
168
 
169
  av_path = os.path.join(video_save_dir, video_name)
170
  return_path = av_path
@@ -173,22 +211,23 @@ class AnimateFromCoeff():
173
  audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
174
  new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
175
  start_time = 0
176
- sound = AudioSegment.from_mp3(audio_path)
 
177
  frames = frame_num
178
  end_time = start_time + frames*1/25*1000
179
  word1=sound.set_frame_rate(16000)
180
  word = word1[start_time:end_time]
181
  word.export(new_audio_path, format="wav")
182
 
183
- save_video_with_watermark(path, new_audio_path, av_path, watermark= None)
184
- print(f'The generated video is named {video_name} in {video_save_dir}')
185
 
186
- if preprocess.lower() == 'full':
187
  # only add watermark to the full image.
188
  video_name_full = x['video_name'] + '_full.mp4'
189
  full_video_path = os.path.join(video_save_dir, video_name_full)
190
  return_path = full_video_path
191
- paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
192
  print(f'The generated video is named {video_save_dir}/{video_name_full}')
193
  else:
194
  full_video_path = av_path
@@ -199,10 +238,15 @@ class AnimateFromCoeff():
199
  enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
200
  av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
201
  return_path = av_path_enhancer
202
- enhanced_images = face_enhancer(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
203
- imageio.mimsave(enhanced_path, enhanced_images, fps=float(25))
 
 
 
 
 
204
 
205
- save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= None)
206
  print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
207
  os.remove(enhanced_path)
208
 
 
4
  import numpy as np
5
  import warnings
6
  from skimage import img_as_ubyte
7
+ import safetensors
8
+ import safetensors.torch
9
  warnings.filterwarnings('ignore')
10
 
11
+
12
  import imageio
13
  import torch
14
+ import torchvision
15
+
16
 
17
  from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
18
  from src.facerender.modules.mapping import MappingNet
 
20
  from src.facerender.modules.make_animation import make_animation
21
 
22
  from pydub import AudioSegment
23
+ from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
24
  from src.utils.paste_pic import paste_pic
25
  from src.utils.videoio import save_video_with_watermark
26
 
27
+ try:
28
+ import webui # in webui
29
+ in_webui = True
30
+ except:
31
+ in_webui = False
32
 
33
  class AnimateFromCoeff():
34
 
35
+ def __init__(self, sadtalker_path, device):
 
36
 
37
+ with open(sadtalker_path['facerender_yaml']) as f:
38
  config = yaml.safe_load(f)
39
 
40
  generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
 
45
  **config['model_params']['common_params'])
46
  mapping = MappingNet(**config['model_params']['mapping_params'])
47
 
 
48
  generator.to(device)
49
  kp_extractor.to(device)
50
  he_estimator.to(device)
 
58
  for param in mapping.parameters():
59
  param.requires_grad = False
60
 
61
+ if sadtalker_path is not None:
62
+ if 'checkpoint' in sadtalker_path: # use safe tensor
63
+ self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
64
+ else:
65
+ self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
66
  else:
67
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
68
 
69
+ if sadtalker_path['mappingnet_checkpoint'] is not None:
70
+ self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
71
  else:
72
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
73
 
 
83
 
84
  self.device = device
85
 
86
+ def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
87
+ kp_detector=None, he_estimator=None,
88
+ device="cpu"):
89
+
90
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
91
+
92
+ if generator is not None:
93
+ x_generator = {}
94
+ for k,v in checkpoint.items():
95
+ if 'generator' in k:
96
+ x_generator[k.replace('generator.', '')] = v
97
+ generator.load_state_dict(x_generator)
98
+ if kp_detector is not None:
99
+ x_generator = {}
100
+ for k,v in checkpoint.items():
101
+ if 'kp_extractor' in k:
102
+ x_generator[k.replace('kp_extractor.', '')] = v
103
+ kp_detector.load_state_dict(x_generator)
104
+ if he_estimator is not None:
105
+ x_generator = {}
106
+ for k,v in checkpoint.items():
107
+ if 'he_estimator' in k:
108
+ x_generator[k.replace('he_estimator.', '')] = v
109
+ he_estimator.load_state_dict(x_generator)
110
+
111
+ return None
112
+
113
  def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
114
  kp_detector=None, he_estimator=None, optimizer_generator=None,
115
  optimizer_discriminator=None, optimizer_kp_detector=None,
 
154
 
155
  return checkpoint['epoch']
156
 
157
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
158
 
159
  source_image=x['source_image'].type(torch.FloatTensor)
160
  source_semantics=x['source_semantics'].type(torch.FloatTensor)
 
194
  video.append(image)
195
  result = img_as_ubyte(video)
196
 
197
+ ### the generated video is 256x256, so we keep the aspect ratio,
198
  original_size = crop_info[0]
199
  if original_size:
200
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
201
 
202
  video_name = x['video_name'] + '.mp4'
203
  path = os.path.join(video_save_dir, 'temp_'+video_name)
204
+
205
+ imageio.mimsave(path, result, fps=float(25))
206
 
207
  av_path = os.path.join(video_save_dir, video_name)
208
  return_path = av_path
 
211
  audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
212
  new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
213
  start_time = 0
214
+ # cog will not keep the .mp3 filename
215
+ sound = AudioSegment.from_file(audio_path)
216
  frames = frame_num
217
  end_time = start_time + frames*1/25*1000
218
  word1=sound.set_frame_rate(16000)
219
  word = word1[start_time:end_time]
220
  word.export(new_audio_path, format="wav")
221
 
222
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
223
+ print(f'The generated video is named {video_save_dir}/{video_name}')
224
 
225
+ if 'full' in preprocess.lower():
226
  # only add watermark to the full image.
227
  video_name_full = x['video_name'] + '_full.mp4'
228
  full_video_path = os.path.join(video_save_dir, video_name_full)
229
  return_path = full_video_path
230
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
231
  print(f'The generated video is named {video_save_dir}/{video_name_full}')
232
  else:
233
  full_video_path = av_path
 
238
  enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
239
  av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
240
  return_path = av_path_enhancer
241
+
242
+ try:
243
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
244
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
245
+ except:
246
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
247
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
248
 
249
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
250
  print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
251
  os.remove(enhanced_path)
252
 
src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc CHANGED
Binary files a/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc and b/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc differ
 
src/facerender/modules/__pycache__/generator.cpython-38.pyc CHANGED
Binary files a/src/facerender/modules/__pycache__/generator.cpython-38.pyc and b/src/facerender/modules/__pycache__/generator.cpython-38.pyc differ
 
src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc CHANGED
Binary files a/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc and b/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc differ
 
src/facerender/modules/__pycache__/make_animation.cpython-38.pyc CHANGED
Binary files a/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc and b/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc differ
 
src/facerender/modules/__pycache__/mapping.cpython-38.pyc CHANGED
Binary files a/src/facerender/modules/__pycache__/mapping.cpython-38.pyc and b/src/facerender/modules/__pycache__/mapping.cpython-38.pyc differ
 
src/facerender/modules/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/src/facerender/modules/__pycache__/util.cpython-38.pyc and b/src/facerender/modules/__pycache__/util.cpython-38.pyc differ
 
src/facerender/modules/make_animation.py CHANGED
@@ -29,7 +29,7 @@ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale
29
  def headpose_pred_to_degree(pred):
30
  device = pred.device
31
  idx_tensor = [idx for idx in range(66)]
32
- idx_tensor = torch.FloatTensor(idx_tensor).to(device)
33
  pred = F.softmax(pred)
34
  degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
35
  return degree
@@ -102,7 +102,7 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
102
  def make_animation(source_image, source_semantics, target_semantics,
103
  generator, kp_detector, he_estimator, mapping,
104
  yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
105
- use_exp=True):
106
  with torch.no_grad():
107
  predictions = []
108
 
@@ -111,6 +111,8 @@ def make_animation(source_image, source_semantics, target_semantics,
111
  kp_source = keypoint_transformation(kp_canonical, he_source)
112
 
113
  for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
 
 
114
  target_semantics_frame = target_semantics[:, frame_idx]
115
  he_driving = mapping(target_semantics_frame)
116
  if yaw_c_seq is not None:
@@ -122,8 +124,6 @@ def make_animation(source_image, source_semantics, target_semantics,
122
 
123
  kp_driving = keypoint_transformation(kp_canonical, he_driving)
124
 
125
- #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
126
- #kp_driving_initial=kp_driving_initial)
127
  kp_norm = kp_driving
128
  out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
129
  '''
 
29
  def headpose_pred_to_degree(pred):
30
  device = pred.device
31
  idx_tensor = [idx for idx in range(66)]
32
+ idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
33
  pred = F.softmax(pred)
34
  degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
35
  return degree
 
102
  def make_animation(source_image, source_semantics, target_semantics,
103
  generator, kp_detector, he_estimator, mapping,
104
  yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
105
+ use_exp=True, use_half=False):
106
  with torch.no_grad():
107
  predictions = []
108
 
 
111
  kp_source = keypoint_transformation(kp_canonical, he_source)
112
 
113
  for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
114
+ # still check the dimension
115
+ # print(target_semantics.shape, source_semantics.shape)
116
  target_semantics_frame = target_semantics[:, frame_idx]
117
  he_driving = mapping(target_semantics_frame)
118
  if yaw_c_seq is not None:
 
124
 
125
  kp_driving = keypoint_transformation(kp_canonical, he_driving)
126
 
 
 
127
  kp_norm = kp_driving
128
  out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
129
  '''
src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc and b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc differ
 
src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc CHANGED
Binary files a/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc and b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc differ
 
src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc CHANGED
Binary files a/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc and b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc differ
 
src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc CHANGED
Binary files a/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc and b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc differ
 
src/generate_batch.py CHANGED
@@ -48,7 +48,7 @@ def generate_blink_seq_randomly(num_frames):
48
  break
49
  return ratio
50
 
51
- def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False):
52
 
53
  syncnet_mel_step_size = 16
54
  fps = 25
@@ -56,22 +56,27 @@ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, stil
56
  pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
57
  audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
58
 
59
- wav = audio.load_wav(audio_path, 16000)
60
- wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
61
- wav = crop_pad_audio(wav, wav_length)
62
- orig_mel = audio.melspectrogram(wav).T
63
- spec = orig_mel.copy() # nframes 80
64
- indiv_mels = []
65
-
66
- for i in tqdm(range(num_frames), 'mel:'):
67
- start_frame_num = i-2
68
- start_idx = int(80. * (start_frame_num / float(fps)))
69
- end_idx = start_idx + syncnet_mel_step_size
70
- seq = list(range(start_idx, end_idx))
71
- seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
72
- m = spec[seq, :]
73
- indiv_mels.append(m.T)
74
- indiv_mels = np.asarray(indiv_mels) # T 80 16
 
 
 
 
 
75
 
76
  ratio = generate_blink_seq_randomly(num_frames) # T
77
  source_semantics_path = first_coeff_path
@@ -96,10 +101,10 @@ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, stil
96
 
97
  indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16
98
 
99
- if still:
100
- ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.) # bs T
101
  else:
102
- ratio = torch.FloatTensor(ratio).unsqueeze(0)
103
  # bs T
104
  ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0) # bs 1 70
105
 
 
48
  break
49
  return ratio
50
 
51
+ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
52
 
53
  syncnet_mel_step_size = 16
54
  fps = 25
 
56
  pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
57
  audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
58
 
59
+
60
+ if idlemode:
61
+ num_frames = int(length_of_audio * 25)
62
+ indiv_mels = np.zeros((num_frames, 80, 16))
63
+ else:
64
+ wav = audio.load_wav(audio_path, 16000)
65
+ wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
66
+ wav = crop_pad_audio(wav, wav_length)
67
+ orig_mel = audio.melspectrogram(wav).T
68
+ spec = orig_mel.copy() # nframes 80
69
+ indiv_mels = []
70
+
71
+ for i in tqdm(range(num_frames), 'mel:'):
72
+ start_frame_num = i-2
73
+ start_idx = int(80. * (start_frame_num / float(fps)))
74
+ end_idx = start_idx + syncnet_mel_step_size
75
+ seq = list(range(start_idx, end_idx))
76
+ seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
77
+ m = spec[seq, :]
78
+ indiv_mels.append(m.T)
79
+ indiv_mels = np.asarray(indiv_mels) # T 80 16
80
 
81
  ratio = generate_blink_seq_randomly(num_frames) # T
82
  source_semantics_path = first_coeff_path
 
101
 
102
  indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16
103
 
104
+ if use_blink:
105
+ ratio = torch.FloatTensor(ratio).unsqueeze(0) # bs T
106
  else:
107
+ ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)
108
  # bs T
109
  ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0) # bs 1 70
110
 
src/generate_facerender_batch.py CHANGED
@@ -7,7 +7,7 @@ import scipy.io as scio
7
 
8
  def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
9
  batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
10
- expression_scale=1.0, still_mode = False, preprocess='crop'):
11
 
12
  semantic_radius = 13
13
  video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
@@ -18,18 +18,22 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
18
  img1 = Image.open(pic_path)
19
  source_image = np.array(img1)
20
  source_image = img_as_float32(source_image)
21
- source_image = transform.resize(source_image, (256, 256, 3))
22
  source_image = source_image.transpose((2, 0, 1))
23
  source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)
24
  source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)
25
  data['source_image'] = source_image_ts
26
 
27
  source_semantics_dict = scio.loadmat(first_coeff_path)
 
28
 
29
- if preprocess.lower() != 'full':
30
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
 
 
31
  else:
32
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70
 
33
 
34
  source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
35
  source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
@@ -37,11 +41,9 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
37
  data['source_semantics'] = source_semantics_ts
38
 
39
  # target
40
- generated_dict = scio.loadmat(coeff_path)
41
- generated_3dmm = generated_dict['coeff_3dmm']
42
  generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
43
 
44
- if preprocess.lower() == 'full':
45
  generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
46
 
47
  if still_mode:
 
7
 
8
  def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
9
  batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
10
+ expression_scale=1.0, still_mode = False, preprocess='crop', size = 256):
11
 
12
  semantic_radius = 13
13
  video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
 
18
  img1 = Image.open(pic_path)
19
  source_image = np.array(img1)
20
  source_image = img_as_float32(source_image)
21
+ source_image = transform.resize(source_image, (size, size, 3))
22
  source_image = source_image.transpose((2, 0, 1))
23
  source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)
24
  source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)
25
  data['source_image'] = source_image_ts
26
 
27
  source_semantics_dict = scio.loadmat(first_coeff_path)
28
+ generated_dict = scio.loadmat(coeff_path)
29
 
30
+ if 'full' not in preprocess.lower():
31
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
32
+ generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
33
+
34
  else:
35
  source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70
36
+ generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
37
 
38
  source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
39
  source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
 
41
  data['source_semantics'] = source_semantics_ts
42
 
43
  # target
 
 
44
  generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
45
 
46
+ if 'full' in preprocess.lower():
47
  generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
48
 
49
  if still_mode:
src/gradio_demo.py CHANGED
@@ -6,8 +6,11 @@ from src.facerender.animate import AnimateFromCoeff
6
  from src.generate_batch import get_data
7
  from src.generate_facerender_batch import get_facerender_data
8
 
 
 
9
  from pydub import AudioSegment
10
 
 
11
  def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
12
  mp3_file = AudioSegment.from_file(file=mp3_filename)
13
  mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
@@ -28,55 +31,24 @@ class SadTalker():
28
 
29
  self.checkpoint_path = checkpoint_path
30
  self.config_path = config_path
31
-
32
- self.path_of_lm_croper = os.path.join( checkpoint_path, 'shape_predictor_68_face_landmarks.dat')
33
- self.path_of_net_recon_model = os.path.join( checkpoint_path, 'epoch_20.pth')
34
- self.dir_of_BFM_fitting = os.path.join( checkpoint_path, 'BFM_Fitting')
35
- self.wav2lip_checkpoint = os.path.join( checkpoint_path, 'wav2lip.pth')
36
-
37
- self.audio2pose_checkpoint = os.path.join( checkpoint_path, 'auido2pose_00140-model.pth')
38
- self.audio2pose_yaml_path = os.path.join( config_path, 'auido2pose.yaml')
39
-
40
- self.audio2exp_checkpoint = os.path.join( checkpoint_path, 'auido2exp_00300-model.pth')
41
- self.audio2exp_yaml_path = os.path.join( config_path, 'auido2exp.yaml')
42
-
43
- self.free_view_checkpoint = os.path.join( checkpoint_path, 'facevid2vid_00189-model.pth.tar')
44
-
45
- self.lazy_load = lazy_load
46
-
47
- if not self.lazy_load:
48
- #init model
49
- print(self.path_of_lm_croper)
50
- self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
51
-
52
- print(self.audio2pose_checkpoint)
53
- self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
54
- self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
55
-
56
- def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, result_dir='./results/'):
57
-
58
- ### crop: only model,
59
-
60
- if self.lazy_load:
61
- #init model
62
- print(self.path_of_lm_croper)
63
- self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
64
-
65
- print(self.audio2pose_checkpoint)
66
- self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
67
- self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
68
-
69
- if preprocess == 'full':
70
- self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00109-model.pth.tar')
71
- self.facerender_yaml_path = os.path.join(self.config_path, 'facerender_still.yaml')
72
- else:
73
- self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00229-model.pth.tar')
74
- self.facerender_yaml_path = os.path.join(self.config_path, 'facerender.yaml')
75
-
76
- print(self.mapping_checkpoint)
77
- print(self.free_view_checkpoint)
78
- self.animate_from_coeff = AnimateFromCoeff(self.free_view_checkpoint, self.mapping_checkpoint,
79
- self.facerender_yaml_path, self.device)
80
 
81
  time_tag = str(uuid.uuid4())
82
  save_dir = os.path.join(result_dir, time_tag)
@@ -89,7 +61,7 @@ class SadTalker():
89
  pic_path = os.path.join(input_dir, os.path.basename(source_image))
90
  shutil.move(source_image, input_dir)
91
 
92
- if os.path.isfile(driven_audio):
93
  audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
94
 
95
  #### mp3 to wav
@@ -98,37 +70,84 @@ class SadTalker():
98
  audio_path = audio_path.replace('.mp3', '.wav')
99
  else:
100
  shutil.move(driven_audio, input_dir)
 
 
 
 
 
 
101
  else:
102
- raise AttributeError("error audio")
 
103
 
 
 
 
 
 
 
 
104
 
105
  os.makedirs(save_dir, exist_ok=True)
106
- pose_style = 0
107
  #crop image and extract 3dmm from image
108
  first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
109
  os.makedirs(first_frame_dir, exist_ok=True)
110
- first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir,preprocess)
111
 
112
  if first_coeff_path is None:
113
  raise AttributeError("No face is detected")
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  #audio2ceoff
116
- batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=None, still=still_mode) # longer audio?
117
- coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
 
 
 
 
118
  #coeff2video
119
- batch_size = 8
120
- data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess)
121
- return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess)
122
  video_name = data['video_name']
123
  print(f'The generated video is named {video_name} in {save_dir}')
124
 
125
- if self.lazy_load:
126
- del self.preprocess_model
127
- del self.audio_to_coeff
128
- del self.animate_from_coeff
129
 
130
- torch.cuda.empty_cache()
131
- torch.cuda.synchronize()
 
 
132
  import gc; gc.collect()
133
 
134
  return return_path
 
6
  from src.generate_batch import get_data
7
  from src.generate_facerender_batch import get_facerender_data
8
 
9
+ from src.utils.init_path import init_path
10
+
11
  from pydub import AudioSegment
12
 
13
+
14
  def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
15
  mp3_file = AudioSegment.from_file(file=mp3_filename)
16
  mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
 
31
 
32
  self.checkpoint_path = checkpoint_path
33
  self.config_path = config_path
34
+
35
+
36
+ def test(self, source_image, driven_audio, preprocess='crop',
37
+ still_mode=False, use_enhancer=False, batch_size=1, size=256,
38
+ pose_style = 0, exp_scale=1.0,
39
+ use_ref_video = False,
40
+ ref_video = None,
41
+ ref_info = None,
42
+ use_idle_mode = False,
43
+ length_of_audio = 0, use_blink=True,
44
+ result_dir='./results/'):
45
+
46
+ self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
47
+ print(self.sadtalker_paths)
48
+
49
+ self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
50
+ self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
51
+ self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  time_tag = str(uuid.uuid4())
54
  save_dir = os.path.join(result_dir, time_tag)
 
61
  pic_path = os.path.join(input_dir, os.path.basename(source_image))
62
  shutil.move(source_image, input_dir)
63
 
64
+ if driven_audio is not None and os.path.isfile(driven_audio):
65
  audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
66
 
67
  #### mp3 to wav
 
70
  audio_path = audio_path.replace('.mp3', '.wav')
71
  else:
72
  shutil.move(driven_audio, input_dir)
73
+
74
+ elif use_idle_mode:
75
+ audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
76
+ from pydub import AudioSegment
77
+ one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds
78
+ one_sec_segment.export(audio_path, format="wav")
79
  else:
80
+ print(use_ref_video, ref_info)
81
+ assert use_ref_video == True and ref_info == 'all'
82
 
83
+ if use_ref_video and ref_info == 'all': # full ref mode
84
+ ref_video_videoname = os.path.basename(ref_video)
85
+ audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
86
+ print('new audiopath:',audio_path)
87
+ # if ref_video contains audio, set the audio from ref_video.
88
+ cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
89
+ os.system(cmd)
90
 
91
  os.makedirs(save_dir, exist_ok=True)
92
+
93
  #crop image and extract 3dmm from image
94
  first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
95
  os.makedirs(first_frame_dir, exist_ok=True)
96
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)
97
 
98
  if first_coeff_path is None:
99
  raise AttributeError("No face is detected")
100
 
101
+ if use_ref_video:
102
+ print('using ref video for genreation')
103
+ ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
104
+ ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
105
+ os.makedirs(ref_video_frame_dir, exist_ok=True)
106
+ print('3DMM Extraction for the reference video providing pose')
107
+ ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
108
+ else:
109
+ ref_video_coeff_path = None
110
+
111
+ if use_ref_video:
112
+ if ref_info == 'pose':
113
+ ref_pose_coeff_path = ref_video_coeff_path
114
+ ref_eyeblink_coeff_path = None
115
+ elif ref_info == 'blink':
116
+ ref_pose_coeff_path = None
117
+ ref_eyeblink_coeff_path = ref_video_coeff_path
118
+ elif ref_info == 'pose+blink':
119
+ ref_pose_coeff_path = ref_video_coeff_path
120
+ ref_eyeblink_coeff_path = ref_video_coeff_path
121
+ elif ref_info == 'all':
122
+ ref_pose_coeff_path = None
123
+ ref_eyeblink_coeff_path = None
124
+ else:
125
+ raise('error in refinfo')
126
+ else:
127
+ ref_pose_coeff_path = None
128
+ ref_eyeblink_coeff_path = None
129
+
130
  #audio2ceoff
131
+ if use_ref_video and ref_info == 'all':
132
+ coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
133
+ else:
134
+ batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
135
+ coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
136
+
137
  #coeff2video
138
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale)
139
+ return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
 
140
  video_name = data['video_name']
141
  print(f'The generated video is named {video_name} in {save_dir}')
142
 
143
+ del self.preprocess_model
144
+ del self.audio_to_coeff
145
+ del self.animate_from_coeff
 
146
 
147
+ if torch.cuda.is_available():
148
+ torch.cuda.empty_cache()
149
+ torch.cuda.synchronize()
150
+
151
  import gc; gc.collect()
152
 
153
  return return_path
src/src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class Audio2Exp(nn.Module):
7
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
+ super(Audio2Exp, self).__init__()
9
+ self.cfg = cfg
10
+ self.device = device
11
+ self.netG = netG.to(device)
12
+
13
+ def test(self, batch):
14
+
15
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
16
+ bs = mel_input.shape[0]
17
+ T = mel_input.shape[1]
18
+
19
+ exp_coeff_pred = []
20
+
21
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
+
23
+ current_mel_input = mel_input[:,i:i+10]
24
+
25
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
+ ref = batch['ref'][:, :, :64][:, i:i+10]
27
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
+
29
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
+
31
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
+
33
+ exp_coeff_pred += [curr_exp_coeff_pred]
34
+
35
+ # BS x T x 64
36
+ results_dict = {
37
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
+ }
39
+ return results_dict
40
+
41
+
src/src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+ self.use_act = use_act
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+
21
+ if self.use_act:
22
+ return self.act(out)
23
+ else:
24
+ return out
25
+
26
+ class SimpleWrapperV2(nn.Module):
27
+ def __init__(self) -> None:
28
+ super().__init__()
29
+ self.audio_encoder = nn.Sequential(
30
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
+
42
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
+
45
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
+ )
48
+
49
+ #### load the pre-trained audio_encoder
50
+ #self.audio_encoder = self.audio_encoder.to(device)
51
+ '''
52
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
+ state_dict = self.audio_encoder.state_dict()
54
+
55
+ for k,v in wav2lip_state_dict.items():
56
+ if 'audio_encoder' in k:
57
+ print('init:', k)
58
+ state_dict[k.replace('module.audio_encoder.', '')] = v
59
+ self.audio_encoder.load_state_dict(state_dict)
60
+ '''
61
+
62
+ self.mapping1 = nn.Linear(512+64+1, 64)
63
+ #self.mapping2 = nn.Linear(30, 64)
64
+ #nn.init.constant_(self.mapping1.weight, 0.)
65
+ nn.init.constant_(self.mapping1.bias, 0.)
66
+
67
+ def forward(self, x, ref, ratio):
68
+ x = self.audio_encoder(x).view(x.size(0), -1)
69
+ ref_reshape = ref.reshape(x.size(0), -1)
70
+ ratio = ratio.reshape(x.size(0), -1)
71
+
72
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
74
+ return out
src/src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from src.audio2pose_models.cvae import CVAE
4
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
+ from src.audio2pose_models.audio_encoder import AudioEncoder
6
+
7
+ class Audio2Pose(nn.Module):
8
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
+ super().__init__()
10
+ self.cfg = cfg
11
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
+ self.device = device
14
+
15
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
+ self.audio_encoder.eval()
17
+ for param in self.audio_encoder.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.netG = CVAE(cfg)
21
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
22
+
23
+
24
+ def forward(self, x):
25
+
26
+ batch = {}
27
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
31
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
+
33
+ # forward
34
+ audio_emb_list = []
35
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
+ batch['audio_emb'] = audio_emb
37
+ batch = self.netG(batch)
38
+
39
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
+
43
+ batch['pose_pred'] = pose_pred
44
+ batch['pose_gt'] = pose_gt
45
+
46
+ return batch
47
+
48
+ def test(self, x):
49
+
50
+ batch = {}
51
+ ref = x['ref'] #bs 1 70
52
+ batch['ref'] = x['ref'][:,0,-6:]
53
+ batch['class'] = x['class']
54
+ bs = ref.shape[0]
55
+
56
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
+ num_frames = x['num_frames']
59
+ num_frames = int(num_frames) - 1
60
+
61
+ #
62
+ div = num_frames//self.seq_len
63
+ re = num_frames%self.seq_len
64
+ audio_emb_list = []
65
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
+ device=batch['ref'].device)]
67
+
68
+ for i in range(div):
69
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
70
+ batch['z'] = z
71
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
+ batch['audio_emb'] = audio_emb
73
+ batch = self.netG.test(batch)
74
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
+
76
+ if re != 0:
77
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
78
+ batch['z'] = z
79
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
+ if audio_emb.shape[1] != self.seq_len:
81
+ pad_dim = self.seq_len-audio_emb.shape[1]
82
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
+ batch['audio_emb'] = audio_emb
85
+ batch = self.netG.test(batch)
86
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
+
88
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
+ batch['pose_motion_pred'] = pose_motion_pred
90
+
91
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
+
93
+ batch['pose_pred'] = pose_pred
94
+ return batch
src/src/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint, device):
23
+ super(AudioEncoder, self).__init__()
24
+
25
+ self.audio_encoder = nn.Sequential(
26
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
+
30
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
+
41
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
+
44
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
+ # state_dict = self.audio_encoder.state_dict()
47
+
48
+ # for k,v in wav2lip_state_dict.items():
49
+ # if 'audio_encoder' in k:
50
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ # self.audio_encoder.load_state_dict(state_dict)
52
+
53
+
54
+ def forward(self, audio_sequences):
55
+ # audio_sequences = (B, T, 1, 80, 16)
56
+ B = audio_sequences.size(0)
57
+
58
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
+
60
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
+ dim = audio_embedding.shape[1]
62
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
+
64
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
src/src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from src.audio2pose_models.res_unet import ResUnet
5
+
6
+ def class2onehot(idx, class_num):
7
+
8
+ assert torch.max(idx).item() < class_num
9
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
+ onehot.scatter_(1, idx, 1)
11
+ return onehot
12
+
13
+ class CVAE(nn.Module):
14
+ def __init__(self, cfg):
15
+ super().__init__()
16
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
+ num_classes = cfg.DATASET.NUM_CLASSES
20
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
+
24
+ self.latent_size = latent_size
25
+
26
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
+ audio_emb_in_size, audio_emb_out_size, seq_len)
28
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
+ audio_emb_in_size, audio_emb_out_size, seq_len)
30
+ def reparameterize(self, mu, logvar):
31
+ std = torch.exp(0.5 * logvar)
32
+ eps = torch.randn_like(std)
33
+ return mu + eps * std
34
+
35
+ def forward(self, batch):
36
+ batch = self.encoder(batch)
37
+ mu = batch['mu']
38
+ logvar = batch['logvar']
39
+ z = self.reparameterize(mu, logvar)
40
+ batch['z'] = z
41
+ return self.decoder(batch)
42
+
43
+ def test(self, batch):
44
+ '''
45
+ class_id = batch['class']
46
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
+ batch['z'] = z
48
+ '''
49
+ return self.decoder(batch)
50
+
51
+ class ENCODER(nn.Module):
52
+ def __init__(self, layer_sizes, latent_size, num_classes,
53
+ audio_emb_in_size, audio_emb_out_size, seq_len):
54
+ super().__init__()
55
+
56
+ self.resunet = ResUnet()
57
+ self.num_classes = num_classes
58
+ self.seq_len = seq_len
59
+
60
+ self.MLP = nn.Sequential()
61
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
+ self.MLP.add_module(
64
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
+
67
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
+
71
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
+
73
+ def forward(self, batch):
74
+ class_id = batch['class']
75
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
+ ref = batch['ref'] #bs 6
77
+ bs = pose_motion_gt.shape[0]
78
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
+
80
+ #pose encode
81
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
+
84
+ #audio mapping
85
+ print(audio_in.shape)
86
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
+ audio_out = audio_out.reshape(bs, -1)
88
+
89
+ class_bias = self.classbias[class_id] #bs latent_size
90
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
+ x_out = self.MLP(x_in)
92
+
93
+ mu = self.linear_means(x_out)
94
+ logvar = self.linear_means(x_out) #bs latent_size
95
+
96
+ batch.update({'mu':mu, 'logvar':logvar})
97
+ return batch
98
+
99
+ class DECODER(nn.Module):
100
+ def __init__(self, layer_sizes, latent_size, num_classes,
101
+ audio_emb_in_size, audio_emb_out_size, seq_len):
102
+ super().__init__()
103
+
104
+ self.resunet = ResUnet()
105
+ self.num_classes = num_classes
106
+ self.seq_len = seq_len
107
+
108
+ self.MLP = nn.Sequential()
109
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
110
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
+ self.MLP.add_module(
112
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
+ if i+1 < len(layer_sizes):
114
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
+ else:
116
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
+
118
+ self.pose_linear = nn.Linear(6, 6)
119
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
+
121
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
+
123
+ def forward(self, batch):
124
+
125
+ z = batch['z'] #bs latent_size
126
+ bs = z.shape[0]
127
+ class_id = batch['class']
128
+ ref = batch['ref'] #bs 6
129
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
+ #print('audio_in: ', audio_in[:, :, :10])
131
+
132
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
+ #print('audio_out: ', audio_out[:, :, :10])
134
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
+ class_bias = self.classbias[class_id] #bs latent_size
136
+
137
+ z = z + class_bias
138
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
139
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
+ x_out = x_out.reshape((bs, self.seq_len, -1))
141
+
142
+ #print('x_out: ', x_out)
143
+
144
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
+
146
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
+
148
+ batch.update({'pose_motion_pred':pose_motion_pred})
149
+ return batch
src/src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class ConvNormRelu(nn.Module):
6
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
+ super().__init__()
9
+ if kernel_size is None:
10
+ if downsample:
11
+ kernel_size, stride, padding = 4, 2, 1
12
+ else:
13
+ kernel_size, stride, padding = 3, 1, 1
14
+
15
+ if conv_type == '2d':
16
+ self.conv = nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size,
20
+ stride,
21
+ padding,
22
+ bias=False,
23
+ )
24
+ if norm == 'BN':
25
+ self.norm = nn.BatchNorm2d(out_channels)
26
+ elif norm == 'IN':
27
+ self.norm = nn.InstanceNorm2d(out_channels)
28
+ else:
29
+ raise NotImplementedError
30
+ elif conv_type == '1d':
31
+ self.conv = nn.Conv1d(
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ stride,
36
+ padding,
37
+ bias=False,
38
+ )
39
+ if norm == 'BN':
40
+ self.norm = nn.BatchNorm1d(out_channels)
41
+ elif norm == 'IN':
42
+ self.norm = nn.InstanceNorm1d(out_channels)
43
+ else:
44
+ raise NotImplementedError
45
+ nn.init.kaiming_normal_(self.conv.weight)
46
+
47
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ if isinstance(self.norm, nn.InstanceNorm1d):
52
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
+ else:
54
+ x = self.norm(x)
55
+ x = self.act(x)
56
+ return x
57
+
58
+
59
+ class PoseSequenceDiscriminator(nn.Module):
60
+ def __init__(self, cfg):
61
+ super().__init__()
62
+ self.cfg = cfg
63
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
+
65
+ self.seq = nn.Sequential(
66
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
+ )
71
+
72
+ def forward(self, x):
73
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
+ x = self.seq(x)
75
+ x = x.squeeze(1)
76
+ return x