mrfakename
commited on
Commit
•
635f007
0
Parent(s):
Initial Commit
Browse files- .gitattributes +37 -0
- .gitignore +163 -0
- Configs/config.yml +116 -0
- Configs/config_ft.yml +111 -0
- Configs/config_libritts.yml +113 -0
- Data/OOD_texts.txt +3 -0
- Data/train_list.txt +0 -0
- Data/val_list.txt +100 -0
- LICENSE +28 -0
- Modules/__init__.py +1 -0
- Modules/diffusion/__init__.py +1 -0
- Modules/diffusion/diffusion.py +92 -0
- Modules/diffusion/modules.py +700 -0
- Modules/diffusion/sampler.py +685 -0
- Modules/diffusion/utils.py +83 -0
- Modules/discriminators.py +267 -0
- Modules/hifigan.py +643 -0
- Modules/istftnet.py +720 -0
- Modules/slmadv.py +256 -0
- Modules/utils.py +14 -0
- README.md +15 -0
- Utils/ASR/__init__.py +1 -0
- Utils/ASR/config.yml +29 -0
- Utils/ASR/epoch_00080.pth +3 -0
- Utils/ASR/layers.py +455 -0
- Utils/ASR/models.py +217 -0
- Utils/JDC/__init__.py +1 -0
- Utils/JDC/bst.t7 +3 -0
- Utils/JDC/model.py +212 -0
- Utils/PLBERT/config.yml +30 -0
- Utils/PLBERT/step_1000000.t7 +3 -0
- Utils/PLBERT/util.py +49 -0
- Utils/__init__.py +1 -0
- _run.py +371 -0
- app.py +46 -0
- losses.py +303 -0
- meldataset.py +294 -0
- models.py +881 -0
- optimizers.py +86 -0
- reference_audio.zip +3 -0
- requirements.txt +21 -0
- styletts2importable.py +361 -0
- text_utils.py +28 -0
- train_finetune.py +839 -0
- train_first.py +540 -0
- train_second.py +958 -0
- utils.py +89 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.t7 filter=lfs diff=lfs merge=lfs -text
|
25 |
+
OOD_texts.txt filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
28 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
# Byte-compiled / optimized / DLL files
|
3 |
+
__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
6 |
+
|
7 |
+
# C extensions
|
8 |
+
*.so
|
9 |
+
|
10 |
+
# Distribution / packaging
|
11 |
+
.Python
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
cover/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
.pybuilder/
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# IPython
|
83 |
+
profile_default/
|
84 |
+
ipython_config.py
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
# For a library or package, you might want to ignore these files since the code is
|
88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
89 |
+
# .python-version
|
90 |
+
|
91 |
+
# pipenv
|
92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
95 |
+
# install all needed dependencies.
|
96 |
+
#Pipfile.lock
|
97 |
+
|
98 |
+
# poetry
|
99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
101 |
+
# commonly ignored for libraries.
|
102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
103 |
+
#poetry.lock
|
104 |
+
|
105 |
+
# pdm
|
106 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
107 |
+
#pdm.lock
|
108 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
109 |
+
# in version control.
|
110 |
+
# https://pdm.fming.dev/#use-with-ide
|
111 |
+
.pdm.toml
|
112 |
+
|
113 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
114 |
+
__pypackages__/
|
115 |
+
|
116 |
+
# Celery stuff
|
117 |
+
celerybeat-schedule
|
118 |
+
celerybeat.pid
|
119 |
+
|
120 |
+
# SageMath parsed files
|
121 |
+
*.sage.py
|
122 |
+
|
123 |
+
# Environments
|
124 |
+
.env
|
125 |
+
.venv
|
126 |
+
env/
|
127 |
+
venv/
|
128 |
+
ENV/
|
129 |
+
env.bak/
|
130 |
+
venv.bak/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
#.idea/
|
162 |
+
|
163 |
+
voice
|
Configs/config.yml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
log_dir: "Models/LJSpeech"
|
2 |
+
first_stage_path: "first_stage.pth"
|
3 |
+
save_freq: 2
|
4 |
+
log_interval: 10
|
5 |
+
device: "cuda"
|
6 |
+
epochs_1st: 200 # number of epochs for first stage training (pre-training)
|
7 |
+
epochs_2nd: 100 # number of peochs for second stage training (joint training)
|
8 |
+
batch_size: 16
|
9 |
+
max_len: 400 # maximum number of frames
|
10 |
+
pretrained_model: ""
|
11 |
+
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
|
12 |
+
load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
|
13 |
+
|
14 |
+
F0_path: "Utils/JDC/bst.t7"
|
15 |
+
ASR_config: "Utils/ASR/config.yml"
|
16 |
+
ASR_path: "Utils/ASR/epoch_00080.pth"
|
17 |
+
PLBERT_dir: 'Utils/PLBERT/'
|
18 |
+
|
19 |
+
data_params:
|
20 |
+
train_data: "Data/train_list.txt"
|
21 |
+
val_data: "Data/val_list.txt"
|
22 |
+
root_path: "/local/LJSpeech-1.1/wavs"
|
23 |
+
OOD_data: "Data/OOD_texts.txt"
|
24 |
+
min_length: 50 # sample until texts with this size are obtained for OOD texts
|
25 |
+
|
26 |
+
preprocess_params:
|
27 |
+
sr: 24000
|
28 |
+
spect_params:
|
29 |
+
n_fft: 2048
|
30 |
+
win_length: 1200
|
31 |
+
hop_length: 300
|
32 |
+
|
33 |
+
model_params:
|
34 |
+
multispeaker: false
|
35 |
+
|
36 |
+
dim_in: 64
|
37 |
+
hidden_dim: 512
|
38 |
+
max_conv_dim: 512
|
39 |
+
n_layer: 3
|
40 |
+
n_mels: 80
|
41 |
+
|
42 |
+
n_token: 178 # number of phoneme tokens
|
43 |
+
max_dur: 50 # maximum duration of a single phoneme
|
44 |
+
style_dim: 128 # style vector size
|
45 |
+
|
46 |
+
dropout: 0.2
|
47 |
+
|
48 |
+
# config for decoder
|
49 |
+
decoder:
|
50 |
+
type: 'istftnet' # either hifigan or istftnet
|
51 |
+
resblock_kernel_sizes: [3,7,11]
|
52 |
+
upsample_rates : [10, 6]
|
53 |
+
upsample_initial_channel: 512
|
54 |
+
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
55 |
+
upsample_kernel_sizes: [20, 12]
|
56 |
+
gen_istft_n_fft: 20
|
57 |
+
gen_istft_hop_size: 5
|
58 |
+
|
59 |
+
# speech language model config
|
60 |
+
slm:
|
61 |
+
model: 'microsoft/wavlm-base-plus'
|
62 |
+
sr: 16000 # sampling rate of SLM
|
63 |
+
hidden: 768 # hidden size of SLM
|
64 |
+
nlayers: 13 # number of layers of SLM
|
65 |
+
initial_channel: 64 # initial channels of SLM discriminator head
|
66 |
+
|
67 |
+
# style diffusion model config
|
68 |
+
diffusion:
|
69 |
+
embedding_mask_proba: 0.1
|
70 |
+
# transformer config
|
71 |
+
transformer:
|
72 |
+
num_layers: 3
|
73 |
+
num_heads: 8
|
74 |
+
head_features: 64
|
75 |
+
multiplier: 2
|
76 |
+
|
77 |
+
# diffusion distribution config
|
78 |
+
dist:
|
79 |
+
sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
|
80 |
+
estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
|
81 |
+
mean: -3.0
|
82 |
+
std: 1.0
|
83 |
+
|
84 |
+
loss_params:
|
85 |
+
lambda_mel: 5. # mel reconstruction loss
|
86 |
+
lambda_gen: 1. # generator loss
|
87 |
+
lambda_slm: 1. # slm feature matching loss
|
88 |
+
|
89 |
+
lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
|
90 |
+
lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
|
91 |
+
TMA_epoch: 50 # TMA starting epoch (1st stage)
|
92 |
+
|
93 |
+
lambda_F0: 1. # F0 reconstruction loss (2nd stage)
|
94 |
+
lambda_norm: 1. # norm reconstruction loss (2nd stage)
|
95 |
+
lambda_dur: 1. # duration loss (2nd stage)
|
96 |
+
lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
|
97 |
+
lambda_sty: 1. # style reconstruction loss (2nd stage)
|
98 |
+
lambda_diff: 1. # score matching loss (2nd stage)
|
99 |
+
|
100 |
+
diff_epoch: 20 # style diffusion starting epoch (2nd stage)
|
101 |
+
joint_epoch: 50 # joint training starting epoch (2nd stage)
|
102 |
+
|
103 |
+
optimizer_params:
|
104 |
+
lr: 0.0001 # general learning rate
|
105 |
+
bert_lr: 0.00001 # learning rate for PLBERT
|
106 |
+
ft_lr: 0.00001 # learning rate for acoustic modules
|
107 |
+
|
108 |
+
slmadv_params:
|
109 |
+
min_len: 400 # minimum length of samples
|
110 |
+
max_len: 500 # maximum length of samples
|
111 |
+
batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
|
112 |
+
iter: 10 # update the discriminator every this iterations of generator update
|
113 |
+
thresh: 5 # gradient norm above which the gradient is scaled
|
114 |
+
scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
|
115 |
+
sig: 1.5 # sigma for differentiable duration modeling
|
116 |
+
|
Configs/config_ft.yml
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
log_dir: "Models/LJSpeech"
|
2 |
+
save_freq: 5
|
3 |
+
log_interval: 10
|
4 |
+
device: "cuda"
|
5 |
+
epochs: 50 # number of finetuning epoch (1 hour of data)
|
6 |
+
batch_size: 8
|
7 |
+
max_len: 400 # maximum number of frames
|
8 |
+
pretrained_model: "Models/LibriTTS/epochs_2nd_00020.pth"
|
9 |
+
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
|
10 |
+
load_only_params: true # set to true if do not want to load epoch numbers and optimizer parameters
|
11 |
+
|
12 |
+
F0_path: "Utils/JDC/bst.t7"
|
13 |
+
ASR_config: "Utils/ASR/config.yml"
|
14 |
+
ASR_path: "Utils/ASR/epoch_00080.pth"
|
15 |
+
PLBERT_dir: 'Utils/PLBERT/'
|
16 |
+
|
17 |
+
data_params:
|
18 |
+
train_data: "Data/train_list.txt"
|
19 |
+
val_data: "Data/val_list.txt"
|
20 |
+
root_path: "/local/LJSpeech-1.1/wavs"
|
21 |
+
OOD_data: "Data/OOD_texts.txt"
|
22 |
+
min_length: 50 # sample until texts with this size are obtained for OOD texts
|
23 |
+
|
24 |
+
preprocess_params:
|
25 |
+
sr: 24000
|
26 |
+
spect_params:
|
27 |
+
n_fft: 2048
|
28 |
+
win_length: 1200
|
29 |
+
hop_length: 300
|
30 |
+
|
31 |
+
model_params:
|
32 |
+
multispeaker: true
|
33 |
+
|
34 |
+
dim_in: 64
|
35 |
+
hidden_dim: 512
|
36 |
+
max_conv_dim: 512
|
37 |
+
n_layer: 3
|
38 |
+
n_mels: 80
|
39 |
+
|
40 |
+
n_token: 178 # number of phoneme tokens
|
41 |
+
max_dur: 50 # maximum duration of a single phoneme
|
42 |
+
style_dim: 128 # style vector size
|
43 |
+
|
44 |
+
dropout: 0.2
|
45 |
+
|
46 |
+
# config for decoder
|
47 |
+
decoder:
|
48 |
+
type: 'hifigan' # either hifigan or istftnet
|
49 |
+
resblock_kernel_sizes: [3,7,11]
|
50 |
+
upsample_rates : [10,5,3,2]
|
51 |
+
upsample_initial_channel: 512
|
52 |
+
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
53 |
+
upsample_kernel_sizes: [20,10,6,4]
|
54 |
+
|
55 |
+
# speech language model config
|
56 |
+
slm:
|
57 |
+
model: 'microsoft/wavlm-base-plus'
|
58 |
+
sr: 16000 # sampling rate of SLM
|
59 |
+
hidden: 768 # hidden size of SLM
|
60 |
+
nlayers: 13 # number of layers of SLM
|
61 |
+
initial_channel: 64 # initial channels of SLM discriminator head
|
62 |
+
|
63 |
+
# style diffusion model config
|
64 |
+
diffusion:
|
65 |
+
embedding_mask_proba: 0.1
|
66 |
+
# transformer config
|
67 |
+
transformer:
|
68 |
+
num_layers: 3
|
69 |
+
num_heads: 8
|
70 |
+
head_features: 64
|
71 |
+
multiplier: 2
|
72 |
+
|
73 |
+
# diffusion distribution config
|
74 |
+
dist:
|
75 |
+
sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
|
76 |
+
estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
|
77 |
+
mean: -3.0
|
78 |
+
std: 1.0
|
79 |
+
|
80 |
+
loss_params:
|
81 |
+
lambda_mel: 5. # mel reconstruction loss
|
82 |
+
lambda_gen: 1. # generator loss
|
83 |
+
lambda_slm: 1. # slm feature matching loss
|
84 |
+
|
85 |
+
lambda_mono: 1. # monotonic alignment loss (TMA)
|
86 |
+
lambda_s2s: 1. # sequence-to-sequence loss (TMA)
|
87 |
+
|
88 |
+
lambda_F0: 1. # F0 reconstruction loss
|
89 |
+
lambda_norm: 1. # norm reconstruction loss
|
90 |
+
lambda_dur: 1. # duration loss
|
91 |
+
lambda_ce: 20. # duration predictor probability output CE loss
|
92 |
+
lambda_sty: 1. # style reconstruction loss
|
93 |
+
lambda_diff: 1. # score matching loss
|
94 |
+
|
95 |
+
diff_epoch: 10 # style diffusion starting epoch
|
96 |
+
joint_epoch: 30 # joint training starting epoch
|
97 |
+
|
98 |
+
optimizer_params:
|
99 |
+
lr: 0.0001 # general learning rate
|
100 |
+
bert_lr: 0.00001 # learning rate for PLBERT
|
101 |
+
ft_lr: 0.0001 # learning rate for acoustic modules
|
102 |
+
|
103 |
+
slmadv_params:
|
104 |
+
min_len: 400 # minimum length of samples
|
105 |
+
max_len: 500 # maximum length of samples
|
106 |
+
batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
|
107 |
+
iter: 10 # update the discriminator every this iterations of generator update
|
108 |
+
thresh: 5 # gradient norm above which the gradient is scaled
|
109 |
+
scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
|
110 |
+
sig: 1.5 # sigma for differentiable duration modeling
|
111 |
+
|
Configs/config_libritts.yml
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
log_dir: "Models/LibriTTS"
|
2 |
+
first_stage_path: "first_stage.pth"
|
3 |
+
save_freq: 1
|
4 |
+
log_interval: 10
|
5 |
+
device: "cuda"
|
6 |
+
epochs_1st: 50 # number of epochs for first stage training (pre-training)
|
7 |
+
epochs_2nd: 30 # number of peochs for second stage training (joint training)
|
8 |
+
batch_size: 16
|
9 |
+
max_len: 300 # maximum number of frames
|
10 |
+
pretrained_model: ""
|
11 |
+
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
|
12 |
+
load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
|
13 |
+
|
14 |
+
F0_path: "Utils/JDC/bst.t7"
|
15 |
+
ASR_config: "Utils/ASR/config.yml"
|
16 |
+
ASR_path: "Utils/ASR/epoch_00080.pth"
|
17 |
+
PLBERT_dir: 'Utils/PLBERT/'
|
18 |
+
|
19 |
+
data_params:
|
20 |
+
train_data: "Data/train_list.txt"
|
21 |
+
val_data: "Data/val_list.txt"
|
22 |
+
root_path: ""
|
23 |
+
OOD_data: "Data/OOD_texts.txt"
|
24 |
+
min_length: 50 # sample until texts with this size are obtained for OOD texts
|
25 |
+
|
26 |
+
preprocess_params:
|
27 |
+
sr: 24000
|
28 |
+
spect_params:
|
29 |
+
n_fft: 2048
|
30 |
+
win_length: 1200
|
31 |
+
hop_length: 300
|
32 |
+
|
33 |
+
model_params:
|
34 |
+
multispeaker: true
|
35 |
+
|
36 |
+
dim_in: 64
|
37 |
+
hidden_dim: 512
|
38 |
+
max_conv_dim: 512
|
39 |
+
n_layer: 3
|
40 |
+
n_mels: 80
|
41 |
+
|
42 |
+
n_token: 178 # number of phoneme tokens
|
43 |
+
max_dur: 50 # maximum duration of a single phoneme
|
44 |
+
style_dim: 128 # style vector size
|
45 |
+
|
46 |
+
dropout: 0.2
|
47 |
+
|
48 |
+
# config for decoder
|
49 |
+
decoder:
|
50 |
+
type: 'hifigan' # either hifigan or istftnet
|
51 |
+
resblock_kernel_sizes: [3,7,11]
|
52 |
+
upsample_rates : [10,5,3,2]
|
53 |
+
upsample_initial_channel: 512
|
54 |
+
resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
|
55 |
+
upsample_kernel_sizes: [20,10,6,4]
|
56 |
+
|
57 |
+
# speech language model config
|
58 |
+
slm:
|
59 |
+
model: 'microsoft/wavlm-base-plus'
|
60 |
+
sr: 16000 # sampling rate of SLM
|
61 |
+
hidden: 768 # hidden size of SLM
|
62 |
+
nlayers: 13 # number of layers of SLM
|
63 |
+
initial_channel: 64 # initial channels of SLM discriminator head
|
64 |
+
|
65 |
+
# style diffusion model config
|
66 |
+
diffusion:
|
67 |
+
embedding_mask_proba: 0.1
|
68 |
+
# transformer config
|
69 |
+
transformer:
|
70 |
+
num_layers: 3
|
71 |
+
num_heads: 8
|
72 |
+
head_features: 64
|
73 |
+
multiplier: 2
|
74 |
+
|
75 |
+
# diffusion distribution config
|
76 |
+
dist:
|
77 |
+
sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
|
78 |
+
estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
|
79 |
+
mean: -3.0
|
80 |
+
std: 1.0
|
81 |
+
|
82 |
+
loss_params:
|
83 |
+
lambda_mel: 5. # mel reconstruction loss
|
84 |
+
lambda_gen: 1. # generator loss
|
85 |
+
lambda_slm: 1. # slm feature matching loss
|
86 |
+
|
87 |
+
lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
|
88 |
+
lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
|
89 |
+
TMA_epoch: 5 # TMA starting epoch (1st stage)
|
90 |
+
|
91 |
+
lambda_F0: 1. # F0 reconstruction loss (2nd stage)
|
92 |
+
lambda_norm: 1. # norm reconstruction loss (2nd stage)
|
93 |
+
lambda_dur: 1. # duration loss (2nd stage)
|
94 |
+
lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
|
95 |
+
lambda_sty: 1. # style reconstruction loss (2nd stage)
|
96 |
+
lambda_diff: 1. # score matching loss (2nd stage)
|
97 |
+
|
98 |
+
diff_epoch: 10 # style diffusion starting epoch (2nd stage)
|
99 |
+
joint_epoch: 15 # joint training starting epoch (2nd stage)
|
100 |
+
|
101 |
+
optimizer_params:
|
102 |
+
lr: 0.0001 # general learning rate
|
103 |
+
bert_lr: 0.00001 # learning rate for PLBERT
|
104 |
+
ft_lr: 0.00001 # learning rate for acoustic modules
|
105 |
+
|
106 |
+
slmadv_params:
|
107 |
+
min_len: 400 # minimum length of samples
|
108 |
+
max_len: 500 # maximum length of samples
|
109 |
+
batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
|
110 |
+
iter: 20 # update the discriminator every this iterations of generator update
|
111 |
+
thresh: 5 # gradient norm above which the gradient is scaled
|
112 |
+
scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
|
113 |
+
sig: 1.5 # sigma for differentiable duration modeling
|
Data/OOD_texts.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e0989ef6a9873b711befefcbe60660ced7a65532359277f766f4db504c558a72
|
3 |
+
size 31758898
|
Data/train_list.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Data/val_list.txt
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹᵻɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wʌt ðeɪ hˈɪɹ ænd wʌt ðeɪ ɹˈiːd .|0
|
2 |
+
LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː , ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt , tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ , ænd ˈɔːl ðə fˈɜːnɪtʃɚ , aɪ wʊd biː mˈæd æz hˈɛl , tˈuː .|0
|
3 |
+
LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹᵻpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪŋkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn ˈeɪtiːn θˈɜːɾi fˈaɪv .|0
|
4 |
+
LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹᵻspˈɛkt :|0
|
5 |
+
LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹᵻspˈɛkt wʌz tə θɹˈoʊ ðə ɹᵻspˌɑːnsəbˈɪlɪɾi ˌɔn ˈʌðɚz .|0
|
6 |
+
LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛl ɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌŋkənvˈɪktᵻd pɹˈɪzənɚ , ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt , ænd stˈɪl ʌŋkəntˈæmᵻnˌeɪɾᵻd ,|0
|
7 |
+
LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔː stˈeɪʃənɚz . hɪz ɚɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz .|0
|
8 |
+
LJ047-0044.wav|ˈɑːswəld wʌz , haʊˈɛvɚ , wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz . hiː dᵻnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz|0
|
9 |
+
LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ . tʃˈɑːɹlz dʒˈeɪ . kˈæɹɪkˌoʊ , ɐ ɹˈɛzᵻdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi .|0
|
10 |
+
LJ048-0194.wav|dˈʊɹɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛnti tˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd .|0
|
11 |
+
LJ049-0026.wav|ˌɔn əkˈeɪʒən ðə sˈiːkɹᵻt sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt .|0
|
12 |
+
LJ004-0152.wav|ɔːlðˈoʊ æt mˈɪstɚ . bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən , ðə fˈɜːst stˈɛp təwˈɔːɹdz ɹᵻfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˈɛvəntˌiːn sˈɛvənti fˈoːɹ .|0
|
13 |
+
LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni , ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsᵻsɚɹi tə dˈɑːlɚ mˌeɪk ɐn ɛɡzˈæmpəl.dˈɑːlɚ|0
|
14 |
+
LJ043-0002.wav|ðə wˈɔːɹəŋ kəmˈɪʃən ɹᵻpˈoːɹt . baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɔnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi . tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld :|0
|
15 |
+
LJ009-0114.wav|mˈɪstɚ . wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dᵻskɹˈaɪbɪŋ ɐnˈʌðɚ ɹᵻlˈɪdʒəs sˈɜːvɪs , wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪˌeɪtli biː ɪnsˈɜːɾᵻd hˈɪɹ .|0
|
16 |
+
LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk .|0
|
17 |
+
LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd . ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzᵻz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə|0
|
18 |
+
LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp , hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪliˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən|0
|
19 |
+
LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl , kwˈoʊt , wiː hæd ɐ mˈoʊɾɚkˌeɪd wɛɹˈɛvɚ kplˈʌsplʌs wˌɪtʃ hɐdbɪn bˌɪn hˈeɪstili sˈʌmənd fɚðə ðə pˈɜːpəs wiː wˈɛnt , ˈɛnd kwˈoʊt .|0
|
20 |
+
LJ031-0070.wav|dˈɑːktɚ . klˈɑːɹk , hˌuː mˈoʊst klˈoʊsli əbzˈɜːvd ðə hˈɛd wˈuːnd ,|0
|
21 |
+
LJ034-0198.wav|jˈuːɪnz , hˌuː wʌz ɔnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstᵻfˌaɪd ðæt hiː kʊd nˌɑːt dᵻskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ .|0
|
22 |
+
LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt , tʊ ɐ smˈɔːl ɛkstˈɛnt ,|0
|
23 |
+
LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɔnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsᵻsɚɹi .|0
|
24 |
+
LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd|0
|
25 |
+
LJ005-0014.wav|spˈiːkɪŋ ˌɔn ɐ dᵻbˈeɪt ˌɔn pɹˈɪzən mˈæɾɚz , hiː dᵻklˈɛɹd ðˈæt|0
|
26 |
+
LJ012-0161.wav|hiː wʌz ɹᵻpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ .|0
|
27 |
+
LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹᵻpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹᵻfˈɜːd tuː|0
|
28 |
+
LJ019-0257.wav|hˈɪɹ ðə tɹˈɛd wˈiːl wʌz ɪn jˈuːs , ðɛɹ sˈɛljʊlɚ kɹˈæŋks , ɔːɹ hˈɑːɹd lˈeɪbɚ məʃˈiːnz .|0
|
29 |
+
LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɔn .|0
|
30 |
+
LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɔnðə kˈoːɹt ;|0
|
31 |
+
LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz , nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz . aɪ hæv hæd ɪnˈʌf .|0
|
32 |
+
LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp .|0
|
33 |
+
LJ046-0058.wav|dˈʊɹɹɪŋ hɪz pɹˈɛzɪdənsi , fɹˈæŋklɪn dˈiː . ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹɪd dʒˈɜːniz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹɪd fˈɪfti θˈaʊzənd mˈaɪlz .|0
|
34 |
+
LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ , ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv .|0
|
35 |
+
LJ002-0043.wav|lˈɔŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾi sˈɪks fˈiːt , sˈɪks twˈɛnti θɹˈiː fˈiːt , ænd ðɪ ˈeɪtθ ˈeɪtiːn ,|0
|
36 |
+
LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən .|0
|
37 |
+
LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hæd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹᵻpɹˈiːv , ænd wɪðˌɪn ɐ fjˈuː ˈaʊɚz ʌv ˌɛksɪkjˈuːʃən .|0
|
38 |
+
LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹᵻt sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹᵻlˈiːst ɔːɹ ɛskˈeɪps .|0
|
39 |
+
LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ , ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt .|0
|
40 |
+
LJ042-0096.wav|ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt|0
|
41 |
+
LJ049-0050.wav|hˈɪl hæd bˈoʊθ fˈiːt ɔnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mˈɪsɪz . kˈɛnədi .|0
|
42 |
+
LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt , nˈuːɡeɪt ɹᵻsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntiz ,|0
|
43 |
+
LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs , ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsəŋ ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd .|0
|
44 |
+
LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd .|0
|
45 |
+
LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kəŋklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɔnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld .|0
|
46 |
+
LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən .|0
|
47 |
+
LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt , ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ᵻlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm ?|0
|
48 |
+
LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪɾ ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz .|0
|
49 |
+
LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪst ænd ɹᵻpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɹiz ʌvðə sˈɪɾi ʌv lˈʌndən ,|0
|
50 |
+
LJ028-0275.wav|æt lˈæst , ɪnðə twˈɛntiəθ mˈʌnθ ,|0
|
51 |
+
LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋ plˈeɪs wɪð ɐ tɹˈæp dˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd .|0
|
52 |
+
LJ011-0096.wav|hiː mˈæɹid ɐ lˈeɪdi ˈɔːlsoʊ bᵻlˈɔŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz , hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃʊn , wˈɪtʃ , ænd hɪz ˈoʊn mˈʌni , hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm ,|0
|
53 |
+
LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː . kɹˈeɪɡ , ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti ,|0
|
54 |
+
LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz , ɡɹˈeɪt lˈɔɪɚz , ɡˈʌvɚnɚz ʌv pɹˈɪzənz , ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː .|0
|
55 |
+
LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst , ɐ səspˈɪʃəs sˈɜːkəmstˌæns , æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ .|0
|
56 |
+
LJ027-0141.wav|ɪz klˈoʊsli ɹᵻpɹədˈuːst ɪnðə lˈaɪf hˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ . ɔːɹ , ɪn ˈʌðɚ wˈɜːdz ,|0
|
57 |
+
LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi , ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz .|0
|
58 |
+
LJ031-0202.wav|mˈɪsɪz . kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hæd sˈɜːvd ɪnðə nˈeɪvi .|0
|
59 |
+
LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈo��pt fɔːɹ pˈiəɹɪəd ʌv pˈiːs ,|0
|
60 |
+
LJ016-0288.wav|dˈɑːlɚ mˈuːlɚ , mˈuːlɚ , hiːz ðə mˈæn , dˈɑːlɚ tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz , wˌɪtʃ wʌz ɹᵻsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz .|0
|
61 |
+
LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ , wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdili dɪstˈɪŋɡwɪʃ ðə fˈɔls fɹʌmðə tɹˈuː ,|0
|
62 |
+
LJ018-0081.wav|hɪz dᵻfˈɛns bˌiːɪŋ ðæt hiː hæd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd , bˌʌt ðˈæt , ɔnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hæd ɹˈɔŋd hˌɪm ,|0
|
63 |
+
LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪŋkɹiːs ɪnðə pˈeɪɹoʊlz , ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts|0
|
64 |
+
LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp , bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd , ænd ðə mˈæn wʌz kˈæɹid bˈæk tə dʒˈeɪl .|0
|
65 |
+
LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz , ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz .|0
|
66 |
+
LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən , ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl .|0
|
67 |
+
LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs , ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts , ðə hˈaʊskiːpɚ ðˈɛɹ .|0
|
68 |
+
LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛnti tˈuː , nˈaɪntiːn sˈɪksti θɹˈiː , fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈɪɹiəl fɚðə pˌiːˌɑːɹɹˈɛs dʒˈɛnɚɹəl fˈaɪlz|0
|
69 |
+
LJ017-0044.wav|ænd ðə dˈiːpɪst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm , ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn , ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ .|0
|
70 |
+
LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ , ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn , ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ .|0
|
71 |
+
LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɚɹˈɛstᵻd ˌɔn səspˈɪʃən , ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd ;|0
|
72 |
+
LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn , bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd , ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sᵻvˈɪɹli .|0
|
73 |
+
LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹihˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ . ɔːlðˈoʊ ɪɾ ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt ,|0
|
74 |
+
LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm .|0
|
75 |
+
LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹᵻkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɔŋ ɪn səspˈɛns .|0
|
76 |
+
LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dᵻfˈɜːd .|0
|
77 |
+
LJ047-0148.wav|ˌɔn ɑːktˈoʊbɚ twˈɛnti fˈaɪv ,|0
|
78 |
+
LJ008-0111.wav|ðeɪ ˈɛntɚd ɐ dˈɑːlɚ stˈoʊŋ kˈoʊld ɹˈuːm , dˈɑːlɚɹ ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ .|0
|
79 |
+
LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstᵻfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld .|0
|
80 |
+
LJ037-0234.wav|mˈɪsɪz . mˈɛɹi bɹˈɑːk , ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən , wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl ,|0
|
81 |
+
LJ040-0002.wav|tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld : bˈækɡɹaʊnd ænd pˈɑːsᵻbəl mˈoʊɾɪvz , pˈɑːɹt wˌʌn .|0
|
82 |
+
LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstᵻfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bᵻkˈʌmɪŋ ɪnvˈɑːlvd|0
|
83 |
+
LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɔn wˈɑːtʃᵻz , wɜː kˈɛɹfəli ɹᵻmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz .|0
|
84 |
+
LJ012-0250.wav|ɔnðə sˈɛvənθ dʒuːlˈaɪ , ˈeɪtiːn θˈɜːɾi sˈɛvən ,|0
|
85 |
+
LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈiːnɚz tə wˈɜːk baɪ ðə dʒˈɑːb .|0
|
86 |
+
LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən .|0
|
87 |
+
LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ᵻsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi .|0
|
88 |
+
LJ031-0134.wav|ˌɔn wˈʌn əkˈeɪʒən mˈɪsɪz . dʒˈɑːnsən , ɐkˈʌmpənid baɪ tˈuː sˈiːkɹᵻt sˈɜːvɪs ˈeɪdʒənts , lˈɛft ðə ɹˈuːm tə sˈiː mˈɪsɪz . kˈɛnədi ænd mˈɪsɪz . kˈɑːnæli .|0
|
89 |
+
LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn .|0
|
90 |
+
LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd , ˈoʊpənd , ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts .|0
|
91 |
+
LJ034-0160.wav|ˌɔn bɹˈɛnənz sˈʌbsᵻkwənt sˈɜːʔn̩ aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl .|0
|
92 |
+
LJ038-0199.wav|ᵻlˈɛvən . ɪf aɪɐm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ ,|0
|
93 |
+
LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈæd fɔːɹ hˌɪm , ænd ɹᵻmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm .|0
|
94 |
+
LJ033-0047.wav|aɪ nˈoʊɾɪst wɛn aɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɔn , ˈɛnd kwˈoʊt ,|0
|
95 |
+
LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ .|0
|
96 |
+
LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli .|0
|
97 |
+
LJ003-0111.wav|hiː wʌz ɪŋ kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː , ˈɛnd kwˈoʊt . ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɹɪˈɔsɪɾi .|0
|
98 |
+
LJ008-0258.wav|lˈɛt mˌiː ɹᵻtɹˈeɪs maɪ stˈɛps , ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz ,|0
|
99 |
+
LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæŋ kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt , mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs , fˈɔːɹt wˈɜːθ , sˌæn æntˈoʊnɪˌoʊ , ænd hjˈuːstən .|0
|
100 |
+
LJ004-0045.wav|mˈɪstɚ . stˈɜːdʒᵻz bˈoːɹn , sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ , sˌɜː dʒˈeɪmz skˈɑːɹlɪt , ænd wˈɪljəm wˈɪlbɚfˌoːɹs .|0
|
LICENSE
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LICENSE FOR STYLETTS2:
|
2 |
+
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Aaron (Yinghao) Li
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
|
25 |
+
|
26 |
+
LICENSE FOR DEMO PAGE:
|
27 |
+
|
28 |
+
COPYRIGHT 2023 MRFAKENAME. ALL RIGHTS RESERVED.
|
Modules/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
Modules/diffusion/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
Modules/diffusion/diffusion.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi
|
2 |
+
from random import randint
|
3 |
+
from typing import Any, Optional, Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from einops import rearrange
|
7 |
+
from torch import Tensor, nn
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from .utils import *
|
11 |
+
from .sampler import *
|
12 |
+
|
13 |
+
"""
|
14 |
+
Diffusion Classes (generic for 1d data)
|
15 |
+
"""
|
16 |
+
|
17 |
+
|
18 |
+
class Model1d(nn.Module):
|
19 |
+
def __init__(self, unet_type: str = "base", **kwargs):
|
20 |
+
super().__init__()
|
21 |
+
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
|
22 |
+
self.unet = None
|
23 |
+
self.diffusion = None
|
24 |
+
|
25 |
+
def forward(self, x: Tensor, **kwargs) -> Tensor:
|
26 |
+
return self.diffusion(x, **kwargs)
|
27 |
+
|
28 |
+
def sample(self, *args, **kwargs) -> Tensor:
|
29 |
+
return self.diffusion.sample(*args, **kwargs)
|
30 |
+
|
31 |
+
|
32 |
+
"""
|
33 |
+
Audio Diffusion Classes (specific for 1d audio data)
|
34 |
+
"""
|
35 |
+
|
36 |
+
|
37 |
+
def get_default_model_kwargs():
|
38 |
+
return dict(
|
39 |
+
channels=128,
|
40 |
+
patch_size=16,
|
41 |
+
multipliers=[1, 2, 4, 4, 4, 4, 4],
|
42 |
+
factors=[4, 4, 4, 2, 2, 2],
|
43 |
+
num_blocks=[2, 2, 2, 2, 2, 2],
|
44 |
+
attentions=[0, 0, 0, 1, 1, 1, 1],
|
45 |
+
attention_heads=8,
|
46 |
+
attention_features=64,
|
47 |
+
attention_multiplier=2,
|
48 |
+
attention_use_rel_pos=False,
|
49 |
+
diffusion_type="v",
|
50 |
+
diffusion_sigma_distribution=UniformDistribution(),
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def get_default_sampling_kwargs():
|
55 |
+
return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
|
56 |
+
|
57 |
+
|
58 |
+
class AudioDiffusionModel(Model1d):
|
59 |
+
def __init__(self, **kwargs):
|
60 |
+
super().__init__(**{**get_default_model_kwargs(), **kwargs})
|
61 |
+
|
62 |
+
def sample(self, *args, **kwargs):
|
63 |
+
return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
|
64 |
+
|
65 |
+
|
66 |
+
class AudioDiffusionConditional(Model1d):
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
embedding_features: int,
|
70 |
+
embedding_max_length: int,
|
71 |
+
embedding_mask_proba: float = 0.1,
|
72 |
+
**kwargs,
|
73 |
+
):
|
74 |
+
self.embedding_mask_proba = embedding_mask_proba
|
75 |
+
default_kwargs = dict(
|
76 |
+
**get_default_model_kwargs(),
|
77 |
+
unet_type="cfg",
|
78 |
+
context_embedding_features=embedding_features,
|
79 |
+
context_embedding_max_length=embedding_max_length,
|
80 |
+
)
|
81 |
+
super().__init__(**{**default_kwargs, **kwargs})
|
82 |
+
|
83 |
+
def forward(self, *args, **kwargs):
|
84 |
+
default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
|
85 |
+
return super().forward(*args, **{**default_kwargs, **kwargs})
|
86 |
+
|
87 |
+
def sample(self, *args, **kwargs):
|
88 |
+
default_kwargs = dict(
|
89 |
+
**get_default_sampling_kwargs(),
|
90 |
+
embedding_scale=5.0,
|
91 |
+
)
|
92 |
+
return super().sample(*args, **{**default_kwargs, **kwargs})
|
Modules/diffusion/modules.py
ADDED
@@ -0,0 +1,700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import floor, log, pi
|
2 |
+
from typing import Any, List, Optional, Sequence, Tuple, Union
|
3 |
+
|
4 |
+
from .utils import *
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange, reduce, repeat
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
from einops_exts import rearrange_many
|
11 |
+
from torch import Tensor, einsum
|
12 |
+
|
13 |
+
|
14 |
+
"""
|
15 |
+
Utils
|
16 |
+
"""
|
17 |
+
|
18 |
+
|
19 |
+
class AdaLayerNorm(nn.Module):
|
20 |
+
def __init__(self, style_dim, channels, eps=1e-5):
|
21 |
+
super().__init__()
|
22 |
+
self.channels = channels
|
23 |
+
self.eps = eps
|
24 |
+
|
25 |
+
self.fc = nn.Linear(style_dim, channels * 2)
|
26 |
+
|
27 |
+
def forward(self, x, s):
|
28 |
+
x = x.transpose(-1, -2)
|
29 |
+
x = x.transpose(1, -1)
|
30 |
+
|
31 |
+
h = self.fc(s)
|
32 |
+
h = h.view(h.size(0), h.size(1), 1)
|
33 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
34 |
+
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
35 |
+
|
36 |
+
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
37 |
+
x = (1 + gamma) * x + beta
|
38 |
+
return x.transpose(1, -1).transpose(-1, -2)
|
39 |
+
|
40 |
+
|
41 |
+
class StyleTransformer1d(nn.Module):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
num_layers: int,
|
45 |
+
channels: int,
|
46 |
+
num_heads: int,
|
47 |
+
head_features: int,
|
48 |
+
multiplier: int,
|
49 |
+
use_context_time: bool = True,
|
50 |
+
use_rel_pos: bool = False,
|
51 |
+
context_features_multiplier: int = 1,
|
52 |
+
rel_pos_num_buckets: Optional[int] = None,
|
53 |
+
rel_pos_max_distance: Optional[int] = None,
|
54 |
+
context_features: Optional[int] = None,
|
55 |
+
context_embedding_features: Optional[int] = None,
|
56 |
+
embedding_max_length: int = 512,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.blocks = nn.ModuleList(
|
61 |
+
[
|
62 |
+
StyleTransformerBlock(
|
63 |
+
features=channels + context_embedding_features,
|
64 |
+
head_features=head_features,
|
65 |
+
num_heads=num_heads,
|
66 |
+
multiplier=multiplier,
|
67 |
+
style_dim=context_features,
|
68 |
+
use_rel_pos=use_rel_pos,
|
69 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
70 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
71 |
+
)
|
72 |
+
for i in range(num_layers)
|
73 |
+
]
|
74 |
+
)
|
75 |
+
|
76 |
+
self.to_out = nn.Sequential(
|
77 |
+
Rearrange("b t c -> b c t"),
|
78 |
+
nn.Conv1d(
|
79 |
+
in_channels=channels + context_embedding_features,
|
80 |
+
out_channels=channels,
|
81 |
+
kernel_size=1,
|
82 |
+
),
|
83 |
+
)
|
84 |
+
|
85 |
+
use_context_features = exists(context_features)
|
86 |
+
self.use_context_features = use_context_features
|
87 |
+
self.use_context_time = use_context_time
|
88 |
+
|
89 |
+
if use_context_time or use_context_features:
|
90 |
+
context_mapping_features = channels + context_embedding_features
|
91 |
+
|
92 |
+
self.to_mapping = nn.Sequential(
|
93 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
94 |
+
nn.GELU(),
|
95 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
96 |
+
nn.GELU(),
|
97 |
+
)
|
98 |
+
|
99 |
+
if use_context_time:
|
100 |
+
assert exists(context_mapping_features)
|
101 |
+
self.to_time = nn.Sequential(
|
102 |
+
TimePositionalEmbedding(
|
103 |
+
dim=channels, out_features=context_mapping_features
|
104 |
+
),
|
105 |
+
nn.GELU(),
|
106 |
+
)
|
107 |
+
|
108 |
+
if use_context_features:
|
109 |
+
assert exists(context_features) and exists(context_mapping_features)
|
110 |
+
self.to_features = nn.Sequential(
|
111 |
+
nn.Linear(
|
112 |
+
in_features=context_features, out_features=context_mapping_features
|
113 |
+
),
|
114 |
+
nn.GELU(),
|
115 |
+
)
|
116 |
+
|
117 |
+
self.fixed_embedding = FixedEmbedding(
|
118 |
+
max_length=embedding_max_length, features=context_embedding_features
|
119 |
+
)
|
120 |
+
|
121 |
+
def get_mapping(
|
122 |
+
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
|
123 |
+
) -> Optional[Tensor]:
|
124 |
+
"""Combines context time features and features into mapping"""
|
125 |
+
items, mapping = [], None
|
126 |
+
# Compute time features
|
127 |
+
if self.use_context_time:
|
128 |
+
assert_message = "use_context_time=True but no time features provided"
|
129 |
+
assert exists(time), assert_message
|
130 |
+
items += [self.to_time(time)]
|
131 |
+
# Compute features
|
132 |
+
if self.use_context_features:
|
133 |
+
assert_message = "context_features exists but no features provided"
|
134 |
+
assert exists(features), assert_message
|
135 |
+
items += [self.to_features(features)]
|
136 |
+
|
137 |
+
# Compute joint mapping
|
138 |
+
if self.use_context_time or self.use_context_features:
|
139 |
+
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
|
140 |
+
mapping = self.to_mapping(mapping)
|
141 |
+
|
142 |
+
return mapping
|
143 |
+
|
144 |
+
def run(self, x, time, embedding, features):
|
145 |
+
mapping = self.get_mapping(time, features)
|
146 |
+
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
147 |
+
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
|
148 |
+
|
149 |
+
for block in self.blocks:
|
150 |
+
x = x + mapping
|
151 |
+
x = block(x, features)
|
152 |
+
|
153 |
+
x = x.mean(axis=1).unsqueeze(1)
|
154 |
+
x = self.to_out(x)
|
155 |
+
x = x.transpose(-1, -2)
|
156 |
+
|
157 |
+
return x
|
158 |
+
|
159 |
+
def forward(
|
160 |
+
self,
|
161 |
+
x: Tensor,
|
162 |
+
time: Tensor,
|
163 |
+
embedding_mask_proba: float = 0.0,
|
164 |
+
embedding: Optional[Tensor] = None,
|
165 |
+
features: Optional[Tensor] = None,
|
166 |
+
embedding_scale: float = 1.0,
|
167 |
+
) -> Tensor:
|
168 |
+
b, device = embedding.shape[0], embedding.device
|
169 |
+
fixed_embedding = self.fixed_embedding(embedding)
|
170 |
+
if embedding_mask_proba > 0.0:
|
171 |
+
# Randomly mask embedding
|
172 |
+
batch_mask = rand_bool(
|
173 |
+
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
|
174 |
+
)
|
175 |
+
embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
176 |
+
|
177 |
+
if embedding_scale != 1.0:
|
178 |
+
# Compute both normal and fixed embedding outputs
|
179 |
+
out = self.run(x, time, embedding=embedding, features=features)
|
180 |
+
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
|
181 |
+
# Scale conditional output using classifier-free guidance
|
182 |
+
return out_masked + (out - out_masked) * embedding_scale
|
183 |
+
else:
|
184 |
+
return self.run(x, time, embedding=embedding, features=features)
|
185 |
+
|
186 |
+
return x
|
187 |
+
|
188 |
+
|
189 |
+
class StyleTransformerBlock(nn.Module):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
features: int,
|
193 |
+
num_heads: int,
|
194 |
+
head_features: int,
|
195 |
+
style_dim: int,
|
196 |
+
multiplier: int,
|
197 |
+
use_rel_pos: bool,
|
198 |
+
rel_pos_num_buckets: Optional[int] = None,
|
199 |
+
rel_pos_max_distance: Optional[int] = None,
|
200 |
+
context_features: Optional[int] = None,
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
|
204 |
+
self.use_cross_attention = exists(context_features) and context_features > 0
|
205 |
+
|
206 |
+
self.attention = StyleAttention(
|
207 |
+
features=features,
|
208 |
+
style_dim=style_dim,
|
209 |
+
num_heads=num_heads,
|
210 |
+
head_features=head_features,
|
211 |
+
use_rel_pos=use_rel_pos,
|
212 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
213 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
214 |
+
)
|
215 |
+
|
216 |
+
if self.use_cross_attention:
|
217 |
+
self.cross_attention = StyleAttention(
|
218 |
+
features=features,
|
219 |
+
style_dim=style_dim,
|
220 |
+
num_heads=num_heads,
|
221 |
+
head_features=head_features,
|
222 |
+
context_features=context_features,
|
223 |
+
use_rel_pos=use_rel_pos,
|
224 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
225 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
226 |
+
)
|
227 |
+
|
228 |
+
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
229 |
+
|
230 |
+
def forward(
|
231 |
+
self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None
|
232 |
+
) -> Tensor:
|
233 |
+
x = self.attention(x, s) + x
|
234 |
+
if self.use_cross_attention:
|
235 |
+
x = self.cross_attention(x, s, context=context) + x
|
236 |
+
x = self.feed_forward(x) + x
|
237 |
+
return x
|
238 |
+
|
239 |
+
|
240 |
+
class StyleAttention(nn.Module):
|
241 |
+
def __init__(
|
242 |
+
self,
|
243 |
+
features: int,
|
244 |
+
*,
|
245 |
+
style_dim: int,
|
246 |
+
head_features: int,
|
247 |
+
num_heads: int,
|
248 |
+
context_features: Optional[int] = None,
|
249 |
+
use_rel_pos: bool,
|
250 |
+
rel_pos_num_buckets: Optional[int] = None,
|
251 |
+
rel_pos_max_distance: Optional[int] = None,
|
252 |
+
):
|
253 |
+
super().__init__()
|
254 |
+
self.context_features = context_features
|
255 |
+
mid_features = head_features * num_heads
|
256 |
+
context_features = default(context_features, features)
|
257 |
+
|
258 |
+
self.norm = AdaLayerNorm(style_dim, features)
|
259 |
+
self.norm_context = AdaLayerNorm(style_dim, context_features)
|
260 |
+
self.to_q = nn.Linear(
|
261 |
+
in_features=features, out_features=mid_features, bias=False
|
262 |
+
)
|
263 |
+
self.to_kv = nn.Linear(
|
264 |
+
in_features=context_features, out_features=mid_features * 2, bias=False
|
265 |
+
)
|
266 |
+
self.attention = AttentionBase(
|
267 |
+
features,
|
268 |
+
num_heads=num_heads,
|
269 |
+
head_features=head_features,
|
270 |
+
use_rel_pos=use_rel_pos,
|
271 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
272 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
273 |
+
)
|
274 |
+
|
275 |
+
def forward(
|
276 |
+
self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None
|
277 |
+
) -> Tensor:
|
278 |
+
assert_message = "You must provide a context when using context_features"
|
279 |
+
assert not self.context_features or exists(context), assert_message
|
280 |
+
# Use context if provided
|
281 |
+
context = default(context, x)
|
282 |
+
# Normalize then compute q from input and k,v from context
|
283 |
+
x, context = self.norm(x, s), self.norm_context(context, s)
|
284 |
+
|
285 |
+
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
286 |
+
# Compute and return attention
|
287 |
+
return self.attention(q, k, v)
|
288 |
+
|
289 |
+
|
290 |
+
class Transformer1d(nn.Module):
|
291 |
+
def __init__(
|
292 |
+
self,
|
293 |
+
num_layers: int,
|
294 |
+
channels: int,
|
295 |
+
num_heads: int,
|
296 |
+
head_features: int,
|
297 |
+
multiplier: int,
|
298 |
+
use_context_time: bool = True,
|
299 |
+
use_rel_pos: bool = False,
|
300 |
+
context_features_multiplier: int = 1,
|
301 |
+
rel_pos_num_buckets: Optional[int] = None,
|
302 |
+
rel_pos_max_distance: Optional[int] = None,
|
303 |
+
context_features: Optional[int] = None,
|
304 |
+
context_embedding_features: Optional[int] = None,
|
305 |
+
embedding_max_length: int = 512,
|
306 |
+
):
|
307 |
+
super().__init__()
|
308 |
+
|
309 |
+
self.blocks = nn.ModuleList(
|
310 |
+
[
|
311 |
+
TransformerBlock(
|
312 |
+
features=channels + context_embedding_features,
|
313 |
+
head_features=head_features,
|
314 |
+
num_heads=num_heads,
|
315 |
+
multiplier=multiplier,
|
316 |
+
use_rel_pos=use_rel_pos,
|
317 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
318 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
319 |
+
)
|
320 |
+
for i in range(num_layers)
|
321 |
+
]
|
322 |
+
)
|
323 |
+
|
324 |
+
self.to_out = nn.Sequential(
|
325 |
+
Rearrange("b t c -> b c t"),
|
326 |
+
nn.Conv1d(
|
327 |
+
in_channels=channels + context_embedding_features,
|
328 |
+
out_channels=channels,
|
329 |
+
kernel_size=1,
|
330 |
+
),
|
331 |
+
)
|
332 |
+
|
333 |
+
use_context_features = exists(context_features)
|
334 |
+
self.use_context_features = use_context_features
|
335 |
+
self.use_context_time = use_context_time
|
336 |
+
|
337 |
+
if use_context_time or use_context_features:
|
338 |
+
context_mapping_features = channels + context_embedding_features
|
339 |
+
|
340 |
+
self.to_mapping = nn.Sequential(
|
341 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
342 |
+
nn.GELU(),
|
343 |
+
nn.Linear(context_mapping_features, context_mapping_features),
|
344 |
+
nn.GELU(),
|
345 |
+
)
|
346 |
+
|
347 |
+
if use_context_time:
|
348 |
+
assert exists(context_mapping_features)
|
349 |
+
self.to_time = nn.Sequential(
|
350 |
+
TimePositionalEmbedding(
|
351 |
+
dim=channels, out_features=context_mapping_features
|
352 |
+
),
|
353 |
+
nn.GELU(),
|
354 |
+
)
|
355 |
+
|
356 |
+
if use_context_features:
|
357 |
+
assert exists(context_features) and exists(context_mapping_features)
|
358 |
+
self.to_features = nn.Sequential(
|
359 |
+
nn.Linear(
|
360 |
+
in_features=context_features, out_features=context_mapping_features
|
361 |
+
),
|
362 |
+
nn.GELU(),
|
363 |
+
)
|
364 |
+
|
365 |
+
self.fixed_embedding = FixedEmbedding(
|
366 |
+
max_length=embedding_max_length, features=context_embedding_features
|
367 |
+
)
|
368 |
+
|
369 |
+
def get_mapping(
|
370 |
+
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
|
371 |
+
) -> Optional[Tensor]:
|
372 |
+
"""Combines context time features and features into mapping"""
|
373 |
+
items, mapping = [], None
|
374 |
+
# Compute time features
|
375 |
+
if self.use_context_time:
|
376 |
+
assert_message = "use_context_time=True but no time features provided"
|
377 |
+
assert exists(time), assert_message
|
378 |
+
items += [self.to_time(time)]
|
379 |
+
# Compute features
|
380 |
+
if self.use_context_features:
|
381 |
+
assert_message = "context_features exists but no features provided"
|
382 |
+
assert exists(features), assert_message
|
383 |
+
items += [self.to_features(features)]
|
384 |
+
|
385 |
+
# Compute joint mapping
|
386 |
+
if self.use_context_time or self.use_context_features:
|
387 |
+
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
|
388 |
+
mapping = self.to_mapping(mapping)
|
389 |
+
|
390 |
+
return mapping
|
391 |
+
|
392 |
+
def run(self, x, time, embedding, features):
|
393 |
+
mapping = self.get_mapping(time, features)
|
394 |
+
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
395 |
+
mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
|
396 |
+
|
397 |
+
for block in self.blocks:
|
398 |
+
x = x + mapping
|
399 |
+
x = block(x)
|
400 |
+
|
401 |
+
x = x.mean(axis=1).unsqueeze(1)
|
402 |
+
x = self.to_out(x)
|
403 |
+
x = x.transpose(-1, -2)
|
404 |
+
|
405 |
+
return x
|
406 |
+
|
407 |
+
def forward(
|
408 |
+
self,
|
409 |
+
x: Tensor,
|
410 |
+
time: Tensor,
|
411 |
+
embedding_mask_proba: float = 0.0,
|
412 |
+
embedding: Optional[Tensor] = None,
|
413 |
+
features: Optional[Tensor] = None,
|
414 |
+
embedding_scale: float = 1.0,
|
415 |
+
) -> Tensor:
|
416 |
+
b, device = embedding.shape[0], embedding.device
|
417 |
+
fixed_embedding = self.fixed_embedding(embedding)
|
418 |
+
if embedding_mask_proba > 0.0:
|
419 |
+
# Randomly mask embedding
|
420 |
+
batch_mask = rand_bool(
|
421 |
+
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
|
422 |
+
)
|
423 |
+
embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
424 |
+
|
425 |
+
if embedding_scale != 1.0:
|
426 |
+
# Compute both normal and fixed embedding outputs
|
427 |
+
out = self.run(x, time, embedding=embedding, features=features)
|
428 |
+
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
|
429 |
+
# Scale conditional output using classifier-free guidance
|
430 |
+
return out_masked + (out - out_masked) * embedding_scale
|
431 |
+
else:
|
432 |
+
return self.run(x, time, embedding=embedding, features=features)
|
433 |
+
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
"""
|
438 |
+
Attention Components
|
439 |
+
"""
|
440 |
+
|
441 |
+
|
442 |
+
class RelativePositionBias(nn.Module):
|
443 |
+
def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
|
444 |
+
super().__init__()
|
445 |
+
self.num_buckets = num_buckets
|
446 |
+
self.max_distance = max_distance
|
447 |
+
self.num_heads = num_heads
|
448 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
449 |
+
|
450 |
+
@staticmethod
|
451 |
+
def _relative_position_bucket(
|
452 |
+
relative_position: Tensor, num_buckets: int, max_distance: int
|
453 |
+
):
|
454 |
+
num_buckets //= 2
|
455 |
+
ret = (relative_position >= 0).to(torch.long) * num_buckets
|
456 |
+
n = torch.abs(relative_position)
|
457 |
+
|
458 |
+
max_exact = num_buckets // 2
|
459 |
+
is_small = n < max_exact
|
460 |
+
|
461 |
+
val_if_large = (
|
462 |
+
max_exact
|
463 |
+
+ (
|
464 |
+
torch.log(n.float() / max_exact)
|
465 |
+
/ log(max_distance / max_exact)
|
466 |
+
* (num_buckets - max_exact)
|
467 |
+
).long()
|
468 |
+
)
|
469 |
+
val_if_large = torch.min(
|
470 |
+
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
|
471 |
+
)
|
472 |
+
|
473 |
+
ret += torch.where(is_small, n, val_if_large)
|
474 |
+
return ret
|
475 |
+
|
476 |
+
def forward(self, num_queries: int, num_keys: int) -> Tensor:
|
477 |
+
i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
|
478 |
+
q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
|
479 |
+
k_pos = torch.arange(j, dtype=torch.long, device=device)
|
480 |
+
rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
|
481 |
+
|
482 |
+
relative_position_bucket = self._relative_position_bucket(
|
483 |
+
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
|
484 |
+
)
|
485 |
+
|
486 |
+
bias = self.relative_attention_bias(relative_position_bucket)
|
487 |
+
bias = rearrange(bias, "m n h -> 1 h m n")
|
488 |
+
return bias
|
489 |
+
|
490 |
+
|
491 |
+
def FeedForward(features: int, multiplier: int) -> nn.Module:
|
492 |
+
mid_features = features * multiplier
|
493 |
+
return nn.Sequential(
|
494 |
+
nn.Linear(in_features=features, out_features=mid_features),
|
495 |
+
nn.GELU(),
|
496 |
+
nn.Linear(in_features=mid_features, out_features=features),
|
497 |
+
)
|
498 |
+
|
499 |
+
|
500 |
+
class AttentionBase(nn.Module):
|
501 |
+
def __init__(
|
502 |
+
self,
|
503 |
+
features: int,
|
504 |
+
*,
|
505 |
+
head_features: int,
|
506 |
+
num_heads: int,
|
507 |
+
use_rel_pos: bool,
|
508 |
+
out_features: Optional[int] = None,
|
509 |
+
rel_pos_num_buckets: Optional[int] = None,
|
510 |
+
rel_pos_max_distance: Optional[int] = None,
|
511 |
+
):
|
512 |
+
super().__init__()
|
513 |
+
self.scale = head_features**-0.5
|
514 |
+
self.num_heads = num_heads
|
515 |
+
self.use_rel_pos = use_rel_pos
|
516 |
+
mid_features = head_features * num_heads
|
517 |
+
|
518 |
+
if use_rel_pos:
|
519 |
+
assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
|
520 |
+
self.rel_pos = RelativePositionBias(
|
521 |
+
num_buckets=rel_pos_num_buckets,
|
522 |
+
max_distance=rel_pos_max_distance,
|
523 |
+
num_heads=num_heads,
|
524 |
+
)
|
525 |
+
if out_features is None:
|
526 |
+
out_features = features
|
527 |
+
|
528 |
+
self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
|
529 |
+
|
530 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
531 |
+
# Split heads
|
532 |
+
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
|
533 |
+
# Compute similarity matrix
|
534 |
+
sim = einsum("... n d, ... m d -> ... n m", q, k)
|
535 |
+
sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
|
536 |
+
sim = sim * self.scale
|
537 |
+
# Get attention matrix with softmax
|
538 |
+
attn = sim.softmax(dim=-1)
|
539 |
+
# Compute values
|
540 |
+
out = einsum("... n m, ... m d -> ... n d", attn, v)
|
541 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
542 |
+
return self.to_out(out)
|
543 |
+
|
544 |
+
|
545 |
+
class Attention(nn.Module):
|
546 |
+
def __init__(
|
547 |
+
self,
|
548 |
+
features: int,
|
549 |
+
*,
|
550 |
+
head_features: int,
|
551 |
+
num_heads: int,
|
552 |
+
out_features: Optional[int] = None,
|
553 |
+
context_features: Optional[int] = None,
|
554 |
+
use_rel_pos: bool,
|
555 |
+
rel_pos_num_buckets: Optional[int] = None,
|
556 |
+
rel_pos_max_distance: Optional[int] = None,
|
557 |
+
):
|
558 |
+
super().__init__()
|
559 |
+
self.context_features = context_features
|
560 |
+
mid_features = head_features * num_heads
|
561 |
+
context_features = default(context_features, features)
|
562 |
+
|
563 |
+
self.norm = nn.LayerNorm(features)
|
564 |
+
self.norm_context = nn.LayerNorm(context_features)
|
565 |
+
self.to_q = nn.Linear(
|
566 |
+
in_features=features, out_features=mid_features, bias=False
|
567 |
+
)
|
568 |
+
self.to_kv = nn.Linear(
|
569 |
+
in_features=context_features, out_features=mid_features * 2, bias=False
|
570 |
+
)
|
571 |
+
|
572 |
+
self.attention = AttentionBase(
|
573 |
+
features,
|
574 |
+
out_features=out_features,
|
575 |
+
num_heads=num_heads,
|
576 |
+
head_features=head_features,
|
577 |
+
use_rel_pos=use_rel_pos,
|
578 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
579 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
580 |
+
)
|
581 |
+
|
582 |
+
def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
|
583 |
+
assert_message = "You must provide a context when using context_features"
|
584 |
+
assert not self.context_features or exists(context), assert_message
|
585 |
+
# Use context if provided
|
586 |
+
context = default(context, x)
|
587 |
+
# Normalize then compute q from input and k,v from context
|
588 |
+
x, context = self.norm(x), self.norm_context(context)
|
589 |
+
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
590 |
+
# Compute and return attention
|
591 |
+
return self.attention(q, k, v)
|
592 |
+
|
593 |
+
|
594 |
+
"""
|
595 |
+
Transformer Blocks
|
596 |
+
"""
|
597 |
+
|
598 |
+
|
599 |
+
class TransformerBlock(nn.Module):
|
600 |
+
def __init__(
|
601 |
+
self,
|
602 |
+
features: int,
|
603 |
+
num_heads: int,
|
604 |
+
head_features: int,
|
605 |
+
multiplier: int,
|
606 |
+
use_rel_pos: bool,
|
607 |
+
rel_pos_num_buckets: Optional[int] = None,
|
608 |
+
rel_pos_max_distance: Optional[int] = None,
|
609 |
+
context_features: Optional[int] = None,
|
610 |
+
):
|
611 |
+
super().__init__()
|
612 |
+
|
613 |
+
self.use_cross_attention = exists(context_features) and context_features > 0
|
614 |
+
|
615 |
+
self.attention = Attention(
|
616 |
+
features=features,
|
617 |
+
num_heads=num_heads,
|
618 |
+
head_features=head_features,
|
619 |
+
use_rel_pos=use_rel_pos,
|
620 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
621 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
622 |
+
)
|
623 |
+
|
624 |
+
if self.use_cross_attention:
|
625 |
+
self.cross_attention = Attention(
|
626 |
+
features=features,
|
627 |
+
num_heads=num_heads,
|
628 |
+
head_features=head_features,
|
629 |
+
context_features=context_features,
|
630 |
+
use_rel_pos=use_rel_pos,
|
631 |
+
rel_pos_num_buckets=rel_pos_num_buckets,
|
632 |
+
rel_pos_max_distance=rel_pos_max_distance,
|
633 |
+
)
|
634 |
+
|
635 |
+
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
636 |
+
|
637 |
+
def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
|
638 |
+
x = self.attention(x) + x
|
639 |
+
if self.use_cross_attention:
|
640 |
+
x = self.cross_attention(x, context=context) + x
|
641 |
+
x = self.feed_forward(x) + x
|
642 |
+
return x
|
643 |
+
|
644 |
+
|
645 |
+
"""
|
646 |
+
Time Embeddings
|
647 |
+
"""
|
648 |
+
|
649 |
+
|
650 |
+
class SinusoidalEmbedding(nn.Module):
|
651 |
+
def __init__(self, dim: int):
|
652 |
+
super().__init__()
|
653 |
+
self.dim = dim
|
654 |
+
|
655 |
+
def forward(self, x: Tensor) -> Tensor:
|
656 |
+
device, half_dim = x.device, self.dim // 2
|
657 |
+
emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
|
658 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
659 |
+
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
|
660 |
+
return torch.cat((emb.sin(), emb.cos()), dim=-1)
|
661 |
+
|
662 |
+
|
663 |
+
class LearnedPositionalEmbedding(nn.Module):
|
664 |
+
"""Used for continuous time"""
|
665 |
+
|
666 |
+
def __init__(self, dim: int):
|
667 |
+
super().__init__()
|
668 |
+
assert (dim % 2) == 0
|
669 |
+
half_dim = dim // 2
|
670 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
671 |
+
|
672 |
+
def forward(self, x: Tensor) -> Tensor:
|
673 |
+
x = rearrange(x, "b -> b 1")
|
674 |
+
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
|
675 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
676 |
+
fouriered = torch.cat((x, fouriered), dim=-1)
|
677 |
+
return fouriered
|
678 |
+
|
679 |
+
|
680 |
+
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
681 |
+
return nn.Sequential(
|
682 |
+
LearnedPositionalEmbedding(dim),
|
683 |
+
nn.Linear(in_features=dim + 1, out_features=out_features),
|
684 |
+
)
|
685 |
+
|
686 |
+
|
687 |
+
class FixedEmbedding(nn.Module):
|
688 |
+
def __init__(self, max_length: int, features: int):
|
689 |
+
super().__init__()
|
690 |
+
self.max_length = max_length
|
691 |
+
self.embedding = nn.Embedding(max_length, features)
|
692 |
+
|
693 |
+
def forward(self, x: Tensor) -> Tensor:
|
694 |
+
batch_size, length, device = *x.shape[0:2], x.device
|
695 |
+
assert_message = "Input sequence length must be <= max_length"
|
696 |
+
assert length <= self.max_length, assert_message
|
697 |
+
position = torch.arange(length, device=device)
|
698 |
+
fixed_embedding = self.embedding(position)
|
699 |
+
fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
|
700 |
+
return fixed_embedding
|
Modules/diffusion/sampler.py
ADDED
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import atan, cos, pi, sin, sqrt
|
2 |
+
from typing import Any, Callable, List, Optional, Tuple, Type
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange, reduce
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from .utils import *
|
11 |
+
|
12 |
+
"""
|
13 |
+
Diffusion Training
|
14 |
+
"""
|
15 |
+
|
16 |
+
""" Distributions """
|
17 |
+
|
18 |
+
|
19 |
+
class Distribution:
|
20 |
+
def __call__(self, num_samples: int, device: torch.device):
|
21 |
+
raise NotImplementedError()
|
22 |
+
|
23 |
+
|
24 |
+
class LogNormalDistribution(Distribution):
|
25 |
+
def __init__(self, mean: float, std: float):
|
26 |
+
self.mean = mean
|
27 |
+
self.std = std
|
28 |
+
|
29 |
+
def __call__(
|
30 |
+
self, num_samples: int, device: torch.device = torch.device("cpu")
|
31 |
+
) -> Tensor:
|
32 |
+
normal = self.mean + self.std * torch.randn((num_samples,), device=device)
|
33 |
+
return normal.exp()
|
34 |
+
|
35 |
+
|
36 |
+
class UniformDistribution(Distribution):
|
37 |
+
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
38 |
+
return torch.rand(num_samples, device=device)
|
39 |
+
|
40 |
+
|
41 |
+
class VKDistribution(Distribution):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
min_value: float = 0.0,
|
45 |
+
max_value: float = float("inf"),
|
46 |
+
sigma_data: float = 1.0,
|
47 |
+
):
|
48 |
+
self.min_value = min_value
|
49 |
+
self.max_value = max_value
|
50 |
+
self.sigma_data = sigma_data
|
51 |
+
|
52 |
+
def __call__(
|
53 |
+
self, num_samples: int, device: torch.device = torch.device("cpu")
|
54 |
+
) -> Tensor:
|
55 |
+
sigma_data = self.sigma_data
|
56 |
+
min_cdf = atan(self.min_value / sigma_data) * 2 / pi
|
57 |
+
max_cdf = atan(self.max_value / sigma_data) * 2 / pi
|
58 |
+
u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
|
59 |
+
return torch.tan(u * pi / 2) * sigma_data
|
60 |
+
|
61 |
+
|
62 |
+
""" Diffusion Classes """
|
63 |
+
|
64 |
+
|
65 |
+
def pad_dims(x: Tensor, ndim: int) -> Tensor:
|
66 |
+
# Pads additional ndims to the right of the tensor
|
67 |
+
return x.view(*x.shape, *((1,) * ndim))
|
68 |
+
|
69 |
+
|
70 |
+
def clip(x: Tensor, dynamic_threshold: float = 0.0):
|
71 |
+
if dynamic_threshold == 0.0:
|
72 |
+
return x.clamp(-1.0, 1.0)
|
73 |
+
else:
|
74 |
+
# Dynamic thresholding
|
75 |
+
# Find dynamic threshold quantile for each batch
|
76 |
+
x_flat = rearrange(x, "b ... -> b (...)")
|
77 |
+
scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
|
78 |
+
# Clamp to a min of 1.0
|
79 |
+
scale.clamp_(min=1.0)
|
80 |
+
# Clamp all values and scale
|
81 |
+
scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
|
82 |
+
x = x.clamp(-scale, scale) / scale
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
def to_batch(
|
87 |
+
batch_size: int,
|
88 |
+
device: torch.device,
|
89 |
+
x: Optional[float] = None,
|
90 |
+
xs: Optional[Tensor] = None,
|
91 |
+
) -> Tensor:
|
92 |
+
assert exists(x) ^ exists(xs), "Either x or xs must be provided"
|
93 |
+
# If x provided use the same for all batch items
|
94 |
+
if exists(x):
|
95 |
+
xs = torch.full(size=(batch_size,), fill_value=x).to(device)
|
96 |
+
assert exists(xs)
|
97 |
+
return xs
|
98 |
+
|
99 |
+
|
100 |
+
class Diffusion(nn.Module):
|
101 |
+
alias: str = ""
|
102 |
+
|
103 |
+
"""Base diffusion class"""
|
104 |
+
|
105 |
+
def denoise_fn(
|
106 |
+
self,
|
107 |
+
x_noisy: Tensor,
|
108 |
+
sigmas: Optional[Tensor] = None,
|
109 |
+
sigma: Optional[float] = None,
|
110 |
+
**kwargs,
|
111 |
+
) -> Tensor:
|
112 |
+
raise NotImplementedError("Diffusion class missing denoise_fn")
|
113 |
+
|
114 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
115 |
+
raise NotImplementedError("Diffusion class missing forward function")
|
116 |
+
|
117 |
+
|
118 |
+
class VDiffusion(Diffusion):
|
119 |
+
alias = "v"
|
120 |
+
|
121 |
+
def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
|
122 |
+
super().__init__()
|
123 |
+
self.net = net
|
124 |
+
self.sigma_distribution = sigma_distribution
|
125 |
+
|
126 |
+
def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
|
127 |
+
angle = sigmas * pi / 2
|
128 |
+
alpha = torch.cos(angle)
|
129 |
+
beta = torch.sin(angle)
|
130 |
+
return alpha, beta
|
131 |
+
|
132 |
+
def denoise_fn(
|
133 |
+
self,
|
134 |
+
x_noisy: Tensor,
|
135 |
+
sigmas: Optional[Tensor] = None,
|
136 |
+
sigma: Optional[float] = None,
|
137 |
+
**kwargs,
|
138 |
+
) -> Tensor:
|
139 |
+
batch_size, device = x_noisy.shape[0], x_noisy.device
|
140 |
+
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
141 |
+
return self.net(x_noisy, sigmas, **kwargs)
|
142 |
+
|
143 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
144 |
+
batch_size, device = x.shape[0], x.device
|
145 |
+
|
146 |
+
# Sample amount of noise to add for each batch element
|
147 |
+
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
|
148 |
+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
|
149 |
+
|
150 |
+
# Get noise
|
151 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
152 |
+
|
153 |
+
# Combine input and noise weighted by half-circle
|
154 |
+
alpha, beta = self.get_alpha_beta(sigmas_padded)
|
155 |
+
x_noisy = x * alpha + noise * beta
|
156 |
+
x_target = noise * alpha - x * beta
|
157 |
+
|
158 |
+
# Denoise and return loss
|
159 |
+
x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
|
160 |
+
return F.mse_loss(x_denoised, x_target)
|
161 |
+
|
162 |
+
|
163 |
+
class KDiffusion(Diffusion):
|
164 |
+
"""Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
|
165 |
+
|
166 |
+
alias = "k"
|
167 |
+
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
net: nn.Module,
|
171 |
+
*,
|
172 |
+
sigma_distribution: Distribution,
|
173 |
+
sigma_data: float, # data distribution standard deviation
|
174 |
+
dynamic_threshold: float = 0.0,
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
self.net = net
|
178 |
+
self.sigma_data = sigma_data
|
179 |
+
self.sigma_distribution = sigma_distribution
|
180 |
+
self.dynamic_threshold = dynamic_threshold
|
181 |
+
|
182 |
+
def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
|
183 |
+
sigma_data = self.sigma_data
|
184 |
+
c_noise = torch.log(sigmas) * 0.25
|
185 |
+
sigmas = rearrange(sigmas, "b -> b 1 1")
|
186 |
+
c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
|
187 |
+
c_out = sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
|
188 |
+
c_in = (sigmas**2 + sigma_data**2) ** -0.5
|
189 |
+
return c_skip, c_out, c_in, c_noise
|
190 |
+
|
191 |
+
def denoise_fn(
|
192 |
+
self,
|
193 |
+
x_noisy: Tensor,
|
194 |
+
sigmas: Optional[Tensor] = None,
|
195 |
+
sigma: Optional[float] = None,
|
196 |
+
**kwargs,
|
197 |
+
) -> Tensor:
|
198 |
+
batch_size, device = x_noisy.shape[0], x_noisy.device
|
199 |
+
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
200 |
+
|
201 |
+
# Predict network output and add skip connection
|
202 |
+
c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
|
203 |
+
x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
|
204 |
+
x_denoised = c_skip * x_noisy + c_out * x_pred
|
205 |
+
|
206 |
+
return x_denoised
|
207 |
+
|
208 |
+
def loss_weight(self, sigmas: Tensor) -> Tensor:
|
209 |
+
# Computes weight depending on data distribution
|
210 |
+
return (sigmas**2 + self.sigma_data**2) * (sigmas * self.sigma_data) ** -2
|
211 |
+
|
212 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
213 |
+
batch_size, device = x.shape[0], x.device
|
214 |
+
from einops import rearrange, reduce
|
215 |
+
|
216 |
+
# Sample amount of noise to add for each batch element
|
217 |
+
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
|
218 |
+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
|
219 |
+
|
220 |
+
# Add noise to input
|
221 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
222 |
+
x_noisy = x + sigmas_padded * noise
|
223 |
+
|
224 |
+
# Compute denoised values
|
225 |
+
x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
|
226 |
+
|
227 |
+
# Compute weighted loss
|
228 |
+
losses = F.mse_loss(x_denoised, x, reduction="none")
|
229 |
+
losses = reduce(losses, "b ... -> b", "mean")
|
230 |
+
losses = losses * self.loss_weight(sigmas)
|
231 |
+
loss = losses.mean()
|
232 |
+
return loss
|
233 |
+
|
234 |
+
|
235 |
+
class VKDiffusion(Diffusion):
|
236 |
+
alias = "vk"
|
237 |
+
|
238 |
+
def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
|
239 |
+
super().__init__()
|
240 |
+
self.net = net
|
241 |
+
self.sigma_distribution = sigma_distribution
|
242 |
+
|
243 |
+
def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
|
244 |
+
sigma_data = 1.0
|
245 |
+
sigmas = rearrange(sigmas, "b -> b 1 1")
|
246 |
+
c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
|
247 |
+
c_out = -sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
|
248 |
+
c_in = (sigmas**2 + sigma_data**2) ** -0.5
|
249 |
+
return c_skip, c_out, c_in
|
250 |
+
|
251 |
+
def sigma_to_t(self, sigmas: Tensor) -> Tensor:
|
252 |
+
return sigmas.atan() / pi * 2
|
253 |
+
|
254 |
+
def t_to_sigma(self, t: Tensor) -> Tensor:
|
255 |
+
return (t * pi / 2).tan()
|
256 |
+
|
257 |
+
def denoise_fn(
|
258 |
+
self,
|
259 |
+
x_noisy: Tensor,
|
260 |
+
sigmas: Optional[Tensor] = None,
|
261 |
+
sigma: Optional[float] = None,
|
262 |
+
**kwargs,
|
263 |
+
) -> Tensor:
|
264 |
+
batch_size, device = x_noisy.shape[0], x_noisy.device
|
265 |
+
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
|
266 |
+
|
267 |
+
# Predict network output and add skip connection
|
268 |
+
c_skip, c_out, c_in = self.get_scale_weights(sigmas)
|
269 |
+
x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
|
270 |
+
x_denoised = c_skip * x_noisy + c_out * x_pred
|
271 |
+
return x_denoised
|
272 |
+
|
273 |
+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
|
274 |
+
batch_size, device = x.shape[0], x.device
|
275 |
+
|
276 |
+
# Sample amount of noise to add for each batch element
|
277 |
+
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
|
278 |
+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
|
279 |
+
|
280 |
+
# Add noise to input
|
281 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
282 |
+
x_noisy = x + sigmas_padded * noise
|
283 |
+
|
284 |
+
# Compute model output
|
285 |
+
c_skip, c_out, c_in = self.get_scale_weights(sigmas)
|
286 |
+
x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
|
287 |
+
|
288 |
+
# Compute v-objective target
|
289 |
+
v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
|
290 |
+
|
291 |
+
# Compute loss
|
292 |
+
loss = F.mse_loss(x_pred, v_target)
|
293 |
+
return loss
|
294 |
+
|
295 |
+
|
296 |
+
"""
|
297 |
+
Diffusion Sampling
|
298 |
+
"""
|
299 |
+
|
300 |
+
""" Schedules """
|
301 |
+
|
302 |
+
|
303 |
+
class Schedule(nn.Module):
|
304 |
+
"""Interface used by different sampling schedules"""
|
305 |
+
|
306 |
+
def forward(self, num_steps: int, device: torch.device) -> Tensor:
|
307 |
+
raise NotImplementedError()
|
308 |
+
|
309 |
+
|
310 |
+
class LinearSchedule(Schedule):
|
311 |
+
def forward(self, num_steps: int, device: Any) -> Tensor:
|
312 |
+
sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
|
313 |
+
return sigmas
|
314 |
+
|
315 |
+
|
316 |
+
class KarrasSchedule(Schedule):
|
317 |
+
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
318 |
+
|
319 |
+
def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
|
320 |
+
super().__init__()
|
321 |
+
self.sigma_min = sigma_min
|
322 |
+
self.sigma_max = sigma_max
|
323 |
+
self.rho = rho
|
324 |
+
|
325 |
+
def forward(self, num_steps: int, device: Any) -> Tensor:
|
326 |
+
rho_inv = 1.0 / self.rho
|
327 |
+
steps = torch.arange(num_steps, device=device, dtype=torch.float32)
|
328 |
+
sigmas = (
|
329 |
+
self.sigma_max**rho_inv
|
330 |
+
+ (steps / (num_steps - 1))
|
331 |
+
* (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
|
332 |
+
) ** self.rho
|
333 |
+
sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
|
334 |
+
return sigmas
|
335 |
+
|
336 |
+
|
337 |
+
""" Samplers """
|
338 |
+
|
339 |
+
|
340 |
+
class Sampler(nn.Module):
|
341 |
+
diffusion_types: List[Type[Diffusion]] = []
|
342 |
+
|
343 |
+
def forward(
|
344 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
345 |
+
) -> Tensor:
|
346 |
+
raise NotImplementedError()
|
347 |
+
|
348 |
+
def inpaint(
|
349 |
+
self,
|
350 |
+
source: Tensor,
|
351 |
+
mask: Tensor,
|
352 |
+
fn: Callable,
|
353 |
+
sigmas: Tensor,
|
354 |
+
num_steps: int,
|
355 |
+
num_resamples: int,
|
356 |
+
) -> Tensor:
|
357 |
+
raise NotImplementedError("Inpainting not available with current sampler")
|
358 |
+
|
359 |
+
|
360 |
+
class VSampler(Sampler):
|
361 |
+
diffusion_types = [VDiffusion]
|
362 |
+
|
363 |
+
def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
|
364 |
+
angle = sigma * pi / 2
|
365 |
+
alpha = cos(angle)
|
366 |
+
beta = sin(angle)
|
367 |
+
return alpha, beta
|
368 |
+
|
369 |
+
def forward(
|
370 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
371 |
+
) -> Tensor:
|
372 |
+
x = sigmas[0] * noise
|
373 |
+
alpha, beta = self.get_alpha_beta(sigmas[0].item())
|
374 |
+
|
375 |
+
for i in range(num_steps - 1):
|
376 |
+
is_last = i == num_steps - 1
|
377 |
+
|
378 |
+
x_denoised = fn(x, sigma=sigmas[i])
|
379 |
+
x_pred = x * alpha - x_denoised * beta
|
380 |
+
x_eps = x * beta + x_denoised * alpha
|
381 |
+
|
382 |
+
if not is_last:
|
383 |
+
alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
|
384 |
+
x = x_pred * alpha + x_eps * beta
|
385 |
+
|
386 |
+
return x_pred
|
387 |
+
|
388 |
+
|
389 |
+
class KarrasSampler(Sampler):
|
390 |
+
"""https://arxiv.org/abs/2206.00364 algorithm 1"""
|
391 |
+
|
392 |
+
diffusion_types = [KDiffusion, VKDiffusion]
|
393 |
+
|
394 |
+
def __init__(
|
395 |
+
self,
|
396 |
+
s_tmin: float = 0,
|
397 |
+
s_tmax: float = float("inf"),
|
398 |
+
s_churn: float = 0.0,
|
399 |
+
s_noise: float = 1.0,
|
400 |
+
):
|
401 |
+
super().__init__()
|
402 |
+
self.s_tmin = s_tmin
|
403 |
+
self.s_tmax = s_tmax
|
404 |
+
self.s_noise = s_noise
|
405 |
+
self.s_churn = s_churn
|
406 |
+
|
407 |
+
def step(
|
408 |
+
self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
|
409 |
+
) -> Tensor:
|
410 |
+
"""Algorithm 2 (step)"""
|
411 |
+
# Select temporarily increased noise level
|
412 |
+
sigma_hat = sigma + gamma * sigma
|
413 |
+
# Add noise to move from sigma to sigma_hat
|
414 |
+
epsilon = self.s_noise * torch.randn_like(x)
|
415 |
+
x_hat = x + sqrt(sigma_hat**2 - sigma**2) * epsilon
|
416 |
+
# Evaluate ∂x/∂sigma at sigma_hat
|
417 |
+
d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
|
418 |
+
# Take euler step from sigma_hat to sigma_next
|
419 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d
|
420 |
+
# Second order correction
|
421 |
+
if sigma_next != 0:
|
422 |
+
model_out_next = fn(x_next, sigma=sigma_next)
|
423 |
+
d_prime = (x_next - model_out_next) / sigma_next
|
424 |
+
x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
|
425 |
+
return x_next
|
426 |
+
|
427 |
+
def forward(
|
428 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
429 |
+
) -> Tensor:
|
430 |
+
x = sigmas[0] * noise
|
431 |
+
# Compute gammas
|
432 |
+
gammas = torch.where(
|
433 |
+
(sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
|
434 |
+
min(self.s_churn / num_steps, sqrt(2) - 1),
|
435 |
+
0.0,
|
436 |
+
)
|
437 |
+
# Denoise to sample
|
438 |
+
for i in range(num_steps - 1):
|
439 |
+
x = self.step(
|
440 |
+
x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
|
441 |
+
)
|
442 |
+
|
443 |
+
return x
|
444 |
+
|
445 |
+
|
446 |
+
class AEulerSampler(Sampler):
|
447 |
+
diffusion_types = [KDiffusion, VKDiffusion]
|
448 |
+
|
449 |
+
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
|
450 |
+
sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
|
451 |
+
sigma_down = sqrt(sigma_next**2 - sigma_up**2)
|
452 |
+
return sigma_up, sigma_down
|
453 |
+
|
454 |
+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
|
455 |
+
# Sigma steps
|
456 |
+
sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
|
457 |
+
# Derivative at sigma (∂x/∂sigma)
|
458 |
+
d = (x - fn(x, sigma=sigma)) / sigma
|
459 |
+
# Euler method
|
460 |
+
x_next = x + d * (sigma_down - sigma)
|
461 |
+
# Add randomness
|
462 |
+
x_next = x_next + torch.randn_like(x) * sigma_up
|
463 |
+
return x_next
|
464 |
+
|
465 |
+
def forward(
|
466 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
467 |
+
) -> Tensor:
|
468 |
+
x = sigmas[0] * noise
|
469 |
+
# Denoise to sample
|
470 |
+
for i in range(num_steps - 1):
|
471 |
+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
472 |
+
return x
|
473 |
+
|
474 |
+
|
475 |
+
class ADPM2Sampler(Sampler):
|
476 |
+
"""https://www.desmos.com/calculator/jbxjlqd9mb"""
|
477 |
+
|
478 |
+
diffusion_types = [KDiffusion, VKDiffusion]
|
479 |
+
|
480 |
+
def __init__(self, rho: float = 1.0):
|
481 |
+
super().__init__()
|
482 |
+
self.rho = rho
|
483 |
+
|
484 |
+
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
|
485 |
+
r = self.rho
|
486 |
+
sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
|
487 |
+
sigma_down = sqrt(sigma_next**2 - sigma_up**2)
|
488 |
+
sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
|
489 |
+
return sigma_up, sigma_down, sigma_mid
|
490 |
+
|
491 |
+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
|
492 |
+
# Sigma steps
|
493 |
+
sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
|
494 |
+
# Derivative at sigma (∂x/∂sigma)
|
495 |
+
d = (x - fn(x, sigma=sigma)) / sigma
|
496 |
+
# Denoise to midpoint
|
497 |
+
x_mid = x + d * (sigma_mid - sigma)
|
498 |
+
# Derivative at sigma_mid (∂x_mid/∂sigma_mid)
|
499 |
+
d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
|
500 |
+
# Denoise to next
|
501 |
+
x = x + d_mid * (sigma_down - sigma)
|
502 |
+
# Add randomness
|
503 |
+
x_next = x + torch.randn_like(x) * sigma_up
|
504 |
+
return x_next
|
505 |
+
|
506 |
+
def forward(
|
507 |
+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
|
508 |
+
) -> Tensor:
|
509 |
+
x = sigmas[0] * noise
|
510 |
+
# Denoise to sample
|
511 |
+
for i in range(num_steps - 1):
|
512 |
+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
513 |
+
return x
|
514 |
+
|
515 |
+
def inpaint(
|
516 |
+
self,
|
517 |
+
source: Tensor,
|
518 |
+
mask: Tensor,
|
519 |
+
fn: Callable,
|
520 |
+
sigmas: Tensor,
|
521 |
+
num_steps: int,
|
522 |
+
num_resamples: int,
|
523 |
+
) -> Tensor:
|
524 |
+
x = sigmas[0] * torch.randn_like(source)
|
525 |
+
|
526 |
+
for i in range(num_steps - 1):
|
527 |
+
# Noise source to current noise level
|
528 |
+
source_noisy = source + sigmas[i] * torch.randn_like(source)
|
529 |
+
for r in range(num_resamples):
|
530 |
+
# Merge noisy source and current then denoise
|
531 |
+
x = source_noisy * mask + x * ~mask
|
532 |
+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
|
533 |
+
# Renoise if not last resample step
|
534 |
+
if r < num_resamples - 1:
|
535 |
+
sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
|
536 |
+
x = x + sigma * torch.randn_like(x)
|
537 |
+
|
538 |
+
return source * mask + x * ~mask
|
539 |
+
|
540 |
+
|
541 |
+
""" Main Classes """
|
542 |
+
|
543 |
+
|
544 |
+
class DiffusionSampler(nn.Module):
|
545 |
+
def __init__(
|
546 |
+
self,
|
547 |
+
diffusion: Diffusion,
|
548 |
+
*,
|
549 |
+
sampler: Sampler,
|
550 |
+
sigma_schedule: Schedule,
|
551 |
+
num_steps: Optional[int] = None,
|
552 |
+
clamp: bool = True,
|
553 |
+
):
|
554 |
+
super().__init__()
|
555 |
+
self.denoise_fn = diffusion.denoise_fn
|
556 |
+
self.sampler = sampler
|
557 |
+
self.sigma_schedule = sigma_schedule
|
558 |
+
self.num_steps = num_steps
|
559 |
+
self.clamp = clamp
|
560 |
+
|
561 |
+
# Check sampler is compatible with diffusion type
|
562 |
+
sampler_class = sampler.__class__.__name__
|
563 |
+
diffusion_class = diffusion.__class__.__name__
|
564 |
+
message = f"{sampler_class} incompatible with {diffusion_class}"
|
565 |
+
assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
|
566 |
+
|
567 |
+
def forward(
|
568 |
+
self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
|
569 |
+
) -> Tensor:
|
570 |
+
device = noise.device
|
571 |
+
num_steps = default(num_steps, self.num_steps) # type: ignore
|
572 |
+
assert exists(num_steps), "Parameter `num_steps` must be provided"
|
573 |
+
# Compute sigmas using schedule
|
574 |
+
sigmas = self.sigma_schedule(num_steps, device)
|
575 |
+
# Append additional kwargs to denoise function (used e.g. for conditional unet)
|
576 |
+
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
577 |
+
# Sample using sampler
|
578 |
+
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
579 |
+
x = x.clamp(-1.0, 1.0) if self.clamp else x
|
580 |
+
return x
|
581 |
+
|
582 |
+
|
583 |
+
class DiffusionInpainter(nn.Module):
|
584 |
+
def __init__(
|
585 |
+
self,
|
586 |
+
diffusion: Diffusion,
|
587 |
+
*,
|
588 |
+
num_steps: int,
|
589 |
+
num_resamples: int,
|
590 |
+
sampler: Sampler,
|
591 |
+
sigma_schedule: Schedule,
|
592 |
+
):
|
593 |
+
super().__init__()
|
594 |
+
self.denoise_fn = diffusion.denoise_fn
|
595 |
+
self.num_steps = num_steps
|
596 |
+
self.num_resamples = num_resamples
|
597 |
+
self.inpaint_fn = sampler.inpaint
|
598 |
+
self.sigma_schedule = sigma_schedule
|
599 |
+
|
600 |
+
@torch.no_grad()
|
601 |
+
def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
|
602 |
+
x = self.inpaint_fn(
|
603 |
+
source=inpaint,
|
604 |
+
mask=inpaint_mask,
|
605 |
+
fn=self.denoise_fn,
|
606 |
+
sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
|
607 |
+
num_steps=self.num_steps,
|
608 |
+
num_resamples=self.num_resamples,
|
609 |
+
)
|
610 |
+
return x
|
611 |
+
|
612 |
+
|
613 |
+
def sequential_mask(like: Tensor, start: int) -> Tensor:
|
614 |
+
length, device = like.shape[2], like.device
|
615 |
+
mask = torch.ones_like(like, dtype=torch.bool)
|
616 |
+
mask[:, :, start:] = torch.zeros((length - start,), device=device)
|
617 |
+
return mask
|
618 |
+
|
619 |
+
|
620 |
+
class SpanBySpanComposer(nn.Module):
|
621 |
+
def __init__(
|
622 |
+
self,
|
623 |
+
inpainter: DiffusionInpainter,
|
624 |
+
*,
|
625 |
+
num_spans: int,
|
626 |
+
):
|
627 |
+
super().__init__()
|
628 |
+
self.inpainter = inpainter
|
629 |
+
self.num_spans = num_spans
|
630 |
+
|
631 |
+
def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
|
632 |
+
half_length = start.shape[2] // 2
|
633 |
+
|
634 |
+
spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
|
635 |
+
# Inpaint second half from first half
|
636 |
+
inpaint = torch.zeros_like(start)
|
637 |
+
inpaint[:, :, :half_length] = start[:, :, half_length:]
|
638 |
+
inpaint_mask = sequential_mask(like=start, start=half_length)
|
639 |
+
|
640 |
+
for i in range(self.num_spans):
|
641 |
+
# Inpaint second half
|
642 |
+
span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
|
643 |
+
# Replace first half with generated second half
|
644 |
+
second_half = span[:, :, half_length:]
|
645 |
+
inpaint[:, :, :half_length] = second_half
|
646 |
+
# Save generated span
|
647 |
+
spans.append(second_half)
|
648 |
+
|
649 |
+
return torch.cat(spans, dim=2)
|
650 |
+
|
651 |
+
|
652 |
+
class XDiffusion(nn.Module):
|
653 |
+
def __init__(self, type: str, net: nn.Module, **kwargs):
|
654 |
+
super().__init__()
|
655 |
+
|
656 |
+
diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
|
657 |
+
aliases = [t.alias for t in diffusion_classes] # type: ignore
|
658 |
+
message = f"type='{type}' must be one of {*aliases,}"
|
659 |
+
assert type in aliases, message
|
660 |
+
self.net = net
|
661 |
+
|
662 |
+
for XDiffusion in diffusion_classes:
|
663 |
+
if XDiffusion.alias == type: # type: ignore
|
664 |
+
self.diffusion = XDiffusion(net=net, **kwargs)
|
665 |
+
|
666 |
+
def forward(self, *args, **kwargs) -> Tensor:
|
667 |
+
return self.diffusion(*args, **kwargs)
|
668 |
+
|
669 |
+
def sample(
|
670 |
+
self,
|
671 |
+
noise: Tensor,
|
672 |
+
num_steps: int,
|
673 |
+
sigma_schedule: Schedule,
|
674 |
+
sampler: Sampler,
|
675 |
+
clamp: bool,
|
676 |
+
**kwargs,
|
677 |
+
) -> Tensor:
|
678 |
+
diffusion_sampler = DiffusionSampler(
|
679 |
+
diffusion=self.diffusion,
|
680 |
+
sampler=sampler,
|
681 |
+
sigma_schedule=sigma_schedule,
|
682 |
+
num_steps=num_steps,
|
683 |
+
clamp=clamp,
|
684 |
+
)
|
685 |
+
return diffusion_sampler(noise, **kwargs)
|
Modules/diffusion/utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import reduce
|
2 |
+
from inspect import isfunction
|
3 |
+
from math import ceil, floor, log2, pi
|
4 |
+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch import Generator, Tensor
|
10 |
+
from typing_extensions import TypeGuard
|
11 |
+
|
12 |
+
T = TypeVar("T")
|
13 |
+
|
14 |
+
|
15 |
+
def exists(val: Optional[T]) -> TypeGuard[T]:
|
16 |
+
return val is not None
|
17 |
+
|
18 |
+
|
19 |
+
def iff(condition: bool, value: T) -> Optional[T]:
|
20 |
+
return value if condition else None
|
21 |
+
|
22 |
+
|
23 |
+
def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
|
24 |
+
return isinstance(obj, list) or isinstance(obj, tuple)
|
25 |
+
|
26 |
+
|
27 |
+
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
|
28 |
+
if exists(val):
|
29 |
+
return val
|
30 |
+
return d() if isfunction(d) else d
|
31 |
+
|
32 |
+
|
33 |
+
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
|
34 |
+
if isinstance(val, tuple):
|
35 |
+
return list(val)
|
36 |
+
if isinstance(val, list):
|
37 |
+
return val
|
38 |
+
return [val] # type: ignore
|
39 |
+
|
40 |
+
|
41 |
+
def prod(vals: Sequence[int]) -> int:
|
42 |
+
return reduce(lambda x, y: x * y, vals)
|
43 |
+
|
44 |
+
|
45 |
+
def closest_power_2(x: float) -> int:
|
46 |
+
exponent = log2(x)
|
47 |
+
distance_fn = lambda z: abs(x - 2**z) # noqa
|
48 |
+
exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
|
49 |
+
return 2 ** int(exponent_closest)
|
50 |
+
|
51 |
+
|
52 |
+
def rand_bool(shape, proba, device=None):
|
53 |
+
if proba == 1:
|
54 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
55 |
+
elif proba == 0:
|
56 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
57 |
+
else:
|
58 |
+
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
|
59 |
+
|
60 |
+
|
61 |
+
"""
|
62 |
+
Kwargs Utils
|
63 |
+
"""
|
64 |
+
|
65 |
+
|
66 |
+
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
|
67 |
+
return_dicts: Tuple[Dict, Dict] = ({}, {})
|
68 |
+
for key in d.keys():
|
69 |
+
no_prefix = int(not key.startswith(prefix))
|
70 |
+
return_dicts[no_prefix][key] = d[key]
|
71 |
+
return return_dicts
|
72 |
+
|
73 |
+
|
74 |
+
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
|
75 |
+
kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
|
76 |
+
if keep_prefix:
|
77 |
+
return kwargs_with_prefix, kwargs
|
78 |
+
kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
|
79 |
+
return kwargs_no_prefix, kwargs
|
80 |
+
|
81 |
+
|
82 |
+
def prefix_dict(prefix: str, d: Dict) -> Dict:
|
83 |
+
return {prefix + str(k): v for k, v in d.items()}
|
Modules/discriminators.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
6 |
+
|
7 |
+
from .utils import get_padding
|
8 |
+
|
9 |
+
LRELU_SLOPE = 0.1
|
10 |
+
|
11 |
+
|
12 |
+
def stft(x, fft_size, hop_size, win_length, window):
|
13 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
14 |
+
Args:
|
15 |
+
x (Tensor): Input signal tensor (B, T).
|
16 |
+
fft_size (int): FFT size.
|
17 |
+
hop_size (int): Hop size.
|
18 |
+
win_length (int): Window length.
|
19 |
+
window (str): Window function type.
|
20 |
+
Returns:
|
21 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
22 |
+
"""
|
23 |
+
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
|
24 |
+
real = x_stft[..., 0]
|
25 |
+
imag = x_stft[..., 1]
|
26 |
+
|
27 |
+
return torch.abs(x_stft).transpose(2, 1)
|
28 |
+
|
29 |
+
|
30 |
+
class SpecDiscriminator(nn.Module):
|
31 |
+
"""docstring for Discriminator."""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
fft_size=1024,
|
36 |
+
shift_size=120,
|
37 |
+
win_length=600,
|
38 |
+
window="hann_window",
|
39 |
+
use_spectral_norm=False,
|
40 |
+
):
|
41 |
+
super(SpecDiscriminator, self).__init__()
|
42 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
43 |
+
self.fft_size = fft_size
|
44 |
+
self.shift_size = shift_size
|
45 |
+
self.win_length = win_length
|
46 |
+
self.window = getattr(torch, window)(win_length)
|
47 |
+
self.discriminators = nn.ModuleList(
|
48 |
+
[
|
49 |
+
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
50 |
+
norm_f(
|
51 |
+
nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
|
52 |
+
),
|
53 |
+
norm_f(
|
54 |
+
nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
|
55 |
+
),
|
56 |
+
norm_f(
|
57 |
+
nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
|
58 |
+
),
|
59 |
+
norm_f(
|
60 |
+
nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
61 |
+
),
|
62 |
+
]
|
63 |
+
)
|
64 |
+
|
65 |
+
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
66 |
+
|
67 |
+
def forward(self, y):
|
68 |
+
fmap = []
|
69 |
+
y = y.squeeze(1)
|
70 |
+
y = stft(
|
71 |
+
y,
|
72 |
+
self.fft_size,
|
73 |
+
self.shift_size,
|
74 |
+
self.win_length,
|
75 |
+
self.window.to(y.get_device()),
|
76 |
+
)
|
77 |
+
y = y.unsqueeze(1)
|
78 |
+
for i, d in enumerate(self.discriminators):
|
79 |
+
y = d(y)
|
80 |
+
y = F.leaky_relu(y, LRELU_SLOPE)
|
81 |
+
fmap.append(y)
|
82 |
+
|
83 |
+
y = self.out(y)
|
84 |
+
fmap.append(y)
|
85 |
+
|
86 |
+
return torch.flatten(y, 1, -1), fmap
|
87 |
+
|
88 |
+
|
89 |
+
class MultiResSpecDiscriminator(torch.nn.Module):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
fft_sizes=[1024, 2048, 512],
|
93 |
+
hop_sizes=[120, 240, 50],
|
94 |
+
win_lengths=[600, 1200, 240],
|
95 |
+
window="hann_window",
|
96 |
+
):
|
97 |
+
super(MultiResSpecDiscriminator, self).__init__()
|
98 |
+
self.discriminators = nn.ModuleList(
|
99 |
+
[
|
100 |
+
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
101 |
+
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
102 |
+
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
|
103 |
+
]
|
104 |
+
)
|
105 |
+
|
106 |
+
def forward(self, y, y_hat):
|
107 |
+
y_d_rs = []
|
108 |
+
y_d_gs = []
|
109 |
+
fmap_rs = []
|
110 |
+
fmap_gs = []
|
111 |
+
for i, d in enumerate(self.discriminators):
|
112 |
+
y_d_r, fmap_r = d(y)
|
113 |
+
y_d_g, fmap_g = d(y_hat)
|
114 |
+
y_d_rs.append(y_d_r)
|
115 |
+
fmap_rs.append(fmap_r)
|
116 |
+
y_d_gs.append(y_d_g)
|
117 |
+
fmap_gs.append(fmap_g)
|
118 |
+
|
119 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
120 |
+
|
121 |
+
|
122 |
+
class DiscriminatorP(torch.nn.Module):
|
123 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
124 |
+
super(DiscriminatorP, self).__init__()
|
125 |
+
self.period = period
|
126 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
127 |
+
self.convs = nn.ModuleList(
|
128 |
+
[
|
129 |
+
norm_f(
|
130 |
+
Conv2d(
|
131 |
+
1,
|
132 |
+
32,
|
133 |
+
(kernel_size, 1),
|
134 |
+
(stride, 1),
|
135 |
+
padding=(get_padding(5, 1), 0),
|
136 |
+
)
|
137 |
+
),
|
138 |
+
norm_f(
|
139 |
+
Conv2d(
|
140 |
+
32,
|
141 |
+
128,
|
142 |
+
(kernel_size, 1),
|
143 |
+
(stride, 1),
|
144 |
+
padding=(get_padding(5, 1), 0),
|
145 |
+
)
|
146 |
+
),
|
147 |
+
norm_f(
|
148 |
+
Conv2d(
|
149 |
+
128,
|
150 |
+
512,
|
151 |
+
(kernel_size, 1),
|
152 |
+
(stride, 1),
|
153 |
+
padding=(get_padding(5, 1), 0),
|
154 |
+
)
|
155 |
+
),
|
156 |
+
norm_f(
|
157 |
+
Conv2d(
|
158 |
+
512,
|
159 |
+
1024,
|
160 |
+
(kernel_size, 1),
|
161 |
+
(stride, 1),
|
162 |
+
padding=(get_padding(5, 1), 0),
|
163 |
+
)
|
164 |
+
),
|
165 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
166 |
+
]
|
167 |
+
)
|
168 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
fmap = []
|
172 |
+
|
173 |
+
# 1d to 2d
|
174 |
+
b, c, t = x.shape
|
175 |
+
if t % self.period != 0: # pad first
|
176 |
+
n_pad = self.period - (t % self.period)
|
177 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
178 |
+
t = t + n_pad
|
179 |
+
x = x.view(b, c, t // self.period, self.period)
|
180 |
+
|
181 |
+
for l in self.convs:
|
182 |
+
x = l(x)
|
183 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
184 |
+
fmap.append(x)
|
185 |
+
x = self.conv_post(x)
|
186 |
+
fmap.append(x)
|
187 |
+
x = torch.flatten(x, 1, -1)
|
188 |
+
|
189 |
+
return x, fmap
|
190 |
+
|
191 |
+
|
192 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
193 |
+
def __init__(self):
|
194 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
195 |
+
self.discriminators = nn.ModuleList(
|
196 |
+
[
|
197 |
+
DiscriminatorP(2),
|
198 |
+
DiscriminatorP(3),
|
199 |
+
DiscriminatorP(5),
|
200 |
+
DiscriminatorP(7),
|
201 |
+
DiscriminatorP(11),
|
202 |
+
]
|
203 |
+
)
|
204 |
+
|
205 |
+
def forward(self, y, y_hat):
|
206 |
+
y_d_rs = []
|
207 |
+
y_d_gs = []
|
208 |
+
fmap_rs = []
|
209 |
+
fmap_gs = []
|
210 |
+
for i, d in enumerate(self.discriminators):
|
211 |
+
y_d_r, fmap_r = d(y)
|
212 |
+
y_d_g, fmap_g = d(y_hat)
|
213 |
+
y_d_rs.append(y_d_r)
|
214 |
+
fmap_rs.append(fmap_r)
|
215 |
+
y_d_gs.append(y_d_g)
|
216 |
+
fmap_gs.append(fmap_g)
|
217 |
+
|
218 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
219 |
+
|
220 |
+
|
221 |
+
class WavLMDiscriminator(nn.Module):
|
222 |
+
"""docstring for Discriminator."""
|
223 |
+
|
224 |
+
def __init__(
|
225 |
+
self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
|
226 |
+
):
|
227 |
+
super(WavLMDiscriminator, self).__init__()
|
228 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
229 |
+
self.pre = norm_f(
|
230 |
+
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
|
231 |
+
)
|
232 |
+
|
233 |
+
self.convs = nn.ModuleList(
|
234 |
+
[
|
235 |
+
norm_f(
|
236 |
+
nn.Conv1d(
|
237 |
+
initial_channel, initial_channel * 2, kernel_size=5, padding=2
|
238 |
+
)
|
239 |
+
),
|
240 |
+
norm_f(
|
241 |
+
nn.Conv1d(
|
242 |
+
initial_channel * 2,
|
243 |
+
initial_channel * 4,
|
244 |
+
kernel_size=5,
|
245 |
+
padding=2,
|
246 |
+
)
|
247 |
+
),
|
248 |
+
norm_f(
|
249 |
+
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
|
250 |
+
),
|
251 |
+
]
|
252 |
+
)
|
253 |
+
|
254 |
+
self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
|
255 |
+
|
256 |
+
def forward(self, x):
|
257 |
+
x = self.pre(x)
|
258 |
+
|
259 |
+
fmap = []
|
260 |
+
for l in self.convs:
|
261 |
+
x = l(x)
|
262 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
263 |
+
fmap.append(x)
|
264 |
+
x = self.conv_post(x)
|
265 |
+
x = torch.flatten(x, 1, -1)
|
266 |
+
|
267 |
+
return x
|
Modules/hifigan.py
ADDED
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
+
from .utils import init_weights, get_padding
|
7 |
+
|
8 |
+
import math
|
9 |
+
import random
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
LRELU_SLOPE = 0.1
|
13 |
+
|
14 |
+
|
15 |
+
class AdaIN1d(nn.Module):
|
16 |
+
def __init__(self, style_dim, num_features):
|
17 |
+
super().__init__()
|
18 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
19 |
+
self.fc = nn.Linear(style_dim, num_features * 2)
|
20 |
+
|
21 |
+
def forward(self, x, s):
|
22 |
+
h = self.fc(s)
|
23 |
+
h = h.view(h.size(0), h.size(1), 1)
|
24 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
25 |
+
return (1 + gamma) * self.norm(x) + beta
|
26 |
+
|
27 |
+
|
28 |
+
class AdaINResBlock1(torch.nn.Module):
|
29 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
30 |
+
super(AdaINResBlock1, self).__init__()
|
31 |
+
self.convs1 = nn.ModuleList(
|
32 |
+
[
|
33 |
+
weight_norm(
|
34 |
+
Conv1d(
|
35 |
+
channels,
|
36 |
+
channels,
|
37 |
+
kernel_size,
|
38 |
+
1,
|
39 |
+
dilation=dilation[0],
|
40 |
+
padding=get_padding(kernel_size, dilation[0]),
|
41 |
+
)
|
42 |
+
),
|
43 |
+
weight_norm(
|
44 |
+
Conv1d(
|
45 |
+
channels,
|
46 |
+
channels,
|
47 |
+
kernel_size,
|
48 |
+
1,
|
49 |
+
dilation=dilation[1],
|
50 |
+
padding=get_padding(kernel_size, dilation[1]),
|
51 |
+
)
|
52 |
+
),
|
53 |
+
weight_norm(
|
54 |
+
Conv1d(
|
55 |
+
channels,
|
56 |
+
channels,
|
57 |
+
kernel_size,
|
58 |
+
1,
|
59 |
+
dilation=dilation[2],
|
60 |
+
padding=get_padding(kernel_size, dilation[2]),
|
61 |
+
)
|
62 |
+
),
|
63 |
+
]
|
64 |
+
)
|
65 |
+
self.convs1.apply(init_weights)
|
66 |
+
|
67 |
+
self.convs2 = nn.ModuleList(
|
68 |
+
[
|
69 |
+
weight_norm(
|
70 |
+
Conv1d(
|
71 |
+
channels,
|
72 |
+
channels,
|
73 |
+
kernel_size,
|
74 |
+
1,
|
75 |
+
dilation=1,
|
76 |
+
padding=get_padding(kernel_size, 1),
|
77 |
+
)
|
78 |
+
),
|
79 |
+
weight_norm(
|
80 |
+
Conv1d(
|
81 |
+
channels,
|
82 |
+
channels,
|
83 |
+
kernel_size,
|
84 |
+
1,
|
85 |
+
dilation=1,
|
86 |
+
padding=get_padding(kernel_size, 1),
|
87 |
+
)
|
88 |
+
),
|
89 |
+
weight_norm(
|
90 |
+
Conv1d(
|
91 |
+
channels,
|
92 |
+
channels,
|
93 |
+
kernel_size,
|
94 |
+
1,
|
95 |
+
dilation=1,
|
96 |
+
padding=get_padding(kernel_size, 1),
|
97 |
+
)
|
98 |
+
),
|
99 |
+
]
|
100 |
+
)
|
101 |
+
self.convs2.apply(init_weights)
|
102 |
+
|
103 |
+
self.adain1 = nn.ModuleList(
|
104 |
+
[
|
105 |
+
AdaIN1d(style_dim, channels),
|
106 |
+
AdaIN1d(style_dim, channels),
|
107 |
+
AdaIN1d(style_dim, channels),
|
108 |
+
]
|
109 |
+
)
|
110 |
+
|
111 |
+
self.adain2 = nn.ModuleList(
|
112 |
+
[
|
113 |
+
AdaIN1d(style_dim, channels),
|
114 |
+
AdaIN1d(style_dim, channels),
|
115 |
+
AdaIN1d(style_dim, channels),
|
116 |
+
]
|
117 |
+
)
|
118 |
+
|
119 |
+
self.alpha1 = nn.ParameterList(
|
120 |
+
[nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]
|
121 |
+
)
|
122 |
+
self.alpha2 = nn.ParameterList(
|
123 |
+
[nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, x, s):
|
127 |
+
for c1, c2, n1, n2, a1, a2 in zip(
|
128 |
+
self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
|
129 |
+
):
|
130 |
+
xt = n1(x, s)
|
131 |
+
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
|
132 |
+
xt = c1(xt)
|
133 |
+
xt = n2(xt, s)
|
134 |
+
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
135 |
+
xt = c2(xt)
|
136 |
+
x = xt + x
|
137 |
+
return x
|
138 |
+
|
139 |
+
def remove_weight_norm(self):
|
140 |
+
for l in self.convs1:
|
141 |
+
remove_weight_norm(l)
|
142 |
+
for l in self.convs2:
|
143 |
+
remove_weight_norm(l)
|
144 |
+
|
145 |
+
|
146 |
+
class SineGen(torch.nn.Module):
|
147 |
+
"""Definition of sine generator
|
148 |
+
SineGen(samp_rate, harmonic_num = 0,
|
149 |
+
sine_amp = 0.1, noise_std = 0.003,
|
150 |
+
voiced_threshold = 0,
|
151 |
+
flag_for_pulse=False)
|
152 |
+
samp_rate: sampling rate in Hz
|
153 |
+
harmonic_num: number of harmonic overtones (default 0)
|
154 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
155 |
+
noise_std: std of Gaussian noise (default 0.003)
|
156 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
157 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
158 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
159 |
+
segment is always sin(np.pi) or cos(0)
|
160 |
+
"""
|
161 |
+
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
samp_rate,
|
165 |
+
upsample_scale,
|
166 |
+
harmonic_num=0,
|
167 |
+
sine_amp=0.1,
|
168 |
+
noise_std=0.003,
|
169 |
+
voiced_threshold=0,
|
170 |
+
flag_for_pulse=False,
|
171 |
+
):
|
172 |
+
super(SineGen, self).__init__()
|
173 |
+
self.sine_amp = sine_amp
|
174 |
+
self.noise_std = noise_std
|
175 |
+
self.harmonic_num = harmonic_num
|
176 |
+
self.dim = self.harmonic_num + 1
|
177 |
+
self.sampling_rate = samp_rate
|
178 |
+
self.voiced_threshold = voiced_threshold
|
179 |
+
self.flag_for_pulse = flag_for_pulse
|
180 |
+
self.upsample_scale = upsample_scale
|
181 |
+
|
182 |
+
def _f02uv(self, f0):
|
183 |
+
# generate uv signal
|
184 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
185 |
+
return uv
|
186 |
+
|
187 |
+
def _f02sine(self, f0_values):
|
188 |
+
"""f0_values: (batchsize, length, dim)
|
189 |
+
where dim indicates fundamental tone and overtones
|
190 |
+
"""
|
191 |
+
# convert to F0 in rad. The interger part n can be ignored
|
192 |
+
# because 2 * np.pi * n doesn't affect phase
|
193 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
194 |
+
|
195 |
+
# initial phase noise (no noise for fundamental component)
|
196 |
+
rand_ini = torch.rand(
|
197 |
+
f0_values.shape[0], f0_values.shape[2], device=f0_values.device
|
198 |
+
)
|
199 |
+
rand_ini[:, 0] = 0
|
200 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
201 |
+
|
202 |
+
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
203 |
+
if not self.flag_for_pulse:
|
204 |
+
# # for normal case
|
205 |
+
|
206 |
+
# # To prevent torch.cumsum numerical overflow,
|
207 |
+
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
208 |
+
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
209 |
+
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
210 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
211 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
212 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
213 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
214 |
+
|
215 |
+
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
216 |
+
rad_values = torch.nn.functional.interpolate(
|
217 |
+
rad_values.transpose(1, 2),
|
218 |
+
scale_factor=1 / self.upsample_scale,
|
219 |
+
mode="linear",
|
220 |
+
).transpose(1, 2)
|
221 |
+
|
222 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
223 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
224 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
225 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
226 |
+
|
227 |
+
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
228 |
+
phase = torch.nn.functional.interpolate(
|
229 |
+
phase.transpose(1, 2) * self.upsample_scale,
|
230 |
+
scale_factor=self.upsample_scale,
|
231 |
+
mode="linear",
|
232 |
+
).transpose(1, 2)
|
233 |
+
sines = torch.sin(phase)
|
234 |
+
|
235 |
+
else:
|
236 |
+
# If necessary, make sure that the first time step of every
|
237 |
+
# voiced segments is sin(pi) or cos(0)
|
238 |
+
# This is used for pulse-train generation
|
239 |
+
|
240 |
+
# identify the last time step in unvoiced segments
|
241 |
+
uv = self._f02uv(f0_values)
|
242 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
243 |
+
uv_1[:, -1, :] = 1
|
244 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
245 |
+
|
246 |
+
# get the instantanouse phase
|
247 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
248 |
+
# different batch needs to be processed differently
|
249 |
+
for idx in range(f0_values.shape[0]):
|
250 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
251 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
252 |
+
# stores the accumulation of i.phase within
|
253 |
+
# each voiced segments
|
254 |
+
tmp_cumsum[idx, :, :] = 0
|
255 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
256 |
+
|
257 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
258 |
+
# within the previous voiced segment.
|
259 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
260 |
+
|
261 |
+
# get the sines
|
262 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
263 |
+
return sines
|
264 |
+
|
265 |
+
def forward(self, f0):
|
266 |
+
"""sine_tensor, uv = forward(f0)
|
267 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
268 |
+
f0 for unvoiced steps should be 0
|
269 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
270 |
+
output uv: tensor(batchsize=1, length, 1)
|
271 |
+
"""
|
272 |
+
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
273 |
+
# fundamental component
|
274 |
+
fn = torch.multiply(
|
275 |
+
f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
|
276 |
+
)
|
277 |
+
|
278 |
+
# generate sine waveforms
|
279 |
+
sine_waves = self._f02sine(fn) * self.sine_amp
|
280 |
+
|
281 |
+
# generate uv signal
|
282 |
+
# uv = torch.ones(f0.shape)
|
283 |
+
# uv = uv * (f0 > self.voiced_threshold)
|
284 |
+
uv = self._f02uv(f0)
|
285 |
+
|
286 |
+
# noise: for unvoiced should be similar to sine_amp
|
287 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
288 |
+
# . for voiced regions is self.noise_std
|
289 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
290 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
291 |
+
|
292 |
+
# first: set the unvoiced part to 0 by uv
|
293 |
+
# then: additive noise
|
294 |
+
sine_waves = sine_waves * uv + noise
|
295 |
+
return sine_waves, uv, noise
|
296 |
+
|
297 |
+
|
298 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
299 |
+
"""SourceModule for hn-nsf
|
300 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
301 |
+
add_noise_std=0.003, voiced_threshod=0)
|
302 |
+
sampling_rate: sampling_rate in Hz
|
303 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
304 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
305 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
306 |
+
note that amplitude of noise in unvoiced is decided
|
307 |
+
by sine_amp
|
308 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
309 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
310 |
+
F0_sampled (batchsize, length, 1)
|
311 |
+
Sine_source (batchsize, length, 1)
|
312 |
+
noise_source (batchsize, length 1)
|
313 |
+
uv (batchsize, length, 1)
|
314 |
+
"""
|
315 |
+
|
316 |
+
def __init__(
|
317 |
+
self,
|
318 |
+
sampling_rate,
|
319 |
+
upsample_scale,
|
320 |
+
harmonic_num=0,
|
321 |
+
sine_amp=0.1,
|
322 |
+
add_noise_std=0.003,
|
323 |
+
voiced_threshod=0,
|
324 |
+
):
|
325 |
+
super(SourceModuleHnNSF, self).__init__()
|
326 |
+
|
327 |
+
self.sine_amp = sine_amp
|
328 |
+
self.noise_std = add_noise_std
|
329 |
+
|
330 |
+
# to produce sine waveforms
|
331 |
+
self.l_sin_gen = SineGen(
|
332 |
+
sampling_rate,
|
333 |
+
upsample_scale,
|
334 |
+
harmonic_num,
|
335 |
+
sine_amp,
|
336 |
+
add_noise_std,
|
337 |
+
voiced_threshod,
|
338 |
+
)
|
339 |
+
|
340 |
+
# to merge source harmonics into a single excitation
|
341 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
342 |
+
self.l_tanh = torch.nn.Tanh()
|
343 |
+
|
344 |
+
def forward(self, x):
|
345 |
+
"""
|
346 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
347 |
+
F0_sampled (batchsize, length, 1)
|
348 |
+
Sine_source (batchsize, length, 1)
|
349 |
+
noise_source (batchsize, length 1)
|
350 |
+
"""
|
351 |
+
# source for harmonic branch
|
352 |
+
with torch.no_grad():
|
353 |
+
sine_wavs, uv, _ = self.l_sin_gen(x)
|
354 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
355 |
+
|
356 |
+
# source for noise branch, in the same shape as uv
|
357 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
358 |
+
return sine_merge, noise, uv
|
359 |
+
|
360 |
+
|
361 |
+
def padDiff(x):
|
362 |
+
return F.pad(
|
363 |
+
F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
|
364 |
+
)
|
365 |
+
|
366 |
+
|
367 |
+
class Generator(torch.nn.Module):
|
368 |
+
def __init__(
|
369 |
+
self,
|
370 |
+
style_dim,
|
371 |
+
resblock_kernel_sizes,
|
372 |
+
upsample_rates,
|
373 |
+
upsample_initial_channel,
|
374 |
+
resblock_dilation_sizes,
|
375 |
+
upsample_kernel_sizes,
|
376 |
+
):
|
377 |
+
super(Generator, self).__init__()
|
378 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
379 |
+
self.num_upsamples = len(upsample_rates)
|
380 |
+
resblock = AdaINResBlock1
|
381 |
+
|
382 |
+
self.m_source = SourceModuleHnNSF(
|
383 |
+
sampling_rate=24000,
|
384 |
+
upsample_scale=np.prod(upsample_rates),
|
385 |
+
harmonic_num=8,
|
386 |
+
voiced_threshod=10,
|
387 |
+
)
|
388 |
+
|
389 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
|
390 |
+
self.noise_convs = nn.ModuleList()
|
391 |
+
self.ups = nn.ModuleList()
|
392 |
+
self.noise_res = nn.ModuleList()
|
393 |
+
|
394 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
395 |
+
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
396 |
+
|
397 |
+
self.ups.append(
|
398 |
+
weight_norm(
|
399 |
+
ConvTranspose1d(
|
400 |
+
upsample_initial_channel // (2**i),
|
401 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
402 |
+
k,
|
403 |
+
u,
|
404 |
+
padding=(u // 2 + u % 2),
|
405 |
+
output_padding=u % 2,
|
406 |
+
)
|
407 |
+
)
|
408 |
+
)
|
409 |
+
|
410 |
+
if i + 1 < len(upsample_rates): #
|
411 |
+
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
412 |
+
self.noise_convs.append(
|
413 |
+
Conv1d(
|
414 |
+
1,
|
415 |
+
c_cur,
|
416 |
+
kernel_size=stride_f0 * 2,
|
417 |
+
stride=stride_f0,
|
418 |
+
padding=(stride_f0 + 1) // 2,
|
419 |
+
)
|
420 |
+
)
|
421 |
+
self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
|
422 |
+
else:
|
423 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
424 |
+
self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
|
425 |
+
|
426 |
+
self.resblocks = nn.ModuleList()
|
427 |
+
|
428 |
+
self.alphas = nn.ParameterList()
|
429 |
+
self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
|
430 |
+
|
431 |
+
for i in range(len(self.ups)):
|
432 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
433 |
+
self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
|
434 |
+
|
435 |
+
for j, (k, d) in enumerate(
|
436 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
437 |
+
):
|
438 |
+
self.resblocks.append(resblock(ch, k, d, style_dim))
|
439 |
+
|
440 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
441 |
+
self.ups.apply(init_weights)
|
442 |
+
self.conv_post.apply(init_weights)
|
443 |
+
|
444 |
+
def forward(self, x, s, f0):
|
445 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
446 |
+
|
447 |
+
har_source, noi_source, uv = self.m_source(f0)
|
448 |
+
har_source = har_source.transpose(1, 2)
|
449 |
+
|
450 |
+
for i in range(self.num_upsamples):
|
451 |
+
x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
|
452 |
+
x_source = self.noise_convs[i](har_source)
|
453 |
+
x_source = self.noise_res[i](x_source, s)
|
454 |
+
|
455 |
+
x = self.ups[i](x)
|
456 |
+
x = x + x_source
|
457 |
+
|
458 |
+
xs = None
|
459 |
+
for j in range(self.num_kernels):
|
460 |
+
if xs is None:
|
461 |
+
xs = self.resblocks[i * self.num_kernels + j](x, s)
|
462 |
+
else:
|
463 |
+
xs += self.resblocks[i * self.num_kernels + j](x, s)
|
464 |
+
x = xs / self.num_kernels
|
465 |
+
x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
|
466 |
+
x = self.conv_post(x)
|
467 |
+
x = torch.tanh(x)
|
468 |
+
|
469 |
+
return x
|
470 |
+
|
471 |
+
def remove_weight_norm(self):
|
472 |
+
print("Removing weight norm...")
|
473 |
+
for l in self.ups:
|
474 |
+
remove_weight_norm(l)
|
475 |
+
for l in self.resblocks:
|
476 |
+
l.remove_weight_norm()
|
477 |
+
remove_weight_norm(self.conv_pre)
|
478 |
+
remove_weight_norm(self.conv_post)
|
479 |
+
|
480 |
+
|
481 |
+
class AdainResBlk1d(nn.Module):
|
482 |
+
def __init__(
|
483 |
+
self,
|
484 |
+
dim_in,
|
485 |
+
dim_out,
|
486 |
+
style_dim=64,
|
487 |
+
actv=nn.LeakyReLU(0.2),
|
488 |
+
upsample="none",
|
489 |
+
dropout_p=0.0,
|
490 |
+
):
|
491 |
+
super().__init__()
|
492 |
+
self.actv = actv
|
493 |
+
self.upsample_type = upsample
|
494 |
+
self.upsample = UpSample1d(upsample)
|
495 |
+
self.learned_sc = dim_in != dim_out
|
496 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
497 |
+
self.dropout = nn.Dropout(dropout_p)
|
498 |
+
|
499 |
+
if upsample == "none":
|
500 |
+
self.pool = nn.Identity()
|
501 |
+
else:
|
502 |
+
self.pool = weight_norm(
|
503 |
+
nn.ConvTranspose1d(
|
504 |
+
dim_in,
|
505 |
+
dim_in,
|
506 |
+
kernel_size=3,
|
507 |
+
stride=2,
|
508 |
+
groups=dim_in,
|
509 |
+
padding=1,
|
510 |
+
output_padding=1,
|
511 |
+
)
|
512 |
+
)
|
513 |
+
|
514 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
515 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
516 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
517 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
518 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
519 |
+
if self.learned_sc:
|
520 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
521 |
+
|
522 |
+
def _shortcut(self, x):
|
523 |
+
x = self.upsample(x)
|
524 |
+
if self.learned_sc:
|
525 |
+
x = self.conv1x1(x)
|
526 |
+
return x
|
527 |
+
|
528 |
+
def _residual(self, x, s):
|
529 |
+
x = self.norm1(x, s)
|
530 |
+
x = self.actv(x)
|
531 |
+
x = self.pool(x)
|
532 |
+
x = self.conv1(self.dropout(x))
|
533 |
+
x = self.norm2(x, s)
|
534 |
+
x = self.actv(x)
|
535 |
+
x = self.conv2(self.dropout(x))
|
536 |
+
return x
|
537 |
+
|
538 |
+
def forward(self, x, s):
|
539 |
+
out = self._residual(x, s)
|
540 |
+
out = (out + self._shortcut(x)) / math.sqrt(2)
|
541 |
+
return out
|
542 |
+
|
543 |
+
|
544 |
+
class UpSample1d(nn.Module):
|
545 |
+
def __init__(self, layer_type):
|
546 |
+
super().__init__()
|
547 |
+
self.layer_type = layer_type
|
548 |
+
|
549 |
+
def forward(self, x):
|
550 |
+
if self.layer_type == "none":
|
551 |
+
return x
|
552 |
+
else:
|
553 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
554 |
+
|
555 |
+
|
556 |
+
class Decoder(nn.Module):
|
557 |
+
def __init__(
|
558 |
+
self,
|
559 |
+
dim_in=512,
|
560 |
+
F0_channel=512,
|
561 |
+
style_dim=64,
|
562 |
+
dim_out=80,
|
563 |
+
resblock_kernel_sizes=[3, 7, 11],
|
564 |
+
upsample_rates=[10, 5, 3, 2],
|
565 |
+
upsample_initial_channel=512,
|
566 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
567 |
+
upsample_kernel_sizes=[20, 10, 6, 4],
|
568 |
+
):
|
569 |
+
super().__init__()
|
570 |
+
|
571 |
+
self.decode = nn.ModuleList()
|
572 |
+
|
573 |
+
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
574 |
+
|
575 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
576 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
577 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
578 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
579 |
+
|
580 |
+
self.F0_conv = weight_norm(
|
581 |
+
nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
|
582 |
+
)
|
583 |
+
|
584 |
+
self.N_conv = weight_norm(
|
585 |
+
nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
|
586 |
+
)
|
587 |
+
|
588 |
+
self.asr_res = nn.Sequential(
|
589 |
+
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
|
590 |
+
)
|
591 |
+
|
592 |
+
self.generator = Generator(
|
593 |
+
style_dim,
|
594 |
+
resblock_kernel_sizes,
|
595 |
+
upsample_rates,
|
596 |
+
upsample_initial_channel,
|
597 |
+
resblock_dilation_sizes,
|
598 |
+
upsample_kernel_sizes,
|
599 |
+
)
|
600 |
+
|
601 |
+
def forward(self, asr, F0_curve, N, s):
|
602 |
+
if self.training:
|
603 |
+
downlist = [0, 3, 7]
|
604 |
+
F0_down = downlist[random.randint(0, 2)]
|
605 |
+
downlist = [0, 3, 7, 15]
|
606 |
+
N_down = downlist[random.randint(0, 3)]
|
607 |
+
if F0_down:
|
608 |
+
F0_curve = (
|
609 |
+
nn.functional.conv1d(
|
610 |
+
F0_curve.unsqueeze(1),
|
611 |
+
torch.ones(1, 1, F0_down).to("cuda"),
|
612 |
+
padding=F0_down // 2,
|
613 |
+
).squeeze(1)
|
614 |
+
/ F0_down
|
615 |
+
)
|
616 |
+
if N_down:
|
617 |
+
N = (
|
618 |
+
nn.functional.conv1d(
|
619 |
+
N.unsqueeze(1),
|
620 |
+
torch.ones(1, 1, N_down).to("cuda"),
|
621 |
+
padding=N_down // 2,
|
622 |
+
).squeeze(1)
|
623 |
+
/ N_down
|
624 |
+
)
|
625 |
+
|
626 |
+
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
627 |
+
N = self.N_conv(N.unsqueeze(1))
|
628 |
+
|
629 |
+
x = torch.cat([asr, F0, N], axis=1)
|
630 |
+
x = self.encode(x, s)
|
631 |
+
|
632 |
+
asr_res = self.asr_res(asr)
|
633 |
+
|
634 |
+
res = True
|
635 |
+
for block in self.decode:
|
636 |
+
if res:
|
637 |
+
x = torch.cat([x, asr_res, F0, N], axis=1)
|
638 |
+
x = block(x, s)
|
639 |
+
if block.upsample_type != "none":
|
640 |
+
res = False
|
641 |
+
|
642 |
+
x = self.generator(x, s, F0_curve)
|
643 |
+
return x
|
Modules/istftnet.py
ADDED
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
+
from .utils import init_weights, get_padding
|
7 |
+
|
8 |
+
import math
|
9 |
+
import random
|
10 |
+
import numpy as np
|
11 |
+
from scipy.signal import get_window
|
12 |
+
|
13 |
+
LRELU_SLOPE = 0.1
|
14 |
+
|
15 |
+
|
16 |
+
class AdaIN1d(nn.Module):
|
17 |
+
def __init__(self, style_dim, num_features):
|
18 |
+
super().__init__()
|
19 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
20 |
+
self.fc = nn.Linear(style_dim, num_features * 2)
|
21 |
+
|
22 |
+
def forward(self, x, s):
|
23 |
+
h = self.fc(s)
|
24 |
+
h = h.view(h.size(0), h.size(1), 1)
|
25 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
26 |
+
return (1 + gamma) * self.norm(x) + beta
|
27 |
+
|
28 |
+
|
29 |
+
class AdaINResBlock1(torch.nn.Module):
|
30 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
31 |
+
super(AdaINResBlock1, self).__init__()
|
32 |
+
self.convs1 = nn.ModuleList(
|
33 |
+
[
|
34 |
+
weight_norm(
|
35 |
+
Conv1d(
|
36 |
+
channels,
|
37 |
+
channels,
|
38 |
+
kernel_size,
|
39 |
+
1,
|
40 |
+
dilation=dilation[0],
|
41 |
+
padding=get_padding(kernel_size, dilation[0]),
|
42 |
+
)
|
43 |
+
),
|
44 |
+
weight_norm(
|
45 |
+
Conv1d(
|
46 |
+
channels,
|
47 |
+
channels,
|
48 |
+
kernel_size,
|
49 |
+
1,
|
50 |
+
dilation=dilation[1],
|
51 |
+
padding=get_padding(kernel_size, dilation[1]),
|
52 |
+
)
|
53 |
+
),
|
54 |
+
weight_norm(
|
55 |
+
Conv1d(
|
56 |
+
channels,
|
57 |
+
channels,
|
58 |
+
kernel_size,
|
59 |
+
1,
|
60 |
+
dilation=dilation[2],
|
61 |
+
padding=get_padding(kernel_size, dilation[2]),
|
62 |
+
)
|
63 |
+
),
|
64 |
+
]
|
65 |
+
)
|
66 |
+
self.convs1.apply(init_weights)
|
67 |
+
|
68 |
+
self.convs2 = nn.ModuleList(
|
69 |
+
[
|
70 |
+
weight_norm(
|
71 |
+
Conv1d(
|
72 |
+
channels,
|
73 |
+
channels,
|
74 |
+
kernel_size,
|
75 |
+
1,
|
76 |
+
dilation=1,
|
77 |
+
padding=get_padding(kernel_size, 1),
|
78 |
+
)
|
79 |
+
),
|
80 |
+
weight_norm(
|
81 |
+
Conv1d(
|
82 |
+
channels,
|
83 |
+
channels,
|
84 |
+
kernel_size,
|
85 |
+
1,
|
86 |
+
dilation=1,
|
87 |
+
padding=get_padding(kernel_size, 1),
|
88 |
+
)
|
89 |
+
),
|
90 |
+
weight_norm(
|
91 |
+
Conv1d(
|
92 |
+
channels,
|
93 |
+
channels,
|
94 |
+
kernel_size,
|
95 |
+
1,
|
96 |
+
dilation=1,
|
97 |
+
padding=get_padding(kernel_size, 1),
|
98 |
+
)
|
99 |
+
),
|
100 |
+
]
|
101 |
+
)
|
102 |
+
self.convs2.apply(init_weights)
|
103 |
+
|
104 |
+
self.adain1 = nn.ModuleList(
|
105 |
+
[
|
106 |
+
AdaIN1d(style_dim, channels),
|
107 |
+
AdaIN1d(style_dim, channels),
|
108 |
+
AdaIN1d(style_dim, channels),
|
109 |
+
]
|
110 |
+
)
|
111 |
+
|
112 |
+
self.adain2 = nn.ModuleList(
|
113 |
+
[
|
114 |
+
AdaIN1d(style_dim, channels),
|
115 |
+
AdaIN1d(style_dim, channels),
|
116 |
+
AdaIN1d(style_dim, channels),
|
117 |
+
]
|
118 |
+
)
|
119 |
+
|
120 |
+
self.alpha1 = nn.ParameterList(
|
121 |
+
[nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]
|
122 |
+
)
|
123 |
+
self.alpha2 = nn.ParameterList(
|
124 |
+
[nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x, s):
|
128 |
+
for c1, c2, n1, n2, a1, a2 in zip(
|
129 |
+
self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
|
130 |
+
):
|
131 |
+
xt = n1(x, s)
|
132 |
+
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
|
133 |
+
xt = c1(xt)
|
134 |
+
xt = n2(xt, s)
|
135 |
+
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
136 |
+
xt = c2(xt)
|
137 |
+
x = xt + x
|
138 |
+
return x
|
139 |
+
|
140 |
+
def remove_weight_norm(self):
|
141 |
+
for l in self.convs1:
|
142 |
+
remove_weight_norm(l)
|
143 |
+
for l in self.convs2:
|
144 |
+
remove_weight_norm(l)
|
145 |
+
|
146 |
+
|
147 |
+
class TorchSTFT(torch.nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self, filter_length=800, hop_length=200, win_length=800, window="hann"
|
150 |
+
):
|
151 |
+
super().__init__()
|
152 |
+
self.filter_length = filter_length
|
153 |
+
self.hop_length = hop_length
|
154 |
+
self.win_length = win_length
|
155 |
+
self.window = torch.from_numpy(
|
156 |
+
get_window(window, win_length, fftbins=True).astype(np.float32)
|
157 |
+
)
|
158 |
+
|
159 |
+
def transform(self, input_data):
|
160 |
+
forward_transform = torch.stft(
|
161 |
+
input_data,
|
162 |
+
self.filter_length,
|
163 |
+
self.hop_length,
|
164 |
+
self.win_length,
|
165 |
+
window=self.window.to(input_data.device),
|
166 |
+
return_complex=True,
|
167 |
+
)
|
168 |
+
|
169 |
+
return torch.abs(forward_transform), torch.angle(forward_transform)
|
170 |
+
|
171 |
+
def inverse(self, magnitude, phase):
|
172 |
+
inverse_transform = torch.istft(
|
173 |
+
magnitude * torch.exp(phase * 1j),
|
174 |
+
self.filter_length,
|
175 |
+
self.hop_length,
|
176 |
+
self.win_length,
|
177 |
+
window=self.window.to(magnitude.device),
|
178 |
+
)
|
179 |
+
|
180 |
+
return inverse_transform.unsqueeze(
|
181 |
+
-2
|
182 |
+
) # unsqueeze to stay consistent with conv_transpose1d implementation
|
183 |
+
|
184 |
+
def forward(self, input_data):
|
185 |
+
self.magnitude, self.phase = self.transform(input_data)
|
186 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
187 |
+
return reconstruction
|
188 |
+
|
189 |
+
|
190 |
+
class SineGen(torch.nn.Module):
|
191 |
+
"""Definition of sine generator
|
192 |
+
SineGen(samp_rate, harmonic_num = 0,
|
193 |
+
sine_amp = 0.1, noise_std = 0.003,
|
194 |
+
voiced_threshold = 0,
|
195 |
+
flag_for_pulse=False)
|
196 |
+
samp_rate: sampling rate in Hz
|
197 |
+
harmonic_num: number of harmonic overtones (default 0)
|
198 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
199 |
+
noise_std: std of Gaussian noise (default 0.003)
|
200 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
201 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
202 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
203 |
+
segment is always sin(np.pi) or cos(0)
|
204 |
+
"""
|
205 |
+
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
samp_rate,
|
209 |
+
upsample_scale,
|
210 |
+
harmonic_num=0,
|
211 |
+
sine_amp=0.1,
|
212 |
+
noise_std=0.003,
|
213 |
+
voiced_threshold=0,
|
214 |
+
flag_for_pulse=False,
|
215 |
+
):
|
216 |
+
super(SineGen, self).__init__()
|
217 |
+
self.sine_amp = sine_amp
|
218 |
+
self.noise_std = noise_std
|
219 |
+
self.harmonic_num = harmonic_num
|
220 |
+
self.dim = self.harmonic_num + 1
|
221 |
+
self.sampling_rate = samp_rate
|
222 |
+
self.voiced_threshold = voiced_threshold
|
223 |
+
self.flag_for_pulse = flag_for_pulse
|
224 |
+
self.upsample_scale = upsample_scale
|
225 |
+
|
226 |
+
def _f02uv(self, f0):
|
227 |
+
# generate uv signal
|
228 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
229 |
+
return uv
|
230 |
+
|
231 |
+
def _f02sine(self, f0_values):
|
232 |
+
"""f0_values: (batchsize, length, dim)
|
233 |
+
where dim indicates fundamental tone and overtones
|
234 |
+
"""
|
235 |
+
# convert to F0 in rad. The interger part n can be ignored
|
236 |
+
# because 2 * np.pi * n doesn't affect phase
|
237 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
238 |
+
|
239 |
+
# initial phase noise (no noise for fundamental component)
|
240 |
+
rand_ini = torch.rand(
|
241 |
+
f0_values.shape[0], f0_values.shape[2], device=f0_values.device
|
242 |
+
)
|
243 |
+
rand_ini[:, 0] = 0
|
244 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
245 |
+
|
246 |
+
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
247 |
+
if not self.flag_for_pulse:
|
248 |
+
# # for normal case
|
249 |
+
|
250 |
+
# # To prevent torch.cumsum numerical overflow,
|
251 |
+
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
252 |
+
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
253 |
+
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
254 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
255 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
256 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
257 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
258 |
+
|
259 |
+
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
260 |
+
rad_values = torch.nn.functional.interpolate(
|
261 |
+
rad_values.transpose(1, 2),
|
262 |
+
scale_factor=1 / self.upsample_scale,
|
263 |
+
mode="linear",
|
264 |
+
).transpose(1, 2)
|
265 |
+
|
266 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
267 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
268 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
269 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
270 |
+
|
271 |
+
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
272 |
+
phase = torch.nn.functional.interpolate(
|
273 |
+
phase.transpose(1, 2) * self.upsample_scale,
|
274 |
+
scale_factor=self.upsample_scale,
|
275 |
+
mode="linear",
|
276 |
+
).transpose(1, 2)
|
277 |
+
sines = torch.sin(phase)
|
278 |
+
|
279 |
+
else:
|
280 |
+
# If necessary, make sure that the first time step of every
|
281 |
+
# voiced segments is sin(pi) or cos(0)
|
282 |
+
# This is used for pulse-train generation
|
283 |
+
|
284 |
+
# identify the last time step in unvoiced segments
|
285 |
+
uv = self._f02uv(f0_values)
|
286 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
287 |
+
uv_1[:, -1, :] = 1
|
288 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
289 |
+
|
290 |
+
# get the instantanouse phase
|
291 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
292 |
+
# different batch needs to be processed differently
|
293 |
+
for idx in range(f0_values.shape[0]):
|
294 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
295 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
296 |
+
# stores the accumulation of i.phase within
|
297 |
+
# each voiced segments
|
298 |
+
tmp_cumsum[idx, :, :] = 0
|
299 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
300 |
+
|
301 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
302 |
+
# within the previous voiced segment.
|
303 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
304 |
+
|
305 |
+
# get the sines
|
306 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
307 |
+
return sines
|
308 |
+
|
309 |
+
def forward(self, f0):
|
310 |
+
"""sine_tensor, uv = forward(f0)
|
311 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
312 |
+
f0 for unvoiced steps should be 0
|
313 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
314 |
+
output uv: tensor(batchsize=1, length, 1)
|
315 |
+
"""
|
316 |
+
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
317 |
+
# fundamental component
|
318 |
+
fn = torch.multiply(
|
319 |
+
f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
|
320 |
+
)
|
321 |
+
|
322 |
+
# generate sine waveforms
|
323 |
+
sine_waves = self._f02sine(fn) * self.sine_amp
|
324 |
+
|
325 |
+
# generate uv signal
|
326 |
+
# uv = torch.ones(f0.shape)
|
327 |
+
# uv = uv * (f0 > self.voiced_threshold)
|
328 |
+
uv = self._f02uv(f0)
|
329 |
+
|
330 |
+
# noise: for unvoiced should be similar to sine_amp
|
331 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
332 |
+
# . for voiced regions is self.noise_std
|
333 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
334 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
335 |
+
|
336 |
+
# first: set the unvoiced part to 0 by uv
|
337 |
+
# then: additive noise
|
338 |
+
sine_waves = sine_waves * uv + noise
|
339 |
+
return sine_waves, uv, noise
|
340 |
+
|
341 |
+
|
342 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
343 |
+
"""SourceModule for hn-nsf
|
344 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
345 |
+
add_noise_std=0.003, voiced_threshod=0)
|
346 |
+
sampling_rate: sampling_rate in Hz
|
347 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
348 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
349 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
350 |
+
note that amplitude of noise in unvoiced is decided
|
351 |
+
by sine_amp
|
352 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
353 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
354 |
+
F0_sampled (batchsize, length, 1)
|
355 |
+
Sine_source (batchsize, length, 1)
|
356 |
+
noise_source (batchsize, length 1)
|
357 |
+
uv (batchsize, length, 1)
|
358 |
+
"""
|
359 |
+
|
360 |
+
def __init__(
|
361 |
+
self,
|
362 |
+
sampling_rate,
|
363 |
+
upsample_scale,
|
364 |
+
harmonic_num=0,
|
365 |
+
sine_amp=0.1,
|
366 |
+
add_noise_std=0.003,
|
367 |
+
voiced_threshod=0,
|
368 |
+
):
|
369 |
+
super(SourceModuleHnNSF, self).__init__()
|
370 |
+
|
371 |
+
self.sine_amp = sine_amp
|
372 |
+
self.noise_std = add_noise_std
|
373 |
+
|
374 |
+
# to produce sine waveforms
|
375 |
+
self.l_sin_gen = SineGen(
|
376 |
+
sampling_rate,
|
377 |
+
upsample_scale,
|
378 |
+
harmonic_num,
|
379 |
+
sine_amp,
|
380 |
+
add_noise_std,
|
381 |
+
voiced_threshod,
|
382 |
+
)
|
383 |
+
|
384 |
+
# to merge source harmonics into a single excitation
|
385 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
386 |
+
self.l_tanh = torch.nn.Tanh()
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
"""
|
390 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
391 |
+
F0_sampled (batchsize, length, 1)
|
392 |
+
Sine_source (batchsize, length, 1)
|
393 |
+
noise_source (batchsize, length 1)
|
394 |
+
"""
|
395 |
+
# source for harmonic branch
|
396 |
+
with torch.no_grad():
|
397 |
+
sine_wavs, uv, _ = self.l_sin_gen(x)
|
398 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
399 |
+
|
400 |
+
# source for noise branch, in the same shape as uv
|
401 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
402 |
+
return sine_merge, noise, uv
|
403 |
+
|
404 |
+
|
405 |
+
def padDiff(x):
|
406 |
+
return F.pad(
|
407 |
+
F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
|
408 |
+
)
|
409 |
+
|
410 |
+
|
411 |
+
class Generator(torch.nn.Module):
|
412 |
+
def __init__(
|
413 |
+
self,
|
414 |
+
style_dim,
|
415 |
+
resblock_kernel_sizes,
|
416 |
+
upsample_rates,
|
417 |
+
upsample_initial_channel,
|
418 |
+
resblock_dilation_sizes,
|
419 |
+
upsample_kernel_sizes,
|
420 |
+
gen_istft_n_fft,
|
421 |
+
gen_istft_hop_size,
|
422 |
+
):
|
423 |
+
super(Generator, self).__init__()
|
424 |
+
|
425 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
426 |
+
self.num_upsamples = len(upsample_rates)
|
427 |
+
resblock = AdaINResBlock1
|
428 |
+
|
429 |
+
self.m_source = SourceModuleHnNSF(
|
430 |
+
sampling_rate=24000,
|
431 |
+
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
|
432 |
+
harmonic_num=8,
|
433 |
+
voiced_threshod=10,
|
434 |
+
)
|
435 |
+
self.f0_upsamp = torch.nn.Upsample(
|
436 |
+
scale_factor=np.prod(upsample_rates) * gen_istft_hop_size
|
437 |
+
)
|
438 |
+
self.noise_convs = nn.ModuleList()
|
439 |
+
self.noise_res = nn.ModuleList()
|
440 |
+
|
441 |
+
self.ups = nn.ModuleList()
|
442 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
443 |
+
self.ups.append(
|
444 |
+
weight_norm(
|
445 |
+
ConvTranspose1d(
|
446 |
+
upsample_initial_channel // (2**i),
|
447 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
448 |
+
k,
|
449 |
+
u,
|
450 |
+
padding=(k - u) // 2,
|
451 |
+
)
|
452 |
+
)
|
453 |
+
)
|
454 |
+
|
455 |
+
self.resblocks = nn.ModuleList()
|
456 |
+
for i in range(len(self.ups)):
|
457 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
458 |
+
for j, (k, d) in enumerate(
|
459 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
460 |
+
):
|
461 |
+
self.resblocks.append(resblock(ch, k, d, style_dim))
|
462 |
+
|
463 |
+
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
464 |
+
|
465 |
+
if i + 1 < len(upsample_rates): #
|
466 |
+
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
467 |
+
self.noise_convs.append(
|
468 |
+
Conv1d(
|
469 |
+
gen_istft_n_fft + 2,
|
470 |
+
c_cur,
|
471 |
+
kernel_size=stride_f0 * 2,
|
472 |
+
stride=stride_f0,
|
473 |
+
padding=(stride_f0 + 1) // 2,
|
474 |
+
)
|
475 |
+
)
|
476 |
+
self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
|
477 |
+
else:
|
478 |
+
self.noise_convs.append(
|
479 |
+
Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)
|
480 |
+
)
|
481 |
+
self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
|
482 |
+
|
483 |
+
self.post_n_fft = gen_istft_n_fft
|
484 |
+
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
485 |
+
self.ups.apply(init_weights)
|
486 |
+
self.conv_post.apply(init_weights)
|
487 |
+
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
488 |
+
self.stft = TorchSTFT(
|
489 |
+
filter_length=gen_istft_n_fft,
|
490 |
+
hop_length=gen_istft_hop_size,
|
491 |
+
win_length=gen_istft_n_fft,
|
492 |
+
)
|
493 |
+
|
494 |
+
def forward(self, x, s, f0):
|
495 |
+
with torch.no_grad():
|
496 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
497 |
+
|
498 |
+
har_source, noi_source, uv = self.m_source(f0)
|
499 |
+
har_source = har_source.transpose(1, 2).squeeze(1)
|
500 |
+
har_spec, har_phase = self.stft.transform(har_source)
|
501 |
+
har = torch.cat([har_spec, har_phase], dim=1)
|
502 |
+
|
503 |
+
for i in range(self.num_upsamples):
|
504 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
505 |
+
x_source = self.noise_convs[i](har)
|
506 |
+
x_source = self.noise_res[i](x_source, s)
|
507 |
+
|
508 |
+
x = self.ups[i](x)
|
509 |
+
if i == self.num_upsamples - 1:
|
510 |
+
x = self.reflection_pad(x)
|
511 |
+
|
512 |
+
x = x + x_source
|
513 |
+
xs = None
|
514 |
+
for j in range(self.num_kernels):
|
515 |
+
if xs is None:
|
516 |
+
xs = self.resblocks[i * self.num_kernels + j](x, s)
|
517 |
+
else:
|
518 |
+
xs += self.resblocks[i * self.num_kernels + j](x, s)
|
519 |
+
x = xs / self.num_kernels
|
520 |
+
x = F.leaky_relu(x)
|
521 |
+
x = self.conv_post(x)
|
522 |
+
spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
|
523 |
+
phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
|
524 |
+
return self.stft.inverse(spec, phase)
|
525 |
+
|
526 |
+
def fw_phase(self, x, s):
|
527 |
+
for i in range(self.num_upsamples):
|
528 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
529 |
+
x = self.ups[i](x)
|
530 |
+
xs = None
|
531 |
+
for j in range(self.num_kernels):
|
532 |
+
if xs is None:
|
533 |
+
xs = self.resblocks[i * self.num_kernels + j](x, s)
|
534 |
+
else:
|
535 |
+
xs += self.resblocks[i * self.num_kernels + j](x, s)
|
536 |
+
x = xs / self.num_kernels
|
537 |
+
x = F.leaky_relu(x)
|
538 |
+
x = self.reflection_pad(x)
|
539 |
+
x = self.conv_post(x)
|
540 |
+
spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
|
541 |
+
phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
|
542 |
+
return spec, phase
|
543 |
+
|
544 |
+
def remove_weight_norm(self):
|
545 |
+
print("Removing weight norm...")
|
546 |
+
for l in self.ups:
|
547 |
+
remove_weight_norm(l)
|
548 |
+
for l in self.resblocks:
|
549 |
+
l.remove_weight_norm()
|
550 |
+
remove_weight_norm(self.conv_pre)
|
551 |
+
remove_weight_norm(self.conv_post)
|
552 |
+
|
553 |
+
|
554 |
+
class AdainResBlk1d(nn.Module):
|
555 |
+
def __init__(
|
556 |
+
self,
|
557 |
+
dim_in,
|
558 |
+
dim_out,
|
559 |
+
style_dim=64,
|
560 |
+
actv=nn.LeakyReLU(0.2),
|
561 |
+
upsample="none",
|
562 |
+
dropout_p=0.0,
|
563 |
+
):
|
564 |
+
super().__init__()
|
565 |
+
self.actv = actv
|
566 |
+
self.upsample_type = upsample
|
567 |
+
self.upsample = UpSample1d(upsample)
|
568 |
+
self.learned_sc = dim_in != dim_out
|
569 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
570 |
+
self.dropout = nn.Dropout(dropout_p)
|
571 |
+
|
572 |
+
if upsample == "none":
|
573 |
+
self.pool = nn.Identity()
|
574 |
+
else:
|
575 |
+
self.pool = weight_norm(
|
576 |
+
nn.ConvTranspose1d(
|
577 |
+
dim_in,
|
578 |
+
dim_in,
|
579 |
+
kernel_size=3,
|
580 |
+
stride=2,
|
581 |
+
groups=dim_in,
|
582 |
+
padding=1,
|
583 |
+
output_padding=1,
|
584 |
+
)
|
585 |
+
)
|
586 |
+
|
587 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
588 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
589 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
590 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
591 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
592 |
+
if self.learned_sc:
|
593 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
594 |
+
|
595 |
+
def _shortcut(self, x):
|
596 |
+
x = self.upsample(x)
|
597 |
+
if self.learned_sc:
|
598 |
+
x = self.conv1x1(x)
|
599 |
+
return x
|
600 |
+
|
601 |
+
def _residual(self, x, s):
|
602 |
+
x = self.norm1(x, s)
|
603 |
+
x = self.actv(x)
|
604 |
+
x = self.pool(x)
|
605 |
+
x = self.conv1(self.dropout(x))
|
606 |
+
x = self.norm2(x, s)
|
607 |
+
x = self.actv(x)
|
608 |
+
x = self.conv2(self.dropout(x))
|
609 |
+
return x
|
610 |
+
|
611 |
+
def forward(self, x, s):
|
612 |
+
out = self._residual(x, s)
|
613 |
+
out = (out + self._shortcut(x)) / math.sqrt(2)
|
614 |
+
return out
|
615 |
+
|
616 |
+
|
617 |
+
class UpSample1d(nn.Module):
|
618 |
+
def __init__(self, layer_type):
|
619 |
+
super().__init__()
|
620 |
+
self.layer_type = layer_type
|
621 |
+
|
622 |
+
def forward(self, x):
|
623 |
+
if self.layer_type == "none":
|
624 |
+
return x
|
625 |
+
else:
|
626 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
627 |
+
|
628 |
+
|
629 |
+
class Decoder(nn.Module):
|
630 |
+
def __init__(
|
631 |
+
self,
|
632 |
+
dim_in=512,
|
633 |
+
F0_channel=512,
|
634 |
+
style_dim=64,
|
635 |
+
dim_out=80,
|
636 |
+
resblock_kernel_sizes=[3, 7, 11],
|
637 |
+
upsample_rates=[10, 6],
|
638 |
+
upsample_initial_channel=512,
|
639 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
640 |
+
upsample_kernel_sizes=[20, 12],
|
641 |
+
gen_istft_n_fft=20,
|
642 |
+
gen_istft_hop_size=5,
|
643 |
+
):
|
644 |
+
super().__init__()
|
645 |
+
|
646 |
+
self.decode = nn.ModuleList()
|
647 |
+
|
648 |
+
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
649 |
+
|
650 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
651 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
652 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
653 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
654 |
+
|
655 |
+
self.F0_conv = weight_norm(
|
656 |
+
nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
|
657 |
+
)
|
658 |
+
|
659 |
+
self.N_conv = weight_norm(
|
660 |
+
nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
|
661 |
+
)
|
662 |
+
|
663 |
+
self.asr_res = nn.Sequential(
|
664 |
+
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
|
665 |
+
)
|
666 |
+
|
667 |
+
self.generator = Generator(
|
668 |
+
style_dim,
|
669 |
+
resblock_kernel_sizes,
|
670 |
+
upsample_rates,
|
671 |
+
upsample_initial_channel,
|
672 |
+
resblock_dilation_sizes,
|
673 |
+
upsample_kernel_sizes,
|
674 |
+
gen_istft_n_fft,
|
675 |
+
gen_istft_hop_size,
|
676 |
+
)
|
677 |
+
|
678 |
+
def forward(self, asr, F0_curve, N, s):
|
679 |
+
if self.training:
|
680 |
+
downlist = [0, 3, 7]
|
681 |
+
F0_down = downlist[random.randint(0, 2)]
|
682 |
+
downlist = [0, 3, 7, 15]
|
683 |
+
N_down = downlist[random.randint(0, 3)]
|
684 |
+
if F0_down:
|
685 |
+
F0_curve = (
|
686 |
+
nn.functional.conv1d(
|
687 |
+
F0_curve.unsqueeze(1),
|
688 |
+
torch.ones(1, 1, F0_down).to("cuda"),
|
689 |
+
padding=F0_down // 2,
|
690 |
+
).squeeze(1)
|
691 |
+
/ F0_down
|
692 |
+
)
|
693 |
+
if N_down:
|
694 |
+
N = (
|
695 |
+
nn.functional.conv1d(
|
696 |
+
N.unsqueeze(1),
|
697 |
+
torch.ones(1, 1, N_down).to("cuda"),
|
698 |
+
padding=N_down // 2,
|
699 |
+
).squeeze(1)
|
700 |
+
/ N_down
|
701 |
+
)
|
702 |
+
|
703 |
+
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
704 |
+
N = self.N_conv(N.unsqueeze(1))
|
705 |
+
|
706 |
+
x = torch.cat([asr, F0, N], axis=1)
|
707 |
+
x = self.encode(x, s)
|
708 |
+
|
709 |
+
asr_res = self.asr_res(asr)
|
710 |
+
|
711 |
+
res = True
|
712 |
+
for block in self.decode:
|
713 |
+
if res:
|
714 |
+
x = torch.cat([x, asr_res, F0, N], axis=1)
|
715 |
+
x = block(x, s)
|
716 |
+
if block.upsample_type != "none":
|
717 |
+
res = False
|
718 |
+
|
719 |
+
x = self.generator(x, s, F0_curve)
|
720 |
+
return x
|
Modules/slmadv.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class SLMAdversarialLoss(torch.nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
model,
|
10 |
+
wl,
|
11 |
+
sampler,
|
12 |
+
min_len,
|
13 |
+
max_len,
|
14 |
+
batch_percentage=0.5,
|
15 |
+
skip_update=10,
|
16 |
+
sig=1.5,
|
17 |
+
):
|
18 |
+
super(SLMAdversarialLoss, self).__init__()
|
19 |
+
self.model = model
|
20 |
+
self.wl = wl
|
21 |
+
self.sampler = sampler
|
22 |
+
|
23 |
+
self.min_len = min_len
|
24 |
+
self.max_len = max_len
|
25 |
+
self.batch_percentage = batch_percentage
|
26 |
+
|
27 |
+
self.sig = sig
|
28 |
+
self.skip_update = skip_update
|
29 |
+
|
30 |
+
def forward(
|
31 |
+
self,
|
32 |
+
iters,
|
33 |
+
y_rec_gt,
|
34 |
+
y_rec_gt_pred,
|
35 |
+
waves,
|
36 |
+
mel_input_length,
|
37 |
+
ref_text,
|
38 |
+
ref_lengths,
|
39 |
+
use_ind,
|
40 |
+
s_trg,
|
41 |
+
ref_s=None,
|
42 |
+
):
|
43 |
+
text_mask = length_to_mask(ref_lengths).to(ref_text.device)
|
44 |
+
bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
|
45 |
+
d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
|
46 |
+
|
47 |
+
if use_ind and np.random.rand() < 0.5:
|
48 |
+
s_preds = s_trg
|
49 |
+
else:
|
50 |
+
num_steps = np.random.randint(3, 5)
|
51 |
+
if ref_s is not None:
|
52 |
+
s_preds = self.sampler(
|
53 |
+
noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
|
54 |
+
embedding=bert_dur,
|
55 |
+
embedding_scale=1,
|
56 |
+
features=ref_s, # reference from the same speaker as the embedding
|
57 |
+
embedding_mask_proba=0.1,
|
58 |
+
num_steps=num_steps,
|
59 |
+
).squeeze(1)
|
60 |
+
else:
|
61 |
+
s_preds = self.sampler(
|
62 |
+
noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
|
63 |
+
embedding=bert_dur,
|
64 |
+
embedding_scale=1,
|
65 |
+
embedding_mask_proba=0.1,
|
66 |
+
num_steps=num_steps,
|
67 |
+
).squeeze(1)
|
68 |
+
|
69 |
+
s_dur = s_preds[:, 128:]
|
70 |
+
s = s_preds[:, :128]
|
71 |
+
|
72 |
+
d, _ = self.model.predictor(
|
73 |
+
d_en,
|
74 |
+
s_dur,
|
75 |
+
ref_lengths,
|
76 |
+
torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
|
77 |
+
text_mask,
|
78 |
+
)
|
79 |
+
|
80 |
+
bib = 0
|
81 |
+
|
82 |
+
output_lengths = []
|
83 |
+
attn_preds = []
|
84 |
+
|
85 |
+
# differentiable duration modeling
|
86 |
+
for _s2s_pred, _text_length in zip(d, ref_lengths):
|
87 |
+
_s2s_pred_org = _s2s_pred[:_text_length, :]
|
88 |
+
|
89 |
+
_s2s_pred = torch.sigmoid(_s2s_pred_org)
|
90 |
+
_dur_pred = _s2s_pred.sum(axis=-1)
|
91 |
+
|
92 |
+
l = int(torch.round(_s2s_pred.sum()).item())
|
93 |
+
t = torch.arange(0, l).expand(l)
|
94 |
+
|
95 |
+
t = (
|
96 |
+
torch.arange(0, l)
|
97 |
+
.unsqueeze(0)
|
98 |
+
.expand((len(_s2s_pred), l))
|
99 |
+
.to(ref_text.device)
|
100 |
+
)
|
101 |
+
loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
|
102 |
+
|
103 |
+
h = torch.exp(
|
104 |
+
-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig) ** 2
|
105 |
+
)
|
106 |
+
|
107 |
+
out = torch.nn.functional.conv1d(
|
108 |
+
_s2s_pred_org.unsqueeze(0),
|
109 |
+
h.unsqueeze(1),
|
110 |
+
padding=h.shape[-1] - 1,
|
111 |
+
groups=int(_text_length),
|
112 |
+
)[..., :l]
|
113 |
+
attn_preds.append(F.softmax(out.squeeze(), dim=0))
|
114 |
+
|
115 |
+
output_lengths.append(l)
|
116 |
+
|
117 |
+
max_len = max(output_lengths)
|
118 |
+
|
119 |
+
with torch.no_grad():
|
120 |
+
t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
|
121 |
+
|
122 |
+
s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(
|
123 |
+
ref_text.device
|
124 |
+
)
|
125 |
+
for bib in range(len(output_lengths)):
|
126 |
+
s2s_attn[bib, : ref_lengths[bib], : output_lengths[bib]] = attn_preds[bib]
|
127 |
+
|
128 |
+
asr_pred = t_en @ s2s_attn
|
129 |
+
|
130 |
+
_, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask)
|
131 |
+
|
132 |
+
mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
|
133 |
+
mel_len = min(mel_len, self.max_len // 2)
|
134 |
+
|
135 |
+
# get clips
|
136 |
+
|
137 |
+
en = []
|
138 |
+
p_en = []
|
139 |
+
sp = []
|
140 |
+
|
141 |
+
F0_fakes = []
|
142 |
+
N_fakes = []
|
143 |
+
|
144 |
+
wav = []
|
145 |
+
|
146 |
+
for bib in range(len(output_lengths)):
|
147 |
+
mel_length_pred = output_lengths[bib]
|
148 |
+
mel_length_gt = int(mel_input_length[bib].item() / 2)
|
149 |
+
if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
|
150 |
+
continue
|
151 |
+
|
152 |
+
sp.append(s_preds[bib])
|
153 |
+
|
154 |
+
random_start = np.random.randint(0, mel_length_pred - mel_len)
|
155 |
+
en.append(asr_pred[bib, :, random_start : random_start + mel_len])
|
156 |
+
p_en.append(p_pred[bib, :, random_start : random_start + mel_len])
|
157 |
+
|
158 |
+
# get ground truth clips
|
159 |
+
random_start = np.random.randint(0, mel_length_gt - mel_len)
|
160 |
+
y = waves[bib][
|
161 |
+
(random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
|
162 |
+
]
|
163 |
+
wav.append(torch.from_numpy(y).to(ref_text.device))
|
164 |
+
|
165 |
+
if len(wav) >= self.batch_percentage * len(
|
166 |
+
waves
|
167 |
+
): # prevent OOM due to longer lengths
|
168 |
+
break
|
169 |
+
|
170 |
+
if len(sp) <= 1:
|
171 |
+
return None
|
172 |
+
|
173 |
+
sp = torch.stack(sp)
|
174 |
+
wav = torch.stack(wav).float()
|
175 |
+
en = torch.stack(en)
|
176 |
+
p_en = torch.stack(p_en)
|
177 |
+
|
178 |
+
F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
|
179 |
+
y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
|
180 |
+
|
181 |
+
# discriminator loss
|
182 |
+
if (iters + 1) % self.skip_update == 0:
|
183 |
+
if np.random.randint(0, 2) == 0:
|
184 |
+
wav = y_rec_gt_pred
|
185 |
+
use_rec = True
|
186 |
+
else:
|
187 |
+
use_rec = False
|
188 |
+
|
189 |
+
crop_size = min(wav.size(-1), y_pred.size(-1))
|
190 |
+
if (
|
191 |
+
use_rec
|
192 |
+
): # use reconstructed (shorter lengths), do length invariant regularization
|
193 |
+
if wav.size(-1) > y_pred.size(-1):
|
194 |
+
real_GP = wav[:, :, :crop_size]
|
195 |
+
out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
|
196 |
+
out_org = self.wl.discriminator_forward(wav.detach().squeeze())
|
197 |
+
loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
|
198 |
+
|
199 |
+
if np.random.randint(0, 2) == 0:
|
200 |
+
d_loss = self.wl.discriminator(
|
201 |
+
real_GP.detach().squeeze(), y_pred.detach().squeeze()
|
202 |
+
).mean()
|
203 |
+
else:
|
204 |
+
d_loss = self.wl.discriminator(
|
205 |
+
wav.detach().squeeze(), y_pred.detach().squeeze()
|
206 |
+
).mean()
|
207 |
+
else:
|
208 |
+
real_GP = y_pred[:, :, :crop_size]
|
209 |
+
out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
|
210 |
+
out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
|
211 |
+
loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
|
212 |
+
|
213 |
+
if np.random.randint(0, 2) == 0:
|
214 |
+
d_loss = self.wl.discriminator(
|
215 |
+
wav.detach().squeeze(), real_GP.detach().squeeze()
|
216 |
+
).mean()
|
217 |
+
else:
|
218 |
+
d_loss = self.wl.discriminator(
|
219 |
+
wav.detach().squeeze(), y_pred.detach().squeeze()
|
220 |
+
).mean()
|
221 |
+
|
222 |
+
# regularization (ignore length variation)
|
223 |
+
d_loss += loss_reg
|
224 |
+
|
225 |
+
out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
|
226 |
+
out_rec = self.wl.discriminator_forward(
|
227 |
+
y_rec_gt_pred.detach().squeeze()
|
228 |
+
)
|
229 |
+
|
230 |
+
# regularization (ignore reconstruction artifacts)
|
231 |
+
d_loss += F.l1_loss(out_gt, out_rec)
|
232 |
+
|
233 |
+
else:
|
234 |
+
d_loss = self.wl.discriminator(
|
235 |
+
wav.detach().squeeze(), y_pred.detach().squeeze()
|
236 |
+
).mean()
|
237 |
+
else:
|
238 |
+
d_loss = 0
|
239 |
+
|
240 |
+
# generator loss
|
241 |
+
gen_loss = self.wl.generator(y_pred.squeeze())
|
242 |
+
|
243 |
+
gen_loss = gen_loss.mean()
|
244 |
+
|
245 |
+
return d_loss, gen_loss, y_pred.detach().cpu().numpy()
|
246 |
+
|
247 |
+
|
248 |
+
def length_to_mask(lengths):
|
249 |
+
mask = (
|
250 |
+
torch.arange(lengths.max())
|
251 |
+
.unsqueeze(0)
|
252 |
+
.expand(lengths.shape[0], -1)
|
253 |
+
.type_as(lengths)
|
254 |
+
)
|
255 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
256 |
+
return mask
|
Modules/utils.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def init_weights(m, mean=0.0, std=0.01):
|
2 |
+
classname = m.__class__.__name__
|
3 |
+
if classname.find("Conv") != -1:
|
4 |
+
m.weight.data.normal_(mean, std)
|
5 |
+
|
6 |
+
|
7 |
+
def apply_weight_norm(m):
|
8 |
+
classname = m.__class__.__name__
|
9 |
+
if classname.find("Conv") != -1:
|
10 |
+
weight_norm(m)
|
11 |
+
|
12 |
+
|
13 |
+
def get_padding(kernel_size, dilation=1):
|
14 |
+
return int((kernel_size * dilation - dilation) / 2)
|
README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: StyleTTS 2
|
3 |
+
emoji: 🗣️
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.5.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: other
|
11 |
+
---
|
12 |
+
|
13 |
+
LICENSE FOR STYLETTS2: MIT LICENSE
|
14 |
+
|
15 |
+
LICENSE FOR STYLETTS2 DEMO PAGE: © 2023 MRFAKENAME. ALL RIGHTS RESERVED.
|
Utils/ASR/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
Utils/ASR/config.yml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
log_dir: "logs/20201006"
|
2 |
+
save_freq: 5
|
3 |
+
device: "cuda"
|
4 |
+
epochs: 180
|
5 |
+
batch_size: 64
|
6 |
+
pretrained_model: ""
|
7 |
+
train_data: "ASRDataset/train_list.txt"
|
8 |
+
val_data: "ASRDataset/val_list.txt"
|
9 |
+
|
10 |
+
dataset_params:
|
11 |
+
data_augmentation: false
|
12 |
+
|
13 |
+
preprocess_parasm:
|
14 |
+
sr: 24000
|
15 |
+
spect_params:
|
16 |
+
n_fft: 2048
|
17 |
+
win_length: 1200
|
18 |
+
hop_length: 300
|
19 |
+
mel_params:
|
20 |
+
n_mels: 80
|
21 |
+
|
22 |
+
model_params:
|
23 |
+
input_dim: 80
|
24 |
+
hidden_dim: 256
|
25 |
+
n_token: 178
|
26 |
+
token_embedding_dim: 512
|
27 |
+
|
28 |
+
optimizer_params:
|
29 |
+
lr: 0.0005
|
Utils/ASR/epoch_00080.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
|
3 |
+
size 94552811
|
Utils/ASR/layers.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from typing import Optional, Any
|
5 |
+
from torch import Tensor
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
import torchaudio.functional as audio_F
|
9 |
+
|
10 |
+
import random
|
11 |
+
|
12 |
+
random.seed(0)
|
13 |
+
|
14 |
+
|
15 |
+
def _get_activation_fn(activ):
|
16 |
+
if activ == "relu":
|
17 |
+
return nn.ReLU()
|
18 |
+
elif activ == "lrelu":
|
19 |
+
return nn.LeakyReLU(0.2)
|
20 |
+
elif activ == "swish":
|
21 |
+
return lambda x: x * torch.sigmoid(x)
|
22 |
+
else:
|
23 |
+
raise RuntimeError(
|
24 |
+
"Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class LinearNorm(torch.nn.Module):
|
29 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
|
30 |
+
super(LinearNorm, self).__init__()
|
31 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
32 |
+
|
33 |
+
torch.nn.init.xavier_uniform_(
|
34 |
+
self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.linear_layer(x)
|
39 |
+
|
40 |
+
|
41 |
+
class ConvNorm(torch.nn.Module):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
in_channels,
|
45 |
+
out_channels,
|
46 |
+
kernel_size=1,
|
47 |
+
stride=1,
|
48 |
+
padding=None,
|
49 |
+
dilation=1,
|
50 |
+
bias=True,
|
51 |
+
w_init_gain="linear",
|
52 |
+
param=None,
|
53 |
+
):
|
54 |
+
super(ConvNorm, self).__init__()
|
55 |
+
if padding is None:
|
56 |
+
assert kernel_size % 2 == 1
|
57 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
58 |
+
|
59 |
+
self.conv = torch.nn.Conv1d(
|
60 |
+
in_channels,
|
61 |
+
out_channels,
|
62 |
+
kernel_size=kernel_size,
|
63 |
+
stride=stride,
|
64 |
+
padding=padding,
|
65 |
+
dilation=dilation,
|
66 |
+
bias=bias,
|
67 |
+
)
|
68 |
+
|
69 |
+
torch.nn.init.xavier_uniform_(
|
70 |
+
self.conv.weight,
|
71 |
+
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
|
72 |
+
)
|
73 |
+
|
74 |
+
def forward(self, signal):
|
75 |
+
conv_signal = self.conv(signal)
|
76 |
+
return conv_signal
|
77 |
+
|
78 |
+
|
79 |
+
class CausualConv(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
in_channels,
|
83 |
+
out_channels,
|
84 |
+
kernel_size=1,
|
85 |
+
stride=1,
|
86 |
+
padding=1,
|
87 |
+
dilation=1,
|
88 |
+
bias=True,
|
89 |
+
w_init_gain="linear",
|
90 |
+
param=None,
|
91 |
+
):
|
92 |
+
super(CausualConv, self).__init__()
|
93 |
+
if padding is None:
|
94 |
+
assert kernel_size % 2 == 1
|
95 |
+
padding = int(dilation * (kernel_size - 1) / 2) * 2
|
96 |
+
else:
|
97 |
+
self.padding = padding * 2
|
98 |
+
self.conv = nn.Conv1d(
|
99 |
+
in_channels,
|
100 |
+
out_channels,
|
101 |
+
kernel_size=kernel_size,
|
102 |
+
stride=stride,
|
103 |
+
padding=self.padding,
|
104 |
+
dilation=dilation,
|
105 |
+
bias=bias,
|
106 |
+
)
|
107 |
+
|
108 |
+
torch.nn.init.xavier_uniform_(
|
109 |
+
self.conv.weight,
|
110 |
+
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
|
111 |
+
)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
x = self.conv(x)
|
115 |
+
x = x[:, :, : -self.padding]
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class CausualBlock(nn.Module):
|
120 |
+
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
|
121 |
+
super(CausualBlock, self).__init__()
|
122 |
+
self.blocks = nn.ModuleList(
|
123 |
+
[
|
124 |
+
self._get_conv(
|
125 |
+
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
|
126 |
+
)
|
127 |
+
for i in range(n_conv)
|
128 |
+
]
|
129 |
+
)
|
130 |
+
|
131 |
+
def forward(self, x):
|
132 |
+
for block in self.blocks:
|
133 |
+
res = x
|
134 |
+
x = block(x)
|
135 |
+
x += res
|
136 |
+
return x
|
137 |
+
|
138 |
+
def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
|
139 |
+
layers = [
|
140 |
+
CausualConv(
|
141 |
+
hidden_dim,
|
142 |
+
hidden_dim,
|
143 |
+
kernel_size=3,
|
144 |
+
padding=dilation,
|
145 |
+
dilation=dilation,
|
146 |
+
),
|
147 |
+
_get_activation_fn(activ),
|
148 |
+
nn.BatchNorm1d(hidden_dim),
|
149 |
+
nn.Dropout(p=dropout_p),
|
150 |
+
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
|
151 |
+
_get_activation_fn(activ),
|
152 |
+
nn.Dropout(p=dropout_p),
|
153 |
+
]
|
154 |
+
return nn.Sequential(*layers)
|
155 |
+
|
156 |
+
|
157 |
+
class ConvBlock(nn.Module):
|
158 |
+
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
|
159 |
+
super().__init__()
|
160 |
+
self._n_groups = 8
|
161 |
+
self.blocks = nn.ModuleList(
|
162 |
+
[
|
163 |
+
self._get_conv(
|
164 |
+
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
|
165 |
+
)
|
166 |
+
for i in range(n_conv)
|
167 |
+
]
|
168 |
+
)
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
for block in self.blocks:
|
172 |
+
res = x
|
173 |
+
x = block(x)
|
174 |
+
x += res
|
175 |
+
return x
|
176 |
+
|
177 |
+
def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
|
178 |
+
layers = [
|
179 |
+
ConvNorm(
|
180 |
+
hidden_dim,
|
181 |
+
hidden_dim,
|
182 |
+
kernel_size=3,
|
183 |
+
padding=dilation,
|
184 |
+
dilation=dilation,
|
185 |
+
),
|
186 |
+
_get_activation_fn(activ),
|
187 |
+
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
|
188 |
+
nn.Dropout(p=dropout_p),
|
189 |
+
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
|
190 |
+
_get_activation_fn(activ),
|
191 |
+
nn.Dropout(p=dropout_p),
|
192 |
+
]
|
193 |
+
return nn.Sequential(*layers)
|
194 |
+
|
195 |
+
|
196 |
+
class LocationLayer(nn.Module):
|
197 |
+
def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
|
198 |
+
super(LocationLayer, self).__init__()
|
199 |
+
padding = int((attention_kernel_size - 1) / 2)
|
200 |
+
self.location_conv = ConvNorm(
|
201 |
+
2,
|
202 |
+
attention_n_filters,
|
203 |
+
kernel_size=attention_kernel_size,
|
204 |
+
padding=padding,
|
205 |
+
bias=False,
|
206 |
+
stride=1,
|
207 |
+
dilation=1,
|
208 |
+
)
|
209 |
+
self.location_dense = LinearNorm(
|
210 |
+
attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
|
211 |
+
)
|
212 |
+
|
213 |
+
def forward(self, attention_weights_cat):
|
214 |
+
processed_attention = self.location_conv(attention_weights_cat)
|
215 |
+
processed_attention = processed_attention.transpose(1, 2)
|
216 |
+
processed_attention = self.location_dense(processed_attention)
|
217 |
+
return processed_attention
|
218 |
+
|
219 |
+
|
220 |
+
class Attention(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
attention_rnn_dim,
|
224 |
+
embedding_dim,
|
225 |
+
attention_dim,
|
226 |
+
attention_location_n_filters,
|
227 |
+
attention_location_kernel_size,
|
228 |
+
):
|
229 |
+
super(Attention, self).__init__()
|
230 |
+
self.query_layer = LinearNorm(
|
231 |
+
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
|
232 |
+
)
|
233 |
+
self.memory_layer = LinearNorm(
|
234 |
+
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
|
235 |
+
)
|
236 |
+
self.v = LinearNorm(attention_dim, 1, bias=False)
|
237 |
+
self.location_layer = LocationLayer(
|
238 |
+
attention_location_n_filters, attention_location_kernel_size, attention_dim
|
239 |
+
)
|
240 |
+
self.score_mask_value = -float("inf")
|
241 |
+
|
242 |
+
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
|
243 |
+
"""
|
244 |
+
PARAMS
|
245 |
+
------
|
246 |
+
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
247 |
+
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
248 |
+
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
|
249 |
+
RETURNS
|
250 |
+
-------
|
251 |
+
alignment (batch, max_time)
|
252 |
+
"""
|
253 |
+
|
254 |
+
processed_query = self.query_layer(query.unsqueeze(1))
|
255 |
+
processed_attention_weights = self.location_layer(attention_weights_cat)
|
256 |
+
energies = self.v(
|
257 |
+
torch.tanh(processed_query + processed_attention_weights + processed_memory)
|
258 |
+
)
|
259 |
+
|
260 |
+
energies = energies.squeeze(-1)
|
261 |
+
return energies
|
262 |
+
|
263 |
+
def forward(
|
264 |
+
self,
|
265 |
+
attention_hidden_state,
|
266 |
+
memory,
|
267 |
+
processed_memory,
|
268 |
+
attention_weights_cat,
|
269 |
+
mask,
|
270 |
+
):
|
271 |
+
"""
|
272 |
+
PARAMS
|
273 |
+
------
|
274 |
+
attention_hidden_state: attention rnn last output
|
275 |
+
memory: encoder outputs
|
276 |
+
processed_memory: processed encoder outputs
|
277 |
+
attention_weights_cat: previous and cummulative attention weights
|
278 |
+
mask: binary mask for padded data
|
279 |
+
"""
|
280 |
+
alignment = self.get_alignment_energies(
|
281 |
+
attention_hidden_state, processed_memory, attention_weights_cat
|
282 |
+
)
|
283 |
+
|
284 |
+
if mask is not None:
|
285 |
+
alignment.data.masked_fill_(mask, self.score_mask_value)
|
286 |
+
|
287 |
+
attention_weights = F.softmax(alignment, dim=1)
|
288 |
+
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
289 |
+
attention_context = attention_context.squeeze(1)
|
290 |
+
|
291 |
+
return attention_context, attention_weights
|
292 |
+
|
293 |
+
|
294 |
+
class ForwardAttentionV2(nn.Module):
|
295 |
+
def __init__(
|
296 |
+
self,
|
297 |
+
attention_rnn_dim,
|
298 |
+
embedding_dim,
|
299 |
+
attention_dim,
|
300 |
+
attention_location_n_filters,
|
301 |
+
attention_location_kernel_size,
|
302 |
+
):
|
303 |
+
super(ForwardAttentionV2, self).__init__()
|
304 |
+
self.query_layer = LinearNorm(
|
305 |
+
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
|
306 |
+
)
|
307 |
+
self.memory_layer = LinearNorm(
|
308 |
+
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
|
309 |
+
)
|
310 |
+
self.v = LinearNorm(attention_dim, 1, bias=False)
|
311 |
+
self.location_layer = LocationLayer(
|
312 |
+
attention_location_n_filters, attention_location_kernel_size, attention_dim
|
313 |
+
)
|
314 |
+
self.score_mask_value = -float(1e20)
|
315 |
+
|
316 |
+
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
|
317 |
+
"""
|
318 |
+
PARAMS
|
319 |
+
------
|
320 |
+
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
321 |
+
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
322 |
+
attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
|
323 |
+
RETURNS
|
324 |
+
-------
|
325 |
+
alignment (batch, max_time)
|
326 |
+
"""
|
327 |
+
|
328 |
+
processed_query = self.query_layer(query.unsqueeze(1))
|
329 |
+
processed_attention_weights = self.location_layer(attention_weights_cat)
|
330 |
+
energies = self.v(
|
331 |
+
torch.tanh(processed_query + processed_attention_weights + processed_memory)
|
332 |
+
)
|
333 |
+
|
334 |
+
energies = energies.squeeze(-1)
|
335 |
+
return energies
|
336 |
+
|
337 |
+
def forward(
|
338 |
+
self,
|
339 |
+
attention_hidden_state,
|
340 |
+
memory,
|
341 |
+
processed_memory,
|
342 |
+
attention_weights_cat,
|
343 |
+
mask,
|
344 |
+
log_alpha,
|
345 |
+
):
|
346 |
+
"""
|
347 |
+
PARAMS
|
348 |
+
------
|
349 |
+
attention_hidden_state: attention rnn last output
|
350 |
+
memory: encoder outputs
|
351 |
+
processed_memory: processed encoder outputs
|
352 |
+
attention_weights_cat: previous and cummulative attention weights
|
353 |
+
mask: binary mask for padded data
|
354 |
+
"""
|
355 |
+
log_energy = self.get_alignment_energies(
|
356 |
+
attention_hidden_state, processed_memory, attention_weights_cat
|
357 |
+
)
|
358 |
+
|
359 |
+
# log_energy =
|
360 |
+
|
361 |
+
if mask is not None:
|
362 |
+
log_energy.data.masked_fill_(mask, self.score_mask_value)
|
363 |
+
|
364 |
+
# attention_weights = F.softmax(alignment, dim=1)
|
365 |
+
|
366 |
+
# content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
|
367 |
+
# log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
|
368 |
+
|
369 |
+
# log_total_score = log_alpha + content_score
|
370 |
+
|
371 |
+
# previous_attention_weights = attention_weights_cat[:,0,:]
|
372 |
+
|
373 |
+
log_alpha_shift_padded = []
|
374 |
+
max_time = log_energy.size(1)
|
375 |
+
for sft in range(2):
|
376 |
+
shifted = log_alpha[:, : max_time - sft]
|
377 |
+
shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
|
378 |
+
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
|
379 |
+
|
380 |
+
biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
|
381 |
+
|
382 |
+
log_alpha_new = biased + log_energy
|
383 |
+
|
384 |
+
attention_weights = F.softmax(log_alpha_new, dim=1)
|
385 |
+
|
386 |
+
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
387 |
+
attention_context = attention_context.squeeze(1)
|
388 |
+
|
389 |
+
return attention_context, attention_weights, log_alpha_new
|
390 |
+
|
391 |
+
|
392 |
+
class PhaseShuffle2d(nn.Module):
|
393 |
+
def __init__(self, n=2):
|
394 |
+
super(PhaseShuffle2d, self).__init__()
|
395 |
+
self.n = n
|
396 |
+
self.random = random.Random(1)
|
397 |
+
|
398 |
+
def forward(self, x, move=None):
|
399 |
+
# x.size = (B, C, M, L)
|
400 |
+
if move is None:
|
401 |
+
move = self.random.randint(-self.n, self.n)
|
402 |
+
|
403 |
+
if move == 0:
|
404 |
+
return x
|
405 |
+
else:
|
406 |
+
left = x[:, :, :, :move]
|
407 |
+
right = x[:, :, :, move:]
|
408 |
+
shuffled = torch.cat([right, left], dim=3)
|
409 |
+
return shuffled
|
410 |
+
|
411 |
+
|
412 |
+
class PhaseShuffle1d(nn.Module):
|
413 |
+
def __init__(self, n=2):
|
414 |
+
super(PhaseShuffle1d, self).__init__()
|
415 |
+
self.n = n
|
416 |
+
self.random = random.Random(1)
|
417 |
+
|
418 |
+
def forward(self, x, move=None):
|
419 |
+
# x.size = (B, C, M, L)
|
420 |
+
if move is None:
|
421 |
+
move = self.random.randint(-self.n, self.n)
|
422 |
+
|
423 |
+
if move == 0:
|
424 |
+
return x
|
425 |
+
else:
|
426 |
+
left = x[:, :, :move]
|
427 |
+
right = x[:, :, move:]
|
428 |
+
shuffled = torch.cat([right, left], dim=2)
|
429 |
+
|
430 |
+
return shuffled
|
431 |
+
|
432 |
+
|
433 |
+
class MFCC(nn.Module):
|
434 |
+
def __init__(self, n_mfcc=40, n_mels=80):
|
435 |
+
super(MFCC, self).__init__()
|
436 |
+
self.n_mfcc = n_mfcc
|
437 |
+
self.n_mels = n_mels
|
438 |
+
self.norm = "ortho"
|
439 |
+
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
|
440 |
+
self.register_buffer("dct_mat", dct_mat)
|
441 |
+
|
442 |
+
def forward(self, mel_specgram):
|
443 |
+
if len(mel_specgram.shape) == 2:
|
444 |
+
mel_specgram = mel_specgram.unsqueeze(0)
|
445 |
+
unsqueezed = True
|
446 |
+
else:
|
447 |
+
unsqueezed = False
|
448 |
+
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
|
449 |
+
# -> (channel, time, n_mfcc).tranpose(...)
|
450 |
+
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
|
451 |
+
|
452 |
+
# unpack batch
|
453 |
+
if unsqueezed:
|
454 |
+
mfcc = mfcc.squeeze(0)
|
455 |
+
return mfcc
|
Utils/ASR/models.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import TransformerEncoder
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
|
7 |
+
|
8 |
+
|
9 |
+
class ASRCNN(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
input_dim=80,
|
13 |
+
hidden_dim=256,
|
14 |
+
n_token=35,
|
15 |
+
n_layers=6,
|
16 |
+
token_embedding_dim=256,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.n_token = n_token
|
20 |
+
self.n_down = 1
|
21 |
+
self.to_mfcc = MFCC()
|
22 |
+
self.init_cnn = ConvNorm(
|
23 |
+
input_dim // 2, hidden_dim, kernel_size=7, padding=3, stride=2
|
24 |
+
)
|
25 |
+
self.cnns = nn.Sequential(
|
26 |
+
*[
|
27 |
+
nn.Sequential(
|
28 |
+
ConvBlock(hidden_dim),
|
29 |
+
nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
|
30 |
+
)
|
31 |
+
for n in range(n_layers)
|
32 |
+
]
|
33 |
+
)
|
34 |
+
self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
|
35 |
+
self.ctc_linear = nn.Sequential(
|
36 |
+
LinearNorm(hidden_dim // 2, hidden_dim),
|
37 |
+
nn.ReLU(),
|
38 |
+
LinearNorm(hidden_dim, n_token),
|
39 |
+
)
|
40 |
+
self.asr_s2s = ASRS2S(
|
41 |
+
embedding_dim=token_embedding_dim,
|
42 |
+
hidden_dim=hidden_dim // 2,
|
43 |
+
n_token=n_token,
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x, src_key_padding_mask=None, text_input=None):
|
47 |
+
x = self.to_mfcc(x)
|
48 |
+
x = self.init_cnn(x)
|
49 |
+
x = self.cnns(x)
|
50 |
+
x = self.projection(x)
|
51 |
+
x = x.transpose(1, 2)
|
52 |
+
ctc_logit = self.ctc_linear(x)
|
53 |
+
if text_input is not None:
|
54 |
+
_, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
|
55 |
+
return ctc_logit, s2s_logit, s2s_attn
|
56 |
+
else:
|
57 |
+
return ctc_logit
|
58 |
+
|
59 |
+
def get_feature(self, x):
|
60 |
+
x = self.to_mfcc(x.squeeze(1))
|
61 |
+
x = self.init_cnn(x)
|
62 |
+
x = self.cnns(x)
|
63 |
+
x = self.projection(x)
|
64 |
+
return x
|
65 |
+
|
66 |
+
def length_to_mask(self, lengths):
|
67 |
+
mask = (
|
68 |
+
torch.arange(lengths.max())
|
69 |
+
.unsqueeze(0)
|
70 |
+
.expand(lengths.shape[0], -1)
|
71 |
+
.type_as(lengths)
|
72 |
+
)
|
73 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1)).to(lengths.device)
|
74 |
+
return mask
|
75 |
+
|
76 |
+
def get_future_mask(self, out_length, unmask_future_steps=0):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
out_length (int): returned mask shape is (out_length, out_length).
|
80 |
+
unmask_futre_steps (int): unmasking future step size.
|
81 |
+
Return:
|
82 |
+
mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
|
83 |
+
"""
|
84 |
+
index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
|
85 |
+
mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
|
86 |
+
return mask
|
87 |
+
|
88 |
+
|
89 |
+
class ASRS2S(nn.Module):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
embedding_dim=256,
|
93 |
+
hidden_dim=512,
|
94 |
+
n_location_filters=32,
|
95 |
+
location_kernel_size=63,
|
96 |
+
n_token=40,
|
97 |
+
):
|
98 |
+
super(ASRS2S, self).__init__()
|
99 |
+
self.embedding = nn.Embedding(n_token, embedding_dim)
|
100 |
+
val_range = math.sqrt(6 / hidden_dim)
|
101 |
+
self.embedding.weight.data.uniform_(-val_range, val_range)
|
102 |
+
|
103 |
+
self.decoder_rnn_dim = hidden_dim
|
104 |
+
self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
|
105 |
+
self.attention_layer = Attention(
|
106 |
+
self.decoder_rnn_dim,
|
107 |
+
hidden_dim,
|
108 |
+
hidden_dim,
|
109 |
+
n_location_filters,
|
110 |
+
location_kernel_size,
|
111 |
+
)
|
112 |
+
self.decoder_rnn = nn.LSTMCell(
|
113 |
+
self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim
|
114 |
+
)
|
115 |
+
self.project_to_hidden = nn.Sequential(
|
116 |
+
LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), nn.Tanh()
|
117 |
+
)
|
118 |
+
self.sos = 1
|
119 |
+
self.eos = 2
|
120 |
+
|
121 |
+
def initialize_decoder_states(self, memory, mask):
|
122 |
+
"""
|
123 |
+
moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
|
124 |
+
"""
|
125 |
+
B, L, H = memory.shape
|
126 |
+
self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
|
127 |
+
self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
|
128 |
+
self.attention_weights = torch.zeros((B, L)).type_as(memory)
|
129 |
+
self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
|
130 |
+
self.attention_context = torch.zeros((B, H)).type_as(memory)
|
131 |
+
self.memory = memory
|
132 |
+
self.processed_memory = self.attention_layer.memory_layer(memory)
|
133 |
+
self.mask = mask
|
134 |
+
self.unk_index = 3
|
135 |
+
self.random_mask = 0.1
|
136 |
+
|
137 |
+
def forward(self, memory, memory_mask, text_input):
|
138 |
+
"""
|
139 |
+
moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
|
140 |
+
moemory_mask.shape = (B, L, )
|
141 |
+
texts_input.shape = (B, T)
|
142 |
+
"""
|
143 |
+
self.initialize_decoder_states(memory, memory_mask)
|
144 |
+
# text random mask
|
145 |
+
random_mask = (torch.rand(text_input.shape) < self.random_mask).to(
|
146 |
+
text_input.device
|
147 |
+
)
|
148 |
+
_text_input = text_input.clone()
|
149 |
+
_text_input.masked_fill_(random_mask, self.unk_index)
|
150 |
+
decoder_inputs = self.embedding(_text_input).transpose(
|
151 |
+
0, 1
|
152 |
+
) # -> [T, B, channel]
|
153 |
+
start_embedding = self.embedding(
|
154 |
+
torch.LongTensor([self.sos] * decoder_inputs.size(1)).to(
|
155 |
+
decoder_inputs.device
|
156 |
+
)
|
157 |
+
)
|
158 |
+
decoder_inputs = torch.cat(
|
159 |
+
(start_embedding.unsqueeze(0), decoder_inputs), dim=0
|
160 |
+
)
|
161 |
+
|
162 |
+
hidden_outputs, logit_outputs, alignments = [], [], []
|
163 |
+
while len(hidden_outputs) < decoder_inputs.size(0):
|
164 |
+
decoder_input = decoder_inputs[len(hidden_outputs)]
|
165 |
+
hidden, logit, attention_weights = self.decode(decoder_input)
|
166 |
+
hidden_outputs += [hidden]
|
167 |
+
logit_outputs += [logit]
|
168 |
+
alignments += [attention_weights]
|
169 |
+
|
170 |
+
hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs(
|
171 |
+
hidden_outputs, logit_outputs, alignments
|
172 |
+
)
|
173 |
+
|
174 |
+
return hidden_outputs, logit_outputs, alignments
|
175 |
+
|
176 |
+
def decode(self, decoder_input):
|
177 |
+
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
178 |
+
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
179 |
+
cell_input, (self.decoder_hidden, self.decoder_cell)
|
180 |
+
)
|
181 |
+
|
182 |
+
attention_weights_cat = torch.cat(
|
183 |
+
(
|
184 |
+
self.attention_weights.unsqueeze(1),
|
185 |
+
self.attention_weights_cum.unsqueeze(1),
|
186 |
+
),
|
187 |
+
dim=1,
|
188 |
+
)
|
189 |
+
|
190 |
+
self.attention_context, self.attention_weights = self.attention_layer(
|
191 |
+
self.decoder_hidden,
|
192 |
+
self.memory,
|
193 |
+
self.processed_memory,
|
194 |
+
attention_weights_cat,
|
195 |
+
self.mask,
|
196 |
+
)
|
197 |
+
|
198 |
+
self.attention_weights_cum += self.attention_weights
|
199 |
+
|
200 |
+
hidden_and_context = torch.cat(
|
201 |
+
(self.decoder_hidden, self.attention_context), -1
|
202 |
+
)
|
203 |
+
hidden = self.project_to_hidden(hidden_and_context)
|
204 |
+
|
205 |
+
# dropout to increasing g
|
206 |
+
logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
|
207 |
+
|
208 |
+
return hidden, logit, self.attention_weights
|
209 |
+
|
210 |
+
def parse_decoder_outputs(self, hidden, logit, alignments):
|
211 |
+
# -> [B, T_out + 1, max_time]
|
212 |
+
alignments = torch.stack(alignments).transpose(0, 1)
|
213 |
+
# [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
|
214 |
+
logit = torch.stack(logit).transpose(0, 1).contiguous()
|
215 |
+
hidden = torch.stack(hidden).transpose(0, 1).contiguous()
|
216 |
+
|
217 |
+
return hidden, logit, alignments
|
Utils/JDC/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
Utils/JDC/bst.t7
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
|
3 |
+
size 21029926
|
Utils/JDC/model.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of model from:
|
3 |
+
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
|
4 |
+
Convolutional Recurrent Neural Networks" (2019)
|
5 |
+
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
|
6 |
+
"""
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
class JDCNet(nn.Module):
|
12 |
+
"""
|
13 |
+
Joint Detection and Classification Network model for singing voice melody.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
|
17 |
+
super().__init__()
|
18 |
+
self.num_class = num_class
|
19 |
+
|
20 |
+
# input = (b, 1, 31, 513), b = batch size
|
21 |
+
self.conv_block = nn.Sequential(
|
22 |
+
nn.Conv2d(
|
23 |
+
in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
|
24 |
+
), # out: (b, 64, 31, 513)
|
25 |
+
nn.BatchNorm2d(num_features=64),
|
26 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
27 |
+
nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
|
28 |
+
)
|
29 |
+
|
30 |
+
# res blocks
|
31 |
+
self.res_block1 = ResBlock(
|
32 |
+
in_channels=64, out_channels=128
|
33 |
+
) # (b, 128, 31, 128)
|
34 |
+
self.res_block2 = ResBlock(
|
35 |
+
in_channels=128, out_channels=192
|
36 |
+
) # (b, 192, 31, 32)
|
37 |
+
self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
|
38 |
+
|
39 |
+
# pool block
|
40 |
+
self.pool_block = nn.Sequential(
|
41 |
+
nn.BatchNorm2d(num_features=256),
|
42 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
43 |
+
nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
|
44 |
+
nn.Dropout(p=0.2),
|
45 |
+
)
|
46 |
+
|
47 |
+
# maxpool layers (for auxiliary network inputs)
|
48 |
+
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
|
49 |
+
self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
|
50 |
+
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
|
51 |
+
self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
|
52 |
+
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
|
53 |
+
self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
|
54 |
+
|
55 |
+
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
|
56 |
+
self.detector_conv = nn.Sequential(
|
57 |
+
nn.Conv2d(640, 256, 1, bias=False),
|
58 |
+
nn.BatchNorm2d(256),
|
59 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
60 |
+
nn.Dropout(p=0.2),
|
61 |
+
)
|
62 |
+
|
63 |
+
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
|
64 |
+
self.bilstm_classifier = nn.LSTM(
|
65 |
+
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
|
66 |
+
) # (b, 31, 512)
|
67 |
+
|
68 |
+
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
|
69 |
+
self.bilstm_detector = nn.LSTM(
|
70 |
+
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
|
71 |
+
) # (b, 31, 512)
|
72 |
+
|
73 |
+
# input: (b * 31, 512)
|
74 |
+
self.classifier = nn.Linear(
|
75 |
+
in_features=512, out_features=self.num_class
|
76 |
+
) # (b * 31, num_class)
|
77 |
+
|
78 |
+
# input: (b * 31, 512)
|
79 |
+
self.detector = nn.Linear(
|
80 |
+
in_features=512, out_features=2
|
81 |
+
) # (b * 31, 2) - binary classifier
|
82 |
+
|
83 |
+
# initialize weights
|
84 |
+
self.apply(self.init_weights)
|
85 |
+
|
86 |
+
def get_feature_GAN(self, x):
|
87 |
+
seq_len = x.shape[-2]
|
88 |
+
x = x.float().transpose(-1, -2)
|
89 |
+
|
90 |
+
convblock_out = self.conv_block(x)
|
91 |
+
|
92 |
+
resblock1_out = self.res_block1(convblock_out)
|
93 |
+
resblock2_out = self.res_block2(resblock1_out)
|
94 |
+
resblock3_out = self.res_block3(resblock2_out)
|
95 |
+
poolblock_out = self.pool_block[0](resblock3_out)
|
96 |
+
poolblock_out = self.pool_block[1](poolblock_out)
|
97 |
+
|
98 |
+
return poolblock_out.transpose(-1, -2)
|
99 |
+
|
100 |
+
def get_feature(self, x):
|
101 |
+
seq_len = x.shape[-2]
|
102 |
+
x = x.float().transpose(-1, -2)
|
103 |
+
|
104 |
+
convblock_out = self.conv_block(x)
|
105 |
+
|
106 |
+
resblock1_out = self.res_block1(convblock_out)
|
107 |
+
resblock2_out = self.res_block2(resblock1_out)
|
108 |
+
resblock3_out = self.res_block3(resblock2_out)
|
109 |
+
poolblock_out = self.pool_block[0](resblock3_out)
|
110 |
+
poolblock_out = self.pool_block[1](poolblock_out)
|
111 |
+
|
112 |
+
return self.pool_block[2](poolblock_out)
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
"""
|
116 |
+
Returns:
|
117 |
+
classification_prediction, detection_prediction
|
118 |
+
sizes: (b, 31, 722), (b, 31, 2)
|
119 |
+
"""
|
120 |
+
###############################
|
121 |
+
# forward pass for classifier #
|
122 |
+
###############################
|
123 |
+
seq_len = x.shape[-1]
|
124 |
+
x = x.float().transpose(-1, -2)
|
125 |
+
|
126 |
+
convblock_out = self.conv_block(x)
|
127 |
+
|
128 |
+
resblock1_out = self.res_block1(convblock_out)
|
129 |
+
resblock2_out = self.res_block2(resblock1_out)
|
130 |
+
resblock3_out = self.res_block3(resblock2_out)
|
131 |
+
|
132 |
+
poolblock_out = self.pool_block[0](resblock3_out)
|
133 |
+
poolblock_out = self.pool_block[1](poolblock_out)
|
134 |
+
GAN_feature = poolblock_out.transpose(-1, -2)
|
135 |
+
poolblock_out = self.pool_block[2](poolblock_out)
|
136 |
+
|
137 |
+
# (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
|
138 |
+
classifier_out = (
|
139 |
+
poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
|
140 |
+
)
|
141 |
+
classifier_out, _ = self.bilstm_classifier(
|
142 |
+
classifier_out
|
143 |
+
) # ignore the hidden states
|
144 |
+
|
145 |
+
classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
|
146 |
+
classifier_out = self.classifier(classifier_out)
|
147 |
+
classifier_out = classifier_out.view(
|
148 |
+
(-1, seq_len, self.num_class)
|
149 |
+
) # (b, 31, num_class)
|
150 |
+
|
151 |
+
# sizes: (b, 31, 722), (b, 31, 2)
|
152 |
+
# classifier output consists of predicted pitch classes per frame
|
153 |
+
# detector output consists of: (isvoice, notvoice) estimates per frame
|
154 |
+
return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def init_weights(m):
|
158 |
+
if isinstance(m, nn.Linear):
|
159 |
+
nn.init.kaiming_uniform_(m.weight)
|
160 |
+
if m.bias is not None:
|
161 |
+
nn.init.constant_(m.bias, 0)
|
162 |
+
elif isinstance(m, nn.Conv2d):
|
163 |
+
nn.init.xavier_normal_(m.weight)
|
164 |
+
elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
|
165 |
+
for p in m.parameters():
|
166 |
+
if p.data is None:
|
167 |
+
continue
|
168 |
+
|
169 |
+
if len(p.shape) >= 2:
|
170 |
+
nn.init.orthogonal_(p.data)
|
171 |
+
else:
|
172 |
+
nn.init.normal_(p.data)
|
173 |
+
|
174 |
+
|
175 |
+
class ResBlock(nn.Module):
|
176 |
+
def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
|
177 |
+
super().__init__()
|
178 |
+
self.downsample = in_channels != out_channels
|
179 |
+
|
180 |
+
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
|
181 |
+
self.pre_conv = nn.Sequential(
|
182 |
+
nn.BatchNorm2d(num_features=in_channels),
|
183 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
184 |
+
nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
|
185 |
+
)
|
186 |
+
|
187 |
+
# conv layers
|
188 |
+
self.conv = nn.Sequential(
|
189 |
+
nn.Conv2d(
|
190 |
+
in_channels=in_channels,
|
191 |
+
out_channels=out_channels,
|
192 |
+
kernel_size=3,
|
193 |
+
padding=1,
|
194 |
+
bias=False,
|
195 |
+
),
|
196 |
+
nn.BatchNorm2d(out_channels),
|
197 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
198 |
+
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
|
199 |
+
)
|
200 |
+
|
201 |
+
# 1 x 1 convolution layer to match the feature dimensions
|
202 |
+
self.conv1by1 = None
|
203 |
+
if self.downsample:
|
204 |
+
self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
|
205 |
+
|
206 |
+
def forward(self, x):
|
207 |
+
x = self.pre_conv(x)
|
208 |
+
if self.downsample:
|
209 |
+
x = self.conv(x) + self.conv1by1(x)
|
210 |
+
else:
|
211 |
+
x = self.conv(x) + x
|
212 |
+
return x
|
Utils/PLBERT/config.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
log_dir: "Checkpoint"
|
2 |
+
mixed_precision: "fp16"
|
3 |
+
data_folder: "wikipedia_20220301.en.processed"
|
4 |
+
batch_size: 192
|
5 |
+
save_interval: 5000
|
6 |
+
log_interval: 10
|
7 |
+
num_process: 1 # number of GPUs
|
8 |
+
num_steps: 1000000
|
9 |
+
|
10 |
+
dataset_params:
|
11 |
+
tokenizer: "transfo-xl-wt103"
|
12 |
+
token_separator: " " # token used for phoneme separator (space)
|
13 |
+
token_mask: "M" # token used for phoneme mask (M)
|
14 |
+
word_separator: 3039 # token used for word separator (<formula>)
|
15 |
+
token_maps: "token_maps.pkl" # token map path
|
16 |
+
|
17 |
+
max_mel_length: 512 # max phoneme length
|
18 |
+
|
19 |
+
word_mask_prob: 0.15 # probability to mask the entire word
|
20 |
+
phoneme_mask_prob: 0.1 # probability to mask each phoneme
|
21 |
+
replace_prob: 0.2 # probablity to replace phonemes
|
22 |
+
|
23 |
+
model_params:
|
24 |
+
vocab_size: 178
|
25 |
+
hidden_size: 768
|
26 |
+
num_attention_heads: 12
|
27 |
+
intermediate_size: 2048
|
28 |
+
max_position_embeddings: 512
|
29 |
+
num_hidden_layers: 12
|
30 |
+
dropout: 0.1
|
Utils/PLBERT/step_1000000.t7
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
|
3 |
+
size 25185187
|
Utils/PLBERT/util.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
from transformers import AlbertConfig, AlbertModel
|
5 |
+
|
6 |
+
|
7 |
+
class CustomAlbert(AlbertModel):
|
8 |
+
def forward(self, *args, **kwargs):
|
9 |
+
# Call the original forward method
|
10 |
+
outputs = super().forward(*args, **kwargs)
|
11 |
+
|
12 |
+
# Only return the last_hidden_state
|
13 |
+
return outputs.last_hidden_state
|
14 |
+
|
15 |
+
|
16 |
+
def load_plbert(log_dir):
|
17 |
+
config_path = os.path.join(log_dir, "config.yml")
|
18 |
+
plbert_config = yaml.safe_load(open(config_path))
|
19 |
+
|
20 |
+
albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
|
21 |
+
bert = CustomAlbert(albert_base_configuration)
|
22 |
+
|
23 |
+
files = os.listdir(log_dir)
|
24 |
+
ckpts = []
|
25 |
+
for f in os.listdir(log_dir):
|
26 |
+
if f.startswith("step_"):
|
27 |
+
ckpts.append(f)
|
28 |
+
|
29 |
+
iters = [
|
30 |
+
int(f.split("_")[-1].split(".")[0])
|
31 |
+
for f in ckpts
|
32 |
+
if os.path.isfile(os.path.join(log_dir, f))
|
33 |
+
]
|
34 |
+
iters = sorted(iters)[-1]
|
35 |
+
|
36 |
+
checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu")
|
37 |
+
state_dict = checkpoint["net"]
|
38 |
+
from collections import OrderedDict
|
39 |
+
|
40 |
+
new_state_dict = OrderedDict()
|
41 |
+
for k, v in state_dict.items():
|
42 |
+
name = k[7:] # remove `module.`
|
43 |
+
if name.startswith("encoder."):
|
44 |
+
name = name[8:] # remove `encoder.`
|
45 |
+
new_state_dict[name] = v
|
46 |
+
del new_state_dict["embeddings.position_ids"]
|
47 |
+
bert.load_state_dict(new_state_dict, strict=False)
|
48 |
+
|
49 |
+
return bert
|
Utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
_run.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cached_path import cached_path
|
2 |
+
|
3 |
+
from dp.phonemizer import Phonemizer
|
4 |
+
print("NLTK")
|
5 |
+
import nltk
|
6 |
+
nltk.download('punkt')
|
7 |
+
print("SCIPY")
|
8 |
+
from scipy.io.wavfile import write
|
9 |
+
print("TORCH STUFF")
|
10 |
+
import torch
|
11 |
+
print("START")
|
12 |
+
torch.manual_seed(0)
|
13 |
+
torch.backends.cudnn.benchmark = False
|
14 |
+
torch.backends.cudnn.deterministic = True
|
15 |
+
|
16 |
+
import random
|
17 |
+
random.seed(0)
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
np.random.seed(0)
|
21 |
+
|
22 |
+
# load packages
|
23 |
+
import time
|
24 |
+
import random
|
25 |
+
import yaml
|
26 |
+
from munch import Munch
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from torch import nn
|
30 |
+
import torch.nn.functional as F
|
31 |
+
import torchaudio
|
32 |
+
import librosa
|
33 |
+
from nltk.tokenize import word_tokenize
|
34 |
+
|
35 |
+
from models import *
|
36 |
+
from utils import *
|
37 |
+
from text_utils import TextCleaner
|
38 |
+
textclenaer = TextCleaner()
|
39 |
+
|
40 |
+
|
41 |
+
to_mel = torchaudio.transforms.MelSpectrogram(
|
42 |
+
n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
|
43 |
+
mean, std = -4, 4
|
44 |
+
|
45 |
+
def length_to_mask(lengths):
|
46 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
47 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
48 |
+
return mask
|
49 |
+
|
50 |
+
def preprocess(wave):
|
51 |
+
wave_tensor = torch.from_numpy(wave).float()
|
52 |
+
mel_tensor = to_mel(wave_tensor)
|
53 |
+
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
|
54 |
+
return mel_tensor
|
55 |
+
|
56 |
+
def compute_style(path):
|
57 |
+
wave, sr = librosa.load(path, sr=24000)
|
58 |
+
audio, index = librosa.effects.trim(wave, top_db=30)
|
59 |
+
if sr != 24000:
|
60 |
+
audio = librosa.resample(audio, sr, 24000)
|
61 |
+
mel_tensor = preprocess(audio).to(device)
|
62 |
+
|
63 |
+
with torch.no_grad():
|
64 |
+
ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
|
65 |
+
ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
|
66 |
+
|
67 |
+
return torch.cat([ref_s, ref_p], dim=1)
|
68 |
+
|
69 |
+
device = 'cpu'
|
70 |
+
if torch.cuda.is_available():
|
71 |
+
device = 'cuda'
|
72 |
+
elif torch.backends.mps.is_available():
|
73 |
+
print("MPS would be available but cannot be used rn")
|
74 |
+
# device = 'mps'
|
75 |
+
|
76 |
+
|
77 |
+
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
|
78 |
+
phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
|
79 |
+
|
80 |
+
|
81 |
+
config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
|
82 |
+
|
83 |
+
# load pretrained ASR model
|
84 |
+
ASR_config = config.get('ASR_config', False)
|
85 |
+
ASR_path = config.get('ASR_path', False)
|
86 |
+
text_aligner = load_ASR_models(ASR_path, ASR_config)
|
87 |
+
|
88 |
+
# load pretrained F0 model
|
89 |
+
F0_path = config.get('F0_path', False)
|
90 |
+
pitch_extractor = load_F0_models(F0_path)
|
91 |
+
|
92 |
+
# load BERT model
|
93 |
+
from Utils.PLBERT.util import load_plbert
|
94 |
+
BERT_path = config.get('PLBERT_dir', False)
|
95 |
+
plbert = load_plbert(BERT_path)
|
96 |
+
|
97 |
+
model_params = recursive_munch(config['model_params'])
|
98 |
+
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
|
99 |
+
_ = [model[key].eval() for key in model]
|
100 |
+
_ = [model[key].to(device) for key in model]
|
101 |
+
|
102 |
+
params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
|
103 |
+
params = params_whole['net']
|
104 |
+
|
105 |
+
for key in model:
|
106 |
+
if key in params:
|
107 |
+
print('%s loaded' % key)
|
108 |
+
try:
|
109 |
+
model[key].load_state_dict(params[key])
|
110 |
+
except:
|
111 |
+
from collections import OrderedDict
|
112 |
+
state_dict = params[key]
|
113 |
+
new_state_dict = OrderedDict()
|
114 |
+
for k, v in state_dict.items():
|
115 |
+
name = k[7:] # remove `module.`
|
116 |
+
new_state_dict[name] = v
|
117 |
+
# load params
|
118 |
+
model[key].load_state_dict(new_state_dict, strict=False)
|
119 |
+
# except:
|
120 |
+
# _load(params[key], model[key])
|
121 |
+
_ = [model[key].eval() for key in model]
|
122 |
+
|
123 |
+
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
|
124 |
+
|
125 |
+
sampler = DiffusionSampler(
|
126 |
+
model.diffusion.diffusion,
|
127 |
+
sampler=ADPM2Sampler(),
|
128 |
+
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
|
129 |
+
clamp=False
|
130 |
+
)
|
131 |
+
|
132 |
+
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
133 |
+
text = text.strip()
|
134 |
+
ps = phonemizer([text], lang='en_us')
|
135 |
+
ps = word_tokenize(ps[0])
|
136 |
+
ps = ' '.join(ps)
|
137 |
+
tokens = textclenaer(ps)
|
138 |
+
tokens.insert(0, 0)
|
139 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
140 |
+
|
141 |
+
with torch.no_grad():
|
142 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
143 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
144 |
+
|
145 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
146 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
147 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
148 |
+
|
149 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
150 |
+
embedding=bert_dur,
|
151 |
+
embedding_scale=embedding_scale,
|
152 |
+
features=ref_s, # reference from the same speaker as the embedding
|
153 |
+
num_steps=diffusion_steps).squeeze(1)
|
154 |
+
|
155 |
+
|
156 |
+
s = s_pred[:, 128:]
|
157 |
+
ref = s_pred[:, :128]
|
158 |
+
|
159 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
160 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
161 |
+
|
162 |
+
d = model.predictor.text_encoder(d_en,
|
163 |
+
s, input_lengths, text_mask)
|
164 |
+
|
165 |
+
x, _ = model.predictor.lstm(d)
|
166 |
+
duration = model.predictor.duration_proj(x)
|
167 |
+
|
168 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
169 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
170 |
+
|
171 |
+
|
172 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
173 |
+
c_frame = 0
|
174 |
+
for i in range(pred_aln_trg.size(0)):
|
175 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
176 |
+
c_frame += int(pred_dur[i].data)
|
177 |
+
|
178 |
+
# encode prosody
|
179 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
180 |
+
if model_params.decoder.type == "hifigan":
|
181 |
+
asr_new = torch.zeros_like(en)
|
182 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
183 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
184 |
+
en = asr_new
|
185 |
+
|
186 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
187 |
+
|
188 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
189 |
+
if model_params.decoder.type == "hifigan":
|
190 |
+
asr_new = torch.zeros_like(asr)
|
191 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
192 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
193 |
+
asr = asr_new
|
194 |
+
|
195 |
+
out = model.decoder(asr,
|
196 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
197 |
+
|
198 |
+
|
199 |
+
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
200 |
+
|
201 |
+
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
|
202 |
+
text = text.strip()
|
203 |
+
ps = phonemizer([text], lang='en_us')
|
204 |
+
ps = word_tokenize(ps[0])
|
205 |
+
ps = ' '.join(ps)
|
206 |
+
ps = ps.replace('``', '"')
|
207 |
+
ps = ps.replace("''", '"')
|
208 |
+
|
209 |
+
tokens = textclenaer(ps)
|
210 |
+
tokens.insert(0, 0)
|
211 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
212 |
+
|
213 |
+
with torch.no_grad():
|
214 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
215 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
216 |
+
|
217 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
218 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
219 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
220 |
+
|
221 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
222 |
+
embedding=bert_dur,
|
223 |
+
embedding_scale=embedding_scale,
|
224 |
+
features=ref_s, # reference from the same speaker as the embedding
|
225 |
+
num_steps=diffusion_steps).squeeze(1)
|
226 |
+
|
227 |
+
if s_prev is not None:
|
228 |
+
# convex combination of previous and current style
|
229 |
+
s_pred = t * s_prev + (1 - t) * s_pred
|
230 |
+
|
231 |
+
s = s_pred[:, 128:]
|
232 |
+
ref = s_pred[:, :128]
|
233 |
+
|
234 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
235 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
236 |
+
|
237 |
+
s_pred = torch.cat([ref, s], dim=-1)
|
238 |
+
|
239 |
+
d = model.predictor.text_encoder(d_en,
|
240 |
+
s, input_lengths, text_mask)
|
241 |
+
|
242 |
+
x, _ = model.predictor.lstm(d)
|
243 |
+
duration = model.predictor.duration_proj(x)
|
244 |
+
|
245 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
246 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
247 |
+
|
248 |
+
|
249 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
250 |
+
c_frame = 0
|
251 |
+
for i in range(pred_aln_trg.size(0)):
|
252 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
253 |
+
c_frame += int(pred_dur[i].data)
|
254 |
+
|
255 |
+
# encode prosody
|
256 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
257 |
+
if model_params.decoder.type == "hifigan":
|
258 |
+
asr_new = torch.zeros_like(en)
|
259 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
260 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
261 |
+
en = asr_new
|
262 |
+
|
263 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
264 |
+
|
265 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
266 |
+
if model_params.decoder.type == "hifigan":
|
267 |
+
asr_new = torch.zeros_like(asr)
|
268 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
269 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
270 |
+
asr = asr_new
|
271 |
+
|
272 |
+
out = model.decoder(asr,
|
273 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
274 |
+
|
275 |
+
|
276 |
+
return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
|
277 |
+
|
278 |
+
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
279 |
+
text = text.strip()
|
280 |
+
ps = phonemizer([text], lang='en_us')
|
281 |
+
ps = word_tokenize(ps[0])
|
282 |
+
ps = ' '.join(ps)
|
283 |
+
|
284 |
+
tokens = textclenaer(ps)
|
285 |
+
tokens.insert(0, 0)
|
286 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
287 |
+
|
288 |
+
ref_text = ref_text.strip()
|
289 |
+
ps = phonemizer([ref_text], lang='en_us')
|
290 |
+
ps = word_tokenize(ps[0])
|
291 |
+
ps = ' '.join(ps)
|
292 |
+
|
293 |
+
ref_tokens = textclenaer(ps)
|
294 |
+
ref_tokens.insert(0, 0)
|
295 |
+
ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
|
296 |
+
|
297 |
+
|
298 |
+
with torch.no_grad():
|
299 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
300 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
301 |
+
|
302 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
303 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
304 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
305 |
+
|
306 |
+
ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
|
307 |
+
ref_text_mask = length_to_mask(ref_input_lengths).to(device)
|
308 |
+
ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
|
309 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
310 |
+
embedding=bert_dur,
|
311 |
+
embedding_scale=embedding_scale,
|
312 |
+
features=ref_s, # reference from the same speaker as the embedding
|
313 |
+
num_steps=diffusion_steps).squeeze(1)
|
314 |
+
|
315 |
+
|
316 |
+
s = s_pred[:, 128:]
|
317 |
+
ref = s_pred[:, :128]
|
318 |
+
|
319 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
320 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
321 |
+
|
322 |
+
d = model.predictor.text_encoder(d_en,
|
323 |
+
s, input_lengths, text_mask)
|
324 |
+
|
325 |
+
x, _ = model.predictor.lstm(d)
|
326 |
+
duration = model.predictor.duration_proj(x)
|
327 |
+
|
328 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
329 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
330 |
+
|
331 |
+
|
332 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
333 |
+
c_frame = 0
|
334 |
+
for i in range(pred_aln_trg.size(0)):
|
335 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
336 |
+
c_frame += int(pred_dur[i].data)
|
337 |
+
|
338 |
+
# encode prosody
|
339 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
340 |
+
if model_params.decoder.type == "hifigan":
|
341 |
+
asr_new = torch.zeros_like(en)
|
342 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
343 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
344 |
+
en = asr_new
|
345 |
+
|
346 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
347 |
+
|
348 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
349 |
+
if model_params.decoder.type == "hifigan":
|
350 |
+
asr_new = torch.zeros_like(asr)
|
351 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
352 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
353 |
+
asr = asr_new
|
354 |
+
|
355 |
+
out = model.decoder(asr,
|
356 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
357 |
+
|
358 |
+
|
359 |
+
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
360 |
+
print("Time to synthesize!")
|
361 |
+
ref_s = compute_style('./voice/voice.wav')
|
362 |
+
while True:
|
363 |
+
text = input("What to say? > ")
|
364 |
+
start = time.time()
|
365 |
+
wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=15, embedding_scale=1)
|
366 |
+
rtf = (time.time() - start) / (len(wav) / 24000)
|
367 |
+
print(f"RTF = {rtf:5f}")
|
368 |
+
print(k + ' Synthesized:')
|
369 |
+
# display(ipd.Audio(wav, rate=24000, normalize=False))
|
370 |
+
write('result.wav', 24000, wav)
|
371 |
+
print("Saved to result.wav")
|
app.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import styletts2importable
|
3 |
+
theme = gr.themes.Base(
|
4 |
+
font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
|
5 |
+
)
|
6 |
+
voices = {
|
7 |
+
'angie': styletts2importable.compute_style('voices/angie.wav'),
|
8 |
+
'daniel': styletts2importable.compute_style('voices/daniel.wav'),
|
9 |
+
'dotrice': styletts2importable.compute_style('voices/dotrice.wav'),
|
10 |
+
'lj': styletts2importable.compute_style('voices/lj.wav'),
|
11 |
+
'mouse': styletts2importable.compute_style('voices/mouse.wav'),
|
12 |
+
'pat': styletts2importable.compute_style('voices/pat.wav'),
|
13 |
+
'tom': styletts2importable.compute_style('voices/tom.wav'),
|
14 |
+
'william': styletts2importable.compute_style('voices/william.wav'),
|
15 |
+
}
|
16 |
+
def synthesize(text, voice):
|
17 |
+
if text.strip() == "":
|
18 |
+
raise gr.Error("You must enter some text")
|
19 |
+
v = voice.lower()
|
20 |
+
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=15, embedding_scale=1))
|
21 |
+
|
22 |
+
with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
|
23 |
+
gr.Markdown("""# StyleTTS 2
|
24 |
+
|
25 |
+
[Paper](https://arxiv.org/abs/2306.07691) - [Samples](https://styletts2.github.io/) - [Code](https://github.com/yl4579/StyleTTS2)
|
26 |
+
|
27 |
+
A free demo of StyleTTS 2. Not affiliated with the StyleTTS 2 Authors.
|
28 |
+
|
29 |
+
**Before using this demo, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.**
|
30 |
+
|
31 |
+
This space does NOT allow voice cloning. We use some default voice from Tortoise TTS instead.
|
32 |
+
|
33 |
+
Is there a long queue on this space? Duplicate it and add a GPU to skip the wait!""")
|
34 |
+
gr.DuplicateButton("Duplicate Space")
|
35 |
+
with gr.Row():
|
36 |
+
with gr.Column(scale=1):
|
37 |
+
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
38 |
+
voice = gr.Dropdown(['Angie', 'Daniel', 'Tom', 'LJ', 'Pat', 'Tom', 'Dotrice', 'Mouse', 'William'], label="Voice", info="Select a voice. We use some voices from Tortoise TTS.", value='Tom', interactive=True)
|
39 |
+
with gr.Column(scale=1):
|
40 |
+
btn = gr.Button("Synthesize")
|
41 |
+
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
42 |
+
btn.click(synthesize, inputs=[inp, voice], outputs=[audio])
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
demo.launch(show_api=False)
|
46 |
+
|
losses.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchaudio
|
5 |
+
from transformers import AutoModel
|
6 |
+
|
7 |
+
|
8 |
+
class SpectralConvergengeLoss(torch.nn.Module):
|
9 |
+
"""Spectral convergence loss module."""
|
10 |
+
|
11 |
+
def __init__(self):
|
12 |
+
"""Initilize spectral convergence loss module."""
|
13 |
+
super(SpectralConvergengeLoss, self).__init__()
|
14 |
+
|
15 |
+
def forward(self, x_mag, y_mag):
|
16 |
+
"""Calculate forward propagation.
|
17 |
+
Args:
|
18 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
19 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
20 |
+
Returns:
|
21 |
+
Tensor: Spectral convergence loss value.
|
22 |
+
"""
|
23 |
+
return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
|
24 |
+
|
25 |
+
|
26 |
+
class STFTLoss(torch.nn.Module):
|
27 |
+
"""STFT loss module."""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
|
31 |
+
):
|
32 |
+
"""Initialize STFT loss module."""
|
33 |
+
super(STFTLoss, self).__init__()
|
34 |
+
self.fft_size = fft_size
|
35 |
+
self.shift_size = shift_size
|
36 |
+
self.win_length = win_length
|
37 |
+
self.to_mel = torchaudio.transforms.MelSpectrogram(
|
38 |
+
sample_rate=24000,
|
39 |
+
n_fft=fft_size,
|
40 |
+
win_length=win_length,
|
41 |
+
hop_length=shift_size,
|
42 |
+
window_fn=window,
|
43 |
+
)
|
44 |
+
|
45 |
+
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
46 |
+
|
47 |
+
def forward(self, x, y):
|
48 |
+
"""Calculate forward propagation.
|
49 |
+
Args:
|
50 |
+
x (Tensor): Predicted signal (B, T).
|
51 |
+
y (Tensor): Groundtruth signal (B, T).
|
52 |
+
Returns:
|
53 |
+
Tensor: Spectral convergence loss value.
|
54 |
+
Tensor: Log STFT magnitude loss value.
|
55 |
+
"""
|
56 |
+
x_mag = self.to_mel(x)
|
57 |
+
mean, std = -4, 4
|
58 |
+
x_mag = (torch.log(1e-5 + x_mag) - mean) / std
|
59 |
+
|
60 |
+
y_mag = self.to_mel(y)
|
61 |
+
mean, std = -4, 4
|
62 |
+
y_mag = (torch.log(1e-5 + y_mag) - mean) / std
|
63 |
+
|
64 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
65 |
+
return sc_loss
|
66 |
+
|
67 |
+
|
68 |
+
class MultiResolutionSTFTLoss(torch.nn.Module):
|
69 |
+
"""Multi resolution STFT loss module."""
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
fft_sizes=[1024, 2048, 512],
|
74 |
+
hop_sizes=[120, 240, 50],
|
75 |
+
win_lengths=[600, 1200, 240],
|
76 |
+
window=torch.hann_window,
|
77 |
+
):
|
78 |
+
"""Initialize Multi resolution STFT loss module.
|
79 |
+
Args:
|
80 |
+
fft_sizes (list): List of FFT sizes.
|
81 |
+
hop_sizes (list): List of hop sizes.
|
82 |
+
win_lengths (list): List of window lengths.
|
83 |
+
window (str): Window function type.
|
84 |
+
"""
|
85 |
+
super(MultiResolutionSTFTLoss, self).__init__()
|
86 |
+
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
87 |
+
self.stft_losses = torch.nn.ModuleList()
|
88 |
+
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
89 |
+
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
90 |
+
|
91 |
+
def forward(self, x, y):
|
92 |
+
"""Calculate forward propagation.
|
93 |
+
Args:
|
94 |
+
x (Tensor): Predicted signal (B, T).
|
95 |
+
y (Tensor): Groundtruth signal (B, T).
|
96 |
+
Returns:
|
97 |
+
Tensor: Multi resolution spectral convergence loss value.
|
98 |
+
Tensor: Multi resolution log STFT magnitude loss value.
|
99 |
+
"""
|
100 |
+
sc_loss = 0.0
|
101 |
+
for f in self.stft_losses:
|
102 |
+
sc_l = f(x, y)
|
103 |
+
sc_loss += sc_l
|
104 |
+
sc_loss /= len(self.stft_losses)
|
105 |
+
|
106 |
+
return sc_loss
|
107 |
+
|
108 |
+
|
109 |
+
def feature_loss(fmap_r, fmap_g):
|
110 |
+
loss = 0
|
111 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
112 |
+
for rl, gl in zip(dr, dg):
|
113 |
+
loss += torch.mean(torch.abs(rl - gl))
|
114 |
+
|
115 |
+
return loss * 2
|
116 |
+
|
117 |
+
|
118 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
119 |
+
loss = 0
|
120 |
+
r_losses = []
|
121 |
+
g_losses = []
|
122 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
123 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
124 |
+
g_loss = torch.mean(dg**2)
|
125 |
+
loss += r_loss + g_loss
|
126 |
+
r_losses.append(r_loss.item())
|
127 |
+
g_losses.append(g_loss.item())
|
128 |
+
|
129 |
+
return loss, r_losses, g_losses
|
130 |
+
|
131 |
+
|
132 |
+
def generator_loss(disc_outputs):
|
133 |
+
loss = 0
|
134 |
+
gen_losses = []
|
135 |
+
for dg in disc_outputs:
|
136 |
+
l = torch.mean((1 - dg) ** 2)
|
137 |
+
gen_losses.append(l)
|
138 |
+
loss += l
|
139 |
+
|
140 |
+
return loss, gen_losses
|
141 |
+
|
142 |
+
|
143 |
+
""" https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
|
144 |
+
|
145 |
+
|
146 |
+
def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
|
147 |
+
loss = 0
|
148 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
149 |
+
tau = 0.04
|
150 |
+
m_DG = torch.median((dr - dg))
|
151 |
+
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
|
152 |
+
loss += tau - F.relu(tau - L_rel)
|
153 |
+
return loss
|
154 |
+
|
155 |
+
|
156 |
+
def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
|
157 |
+
loss = 0
|
158 |
+
for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
|
159 |
+
tau = 0.04
|
160 |
+
m_DG = torch.median((dr - dg))
|
161 |
+
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
|
162 |
+
loss += tau - F.relu(tau - L_rel)
|
163 |
+
return loss
|
164 |
+
|
165 |
+
|
166 |
+
class GeneratorLoss(torch.nn.Module):
|
167 |
+
def __init__(self, mpd, msd):
|
168 |
+
super(GeneratorLoss, self).__init__()
|
169 |
+
self.mpd = mpd
|
170 |
+
self.msd = msd
|
171 |
+
|
172 |
+
def forward(self, y, y_hat):
|
173 |
+
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
|
174 |
+
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
|
175 |
+
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
176 |
+
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
177 |
+
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
178 |
+
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
179 |
+
|
180 |
+
loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(
|
181 |
+
y_ds_hat_r, y_ds_hat_g
|
182 |
+
)
|
183 |
+
|
184 |
+
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
|
185 |
+
|
186 |
+
return loss_gen_all.mean()
|
187 |
+
|
188 |
+
|
189 |
+
class DiscriminatorLoss(torch.nn.Module):
|
190 |
+
def __init__(self, mpd, msd):
|
191 |
+
super(DiscriminatorLoss, self).__init__()
|
192 |
+
self.mpd = mpd
|
193 |
+
self.msd = msd
|
194 |
+
|
195 |
+
def forward(self, y, y_hat):
|
196 |
+
# MPD
|
197 |
+
y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
|
198 |
+
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
|
199 |
+
y_df_hat_r, y_df_hat_g
|
200 |
+
)
|
201 |
+
# MSD
|
202 |
+
y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
|
203 |
+
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
|
204 |
+
y_ds_hat_r, y_ds_hat_g
|
205 |
+
)
|
206 |
+
|
207 |
+
loss_rel = discriminator_TPRLS_loss(
|
208 |
+
y_df_hat_r, y_df_hat_g
|
209 |
+
) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
|
210 |
+
|
211 |
+
d_loss = loss_disc_s + loss_disc_f + loss_rel
|
212 |
+
|
213 |
+
return d_loss.mean()
|
214 |
+
|
215 |
+
|
216 |
+
class WavLMLoss(torch.nn.Module):
|
217 |
+
def __init__(self, model, wd, model_sr, slm_sr=16000):
|
218 |
+
super(WavLMLoss, self).__init__()
|
219 |
+
self.wavlm = AutoModel.from_pretrained(model)
|
220 |
+
self.wd = wd
|
221 |
+
self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
|
222 |
+
|
223 |
+
def forward(self, wav, y_rec):
|
224 |
+
with torch.no_grad():
|
225 |
+
wav_16 = self.resample(wav)
|
226 |
+
wav_embeddings = self.wavlm(
|
227 |
+
input_values=wav_16, output_hidden_states=True
|
228 |
+
).hidden_states
|
229 |
+
y_rec_16 = self.resample(y_rec)
|
230 |
+
y_rec_embeddings = self.wavlm(
|
231 |
+
input_values=y_rec_16.squeeze(), output_hidden_states=True
|
232 |
+
).hidden_states
|
233 |
+
|
234 |
+
floss = 0
|
235 |
+
for er, eg in zip(wav_embeddings, y_rec_embeddings):
|
236 |
+
floss += torch.mean(torch.abs(er - eg))
|
237 |
+
|
238 |
+
return floss.mean()
|
239 |
+
|
240 |
+
def generator(self, y_rec):
|
241 |
+
y_rec_16 = self.resample(y_rec)
|
242 |
+
y_rec_embeddings = self.wavlm(
|
243 |
+
input_values=y_rec_16, output_hidden_states=True
|
244 |
+
).hidden_states
|
245 |
+
y_rec_embeddings = (
|
246 |
+
torch.stack(y_rec_embeddings, dim=1)
|
247 |
+
.transpose(-1, -2)
|
248 |
+
.flatten(start_dim=1, end_dim=2)
|
249 |
+
)
|
250 |
+
y_df_hat_g = self.wd(y_rec_embeddings)
|
251 |
+
loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
|
252 |
+
|
253 |
+
return loss_gen
|
254 |
+
|
255 |
+
def discriminator(self, wav, y_rec):
|
256 |
+
with torch.no_grad():
|
257 |
+
wav_16 = self.resample(wav)
|
258 |
+
wav_embeddings = self.wavlm(
|
259 |
+
input_values=wav_16, output_hidden_states=True
|
260 |
+
).hidden_states
|
261 |
+
y_rec_16 = self.resample(y_rec)
|
262 |
+
y_rec_embeddings = self.wavlm(
|
263 |
+
input_values=y_rec_16, output_hidden_states=True
|
264 |
+
).hidden_states
|
265 |
+
|
266 |
+
y_embeddings = (
|
267 |
+
torch.stack(wav_embeddings, dim=1)
|
268 |
+
.transpose(-1, -2)
|
269 |
+
.flatten(start_dim=1, end_dim=2)
|
270 |
+
)
|
271 |
+
y_rec_embeddings = (
|
272 |
+
torch.stack(y_rec_embeddings, dim=1)
|
273 |
+
.transpose(-1, -2)
|
274 |
+
.flatten(start_dim=1, end_dim=2)
|
275 |
+
)
|
276 |
+
|
277 |
+
y_d_rs = self.wd(y_embeddings)
|
278 |
+
y_d_gs = self.wd(y_rec_embeddings)
|
279 |
+
|
280 |
+
y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
|
281 |
+
|
282 |
+
r_loss = torch.mean((1 - y_df_hat_r) ** 2)
|
283 |
+
g_loss = torch.mean((y_df_hat_g) ** 2)
|
284 |
+
|
285 |
+
loss_disc_f = r_loss + g_loss
|
286 |
+
|
287 |
+
return loss_disc_f.mean()
|
288 |
+
|
289 |
+
def discriminator_forward(self, wav):
|
290 |
+
with torch.no_grad():
|
291 |
+
wav_16 = self.resample(wav)
|
292 |
+
wav_embeddings = self.wavlm(
|
293 |
+
input_values=wav_16, output_hidden_states=True
|
294 |
+
).hidden_states
|
295 |
+
y_embeddings = (
|
296 |
+
torch.stack(wav_embeddings, dim=1)
|
297 |
+
.transpose(-1, -2)
|
298 |
+
.flatten(start_dim=1, end_dim=2)
|
299 |
+
)
|
300 |
+
|
301 |
+
y_d_rs = self.wd(y_embeddings)
|
302 |
+
|
303 |
+
return y_d_rs
|
meldataset.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import time
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import random
|
8 |
+
import soundfile as sf
|
9 |
+
import librosa
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torchaudio
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
logger.setLevel(logging.DEBUG)
|
21 |
+
|
22 |
+
import pandas as pd
|
23 |
+
|
24 |
+
_pad = "$"
|
25 |
+
_punctuation = ';:,.!?¡¿—…"«»“” '
|
26 |
+
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
27 |
+
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
28 |
+
|
29 |
+
# Export all symbols:
|
30 |
+
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
31 |
+
|
32 |
+
dicts = {}
|
33 |
+
for i in range(len((symbols))):
|
34 |
+
dicts[symbols[i]] = i
|
35 |
+
|
36 |
+
|
37 |
+
class TextCleaner:
|
38 |
+
def __init__(self, dummy=None):
|
39 |
+
self.word_index_dictionary = dicts
|
40 |
+
|
41 |
+
def __call__(self, text):
|
42 |
+
indexes = []
|
43 |
+
for char in text:
|
44 |
+
try:
|
45 |
+
indexes.append(self.word_index_dictionary[char])
|
46 |
+
except KeyError:
|
47 |
+
print(text)
|
48 |
+
return indexes
|
49 |
+
|
50 |
+
|
51 |
+
np.random.seed(1)
|
52 |
+
random.seed(1)
|
53 |
+
SPECT_PARAMS = {"n_fft": 2048, "win_length": 1200, "hop_length": 300}
|
54 |
+
MEL_PARAMS = {
|
55 |
+
"n_mels": 80,
|
56 |
+
}
|
57 |
+
|
58 |
+
to_mel = torchaudio.transforms.MelSpectrogram(
|
59 |
+
n_mels=80, n_fft=2048, win_length=1200, hop_length=300
|
60 |
+
)
|
61 |
+
mean, std = -4, 4
|
62 |
+
|
63 |
+
|
64 |
+
def preprocess(wave):
|
65 |
+
wave_tensor = torch.from_numpy(wave).float()
|
66 |
+
mel_tensor = to_mel(wave_tensor)
|
67 |
+
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
|
68 |
+
return mel_tensor
|
69 |
+
|
70 |
+
|
71 |
+
class FilePathDataset(torch.utils.data.Dataset):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
data_list,
|
75 |
+
root_path,
|
76 |
+
sr=24000,
|
77 |
+
data_augmentation=False,
|
78 |
+
validation=False,
|
79 |
+
OOD_data="Data/OOD_texts.txt",
|
80 |
+
min_length=50,
|
81 |
+
):
|
82 |
+
spect_params = SPECT_PARAMS
|
83 |
+
mel_params = MEL_PARAMS
|
84 |
+
|
85 |
+
_data_list = [l[:-1].split("|") for l in data_list]
|
86 |
+
self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
|
87 |
+
self.text_cleaner = TextCleaner()
|
88 |
+
self.sr = sr
|
89 |
+
|
90 |
+
self.df = pd.DataFrame(self.data_list)
|
91 |
+
|
92 |
+
self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
|
93 |
+
|
94 |
+
self.mean, self.std = -4, 4
|
95 |
+
self.data_augmentation = data_augmentation and (not validation)
|
96 |
+
self.max_mel_length = 192
|
97 |
+
|
98 |
+
self.min_length = min_length
|
99 |
+
with open(OOD_data, "r") as f:
|
100 |
+
tl = f.readlines()
|
101 |
+
idx = 1 if ".wav" in tl[0].split("|")[0] else 0
|
102 |
+
self.ptexts = [t.split("|")[idx] for t in tl]
|
103 |
+
|
104 |
+
self.root_path = root_path
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.data_list)
|
108 |
+
|
109 |
+
def __getitem__(self, idx):
|
110 |
+
data = self.data_list[idx]
|
111 |
+
path = data[0]
|
112 |
+
|
113 |
+
wave, text_tensor, speaker_id = self._load_tensor(data)
|
114 |
+
|
115 |
+
mel_tensor = preprocess(wave).squeeze()
|
116 |
+
|
117 |
+
acoustic_feature = mel_tensor.squeeze()
|
118 |
+
length_feature = acoustic_feature.size(1)
|
119 |
+
acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]
|
120 |
+
|
121 |
+
# get reference sample
|
122 |
+
ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
|
123 |
+
ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
|
124 |
+
|
125 |
+
# get OOD text
|
126 |
+
|
127 |
+
ps = ""
|
128 |
+
|
129 |
+
while len(ps) < self.min_length:
|
130 |
+
rand_idx = np.random.randint(0, len(self.ptexts) - 1)
|
131 |
+
ps = self.ptexts[rand_idx]
|
132 |
+
|
133 |
+
text = self.text_cleaner(ps)
|
134 |
+
text.insert(0, 0)
|
135 |
+
text.append(0)
|
136 |
+
|
137 |
+
ref_text = torch.LongTensor(text)
|
138 |
+
|
139 |
+
return (
|
140 |
+
speaker_id,
|
141 |
+
acoustic_feature,
|
142 |
+
text_tensor,
|
143 |
+
ref_text,
|
144 |
+
ref_mel_tensor,
|
145 |
+
ref_label,
|
146 |
+
path,
|
147 |
+
wave,
|
148 |
+
)
|
149 |
+
|
150 |
+
def _load_tensor(self, data):
|
151 |
+
wave_path, text, speaker_id = data
|
152 |
+
speaker_id = int(speaker_id)
|
153 |
+
wave, sr = sf.read(osp.join(self.root_path, wave_path))
|
154 |
+
if wave.shape[-1] == 2:
|
155 |
+
wave = wave[:, 0].squeeze()
|
156 |
+
if sr != 24000:
|
157 |
+
wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
|
158 |
+
print(wave_path, sr)
|
159 |
+
|
160 |
+
wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
|
161 |
+
|
162 |
+
text = self.text_cleaner(text)
|
163 |
+
|
164 |
+
text.insert(0, 0)
|
165 |
+
text.append(0)
|
166 |
+
|
167 |
+
text = torch.LongTensor(text)
|
168 |
+
|
169 |
+
return wave, text, speaker_id
|
170 |
+
|
171 |
+
def _load_data(self, data):
|
172 |
+
wave, text_tensor, speaker_id = self._load_tensor(data)
|
173 |
+
mel_tensor = preprocess(wave).squeeze()
|
174 |
+
|
175 |
+
mel_length = mel_tensor.size(1)
|
176 |
+
if mel_length > self.max_mel_length:
|
177 |
+
random_start = np.random.randint(0, mel_length - self.max_mel_length)
|
178 |
+
mel_tensor = mel_tensor[
|
179 |
+
:, random_start : random_start + self.max_mel_length
|
180 |
+
]
|
181 |
+
|
182 |
+
return mel_tensor, speaker_id
|
183 |
+
|
184 |
+
|
185 |
+
class Collater(object):
|
186 |
+
"""
|
187 |
+
Args:
|
188 |
+
adaptive_batch_size (bool): if true, decrease batch size when long data comes.
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(self, return_wave=False):
|
192 |
+
self.text_pad_index = 0
|
193 |
+
self.min_mel_length = 192
|
194 |
+
self.max_mel_length = 192
|
195 |
+
self.return_wave = return_wave
|
196 |
+
|
197 |
+
def __call__(self, batch):
|
198 |
+
# batch[0] = wave, mel, text, f0, speakerid
|
199 |
+
batch_size = len(batch)
|
200 |
+
|
201 |
+
# sort by mel length
|
202 |
+
lengths = [b[1].shape[1] for b in batch]
|
203 |
+
batch_indexes = np.argsort(lengths)[::-1]
|
204 |
+
batch = [batch[bid] for bid in batch_indexes]
|
205 |
+
|
206 |
+
nmels = batch[0][1].size(0)
|
207 |
+
max_mel_length = max([b[1].shape[1] for b in batch])
|
208 |
+
max_text_length = max([b[2].shape[0] for b in batch])
|
209 |
+
max_rtext_length = max([b[3].shape[0] for b in batch])
|
210 |
+
|
211 |
+
labels = torch.zeros((batch_size)).long()
|
212 |
+
mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
|
213 |
+
texts = torch.zeros((batch_size, max_text_length)).long()
|
214 |
+
ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
|
215 |
+
|
216 |
+
input_lengths = torch.zeros(batch_size).long()
|
217 |
+
ref_lengths = torch.zeros(batch_size).long()
|
218 |
+
output_lengths = torch.zeros(batch_size).long()
|
219 |
+
ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
|
220 |
+
ref_labels = torch.zeros((batch_size)).long()
|
221 |
+
paths = ["" for _ in range(batch_size)]
|
222 |
+
waves = [None for _ in range(batch_size)]
|
223 |
+
|
224 |
+
for bid, (
|
225 |
+
label,
|
226 |
+
mel,
|
227 |
+
text,
|
228 |
+
ref_text,
|
229 |
+
ref_mel,
|
230 |
+
ref_label,
|
231 |
+
path,
|
232 |
+
wave,
|
233 |
+
) in enumerate(batch):
|
234 |
+
mel_size = mel.size(1)
|
235 |
+
text_size = text.size(0)
|
236 |
+
rtext_size = ref_text.size(0)
|
237 |
+
labels[bid] = label
|
238 |
+
mels[bid, :, :mel_size] = mel
|
239 |
+
texts[bid, :text_size] = text
|
240 |
+
ref_texts[bid, :rtext_size] = ref_text
|
241 |
+
input_lengths[bid] = text_size
|
242 |
+
ref_lengths[bid] = rtext_size
|
243 |
+
output_lengths[bid] = mel_size
|
244 |
+
paths[bid] = path
|
245 |
+
ref_mel_size = ref_mel.size(1)
|
246 |
+
ref_mels[bid, :, :ref_mel_size] = ref_mel
|
247 |
+
|
248 |
+
ref_labels[bid] = ref_label
|
249 |
+
waves[bid] = wave
|
250 |
+
|
251 |
+
return (
|
252 |
+
waves,
|
253 |
+
texts,
|
254 |
+
input_lengths,
|
255 |
+
ref_texts,
|
256 |
+
ref_lengths,
|
257 |
+
mels,
|
258 |
+
output_lengths,
|
259 |
+
ref_mels,
|
260 |
+
)
|
261 |
+
|
262 |
+
|
263 |
+
def build_dataloader(
|
264 |
+
path_list,
|
265 |
+
root_path,
|
266 |
+
validation=False,
|
267 |
+
OOD_data="Data/OOD_texts.txt",
|
268 |
+
min_length=50,
|
269 |
+
batch_size=4,
|
270 |
+
num_workers=1,
|
271 |
+
device="cpu",
|
272 |
+
collate_config={},
|
273 |
+
dataset_config={},
|
274 |
+
):
|
275 |
+
dataset = FilePathDataset(
|
276 |
+
path_list,
|
277 |
+
root_path,
|
278 |
+
OOD_data=OOD_data,
|
279 |
+
min_length=min_length,
|
280 |
+
validation=validation,
|
281 |
+
**dataset_config
|
282 |
+
)
|
283 |
+
collate_fn = Collater(**collate_config)
|
284 |
+
data_loader = DataLoader(
|
285 |
+
dataset,
|
286 |
+
batch_size=batch_size,
|
287 |
+
shuffle=(not validation),
|
288 |
+
num_workers=num_workers,
|
289 |
+
drop_last=(not validation),
|
290 |
+
collate_fn=collate_fn,
|
291 |
+
pin_memory=(device != "cpu"),
|
292 |
+
)
|
293 |
+
|
294 |
+
return data_loader
|
models.py
ADDED
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding:utf-8
|
2 |
+
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
14 |
+
|
15 |
+
from Utils.ASR.models import ASRCNN
|
16 |
+
from Utils.JDC.model import JDCNet
|
17 |
+
|
18 |
+
from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
|
19 |
+
from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
|
20 |
+
from Modules.diffusion.diffusion import AudioDiffusionConditional
|
21 |
+
|
22 |
+
from Modules.discriminators import (
|
23 |
+
MultiPeriodDiscriminator,
|
24 |
+
MultiResSpecDiscriminator,
|
25 |
+
WavLMDiscriminator,
|
26 |
+
)
|
27 |
+
|
28 |
+
from munch import Munch
|
29 |
+
import yaml
|
30 |
+
|
31 |
+
|
32 |
+
class LearnedDownSample(nn.Module):
|
33 |
+
def __init__(self, layer_type, dim_in):
|
34 |
+
super().__init__()
|
35 |
+
self.layer_type = layer_type
|
36 |
+
|
37 |
+
if self.layer_type == "none":
|
38 |
+
self.conv = nn.Identity()
|
39 |
+
elif self.layer_type == "timepreserve":
|
40 |
+
self.conv = spectral_norm(
|
41 |
+
nn.Conv2d(
|
42 |
+
dim_in,
|
43 |
+
dim_in,
|
44 |
+
kernel_size=(3, 1),
|
45 |
+
stride=(2, 1),
|
46 |
+
groups=dim_in,
|
47 |
+
padding=(1, 0),
|
48 |
+
)
|
49 |
+
)
|
50 |
+
elif self.layer_type == "half":
|
51 |
+
self.conv = spectral_norm(
|
52 |
+
nn.Conv2d(
|
53 |
+
dim_in,
|
54 |
+
dim_in,
|
55 |
+
kernel_size=(3, 3),
|
56 |
+
stride=(2, 2),
|
57 |
+
groups=dim_in,
|
58 |
+
padding=1,
|
59 |
+
)
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
raise RuntimeError(
|
63 |
+
"Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
|
64 |
+
% self.layer_type
|
65 |
+
)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
return self.conv(x)
|
69 |
+
|
70 |
+
|
71 |
+
class LearnedUpSample(nn.Module):
|
72 |
+
def __init__(self, layer_type, dim_in):
|
73 |
+
super().__init__()
|
74 |
+
self.layer_type = layer_type
|
75 |
+
|
76 |
+
if self.layer_type == "none":
|
77 |
+
self.conv = nn.Identity()
|
78 |
+
elif self.layer_type == "timepreserve":
|
79 |
+
self.conv = nn.ConvTranspose2d(
|
80 |
+
dim_in,
|
81 |
+
dim_in,
|
82 |
+
kernel_size=(3, 1),
|
83 |
+
stride=(2, 1),
|
84 |
+
groups=dim_in,
|
85 |
+
output_padding=(1, 0),
|
86 |
+
padding=(1, 0),
|
87 |
+
)
|
88 |
+
elif self.layer_type == "half":
|
89 |
+
self.conv = nn.ConvTranspose2d(
|
90 |
+
dim_in,
|
91 |
+
dim_in,
|
92 |
+
kernel_size=(3, 3),
|
93 |
+
stride=(2, 2),
|
94 |
+
groups=dim_in,
|
95 |
+
output_padding=1,
|
96 |
+
padding=1,
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
raise RuntimeError(
|
100 |
+
"Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
|
101 |
+
% self.layer_type
|
102 |
+
)
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
return self.conv(x)
|
106 |
+
|
107 |
+
|
108 |
+
class DownSample(nn.Module):
|
109 |
+
def __init__(self, layer_type):
|
110 |
+
super().__init__()
|
111 |
+
self.layer_type = layer_type
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
if self.layer_type == "none":
|
115 |
+
return x
|
116 |
+
elif self.layer_type == "timepreserve":
|
117 |
+
return F.avg_pool2d(x, (2, 1))
|
118 |
+
elif self.layer_type == "half":
|
119 |
+
if x.shape[-1] % 2 != 0:
|
120 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
121 |
+
return F.avg_pool2d(x, 2)
|
122 |
+
else:
|
123 |
+
raise RuntimeError(
|
124 |
+
"Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
|
125 |
+
% self.layer_type
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
class UpSample(nn.Module):
|
130 |
+
def __init__(self, layer_type):
|
131 |
+
super().__init__()
|
132 |
+
self.layer_type = layer_type
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
if self.layer_type == "none":
|
136 |
+
return x
|
137 |
+
elif self.layer_type == "timepreserve":
|
138 |
+
return F.interpolate(x, scale_factor=(2, 1), mode="nearest")
|
139 |
+
elif self.layer_type == "half":
|
140 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
141 |
+
else:
|
142 |
+
raise RuntimeError(
|
143 |
+
"Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
|
144 |
+
% self.layer_type
|
145 |
+
)
|
146 |
+
|
147 |
+
|
148 |
+
class ResBlk(nn.Module):
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
dim_in,
|
152 |
+
dim_out,
|
153 |
+
actv=nn.LeakyReLU(0.2),
|
154 |
+
normalize=False,
|
155 |
+
downsample="none",
|
156 |
+
):
|
157 |
+
super().__init__()
|
158 |
+
self.actv = actv
|
159 |
+
self.normalize = normalize
|
160 |
+
self.downsample = DownSample(downsample)
|
161 |
+
self.downsample_res = LearnedDownSample(downsample, dim_in)
|
162 |
+
self.learned_sc = dim_in != dim_out
|
163 |
+
self._build_weights(dim_in, dim_out)
|
164 |
+
|
165 |
+
def _build_weights(self, dim_in, dim_out):
|
166 |
+
self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
|
167 |
+
self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
|
168 |
+
if self.normalize:
|
169 |
+
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
170 |
+
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
171 |
+
if self.learned_sc:
|
172 |
+
self.conv1x1 = spectral_norm(
|
173 |
+
nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
174 |
+
)
|
175 |
+
|
176 |
+
def _shortcut(self, x):
|
177 |
+
if self.learned_sc:
|
178 |
+
x = self.conv1x1(x)
|
179 |
+
if self.downsample:
|
180 |
+
x = self.downsample(x)
|
181 |
+
return x
|
182 |
+
|
183 |
+
def _residual(self, x):
|
184 |
+
if self.normalize:
|
185 |
+
x = self.norm1(x)
|
186 |
+
x = self.actv(x)
|
187 |
+
x = self.conv1(x)
|
188 |
+
x = self.downsample_res(x)
|
189 |
+
if self.normalize:
|
190 |
+
x = self.norm2(x)
|
191 |
+
x = self.actv(x)
|
192 |
+
x = self.conv2(x)
|
193 |
+
return x
|
194 |
+
|
195 |
+
def forward(self, x):
|
196 |
+
x = self._shortcut(x) + self._residual(x)
|
197 |
+
return x / math.sqrt(2) # unit variance
|
198 |
+
|
199 |
+
|
200 |
+
class StyleEncoder(nn.Module):
|
201 |
+
def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
|
202 |
+
super().__init__()
|
203 |
+
blocks = []
|
204 |
+
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
|
205 |
+
|
206 |
+
repeat_num = 4
|
207 |
+
for _ in range(repeat_num):
|
208 |
+
dim_out = min(dim_in * 2, max_conv_dim)
|
209 |
+
blocks += [ResBlk(dim_in, dim_out, downsample="half")]
|
210 |
+
dim_in = dim_out
|
211 |
+
|
212 |
+
blocks += [nn.LeakyReLU(0.2)]
|
213 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
|
214 |
+
blocks += [nn.AdaptiveAvgPool2d(1)]
|
215 |
+
blocks += [nn.LeakyReLU(0.2)]
|
216 |
+
self.shared = nn.Sequential(*blocks)
|
217 |
+
|
218 |
+
self.unshared = nn.Linear(dim_out, style_dim)
|
219 |
+
|
220 |
+
def forward(self, x):
|
221 |
+
h = self.shared(x)
|
222 |
+
h = h.view(h.size(0), -1)
|
223 |
+
s = self.unshared(h)
|
224 |
+
|
225 |
+
return s
|
226 |
+
|
227 |
+
|
228 |
+
class LinearNorm(torch.nn.Module):
|
229 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
|
230 |
+
super(LinearNorm, self).__init__()
|
231 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
232 |
+
|
233 |
+
torch.nn.init.xavier_uniform_(
|
234 |
+
self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
|
235 |
+
)
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
return self.linear_layer(x)
|
239 |
+
|
240 |
+
|
241 |
+
class Discriminator2d(nn.Module):
|
242 |
+
def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
|
243 |
+
super().__init__()
|
244 |
+
blocks = []
|
245 |
+
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
|
246 |
+
|
247 |
+
for lid in range(repeat_num):
|
248 |
+
dim_out = min(dim_in * 2, max_conv_dim)
|
249 |
+
blocks += [ResBlk(dim_in, dim_out, downsample="half")]
|
250 |
+
dim_in = dim_out
|
251 |
+
|
252 |
+
blocks += [nn.LeakyReLU(0.2)]
|
253 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
|
254 |
+
blocks += [nn.LeakyReLU(0.2)]
|
255 |
+
blocks += [nn.AdaptiveAvgPool2d(1)]
|
256 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
|
257 |
+
self.main = nn.Sequential(*blocks)
|
258 |
+
|
259 |
+
def get_feature(self, x):
|
260 |
+
features = []
|
261 |
+
for l in self.main:
|
262 |
+
x = l(x)
|
263 |
+
features.append(x)
|
264 |
+
out = features[-1]
|
265 |
+
out = out.view(out.size(0), -1) # (batch, num_domains)
|
266 |
+
return out, features
|
267 |
+
|
268 |
+
def forward(self, x):
|
269 |
+
out, features = self.get_feature(x)
|
270 |
+
out = out.squeeze() # (batch)
|
271 |
+
return out, features
|
272 |
+
|
273 |
+
|
274 |
+
class ResBlk1d(nn.Module):
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
dim_in,
|
278 |
+
dim_out,
|
279 |
+
actv=nn.LeakyReLU(0.2),
|
280 |
+
normalize=False,
|
281 |
+
downsample="none",
|
282 |
+
dropout_p=0.2,
|
283 |
+
):
|
284 |
+
super().__init__()
|
285 |
+
self.actv = actv
|
286 |
+
self.normalize = normalize
|
287 |
+
self.downsample_type = downsample
|
288 |
+
self.learned_sc = dim_in != dim_out
|
289 |
+
self._build_weights(dim_in, dim_out)
|
290 |
+
self.dropout_p = dropout_p
|
291 |
+
|
292 |
+
if self.downsample_type == "none":
|
293 |
+
self.pool = nn.Identity()
|
294 |
+
else:
|
295 |
+
self.pool = weight_norm(
|
296 |
+
nn.Conv1d(
|
297 |
+
dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1
|
298 |
+
)
|
299 |
+
)
|
300 |
+
|
301 |
+
def _build_weights(self, dim_in, dim_out):
|
302 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
|
303 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
304 |
+
if self.normalize:
|
305 |
+
self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
|
306 |
+
self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
|
307 |
+
if self.learned_sc:
|
308 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
309 |
+
|
310 |
+
def downsample(self, x):
|
311 |
+
if self.downsample_type == "none":
|
312 |
+
return x
|
313 |
+
else:
|
314 |
+
if x.shape[-1] % 2 != 0:
|
315 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
316 |
+
return F.avg_pool1d(x, 2)
|
317 |
+
|
318 |
+
def _shortcut(self, x):
|
319 |
+
if self.learned_sc:
|
320 |
+
x = self.conv1x1(x)
|
321 |
+
x = self.downsample(x)
|
322 |
+
return x
|
323 |
+
|
324 |
+
def _residual(self, x):
|
325 |
+
if self.normalize:
|
326 |
+
x = self.norm1(x)
|
327 |
+
x = self.actv(x)
|
328 |
+
x = F.dropout(x, p=self.dropout_p, training=self.training)
|
329 |
+
|
330 |
+
x = self.conv1(x)
|
331 |
+
x = self.pool(x)
|
332 |
+
if self.normalize:
|
333 |
+
x = self.norm2(x)
|
334 |
+
|
335 |
+
x = self.actv(x)
|
336 |
+
x = F.dropout(x, p=self.dropout_p, training=self.training)
|
337 |
+
|
338 |
+
x = self.conv2(x)
|
339 |
+
return x
|
340 |
+
|
341 |
+
def forward(self, x):
|
342 |
+
x = self._shortcut(x) + self._residual(x)
|
343 |
+
return x / math.sqrt(2) # unit variance
|
344 |
+
|
345 |
+
|
346 |
+
class LayerNorm(nn.Module):
|
347 |
+
def __init__(self, channels, eps=1e-5):
|
348 |
+
super().__init__()
|
349 |
+
self.channels = channels
|
350 |
+
self.eps = eps
|
351 |
+
|
352 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
353 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
354 |
+
|
355 |
+
def forward(self, x):
|
356 |
+
x = x.transpose(1, -1)
|
357 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
358 |
+
return x.transpose(1, -1)
|
359 |
+
|
360 |
+
|
361 |
+
class TextEncoder(nn.Module):
|
362 |
+
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
363 |
+
super().__init__()
|
364 |
+
self.embedding = nn.Embedding(n_symbols, channels)
|
365 |
+
|
366 |
+
padding = (kernel_size - 1) // 2
|
367 |
+
self.cnn = nn.ModuleList()
|
368 |
+
for _ in range(depth):
|
369 |
+
self.cnn.append(
|
370 |
+
nn.Sequential(
|
371 |
+
weight_norm(
|
372 |
+
nn.Conv1d(
|
373 |
+
channels, channels, kernel_size=kernel_size, padding=padding
|
374 |
+
)
|
375 |
+
),
|
376 |
+
LayerNorm(channels),
|
377 |
+
actv,
|
378 |
+
nn.Dropout(0.2),
|
379 |
+
)
|
380 |
+
)
|
381 |
+
# self.cnn = nn.Sequential(*self.cnn)
|
382 |
+
|
383 |
+
self.lstm = nn.LSTM(
|
384 |
+
channels, channels // 2, 1, batch_first=True, bidirectional=True
|
385 |
+
)
|
386 |
+
|
387 |
+
def forward(self, x, input_lengths, m):
|
388 |
+
x = self.embedding(x) # [B, T, emb]
|
389 |
+
x = x.transpose(1, 2) # [B, emb, T]
|
390 |
+
m = m.to(input_lengths.device).unsqueeze(1)
|
391 |
+
x.masked_fill_(m, 0.0)
|
392 |
+
|
393 |
+
for c in self.cnn:
|
394 |
+
x = c(x)
|
395 |
+
x.masked_fill_(m, 0.0)
|
396 |
+
|
397 |
+
x = x.transpose(1, 2) # [B, T, chn]
|
398 |
+
|
399 |
+
input_lengths = input_lengths.cpu().numpy()
|
400 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
401 |
+
x, input_lengths, batch_first=True, enforce_sorted=False
|
402 |
+
)
|
403 |
+
|
404 |
+
self.lstm.flatten_parameters()
|
405 |
+
x, _ = self.lstm(x)
|
406 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
407 |
+
|
408 |
+
x = x.transpose(-1, -2)
|
409 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
410 |
+
|
411 |
+
x_pad[:, :, : x.shape[-1]] = x
|
412 |
+
x = x_pad.to(x.device)
|
413 |
+
|
414 |
+
x.masked_fill_(m, 0.0)
|
415 |
+
|
416 |
+
return x
|
417 |
+
|
418 |
+
def inference(self, x):
|
419 |
+
x = self.embedding(x)
|
420 |
+
x = x.transpose(1, 2)
|
421 |
+
x = self.cnn(x)
|
422 |
+
x = x.transpose(1, 2)
|
423 |
+
self.lstm.flatten_parameters()
|
424 |
+
x, _ = self.lstm(x)
|
425 |
+
return x
|
426 |
+
|
427 |
+
def length_to_mask(self, lengths):
|
428 |
+
mask = (
|
429 |
+
torch.arange(lengths.max())
|
430 |
+
.unsqueeze(0)
|
431 |
+
.expand(lengths.shape[0], -1)
|
432 |
+
.type_as(lengths)
|
433 |
+
)
|
434 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
435 |
+
return mask
|
436 |
+
|
437 |
+
|
438 |
+
class AdaIN1d(nn.Module):
|
439 |
+
def __init__(self, style_dim, num_features):
|
440 |
+
super().__init__()
|
441 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
442 |
+
self.fc = nn.Linear(style_dim, num_features * 2)
|
443 |
+
|
444 |
+
def forward(self, x, s):
|
445 |
+
h = self.fc(s)
|
446 |
+
h = h.view(h.size(0), h.size(1), 1)
|
447 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
448 |
+
return (1 + gamma) * self.norm(x) + beta
|
449 |
+
|
450 |
+
|
451 |
+
class UpSample1d(nn.Module):
|
452 |
+
def __init__(self, layer_type):
|
453 |
+
super().__init__()
|
454 |
+
self.layer_type = layer_type
|
455 |
+
|
456 |
+
def forward(self, x):
|
457 |
+
if self.layer_type == "none":
|
458 |
+
return x
|
459 |
+
else:
|
460 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
461 |
+
|
462 |
+
|
463 |
+
class AdainResBlk1d(nn.Module):
|
464 |
+
def __init__(
|
465 |
+
self,
|
466 |
+
dim_in,
|
467 |
+
dim_out,
|
468 |
+
style_dim=64,
|
469 |
+
actv=nn.LeakyReLU(0.2),
|
470 |
+
upsample="none",
|
471 |
+
dropout_p=0.0,
|
472 |
+
):
|
473 |
+
super().__init__()
|
474 |
+
self.actv = actv
|
475 |
+
self.upsample_type = upsample
|
476 |
+
self.upsample = UpSample1d(upsample)
|
477 |
+
self.learned_sc = dim_in != dim_out
|
478 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
479 |
+
self.dropout = nn.Dropout(dropout_p)
|
480 |
+
|
481 |
+
if upsample == "none":
|
482 |
+
self.pool = nn.Identity()
|
483 |
+
else:
|
484 |
+
self.pool = weight_norm(
|
485 |
+
nn.ConvTranspose1d(
|
486 |
+
dim_in,
|
487 |
+
dim_in,
|
488 |
+
kernel_size=3,
|
489 |
+
stride=2,
|
490 |
+
groups=dim_in,
|
491 |
+
padding=1,
|
492 |
+
output_padding=1,
|
493 |
+
)
|
494 |
+
)
|
495 |
+
|
496 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
497 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
498 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
499 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
500 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
501 |
+
if self.learned_sc:
|
502 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
503 |
+
|
504 |
+
def _shortcut(self, x):
|
505 |
+
x = self.upsample(x)
|
506 |
+
if self.learned_sc:
|
507 |
+
x = self.conv1x1(x)
|
508 |
+
return x
|
509 |
+
|
510 |
+
def _residual(self, x, s):
|
511 |
+
x = self.norm1(x, s)
|
512 |
+
x = self.actv(x)
|
513 |
+
x = self.pool(x)
|
514 |
+
x = self.conv1(self.dropout(x))
|
515 |
+
x = self.norm2(x, s)
|
516 |
+
x = self.actv(x)
|
517 |
+
x = self.conv2(self.dropout(x))
|
518 |
+
return x
|
519 |
+
|
520 |
+
def forward(self, x, s):
|
521 |
+
out = self._residual(x, s)
|
522 |
+
out = (out + self._shortcut(x)) / math.sqrt(2)
|
523 |
+
return out
|
524 |
+
|
525 |
+
|
526 |
+
class AdaLayerNorm(nn.Module):
|
527 |
+
def __init__(self, style_dim, channels, eps=1e-5):
|
528 |
+
super().__init__()
|
529 |
+
self.channels = channels
|
530 |
+
self.eps = eps
|
531 |
+
|
532 |
+
self.fc = nn.Linear(style_dim, channels * 2)
|
533 |
+
|
534 |
+
def forward(self, x, s):
|
535 |
+
x = x.transpose(-1, -2)
|
536 |
+
x = x.transpose(1, -1)
|
537 |
+
|
538 |
+
h = self.fc(s)
|
539 |
+
h = h.view(h.size(0), h.size(1), 1)
|
540 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
541 |
+
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
542 |
+
|
543 |
+
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
544 |
+
x = (1 + gamma) * x + beta
|
545 |
+
return x.transpose(1, -1).transpose(-1, -2)
|
546 |
+
|
547 |
+
|
548 |
+
class ProsodyPredictor(nn.Module):
|
549 |
+
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
550 |
+
super().__init__()
|
551 |
+
|
552 |
+
self.text_encoder = DurationEncoder(
|
553 |
+
sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout
|
554 |
+
)
|
555 |
+
|
556 |
+
self.lstm = nn.LSTM(
|
557 |
+
d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
|
558 |
+
)
|
559 |
+
self.duration_proj = LinearNorm(d_hid, max_dur)
|
560 |
+
|
561 |
+
self.shared = nn.LSTM(
|
562 |
+
d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
|
563 |
+
)
|
564 |
+
self.F0 = nn.ModuleList()
|
565 |
+
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
566 |
+
self.F0.append(
|
567 |
+
AdainResBlk1d(
|
568 |
+
d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
|
569 |
+
)
|
570 |
+
)
|
571 |
+
self.F0.append(
|
572 |
+
AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
|
573 |
+
)
|
574 |
+
|
575 |
+
self.N = nn.ModuleList()
|
576 |
+
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
577 |
+
self.N.append(
|
578 |
+
AdainResBlk1d(
|
579 |
+
d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
|
580 |
+
)
|
581 |
+
)
|
582 |
+
self.N.append(
|
583 |
+
AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
|
584 |
+
)
|
585 |
+
|
586 |
+
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
587 |
+
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
588 |
+
|
589 |
+
def forward(self, texts, style, text_lengths, alignment, m):
|
590 |
+
d = self.text_encoder(texts, style, text_lengths, m)
|
591 |
+
|
592 |
+
batch_size = d.shape[0]
|
593 |
+
text_size = d.shape[1]
|
594 |
+
|
595 |
+
# predict duration
|
596 |
+
input_lengths = text_lengths.cpu().numpy()
|
597 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
598 |
+
d, input_lengths, batch_first=True, enforce_sorted=False
|
599 |
+
)
|
600 |
+
|
601 |
+
m = m.to(text_lengths.device).unsqueeze(1)
|
602 |
+
|
603 |
+
self.lstm.flatten_parameters()
|
604 |
+
x, _ = self.lstm(x)
|
605 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
606 |
+
|
607 |
+
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
608 |
+
|
609 |
+
x_pad[:, : x.shape[1], :] = x
|
610 |
+
x = x_pad.to(x.device)
|
611 |
+
|
612 |
+
duration = self.duration_proj(
|
613 |
+
nn.functional.dropout(x, 0.5, training=self.training)
|
614 |
+
)
|
615 |
+
|
616 |
+
en = d.transpose(-1, -2) @ alignment
|
617 |
+
|
618 |
+
return duration.squeeze(-1), en
|
619 |
+
|
620 |
+
def F0Ntrain(self, x, s):
|
621 |
+
x, _ = self.shared(x.transpose(-1, -2))
|
622 |
+
|
623 |
+
F0 = x.transpose(-1, -2)
|
624 |
+
for block in self.F0:
|
625 |
+
F0 = block(F0, s)
|
626 |
+
F0 = self.F0_proj(F0)
|
627 |
+
|
628 |
+
N = x.transpose(-1, -2)
|
629 |
+
for block in self.N:
|
630 |
+
N = block(N, s)
|
631 |
+
N = self.N_proj(N)
|
632 |
+
|
633 |
+
return F0.squeeze(1), N.squeeze(1)
|
634 |
+
|
635 |
+
def length_to_mask(self, lengths):
|
636 |
+
mask = (
|
637 |
+
torch.arange(lengths.max())
|
638 |
+
.unsqueeze(0)
|
639 |
+
.expand(lengths.shape[0], -1)
|
640 |
+
.type_as(lengths)
|
641 |
+
)
|
642 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
643 |
+
return mask
|
644 |
+
|
645 |
+
|
646 |
+
class DurationEncoder(nn.Module):
|
647 |
+
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
648 |
+
super().__init__()
|
649 |
+
self.lstms = nn.ModuleList()
|
650 |
+
for _ in range(nlayers):
|
651 |
+
self.lstms.append(
|
652 |
+
nn.LSTM(
|
653 |
+
d_model + sty_dim,
|
654 |
+
d_model // 2,
|
655 |
+
num_layers=1,
|
656 |
+
batch_first=True,
|
657 |
+
bidirectional=True,
|
658 |
+
dropout=dropout,
|
659 |
+
)
|
660 |
+
)
|
661 |
+
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
662 |
+
|
663 |
+
self.dropout = dropout
|
664 |
+
self.d_model = d_model
|
665 |
+
self.sty_dim = sty_dim
|
666 |
+
|
667 |
+
def forward(self, x, style, text_lengths, m):
|
668 |
+
masks = m.to(text_lengths.device)
|
669 |
+
|
670 |
+
x = x.permute(2, 0, 1)
|
671 |
+
s = style.expand(x.shape[0], x.shape[1], -1)
|
672 |
+
x = torch.cat([x, s], axis=-1)
|
673 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
674 |
+
|
675 |
+
x = x.transpose(0, 1)
|
676 |
+
input_lengths = text_lengths.cpu().numpy()
|
677 |
+
x = x.transpose(-1, -2)
|
678 |
+
|
679 |
+
for block in self.lstms:
|
680 |
+
if isinstance(block, AdaLayerNorm):
|
681 |
+
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
682 |
+
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
|
683 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
684 |
+
else:
|
685 |
+
x = x.transpose(-1, -2)
|
686 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
687 |
+
x, input_lengths, batch_first=True, enforce_sorted=False
|
688 |
+
)
|
689 |
+
block.flatten_parameters()
|
690 |
+
x, _ = block(x)
|
691 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
692 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
693 |
+
x = x.transpose(-1, -2)
|
694 |
+
|
695 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
696 |
+
|
697 |
+
x_pad[:, :, : x.shape[-1]] = x
|
698 |
+
x = x_pad.to(x.device)
|
699 |
+
|
700 |
+
return x.transpose(-1, -2)
|
701 |
+
|
702 |
+
def inference(self, x, style):
|
703 |
+
x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
|
704 |
+
style = style.expand(x.shape[0], x.shape[1], -1)
|
705 |
+
x = torch.cat([x, style], axis=-1)
|
706 |
+
src = self.pos_encoder(x)
|
707 |
+
output = self.transformer_encoder(src).transpose(0, 1)
|
708 |
+
return output
|
709 |
+
|
710 |
+
def length_to_mask(self, lengths):
|
711 |
+
mask = (
|
712 |
+
torch.arange(lengths.max())
|
713 |
+
.unsqueeze(0)
|
714 |
+
.expand(lengths.shape[0], -1)
|
715 |
+
.type_as(lengths)
|
716 |
+
)
|
717 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
718 |
+
return mask
|
719 |
+
|
720 |
+
|
721 |
+
def load_F0_models(path):
|
722 |
+
# load F0 model
|
723 |
+
|
724 |
+
F0_model = JDCNet(num_class=1, seq_len=192)
|
725 |
+
params = torch.load(path, map_location="cpu")["net"]
|
726 |
+
F0_model.load_state_dict(params)
|
727 |
+
_ = F0_model.train()
|
728 |
+
|
729 |
+
return F0_model
|
730 |
+
|
731 |
+
|
732 |
+
def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
|
733 |
+
# load ASR model
|
734 |
+
def _load_config(path):
|
735 |
+
with open(path) as f:
|
736 |
+
config = yaml.safe_load(f)
|
737 |
+
model_config = config["model_params"]
|
738 |
+
return model_config
|
739 |
+
|
740 |
+
def _load_model(model_config, model_path):
|
741 |
+
model = ASRCNN(**model_config)
|
742 |
+
params = torch.load(model_path, map_location="cpu")["model"]
|
743 |
+
model.load_state_dict(params)
|
744 |
+
return model
|
745 |
+
|
746 |
+
asr_model_config = _load_config(ASR_MODEL_CONFIG)
|
747 |
+
asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
|
748 |
+
_ = asr_model.train()
|
749 |
+
|
750 |
+
return asr_model
|
751 |
+
|
752 |
+
|
753 |
+
def build_model(args, text_aligner, pitch_extractor, bert):
|
754 |
+
assert args.decoder.type in ["istftnet", "hifigan"], "Decoder type unknown"
|
755 |
+
|
756 |
+
if args.decoder.type == "istftnet":
|
757 |
+
from Modules.istftnet import Decoder
|
758 |
+
|
759 |
+
decoder = Decoder(
|
760 |
+
dim_in=args.hidden_dim,
|
761 |
+
style_dim=args.style_dim,
|
762 |
+
dim_out=args.n_mels,
|
763 |
+
resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
|
764 |
+
upsample_rates=args.decoder.upsample_rates,
|
765 |
+
upsample_initial_channel=args.decoder.upsample_initial_channel,
|
766 |
+
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
|
767 |
+
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
|
768 |
+
gen_istft_n_fft=args.decoder.gen_istft_n_fft,
|
769 |
+
gen_istft_hop_size=args.decoder.gen_istft_hop_size,
|
770 |
+
)
|
771 |
+
else:
|
772 |
+
from Modules.hifigan import Decoder
|
773 |
+
|
774 |
+
decoder = Decoder(
|
775 |
+
dim_in=args.hidden_dim,
|
776 |
+
style_dim=args.style_dim,
|
777 |
+
dim_out=args.n_mels,
|
778 |
+
resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
|
779 |
+
upsample_rates=args.decoder.upsample_rates,
|
780 |
+
upsample_initial_channel=args.decoder.upsample_initial_channel,
|
781 |
+
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
|
782 |
+
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
|
783 |
+
)
|
784 |
+
|
785 |
+
text_encoder = TextEncoder(
|
786 |
+
channels=args.hidden_dim,
|
787 |
+
kernel_size=5,
|
788 |
+
depth=args.n_layer,
|
789 |
+
n_symbols=args.n_token,
|
790 |
+
)
|
791 |
+
|
792 |
+
predictor = ProsodyPredictor(
|
793 |
+
style_dim=args.style_dim,
|
794 |
+
d_hid=args.hidden_dim,
|
795 |
+
nlayers=args.n_layer,
|
796 |
+
max_dur=args.max_dur,
|
797 |
+
dropout=args.dropout,
|
798 |
+
)
|
799 |
+
|
800 |
+
style_encoder = StyleEncoder(
|
801 |
+
dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
|
802 |
+
) # acoustic style encoder
|
803 |
+
predictor_encoder = StyleEncoder(
|
804 |
+
dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
|
805 |
+
) # prosodic style encoder
|
806 |
+
|
807 |
+
# define diffusion model
|
808 |
+
if args.multispeaker:
|
809 |
+
transformer = StyleTransformer1d(
|
810 |
+
channels=args.style_dim * 2,
|
811 |
+
context_embedding_features=bert.config.hidden_size,
|
812 |
+
context_features=args.style_dim * 2,
|
813 |
+
**args.diffusion.transformer
|
814 |
+
)
|
815 |
+
else:
|
816 |
+
transformer = Transformer1d(
|
817 |
+
channels=args.style_dim * 2,
|
818 |
+
context_embedding_features=bert.config.hidden_size,
|
819 |
+
**args.diffusion.transformer
|
820 |
+
)
|
821 |
+
|
822 |
+
diffusion = AudioDiffusionConditional(
|
823 |
+
in_channels=1,
|
824 |
+
embedding_max_length=bert.config.max_position_embeddings,
|
825 |
+
embedding_features=bert.config.hidden_size,
|
826 |
+
embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
|
827 |
+
channels=args.style_dim * 2,
|
828 |
+
context_features=args.style_dim * 2,
|
829 |
+
)
|
830 |
+
|
831 |
+
diffusion.diffusion = KDiffusion(
|
832 |
+
net=diffusion.unet,
|
833 |
+
sigma_distribution=LogNormalDistribution(
|
834 |
+
mean=args.diffusion.dist.mean, std=args.diffusion.dist.std
|
835 |
+
),
|
836 |
+
sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
|
837 |
+
dynamic_threshold=0.0,
|
838 |
+
)
|
839 |
+
diffusion.diffusion.net = transformer
|
840 |
+
diffusion.unet = transformer
|
841 |
+
|
842 |
+
nets = Munch(
|
843 |
+
bert=bert,
|
844 |
+
bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
|
845 |
+
predictor=predictor,
|
846 |
+
decoder=decoder,
|
847 |
+
text_encoder=text_encoder,
|
848 |
+
predictor_encoder=predictor_encoder,
|
849 |
+
style_encoder=style_encoder,
|
850 |
+
diffusion=diffusion,
|
851 |
+
text_aligner=text_aligner,
|
852 |
+
pitch_extractor=pitch_extractor,
|
853 |
+
mpd=MultiPeriodDiscriminator(),
|
854 |
+
msd=MultiResSpecDiscriminator(),
|
855 |
+
# slm discriminator head
|
856 |
+
wd=WavLMDiscriminator(
|
857 |
+
args.slm.hidden, args.slm.nlayers, args.slm.initial_channel
|
858 |
+
),
|
859 |
+
)
|
860 |
+
|
861 |
+
return nets
|
862 |
+
|
863 |
+
|
864 |
+
def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
|
865 |
+
state = torch.load(path, map_location="cpu")
|
866 |
+
params = state["net"]
|
867 |
+
for key in model:
|
868 |
+
if key in params and key not in ignore_modules:
|
869 |
+
print("%s loaded" % key)
|
870 |
+
model[key].load_state_dict(params[key], strict=False)
|
871 |
+
_ = [model[key].eval() for key in model]
|
872 |
+
|
873 |
+
if not load_only_params:
|
874 |
+
epoch = state["epoch"]
|
875 |
+
iters = state["iters"]
|
876 |
+
optimizer.load_state_dict(state["optimizer"])
|
877 |
+
else:
|
878 |
+
epoch = 0
|
879 |
+
iters = 0
|
880 |
+
|
881 |
+
return model, optimizer, epoch, iters
|
optimizers.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding:utf-8
|
2 |
+
import os, sys
|
3 |
+
import os.path as osp
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.optim import Optimizer
|
8 |
+
from functools import reduce
|
9 |
+
from torch.optim import AdamW
|
10 |
+
|
11 |
+
|
12 |
+
class MultiOptimizer:
|
13 |
+
def __init__(self, optimizers={}, schedulers={}):
|
14 |
+
self.optimizers = optimizers
|
15 |
+
self.schedulers = schedulers
|
16 |
+
self.keys = list(optimizers.keys())
|
17 |
+
self.param_groups = reduce(
|
18 |
+
lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
|
19 |
+
)
|
20 |
+
|
21 |
+
def state_dict(self):
|
22 |
+
state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
|
23 |
+
return state_dicts
|
24 |
+
|
25 |
+
def load_state_dict(self, state_dict):
|
26 |
+
for key, val in state_dict:
|
27 |
+
try:
|
28 |
+
self.optimizers[key].load_state_dict(val)
|
29 |
+
except:
|
30 |
+
print("Unloaded %s" % key)
|
31 |
+
|
32 |
+
def step(self, key=None, scaler=None):
|
33 |
+
keys = [key] if key is not None else self.keys
|
34 |
+
_ = [self._step(key, scaler) for key in keys]
|
35 |
+
|
36 |
+
def _step(self, key, scaler=None):
|
37 |
+
if scaler is not None:
|
38 |
+
scaler.step(self.optimizers[key])
|
39 |
+
scaler.update()
|
40 |
+
else:
|
41 |
+
self.optimizers[key].step()
|
42 |
+
|
43 |
+
def zero_grad(self, key=None):
|
44 |
+
if key is not None:
|
45 |
+
self.optimizers[key].zero_grad()
|
46 |
+
else:
|
47 |
+
_ = [self.optimizers[key].zero_grad() for key in self.keys]
|
48 |
+
|
49 |
+
def scheduler(self, *args, key=None):
|
50 |
+
if key is not None:
|
51 |
+
self.schedulers[key].step(*args)
|
52 |
+
else:
|
53 |
+
_ = [self.schedulers[key].step(*args) for key in self.keys]
|
54 |
+
|
55 |
+
|
56 |
+
def define_scheduler(optimizer, params):
|
57 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
58 |
+
optimizer,
|
59 |
+
max_lr=params.get("max_lr", 2e-4),
|
60 |
+
epochs=params.get("epochs", 200),
|
61 |
+
steps_per_epoch=params.get("steps_per_epoch", 1000),
|
62 |
+
pct_start=params.get("pct_start", 0.0),
|
63 |
+
div_factor=1,
|
64 |
+
final_div_factor=1,
|
65 |
+
)
|
66 |
+
|
67 |
+
return scheduler
|
68 |
+
|
69 |
+
|
70 |
+
def build_optimizer(parameters_dict, scheduler_params_dict, lr):
|
71 |
+
optim = dict(
|
72 |
+
[
|
73 |
+
(key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
|
74 |
+
for key, params in parameters_dict.items()
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
schedulers = dict(
|
79 |
+
[
|
80 |
+
(key, define_scheduler(opt, scheduler_params_dict[key]))
|
81 |
+
for key, opt in optim.items()
|
82 |
+
]
|
83 |
+
)
|
84 |
+
|
85 |
+
multi_optim = MultiOptimizer(optim, schedulers)
|
86 |
+
return multi_optim
|
reference_audio.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d25b4950ec39cec5a00f5061491ad0b3606edc6618a54adc59663bfd6e6ab55e
|
3 |
+
size 2917622
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SoundFile
|
2 |
+
torchaudio
|
3 |
+
munch
|
4 |
+
torch
|
5 |
+
pydub
|
6 |
+
pyyaml
|
7 |
+
librosa
|
8 |
+
nltk
|
9 |
+
matplotlib
|
10 |
+
accelerate
|
11 |
+
transformers
|
12 |
+
einops
|
13 |
+
einops-exts
|
14 |
+
tqdm
|
15 |
+
typing
|
16 |
+
typing-extensions
|
17 |
+
git+https://github.com/resemble-ai/monotonic_align.git
|
18 |
+
scipy
|
19 |
+
deep-phonemizer
|
20 |
+
cached-path
|
21 |
+
gradio
|
styletts2importable.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cached_path import cached_path
|
2 |
+
|
3 |
+
from dp.phonemizer import Phonemizer
|
4 |
+
print("NLTK")
|
5 |
+
import nltk
|
6 |
+
nltk.download('punkt')
|
7 |
+
print("SCIPY")
|
8 |
+
from scipy.io.wavfile import write
|
9 |
+
print("TORCH STUFF")
|
10 |
+
import torch
|
11 |
+
print("START")
|
12 |
+
torch.manual_seed(0)
|
13 |
+
torch.backends.cudnn.benchmark = False
|
14 |
+
torch.backends.cudnn.deterministic = True
|
15 |
+
|
16 |
+
import random
|
17 |
+
random.seed(0)
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
np.random.seed(0)
|
21 |
+
|
22 |
+
# load packages
|
23 |
+
import time
|
24 |
+
import random
|
25 |
+
import yaml
|
26 |
+
from munch import Munch
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from torch import nn
|
30 |
+
import torch.nn.functional as F
|
31 |
+
import torchaudio
|
32 |
+
import librosa
|
33 |
+
from nltk.tokenize import word_tokenize
|
34 |
+
|
35 |
+
from models import *
|
36 |
+
from utils import *
|
37 |
+
from text_utils import TextCleaner
|
38 |
+
textclenaer = TextCleaner()
|
39 |
+
|
40 |
+
|
41 |
+
to_mel = torchaudio.transforms.MelSpectrogram(
|
42 |
+
n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
|
43 |
+
mean, std = -4, 4
|
44 |
+
|
45 |
+
def length_to_mask(lengths):
|
46 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
47 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
48 |
+
return mask
|
49 |
+
|
50 |
+
def preprocess(wave):
|
51 |
+
wave_tensor = torch.from_numpy(wave).float()
|
52 |
+
mel_tensor = to_mel(wave_tensor)
|
53 |
+
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
|
54 |
+
return mel_tensor
|
55 |
+
|
56 |
+
def compute_style(path):
|
57 |
+
wave, sr = librosa.load(path, sr=24000)
|
58 |
+
audio, index = librosa.effects.trim(wave, top_db=30)
|
59 |
+
if sr != 24000:
|
60 |
+
audio = librosa.resample(audio, sr, 24000)
|
61 |
+
mel_tensor = preprocess(audio).to(device)
|
62 |
+
|
63 |
+
with torch.no_grad():
|
64 |
+
ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
|
65 |
+
ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
|
66 |
+
|
67 |
+
return torch.cat([ref_s, ref_p], dim=1)
|
68 |
+
|
69 |
+
device = 'cpu'
|
70 |
+
if torch.cuda.is_available():
|
71 |
+
device = 'cuda'
|
72 |
+
elif torch.backends.mps.is_available():
|
73 |
+
print("MPS would be available but cannot be used rn")
|
74 |
+
# device = 'mps'
|
75 |
+
|
76 |
+
|
77 |
+
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
|
78 |
+
phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
|
79 |
+
|
80 |
+
|
81 |
+
# config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
|
82 |
+
config = yaml.safe_load(open(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml"))))
|
83 |
+
|
84 |
+
# load pretrained ASR model
|
85 |
+
ASR_config = config.get('ASR_config', False)
|
86 |
+
ASR_path = config.get('ASR_path', False)
|
87 |
+
text_aligner = load_ASR_models(ASR_path, ASR_config)
|
88 |
+
|
89 |
+
# load pretrained F0 model
|
90 |
+
F0_path = config.get('F0_path', False)
|
91 |
+
pitch_extractor = load_F0_models(F0_path)
|
92 |
+
|
93 |
+
# load BERT model
|
94 |
+
from Utils.PLBERT.util import load_plbert
|
95 |
+
BERT_path = config.get('PLBERT_dir', False)
|
96 |
+
plbert = load_plbert(BERT_path)
|
97 |
+
|
98 |
+
model_params = recursive_munch(config['model_params'])
|
99 |
+
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
|
100 |
+
_ = [model[key].eval() for key in model]
|
101 |
+
_ = [model[key].to(device) for key in model]
|
102 |
+
|
103 |
+
# params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
|
104 |
+
params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
|
105 |
+
params = params_whole['net']
|
106 |
+
|
107 |
+
for key in model:
|
108 |
+
if key in params:
|
109 |
+
print('%s loaded' % key)
|
110 |
+
try:
|
111 |
+
model[key].load_state_dict(params[key])
|
112 |
+
except:
|
113 |
+
from collections import OrderedDict
|
114 |
+
state_dict = params[key]
|
115 |
+
new_state_dict = OrderedDict()
|
116 |
+
for k, v in state_dict.items():
|
117 |
+
name = k[7:] # remove `module.`
|
118 |
+
new_state_dict[name] = v
|
119 |
+
# load params
|
120 |
+
model[key].load_state_dict(new_state_dict, strict=False)
|
121 |
+
# except:
|
122 |
+
# _load(params[key], model[key])
|
123 |
+
_ = [model[key].eval() for key in model]
|
124 |
+
|
125 |
+
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
|
126 |
+
|
127 |
+
sampler = DiffusionSampler(
|
128 |
+
model.diffusion.diffusion,
|
129 |
+
sampler=ADPM2Sampler(),
|
130 |
+
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
|
131 |
+
clamp=False
|
132 |
+
)
|
133 |
+
|
134 |
+
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
135 |
+
text = text.strip()
|
136 |
+
ps = phonemizer([text], lang='en_us')
|
137 |
+
ps = word_tokenize(ps[0])
|
138 |
+
ps = ' '.join(ps)
|
139 |
+
tokens = textclenaer(ps)
|
140 |
+
tokens.insert(0, 0)
|
141 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
142 |
+
|
143 |
+
with torch.no_grad():
|
144 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
145 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
146 |
+
|
147 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
148 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
149 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
150 |
+
|
151 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
152 |
+
embedding=bert_dur,
|
153 |
+
embedding_scale=embedding_scale,
|
154 |
+
features=ref_s, # reference from the same speaker as the embedding
|
155 |
+
num_steps=diffusion_steps).squeeze(1)
|
156 |
+
|
157 |
+
|
158 |
+
s = s_pred[:, 128:]
|
159 |
+
ref = s_pred[:, :128]
|
160 |
+
|
161 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
162 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
163 |
+
|
164 |
+
d = model.predictor.text_encoder(d_en,
|
165 |
+
s, input_lengths, text_mask)
|
166 |
+
|
167 |
+
x, _ = model.predictor.lstm(d)
|
168 |
+
duration = model.predictor.duration_proj(x)
|
169 |
+
|
170 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
171 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
172 |
+
|
173 |
+
|
174 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
175 |
+
c_frame = 0
|
176 |
+
for i in range(pred_aln_trg.size(0)):
|
177 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
178 |
+
c_frame += int(pred_dur[i].data)
|
179 |
+
|
180 |
+
# encode prosody
|
181 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
182 |
+
if model_params.decoder.type == "hifigan":
|
183 |
+
asr_new = torch.zeros_like(en)
|
184 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
185 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
186 |
+
en = asr_new
|
187 |
+
|
188 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
189 |
+
|
190 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
191 |
+
if model_params.decoder.type == "hifigan":
|
192 |
+
asr_new = torch.zeros_like(asr)
|
193 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
194 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
195 |
+
asr = asr_new
|
196 |
+
|
197 |
+
out = model.decoder(asr,
|
198 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
199 |
+
|
200 |
+
|
201 |
+
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
202 |
+
|
203 |
+
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
|
204 |
+
text = text.strip()
|
205 |
+
ps = phonemizer([text], lang='en_us')
|
206 |
+
ps = word_tokenize(ps[0])
|
207 |
+
ps = ' '.join(ps)
|
208 |
+
ps = ps.replace('``', '"')
|
209 |
+
ps = ps.replace("''", '"')
|
210 |
+
|
211 |
+
tokens = textclenaer(ps)
|
212 |
+
tokens.insert(0, 0)
|
213 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
214 |
+
|
215 |
+
with torch.no_grad():
|
216 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
217 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
218 |
+
|
219 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
220 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
221 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
222 |
+
|
223 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
224 |
+
embedding=bert_dur,
|
225 |
+
embedding_scale=embedding_scale,
|
226 |
+
features=ref_s, # reference from the same speaker as the embedding
|
227 |
+
num_steps=diffusion_steps).squeeze(1)
|
228 |
+
|
229 |
+
if s_prev is not None:
|
230 |
+
# convex combination of previous and current style
|
231 |
+
s_pred = t * s_prev + (1 - t) * s_pred
|
232 |
+
|
233 |
+
s = s_pred[:, 128:]
|
234 |
+
ref = s_pred[:, :128]
|
235 |
+
|
236 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
237 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
238 |
+
|
239 |
+
s_pred = torch.cat([ref, s], dim=-1)
|
240 |
+
|
241 |
+
d = model.predictor.text_encoder(d_en,
|
242 |
+
s, input_lengths, text_mask)
|
243 |
+
|
244 |
+
x, _ = model.predictor.lstm(d)
|
245 |
+
duration = model.predictor.duration_proj(x)
|
246 |
+
|
247 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
248 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
249 |
+
|
250 |
+
|
251 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
252 |
+
c_frame = 0
|
253 |
+
for i in range(pred_aln_trg.size(0)):
|
254 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
255 |
+
c_frame += int(pred_dur[i].data)
|
256 |
+
|
257 |
+
# encode prosody
|
258 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
259 |
+
if model_params.decoder.type == "hifigan":
|
260 |
+
asr_new = torch.zeros_like(en)
|
261 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
262 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
263 |
+
en = asr_new
|
264 |
+
|
265 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
266 |
+
|
267 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
268 |
+
if model_params.decoder.type == "hifigan":
|
269 |
+
asr_new = torch.zeros_like(asr)
|
270 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
271 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
272 |
+
asr = asr_new
|
273 |
+
|
274 |
+
out = model.decoder(asr,
|
275 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
276 |
+
|
277 |
+
|
278 |
+
return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
|
279 |
+
|
280 |
+
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
281 |
+
text = text.strip()
|
282 |
+
ps = phonemizer([text], lang='en_us')
|
283 |
+
ps = word_tokenize(ps[0])
|
284 |
+
ps = ' '.join(ps)
|
285 |
+
|
286 |
+
tokens = textclenaer(ps)
|
287 |
+
tokens.insert(0, 0)
|
288 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
289 |
+
|
290 |
+
ref_text = ref_text.strip()
|
291 |
+
ps = phonemizer([ref_text], lang='en_us')
|
292 |
+
ps = word_tokenize(ps[0])
|
293 |
+
ps = ' '.join(ps)
|
294 |
+
|
295 |
+
ref_tokens = textclenaer(ps)
|
296 |
+
ref_tokens.insert(0, 0)
|
297 |
+
ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
|
298 |
+
|
299 |
+
|
300 |
+
with torch.no_grad():
|
301 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
302 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
303 |
+
|
304 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
305 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
306 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
307 |
+
|
308 |
+
ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
|
309 |
+
ref_text_mask = length_to_mask(ref_input_lengths).to(device)
|
310 |
+
ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
|
311 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
312 |
+
embedding=bert_dur,
|
313 |
+
embedding_scale=embedding_scale,
|
314 |
+
features=ref_s, # reference from the same speaker as the embedding
|
315 |
+
num_steps=diffusion_steps).squeeze(1)
|
316 |
+
|
317 |
+
|
318 |
+
s = s_pred[:, 128:]
|
319 |
+
ref = s_pred[:, :128]
|
320 |
+
|
321 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
322 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
323 |
+
|
324 |
+
d = model.predictor.text_encoder(d_en,
|
325 |
+
s, input_lengths, text_mask)
|
326 |
+
|
327 |
+
x, _ = model.predictor.lstm(d)
|
328 |
+
duration = model.predictor.duration_proj(x)
|
329 |
+
|
330 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
331 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
332 |
+
|
333 |
+
|
334 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
335 |
+
c_frame = 0
|
336 |
+
for i in range(pred_aln_trg.size(0)):
|
337 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
338 |
+
c_frame += int(pred_dur[i].data)
|
339 |
+
|
340 |
+
# encode prosody
|
341 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
342 |
+
if model_params.decoder.type == "hifigan":
|
343 |
+
asr_new = torch.zeros_like(en)
|
344 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
345 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
346 |
+
en = asr_new
|
347 |
+
|
348 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
349 |
+
|
350 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
351 |
+
if model_params.decoder.type == "hifigan":
|
352 |
+
asr_new = torch.zeros_like(asr)
|
353 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
354 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
355 |
+
asr = asr_new
|
356 |
+
|
357 |
+
out = model.decoder(asr,
|
358 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
359 |
+
|
360 |
+
|
361 |
+
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
text_utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IPA Phonemizer: https://github.com/bootphon/phonemizer
|
2 |
+
|
3 |
+
_pad = "$"
|
4 |
+
_punctuation = ';:,.!?¡¿—…"«»“” '
|
5 |
+
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
6 |
+
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
7 |
+
|
8 |
+
# Export all symbols:
|
9 |
+
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
10 |
+
|
11 |
+
dicts = {}
|
12 |
+
for i in range(len((symbols))):
|
13 |
+
dicts[symbols[i]] = i
|
14 |
+
|
15 |
+
|
16 |
+
class TextCleaner:
|
17 |
+
def __init__(self, dummy=None):
|
18 |
+
self.word_index_dictionary = dicts
|
19 |
+
print(len(dicts))
|
20 |
+
|
21 |
+
def __call__(self, text):
|
22 |
+
indexes = []
|
23 |
+
for char in text:
|
24 |
+
try:
|
25 |
+
indexes.append(self.word_index_dictionary[char])
|
26 |
+
except KeyError:
|
27 |
+
print(text)
|
28 |
+
return indexes
|
train_finetune.py
ADDED
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# load packages
|
2 |
+
import random
|
3 |
+
import yaml
|
4 |
+
import time
|
5 |
+
from munch import Munch
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchaudio
|
11 |
+
import librosa
|
12 |
+
import click
|
13 |
+
import shutil
|
14 |
+
import warnings
|
15 |
+
|
16 |
+
warnings.simplefilter("ignore")
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
|
19 |
+
from meldataset import build_dataloader
|
20 |
+
|
21 |
+
from Utils.ASR.models import ASRCNN
|
22 |
+
from Utils.JDC.model import JDCNet
|
23 |
+
from Utils.PLBERT.util import load_plbert
|
24 |
+
|
25 |
+
from models import *
|
26 |
+
from losses import *
|
27 |
+
from utils import *
|
28 |
+
|
29 |
+
from Modules.slmadv import SLMAdversarialLoss
|
30 |
+
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
|
31 |
+
|
32 |
+
from optimizers import build_optimizer
|
33 |
+
|
34 |
+
|
35 |
+
# simple fix for dataparallel that allows access to class attributes
|
36 |
+
class MyDataParallel(torch.nn.DataParallel):
|
37 |
+
def __getattr__(self, name):
|
38 |
+
try:
|
39 |
+
return super().__getattr__(name)
|
40 |
+
except AttributeError:
|
41 |
+
return getattr(self.module, name)
|
42 |
+
|
43 |
+
|
44 |
+
import logging
|
45 |
+
from logging import StreamHandler
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
logger.setLevel(logging.DEBUG)
|
49 |
+
handler = StreamHandler()
|
50 |
+
handler.setLevel(logging.DEBUG)
|
51 |
+
logger.addHandler(handler)
|
52 |
+
|
53 |
+
|
54 |
+
@click.command()
|
55 |
+
@click.option("-p", "--config_path", default="Configs/config_ft.yml", type=str)
|
56 |
+
def main(config_path):
|
57 |
+
config = yaml.safe_load(open(config_path))
|
58 |
+
|
59 |
+
log_dir = config["log_dir"]
|
60 |
+
if not osp.exists(log_dir):
|
61 |
+
os.makedirs(log_dir, exist_ok=True)
|
62 |
+
shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
|
63 |
+
writer = SummaryWriter(log_dir + "/tensorboard")
|
64 |
+
|
65 |
+
# write logs
|
66 |
+
file_handler = logging.FileHandler(osp.join(log_dir, "train.log"))
|
67 |
+
file_handler.setLevel(logging.DEBUG)
|
68 |
+
file_handler.setFormatter(
|
69 |
+
logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")
|
70 |
+
)
|
71 |
+
logger.addHandler(file_handler)
|
72 |
+
|
73 |
+
batch_size = config.get("batch_size", 10)
|
74 |
+
|
75 |
+
epochs = config.get("epochs", 200)
|
76 |
+
save_freq = config.get("save_freq", 2)
|
77 |
+
log_interval = config.get("log_interval", 10)
|
78 |
+
saving_epoch = config.get("save_freq", 2)
|
79 |
+
|
80 |
+
data_params = config.get("data_params", None)
|
81 |
+
sr = config["preprocess_params"].get("sr", 24000)
|
82 |
+
train_path = data_params["train_data"]
|
83 |
+
val_path = data_params["val_data"]
|
84 |
+
root_path = data_params["root_path"]
|
85 |
+
min_length = data_params["min_length"]
|
86 |
+
OOD_data = data_params["OOD_data"]
|
87 |
+
|
88 |
+
max_len = config.get("max_len", 200)
|
89 |
+
|
90 |
+
loss_params = Munch(config["loss_params"])
|
91 |
+
diff_epoch = loss_params.diff_epoch
|
92 |
+
joint_epoch = loss_params.joint_epoch
|
93 |
+
|
94 |
+
optimizer_params = Munch(config["optimizer_params"])
|
95 |
+
|
96 |
+
train_list, val_list = get_data_path_list(train_path, val_path)
|
97 |
+
device = "cuda"
|
98 |
+
|
99 |
+
train_dataloader = build_dataloader(
|
100 |
+
train_list,
|
101 |
+
root_path,
|
102 |
+
OOD_data=OOD_data,
|
103 |
+
min_length=min_length,
|
104 |
+
batch_size=batch_size,
|
105 |
+
num_workers=2,
|
106 |
+
dataset_config={},
|
107 |
+
device=device,
|
108 |
+
)
|
109 |
+
|
110 |
+
val_dataloader = build_dataloader(
|
111 |
+
val_list,
|
112 |
+
root_path,
|
113 |
+
OOD_data=OOD_data,
|
114 |
+
min_length=min_length,
|
115 |
+
batch_size=batch_size,
|
116 |
+
validation=True,
|
117 |
+
num_workers=0,
|
118 |
+
device=device,
|
119 |
+
dataset_config={},
|
120 |
+
)
|
121 |
+
|
122 |
+
# load pretrained ASR model
|
123 |
+
ASR_config = config.get("ASR_config", False)
|
124 |
+
ASR_path = config.get("ASR_path", False)
|
125 |
+
text_aligner = load_ASR_models(ASR_path, ASR_config)
|
126 |
+
|
127 |
+
# load pretrained F0 model
|
128 |
+
F0_path = config.get("F0_path", False)
|
129 |
+
pitch_extractor = load_F0_models(F0_path)
|
130 |
+
|
131 |
+
# load PL-BERT model
|
132 |
+
BERT_path = config.get("PLBERT_dir", False)
|
133 |
+
plbert = load_plbert(BERT_path)
|
134 |
+
|
135 |
+
# build model
|
136 |
+
model_params = recursive_munch(config["model_params"])
|
137 |
+
multispeaker = model_params.multispeaker
|
138 |
+
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
|
139 |
+
_ = [model[key].to(device) for key in model]
|
140 |
+
|
141 |
+
# DP
|
142 |
+
for key in model:
|
143 |
+
if key != "mpd" and key != "msd" and key != "wd":
|
144 |
+
model[key] = MyDataParallel(model[key])
|
145 |
+
|
146 |
+
start_epoch = 0
|
147 |
+
iters = 0
|
148 |
+
|
149 |
+
load_pretrained = config.get("pretrained_model", "") != "" and config.get(
|
150 |
+
"second_stage_load_pretrained", False
|
151 |
+
)
|
152 |
+
|
153 |
+
if not load_pretrained:
|
154 |
+
if config.get("first_stage_path", "") != "":
|
155 |
+
first_stage_path = osp.join(
|
156 |
+
log_dir, config.get("first_stage_path", "first_stage.pth")
|
157 |
+
)
|
158 |
+
print("Loading the first stage model at %s ..." % first_stage_path)
|
159 |
+
model, _, start_epoch, iters = load_checkpoint(
|
160 |
+
model,
|
161 |
+
None,
|
162 |
+
first_stage_path,
|
163 |
+
load_only_params=True,
|
164 |
+
ignore_modules=[
|
165 |
+
"bert",
|
166 |
+
"bert_encoder",
|
167 |
+
"predictor",
|
168 |
+
"predictor_encoder",
|
169 |
+
"msd",
|
170 |
+
"mpd",
|
171 |
+
"wd",
|
172 |
+
"diffusion",
|
173 |
+
],
|
174 |
+
) # keep starting epoch for tensorboard log
|
175 |
+
|
176 |
+
# these epochs should be counted from the start epoch
|
177 |
+
diff_epoch += start_epoch
|
178 |
+
joint_epoch += start_epoch
|
179 |
+
epochs += start_epoch
|
180 |
+
|
181 |
+
model.predictor_encoder = copy.deepcopy(model.style_encoder)
|
182 |
+
else:
|
183 |
+
raise ValueError("You need to specify the path to the first stage model.")
|
184 |
+
|
185 |
+
gl = GeneratorLoss(model.mpd, model.msd).to(device)
|
186 |
+
dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
|
187 |
+
wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
|
188 |
+
|
189 |
+
gl = MyDataParallel(gl)
|
190 |
+
dl = MyDataParallel(dl)
|
191 |
+
wl = MyDataParallel(wl)
|
192 |
+
|
193 |
+
sampler = DiffusionSampler(
|
194 |
+
model.diffusion.diffusion,
|
195 |
+
sampler=ADPM2Sampler(),
|
196 |
+
sigma_schedule=KarrasSchedule(
|
197 |
+
sigma_min=0.0001, sigma_max=3.0, rho=9.0
|
198 |
+
), # empirical parameters
|
199 |
+
clamp=False,
|
200 |
+
)
|
201 |
+
|
202 |
+
scheduler_params = {
|
203 |
+
"max_lr": optimizer_params.lr,
|
204 |
+
"pct_start": float(0),
|
205 |
+
"epochs": epochs,
|
206 |
+
"steps_per_epoch": len(train_dataloader),
|
207 |
+
}
|
208 |
+
scheduler_params_dict = {key: scheduler_params.copy() for key in model}
|
209 |
+
scheduler_params_dict["bert"]["max_lr"] = optimizer_params.bert_lr * 2
|
210 |
+
scheduler_params_dict["decoder"]["max_lr"] = optimizer_params.ft_lr * 2
|
211 |
+
scheduler_params_dict["style_encoder"]["max_lr"] = optimizer_params.ft_lr * 2
|
212 |
+
|
213 |
+
optimizer = build_optimizer(
|
214 |
+
{key: model[key].parameters() for key in model},
|
215 |
+
scheduler_params_dict=scheduler_params_dict,
|
216 |
+
lr=optimizer_params.lr,
|
217 |
+
)
|
218 |
+
|
219 |
+
# adjust BERT learning rate
|
220 |
+
for g in optimizer.optimizers["bert"].param_groups:
|
221 |
+
g["betas"] = (0.9, 0.99)
|
222 |
+
g["lr"] = optimizer_params.bert_lr
|
223 |
+
g["initial_lr"] = optimizer_params.bert_lr
|
224 |
+
g["min_lr"] = 0
|
225 |
+
g["weight_decay"] = 0.01
|
226 |
+
|
227 |
+
# adjust acoustic module learning rate
|
228 |
+
for module in ["decoder", "style_encoder"]:
|
229 |
+
for g in optimizer.optimizers[module].param_groups:
|
230 |
+
g["betas"] = (0.0, 0.99)
|
231 |
+
g["lr"] = optimizer_params.ft_lr
|
232 |
+
g["initial_lr"] = optimizer_params.ft_lr
|
233 |
+
g["min_lr"] = 0
|
234 |
+
g["weight_decay"] = 1e-4
|
235 |
+
|
236 |
+
# load models if there is a model
|
237 |
+
if load_pretrained:
|
238 |
+
model, optimizer, start_epoch, iters = load_checkpoint(
|
239 |
+
model,
|
240 |
+
optimizer,
|
241 |
+
config["pretrained_model"],
|
242 |
+
load_only_params=config.get("load_only_params", True),
|
243 |
+
)
|
244 |
+
|
245 |
+
n_down = model.text_aligner.n_down
|
246 |
+
|
247 |
+
best_loss = float("inf") # best test loss
|
248 |
+
loss_train_record = list([])
|
249 |
+
loss_test_record = list([])
|
250 |
+
iters = 0
|
251 |
+
|
252 |
+
criterion = nn.L1Loss() # F0 loss (regression)
|
253 |
+
torch.cuda.empty_cache()
|
254 |
+
|
255 |
+
stft_loss = MultiResolutionSTFTLoss().to(device)
|
256 |
+
|
257 |
+
print("BERT", optimizer.optimizers["bert"])
|
258 |
+
print("decoder", optimizer.optimizers["decoder"])
|
259 |
+
|
260 |
+
start_ds = False
|
261 |
+
|
262 |
+
running_std = []
|
263 |
+
|
264 |
+
slmadv_params = Munch(config["slmadv_params"])
|
265 |
+
slmadv = SLMAdversarialLoss(
|
266 |
+
model,
|
267 |
+
wl,
|
268 |
+
sampler,
|
269 |
+
slmadv_params.min_len,
|
270 |
+
slmadv_params.max_len,
|
271 |
+
batch_percentage=slmadv_params.batch_percentage,
|
272 |
+
skip_update=slmadv_params.iter,
|
273 |
+
sig=slmadv_params.sig,
|
274 |
+
)
|
275 |
+
|
276 |
+
for epoch in range(start_epoch, epochs):
|
277 |
+
running_loss = 0
|
278 |
+
start_time = time.time()
|
279 |
+
|
280 |
+
_ = [model[key].eval() for key in model]
|
281 |
+
|
282 |
+
model.text_aligner.train()
|
283 |
+
model.text_encoder.train()
|
284 |
+
|
285 |
+
model.predictor.train()
|
286 |
+
model.bert_encoder.train()
|
287 |
+
model.bert.train()
|
288 |
+
model.msd.train()
|
289 |
+
model.mpd.train()
|
290 |
+
|
291 |
+
for i, batch in enumerate(train_dataloader):
|
292 |
+
waves = batch[0]
|
293 |
+
batch = [b.to(device) for b in batch[1:]]
|
294 |
+
(
|
295 |
+
texts,
|
296 |
+
input_lengths,
|
297 |
+
ref_texts,
|
298 |
+
ref_lengths,
|
299 |
+
mels,
|
300 |
+
mel_input_length,
|
301 |
+
ref_mels,
|
302 |
+
) = batch
|
303 |
+
with torch.no_grad():
|
304 |
+
mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
|
305 |
+
mel_mask = length_to_mask(mel_input_length).to(device)
|
306 |
+
text_mask = length_to_mask(input_lengths).to(texts.device)
|
307 |
+
|
308 |
+
# compute reference styles
|
309 |
+
if multispeaker and epoch >= diff_epoch:
|
310 |
+
ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
|
311 |
+
ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
|
312 |
+
ref = torch.cat([ref_ss, ref_sp], dim=1)
|
313 |
+
|
314 |
+
try:
|
315 |
+
ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
|
316 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
317 |
+
s2s_attn = s2s_attn[..., 1:]
|
318 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
319 |
+
except:
|
320 |
+
continue
|
321 |
+
|
322 |
+
mask_ST = mask_from_lens(
|
323 |
+
s2s_attn, input_lengths, mel_input_length // (2**n_down)
|
324 |
+
)
|
325 |
+
s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
|
326 |
+
|
327 |
+
# encode
|
328 |
+
t_en = model.text_encoder(texts, input_lengths, text_mask)
|
329 |
+
|
330 |
+
# 50% of chance of using monotonic version
|
331 |
+
if bool(random.getrandbits(1)):
|
332 |
+
asr = t_en @ s2s_attn
|
333 |
+
else:
|
334 |
+
asr = t_en @ s2s_attn_mono
|
335 |
+
|
336 |
+
d_gt = s2s_attn_mono.sum(axis=-1).detach()
|
337 |
+
|
338 |
+
# compute the style of the entire utterance
|
339 |
+
# this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
|
340 |
+
ss = []
|
341 |
+
gs = []
|
342 |
+
for bib in range(len(mel_input_length)):
|
343 |
+
mel_length = int(mel_input_length[bib].item())
|
344 |
+
mel = mels[bib, :, : mel_input_length[bib]]
|
345 |
+
s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
|
346 |
+
ss.append(s)
|
347 |
+
s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
|
348 |
+
gs.append(s)
|
349 |
+
|
350 |
+
s_dur = torch.stack(ss).squeeze() # global prosodic styles
|
351 |
+
gs = torch.stack(gs).squeeze() # global acoustic styles
|
352 |
+
s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
|
353 |
+
|
354 |
+
bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
|
355 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
356 |
+
|
357 |
+
# denoiser training
|
358 |
+
if epoch >= diff_epoch:
|
359 |
+
num_steps = np.random.randint(3, 5)
|
360 |
+
|
361 |
+
if model_params.diffusion.dist.estimate_sigma_data:
|
362 |
+
model.diffusion.module.diffusion.sigma_data = (
|
363 |
+
s_trg.std(axis=-1).mean().item()
|
364 |
+
) # batch-wise std estimation
|
365 |
+
running_std.append(model.diffusion.module.diffusion.sigma_data)
|
366 |
+
|
367 |
+
if multispeaker:
|
368 |
+
s_preds = sampler(
|
369 |
+
noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
|
370 |
+
embedding=bert_dur,
|
371 |
+
embedding_scale=1,
|
372 |
+
features=ref, # reference from the same speaker as the embedding
|
373 |
+
embedding_mask_proba=0.1,
|
374 |
+
num_steps=num_steps,
|
375 |
+
).squeeze(1)
|
376 |
+
loss_diff = model.diffusion(
|
377 |
+
s_trg.unsqueeze(1), embedding=bert_dur, features=ref
|
378 |
+
).mean() # EDM loss
|
379 |
+
loss_sty = F.l1_loss(
|
380 |
+
s_preds, s_trg.detach()
|
381 |
+
) # style reconstruction loss
|
382 |
+
else:
|
383 |
+
s_preds = sampler(
|
384 |
+
noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
|
385 |
+
embedding=bert_dur,
|
386 |
+
embedding_scale=1,
|
387 |
+
embedding_mask_proba=0.1,
|
388 |
+
num_steps=num_steps,
|
389 |
+
).squeeze(1)
|
390 |
+
loss_diff = model.diffusion.module.diffusion(
|
391 |
+
s_trg.unsqueeze(1), embedding=bert_dur
|
392 |
+
).mean() # EDM loss
|
393 |
+
loss_sty = F.l1_loss(
|
394 |
+
s_preds, s_trg.detach()
|
395 |
+
) # style reconstruction loss
|
396 |
+
else:
|
397 |
+
loss_sty = 0
|
398 |
+
loss_diff = 0
|
399 |
+
|
400 |
+
s_loss = 0
|
401 |
+
|
402 |
+
d, p = model.predictor(d_en, s_dur, input_lengths, s2s_attn_mono, text_mask)
|
403 |
+
|
404 |
+
mel_len_st = int(mel_input_length.min().item() / 2 - 1)
|
405 |
+
mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
|
406 |
+
en = []
|
407 |
+
gt = []
|
408 |
+
p_en = []
|
409 |
+
wav = []
|
410 |
+
st = []
|
411 |
+
|
412 |
+
for bib in range(len(mel_input_length)):
|
413 |
+
mel_length = int(mel_input_length[bib].item() / 2)
|
414 |
+
|
415 |
+
random_start = np.random.randint(0, mel_length - mel_len)
|
416 |
+
en.append(asr[bib, :, random_start : random_start + mel_len])
|
417 |
+
p_en.append(p[bib, :, random_start : random_start + mel_len])
|
418 |
+
gt.append(
|
419 |
+
mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)]
|
420 |
+
)
|
421 |
+
|
422 |
+
y = waves[bib][
|
423 |
+
(random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
|
424 |
+
]
|
425 |
+
wav.append(torch.from_numpy(y).to(device))
|
426 |
+
|
427 |
+
# style reference (better to be different from the GT)
|
428 |
+
random_start = np.random.randint(0, mel_length - mel_len_st)
|
429 |
+
st.append(
|
430 |
+
mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)]
|
431 |
+
)
|
432 |
+
|
433 |
+
wav = torch.stack(wav).float().detach()
|
434 |
+
|
435 |
+
en = torch.stack(en)
|
436 |
+
p_en = torch.stack(p_en)
|
437 |
+
gt = torch.stack(gt).detach()
|
438 |
+
st = torch.stack(st).detach()
|
439 |
+
|
440 |
+
if gt.size(-1) < 80:
|
441 |
+
continue
|
442 |
+
|
443 |
+
s = model.style_encoder(gt.unsqueeze(1))
|
444 |
+
s_dur = model.predictor_encoder(gt.unsqueeze(1))
|
445 |
+
|
446 |
+
with torch.no_grad():
|
447 |
+
F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
|
448 |
+
F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
|
449 |
+
|
450 |
+
N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
|
451 |
+
|
452 |
+
y_rec_gt = wav.unsqueeze(1)
|
453 |
+
y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
|
454 |
+
|
455 |
+
wav = y_rec_gt
|
456 |
+
|
457 |
+
F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
|
458 |
+
|
459 |
+
y_rec = model.decoder(en, F0_fake, N_fake, s)
|
460 |
+
|
461 |
+
loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
|
462 |
+
loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
|
463 |
+
|
464 |
+
optimizer.zero_grad()
|
465 |
+
d_loss = dl(wav.detach(), y_rec.detach()).mean()
|
466 |
+
d_loss.backward()
|
467 |
+
optimizer.step("msd")
|
468 |
+
optimizer.step("mpd")
|
469 |
+
|
470 |
+
# generator loss
|
471 |
+
optimizer.zero_grad()
|
472 |
+
|
473 |
+
loss_mel = stft_loss(y_rec, wav)
|
474 |
+
loss_gen_all = gl(wav, y_rec).mean()
|
475 |
+
loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
|
476 |
+
|
477 |
+
loss_ce = 0
|
478 |
+
loss_dur = 0
|
479 |
+
for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
|
480 |
+
_s2s_pred = _s2s_pred[:_text_length, :]
|
481 |
+
_text_input = _text_input[:_text_length].long()
|
482 |
+
_s2s_trg = torch.zeros_like(_s2s_pred)
|
483 |
+
for p in range(_s2s_trg.shape[0]):
|
484 |
+
_s2s_trg[p, : _text_input[p]] = 1
|
485 |
+
_dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
|
486 |
+
|
487 |
+
loss_dur += F.l1_loss(
|
488 |
+
_dur_pred[1 : _text_length - 1], _text_input[1 : _text_length - 1]
|
489 |
+
)
|
490 |
+
loss_ce += F.binary_cross_entropy_with_logits(
|
491 |
+
_s2s_pred.flatten(), _s2s_trg.flatten()
|
492 |
+
)
|
493 |
+
|
494 |
+
loss_ce /= texts.size(0)
|
495 |
+
loss_dur /= texts.size(0)
|
496 |
+
|
497 |
+
loss_s2s = 0
|
498 |
+
for _s2s_pred, _text_input, _text_length in zip(
|
499 |
+
s2s_pred, texts, input_lengths
|
500 |
+
):
|
501 |
+
loss_s2s += F.cross_entropy(
|
502 |
+
_s2s_pred[:_text_length], _text_input[:_text_length]
|
503 |
+
)
|
504 |
+
loss_s2s /= texts.size(0)
|
505 |
+
|
506 |
+
loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
|
507 |
+
|
508 |
+
g_loss = (
|
509 |
+
loss_params.lambda_mel * loss_mel
|
510 |
+
+ loss_params.lambda_F0 * loss_F0_rec
|
511 |
+
+ loss_params.lambda_ce * loss_ce
|
512 |
+
+ loss_params.lambda_norm * loss_norm_rec
|
513 |
+
+ loss_params.lambda_dur * loss_dur
|
514 |
+
+ loss_params.lambda_gen * loss_gen_all
|
515 |
+
+ loss_params.lambda_slm * loss_lm
|
516 |
+
+ loss_params.lambda_sty * loss_sty
|
517 |
+
+ loss_params.lambda_diff * loss_diff
|
518 |
+
+ loss_params.lambda_mono * loss_mono
|
519 |
+
+ loss_params.lambda_s2s * loss_s2s
|
520 |
+
)
|
521 |
+
|
522 |
+
running_loss += loss_mel.item()
|
523 |
+
g_loss.backward()
|
524 |
+
if torch.isnan(g_loss):
|
525 |
+
from IPython.core.debugger import set_trace
|
526 |
+
|
527 |
+
set_trace()
|
528 |
+
|
529 |
+
optimizer.step("bert_encoder")
|
530 |
+
optimizer.step("bert")
|
531 |
+
optimizer.step("predictor")
|
532 |
+
optimizer.step("predictor_encoder")
|
533 |
+
optimizer.step("style_encoder")
|
534 |
+
optimizer.step("decoder")
|
535 |
+
|
536 |
+
optimizer.step("text_encoder")
|
537 |
+
optimizer.step("text_aligner")
|
538 |
+
|
539 |
+
if epoch >= diff_epoch:
|
540 |
+
optimizer.step("diffusion")
|
541 |
+
|
542 |
+
if epoch >= joint_epoch:
|
543 |
+
# randomly pick whether to use in-distribution text
|
544 |
+
if np.random.rand() < 0.5:
|
545 |
+
use_ind = True
|
546 |
+
else:
|
547 |
+
use_ind = False
|
548 |
+
|
549 |
+
if use_ind:
|
550 |
+
ref_lengths = input_lengths
|
551 |
+
ref_texts = texts
|
552 |
+
|
553 |
+
slm_out = slmadv(
|
554 |
+
i,
|
555 |
+
y_rec_gt,
|
556 |
+
y_rec_gt_pred,
|
557 |
+
waves,
|
558 |
+
mel_input_length,
|
559 |
+
ref_texts,
|
560 |
+
ref_lengths,
|
561 |
+
use_ind,
|
562 |
+
s_trg.detach(),
|
563 |
+
ref if multispeaker else None,
|
564 |
+
)
|
565 |
+
|
566 |
+
if slm_out is None:
|
567 |
+
continue
|
568 |
+
|
569 |
+
d_loss_slm, loss_gen_lm, y_pred = slm_out
|
570 |
+
|
571 |
+
# SLM discriminator loss
|
572 |
+
if d_loss_slm != 0:
|
573 |
+
optimizer.zero_grad()
|
574 |
+
d_loss_slm.backward()
|
575 |
+
optimizer.step("wd")
|
576 |
+
|
577 |
+
# SLM generator loss
|
578 |
+
optimizer.zero_grad()
|
579 |
+
loss_gen_lm.backward()
|
580 |
+
|
581 |
+
# compute the gradient norm
|
582 |
+
total_norm = {}
|
583 |
+
for key in model.keys():
|
584 |
+
total_norm[key] = 0
|
585 |
+
parameters = [
|
586 |
+
p
|
587 |
+
for p in model[key].parameters()
|
588 |
+
if p.grad is not None and p.requires_grad
|
589 |
+
]
|
590 |
+
for p in parameters:
|
591 |
+
param_norm = p.grad.detach().data.norm(2)
|
592 |
+
total_norm[key] += param_norm.item() ** 2
|
593 |
+
total_norm[key] = total_norm[key] ** 0.5
|
594 |
+
|
595 |
+
# gradient scaling
|
596 |
+
if total_norm["predictor"] > slmadv_params.thresh:
|
597 |
+
for key in model.keys():
|
598 |
+
for p in model[key].parameters():
|
599 |
+
if p.grad is not None:
|
600 |
+
p.grad *= 1 / total_norm["predictor"]
|
601 |
+
|
602 |
+
for p in model.predictor.duration_proj.parameters():
|
603 |
+
if p.grad is not None:
|
604 |
+
p.grad *= slmadv_params.scale
|
605 |
+
|
606 |
+
for p in model.predictor.lstm.parameters():
|
607 |
+
if p.grad is not None:
|
608 |
+
p.grad *= slmadv_params.scale
|
609 |
+
|
610 |
+
for p in model.diffusion.parameters():
|
611 |
+
if p.grad is not None:
|
612 |
+
p.grad *= slmadv_params.scale
|
613 |
+
|
614 |
+
optimizer.step("bert_encoder")
|
615 |
+
optimizer.step("bert")
|
616 |
+
optimizer.step("predictor")
|
617 |
+
optimizer.step("diffusion")
|
618 |
+
|
619 |
+
else:
|
620 |
+
d_loss_slm, loss_gen_lm = 0, 0
|
621 |
+
|
622 |
+
iters = iters + 1
|
623 |
+
|
624 |
+
if (i + 1) % log_interval == 0:
|
625 |
+
logger.info(
|
626 |
+
"Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f"
|
627 |
+
% (
|
628 |
+
epoch + 1,
|
629 |
+
epochs,
|
630 |
+
i + 1,
|
631 |
+
len(train_list) // batch_size,
|
632 |
+
running_loss / log_interval,
|
633 |
+
d_loss,
|
634 |
+
loss_dur,
|
635 |
+
loss_ce,
|
636 |
+
loss_norm_rec,
|
637 |
+
loss_F0_rec,
|
638 |
+
loss_lm,
|
639 |
+
loss_gen_all,
|
640 |
+
loss_sty,
|
641 |
+
loss_diff,
|
642 |
+
d_loss_slm,
|
643 |
+
loss_gen_lm,
|
644 |
+
s_loss,
|
645 |
+
loss_s2s,
|
646 |
+
loss_mono,
|
647 |
+
)
|
648 |
+
)
|
649 |
+
|
650 |
+
writer.add_scalar("train/mel_loss", running_loss / log_interval, iters)
|
651 |
+
writer.add_scalar("train/gen_loss", loss_gen_all, iters)
|
652 |
+
writer.add_scalar("train/d_loss", d_loss, iters)
|
653 |
+
writer.add_scalar("train/ce_loss", loss_ce, iters)
|
654 |
+
writer.add_scalar("train/dur_loss", loss_dur, iters)
|
655 |
+
writer.add_scalar("train/slm_loss", loss_lm, iters)
|
656 |
+
writer.add_scalar("train/norm_loss", loss_norm_rec, iters)
|
657 |
+
writer.add_scalar("train/F0_loss", loss_F0_rec, iters)
|
658 |
+
writer.add_scalar("train/sty_loss", loss_sty, iters)
|
659 |
+
writer.add_scalar("train/diff_loss", loss_diff, iters)
|
660 |
+
writer.add_scalar("train/d_loss_slm", d_loss_slm, iters)
|
661 |
+
writer.add_scalar("train/gen_loss_slm", loss_gen_lm, iters)
|
662 |
+
|
663 |
+
running_loss = 0
|
664 |
+
|
665 |
+
print("Time elasped:", time.time() - start_time)
|
666 |
+
|
667 |
+
loss_test = 0
|
668 |
+
loss_align = 0
|
669 |
+
loss_f = 0
|
670 |
+
_ = [model[key].eval() for key in model]
|
671 |
+
|
672 |
+
with torch.no_grad():
|
673 |
+
iters_test = 0
|
674 |
+
for batch_idx, batch in enumerate(val_dataloader):
|
675 |
+
optimizer.zero_grad()
|
676 |
+
|
677 |
+
try:
|
678 |
+
waves = batch[0]
|
679 |
+
batch = [b.to(device) for b in batch[1:]]
|
680 |
+
(
|
681 |
+
texts,
|
682 |
+
input_lengths,
|
683 |
+
ref_texts,
|
684 |
+
ref_lengths,
|
685 |
+
mels,
|
686 |
+
mel_input_length,
|
687 |
+
ref_mels,
|
688 |
+
) = batch
|
689 |
+
with torch.no_grad():
|
690 |
+
mask = length_to_mask(mel_input_length // (2**n_down)).to(
|
691 |
+
"cuda"
|
692 |
+
)
|
693 |
+
text_mask = length_to_mask(input_lengths).to(texts.device)
|
694 |
+
|
695 |
+
_, _, s2s_attn = model.text_aligner(mels, mask, texts)
|
696 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
697 |
+
s2s_attn = s2s_attn[..., 1:]
|
698 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
699 |
+
|
700 |
+
mask_ST = mask_from_lens(
|
701 |
+
s2s_attn, input_lengths, mel_input_length // (2**n_down)
|
702 |
+
)
|
703 |
+
s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
|
704 |
+
|
705 |
+
# encode
|
706 |
+
t_en = model.text_encoder(texts, input_lengths, text_mask)
|
707 |
+
asr = t_en @ s2s_attn_mono
|
708 |
+
|
709 |
+
d_gt = s2s_attn_mono.sum(axis=-1).detach()
|
710 |
+
|
711 |
+
ss = []
|
712 |
+
gs = []
|
713 |
+
|
714 |
+
for bib in range(len(mel_input_length)):
|
715 |
+
mel_length = int(mel_input_length[bib].item())
|
716 |
+
mel = mels[bib, :, : mel_input_length[bib]]
|
717 |
+
s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
|
718 |
+
ss.append(s)
|
719 |
+
s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
|
720 |
+
gs.append(s)
|
721 |
+
|
722 |
+
s = torch.stack(ss).squeeze()
|
723 |
+
gs = torch.stack(gs).squeeze()
|
724 |
+
s_trg = torch.cat([s, gs], dim=-1).detach()
|
725 |
+
|
726 |
+
bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
|
727 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
728 |
+
d, p = model.predictor(
|
729 |
+
d_en, s, input_lengths, s2s_attn_mono, text_mask
|
730 |
+
)
|
731 |
+
# get clips
|
732 |
+
mel_len = int(mel_input_length.min().item() / 2 - 1)
|
733 |
+
en = []
|
734 |
+
gt = []
|
735 |
+
|
736 |
+
p_en = []
|
737 |
+
wav = []
|
738 |
+
|
739 |
+
for bib in range(len(mel_input_length)):
|
740 |
+
mel_length = int(mel_input_length[bib].item() / 2)
|
741 |
+
|
742 |
+
random_start = np.random.randint(0, mel_length - mel_len)
|
743 |
+
en.append(asr[bib, :, random_start : random_start + mel_len])
|
744 |
+
p_en.append(p[bib, :, random_start : random_start + mel_len])
|
745 |
+
|
746 |
+
gt.append(
|
747 |
+
mels[
|
748 |
+
bib,
|
749 |
+
:,
|
750 |
+
(random_start * 2) : ((random_start + mel_len) * 2),
|
751 |
+
]
|
752 |
+
)
|
753 |
+
y = waves[bib][
|
754 |
+
(random_start * 2)
|
755 |
+
* 300 : ((random_start + mel_len) * 2)
|
756 |
+
* 300
|
757 |
+
]
|
758 |
+
wav.append(torch.from_numpy(y).to(device))
|
759 |
+
|
760 |
+
wav = torch.stack(wav).float().detach()
|
761 |
+
|
762 |
+
en = torch.stack(en)
|
763 |
+
p_en = torch.stack(p_en)
|
764 |
+
gt = torch.stack(gt).detach()
|
765 |
+
s = model.predictor_encoder(gt.unsqueeze(1))
|
766 |
+
|
767 |
+
F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
|
768 |
+
|
769 |
+
loss_dur = 0
|
770 |
+
for _s2s_pred, _text_input, _text_length in zip(
|
771 |
+
d, (d_gt), input_lengths
|
772 |
+
):
|
773 |
+
_s2s_pred = _s2s_pred[:_text_length, :]
|
774 |
+
_text_input = _text_input[:_text_length].long()
|
775 |
+
_s2s_trg = torch.zeros_like(_s2s_pred)
|
776 |
+
for bib in range(_s2s_trg.shape[0]):
|
777 |
+
_s2s_trg[bib, : _text_input[bib]] = 1
|
778 |
+
_dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
|
779 |
+
loss_dur += F.l1_loss(
|
780 |
+
_dur_pred[1 : _text_length - 1],
|
781 |
+
_text_input[1 : _text_length - 1],
|
782 |
+
)
|
783 |
+
|
784 |
+
loss_dur /= texts.size(0)
|
785 |
+
|
786 |
+
s = model.style_encoder(gt.unsqueeze(1))
|
787 |
+
|
788 |
+
y_rec = model.decoder(en, F0_fake, N_fake, s)
|
789 |
+
loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
|
790 |
+
|
791 |
+
F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
|
792 |
+
|
793 |
+
loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
|
794 |
+
|
795 |
+
loss_test += (loss_mel).mean()
|
796 |
+
loss_align += (loss_dur).mean()
|
797 |
+
loss_f += (loss_F0).mean()
|
798 |
+
|
799 |
+
iters_test += 1
|
800 |
+
except:
|
801 |
+
continue
|
802 |
+
|
803 |
+
print("Epochs:", epoch + 1)
|
804 |
+
logger.info(
|
805 |
+
"Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f"
|
806 |
+
% (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test)
|
807 |
+
+ "\n\n\n"
|
808 |
+
)
|
809 |
+
print("\n\n\n")
|
810 |
+
writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1)
|
811 |
+
writer.add_scalar("eval/dur_loss", loss_test / iters_test, epoch + 1)
|
812 |
+
writer.add_scalar("eval/F0_loss", loss_f / iters_test, epoch + 1)
|
813 |
+
|
814 |
+
if (epoch + 1) % save_freq == 0:
|
815 |
+
if (loss_test / iters_test) < best_loss:
|
816 |
+
best_loss = loss_test / iters_test
|
817 |
+
print("Saving..")
|
818 |
+
state = {
|
819 |
+
"net": {key: model[key].state_dict() for key in model},
|
820 |
+
"optimizer": optimizer.state_dict(),
|
821 |
+
"iters": iters,
|
822 |
+
"val_loss": loss_test / iters_test,
|
823 |
+
"epoch": epoch,
|
824 |
+
}
|
825 |
+
save_path = osp.join(log_dir, "epoch_2nd_%05d.pth" % epoch)
|
826 |
+
torch.save(state, save_path)
|
827 |
+
|
828 |
+
# if estimate sigma, save the estimated simga
|
829 |
+
if model_params.diffusion.dist.estimate_sigma_data:
|
830 |
+
config["model_params"]["diffusion"]["dist"]["sigma_data"] = float(
|
831 |
+
np.mean(running_std)
|
832 |
+
)
|
833 |
+
|
834 |
+
with open(osp.join(log_dir, osp.basename(config_path)), "w") as outfile:
|
835 |
+
yaml.dump(config, outfile, default_flow_style=True)
|
836 |
+
|
837 |
+
|
838 |
+
if __name__ == "__main__":
|
839 |
+
main()
|
train_first.py
ADDED
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
import yaml
|
6 |
+
import shutil
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import click
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
warnings.simplefilter("ignore")
|
13 |
+
|
14 |
+
# load packages
|
15 |
+
import random
|
16 |
+
import yaml
|
17 |
+
from munch import Munch
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torchaudio
|
23 |
+
import librosa
|
24 |
+
|
25 |
+
from models import *
|
26 |
+
from meldataset import build_dataloader
|
27 |
+
from utils import *
|
28 |
+
from losses import *
|
29 |
+
from optimizers import build_optimizer
|
30 |
+
import time
|
31 |
+
|
32 |
+
from accelerate import Accelerator
|
33 |
+
from accelerate.utils import LoggerType
|
34 |
+
from accelerate import DistributedDataParallelKwargs
|
35 |
+
|
36 |
+
from torch.utils.tensorboard import SummaryWriter
|
37 |
+
|
38 |
+
import logging
|
39 |
+
from accelerate.logging import get_logger
|
40 |
+
|
41 |
+
logger = get_logger(__name__, log_level="DEBUG")
|
42 |
+
|
43 |
+
|
44 |
+
@click.command()
|
45 |
+
@click.option("-p", "--config_path", default="Configs/config.yml", type=str)
|
46 |
+
def main(config_path):
|
47 |
+
config = yaml.safe_load(open(config_path))
|
48 |
+
|
49 |
+
log_dir = config["log_dir"]
|
50 |
+
if not osp.exists(log_dir):
|
51 |
+
os.makedirs(log_dir, exist_ok=True)
|
52 |
+
shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
|
53 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
54 |
+
accelerator = Accelerator(
|
55 |
+
project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs]
|
56 |
+
)
|
57 |
+
if accelerator.is_main_process:
|
58 |
+
writer = SummaryWriter(log_dir + "/tensorboard")
|
59 |
+
|
60 |
+
# write logs
|
61 |
+
file_handler = logging.FileHandler(osp.join(log_dir, "train.log"))
|
62 |
+
file_handler.setLevel(logging.DEBUG)
|
63 |
+
file_handler.setFormatter(
|
64 |
+
logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")
|
65 |
+
)
|
66 |
+
logger.logger.addHandler(file_handler)
|
67 |
+
|
68 |
+
batch_size = config.get("batch_size", 10)
|
69 |
+
device = accelerator.device
|
70 |
+
|
71 |
+
epochs = config.get("epochs_1st", 200)
|
72 |
+
save_freq = config.get("save_freq", 2)
|
73 |
+
log_interval = config.get("log_interval", 10)
|
74 |
+
saving_epoch = config.get("save_freq", 2)
|
75 |
+
|
76 |
+
data_params = config.get("data_params", None)
|
77 |
+
sr = config["preprocess_params"].get("sr", 24000)
|
78 |
+
train_path = data_params["train_data"]
|
79 |
+
val_path = data_params["val_data"]
|
80 |
+
root_path = data_params["root_path"]
|
81 |
+
min_length = data_params["min_length"]
|
82 |
+
OOD_data = data_params["OOD_data"]
|
83 |
+
|
84 |
+
max_len = config.get("max_len", 200)
|
85 |
+
|
86 |
+
# load data
|
87 |
+
train_list, val_list = get_data_path_list(train_path, val_path)
|
88 |
+
|
89 |
+
train_dataloader = build_dataloader(
|
90 |
+
train_list,
|
91 |
+
root_path,
|
92 |
+
OOD_data=OOD_data,
|
93 |
+
min_length=min_length,
|
94 |
+
batch_size=batch_size,
|
95 |
+
num_workers=2,
|
96 |
+
dataset_config={},
|
97 |
+
device=device,
|
98 |
+
)
|
99 |
+
|
100 |
+
val_dataloader = build_dataloader(
|
101 |
+
val_list,
|
102 |
+
root_path,
|
103 |
+
OOD_data=OOD_data,
|
104 |
+
min_length=min_length,
|
105 |
+
batch_size=batch_size,
|
106 |
+
validation=True,
|
107 |
+
num_workers=0,
|
108 |
+
device=device,
|
109 |
+
dataset_config={},
|
110 |
+
)
|
111 |
+
|
112 |
+
with accelerator.main_process_first():
|
113 |
+
# load pretrained ASR model
|
114 |
+
ASR_config = config.get("ASR_config", False)
|
115 |
+
ASR_path = config.get("ASR_path", False)
|
116 |
+
text_aligner = load_ASR_models(ASR_path, ASR_config)
|
117 |
+
|
118 |
+
# load pretrained F0 model
|
119 |
+
F0_path = config.get("F0_path", False)
|
120 |
+
pitch_extractor = load_F0_models(F0_path)
|
121 |
+
|
122 |
+
# load BERT model
|
123 |
+
from Utils.PLBERT.util import load_plbert
|
124 |
+
|
125 |
+
BERT_path = config.get("PLBERT_dir", False)
|
126 |
+
plbert = load_plbert(BERT_path)
|
127 |
+
|
128 |
+
scheduler_params = {
|
129 |
+
"max_lr": float(config["optimizer_params"].get("lr", 1e-4)),
|
130 |
+
"pct_start": float(config["optimizer_params"].get("pct_start", 0.0)),
|
131 |
+
"epochs": epochs,
|
132 |
+
"steps_per_epoch": len(train_dataloader),
|
133 |
+
}
|
134 |
+
|
135 |
+
model_params = recursive_munch(config["model_params"])
|
136 |
+
multispeaker = model_params.multispeaker
|
137 |
+
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
|
138 |
+
|
139 |
+
best_loss = float("inf") # best test loss
|
140 |
+
loss_train_record = list([])
|
141 |
+
loss_test_record = list([])
|
142 |
+
|
143 |
+
loss_params = Munch(config["loss_params"])
|
144 |
+
TMA_epoch = loss_params.TMA_epoch
|
145 |
+
|
146 |
+
for k in model:
|
147 |
+
model[k] = accelerator.prepare(model[k])
|
148 |
+
|
149 |
+
train_dataloader, val_dataloader = accelerator.prepare(
|
150 |
+
train_dataloader, val_dataloader
|
151 |
+
)
|
152 |
+
|
153 |
+
_ = [model[key].to(device) for key in model]
|
154 |
+
|
155 |
+
# initialize optimizers after preparing models for compatibility with FSDP
|
156 |
+
optimizer = build_optimizer(
|
157 |
+
{key: model[key].parameters() for key in model},
|
158 |
+
scheduler_params_dict={key: scheduler_params.copy() for key in model},
|
159 |
+
lr=float(config["optimizer_params"].get("lr", 1e-4)),
|
160 |
+
)
|
161 |
+
|
162 |
+
for k, v in optimizer.optimizers.items():
|
163 |
+
optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
|
164 |
+
optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
|
165 |
+
|
166 |
+
with accelerator.main_process_first():
|
167 |
+
if config.get("pretrained_model", "") != "":
|
168 |
+
model, optimizer, start_epoch, iters = load_checkpoint(
|
169 |
+
model,
|
170 |
+
optimizer,
|
171 |
+
config["pretrained_model"],
|
172 |
+
load_only_params=config.get("load_only_params", True),
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
start_epoch = 0
|
176 |
+
iters = 0
|
177 |
+
|
178 |
+
# in case not distributed
|
179 |
+
try:
|
180 |
+
n_down = model.text_aligner.module.n_down
|
181 |
+
except:
|
182 |
+
n_down = model.text_aligner.n_down
|
183 |
+
|
184 |
+
# wrapped losses for compatibility with mixed precision
|
185 |
+
stft_loss = MultiResolutionSTFTLoss().to(device)
|
186 |
+
gl = GeneratorLoss(model.mpd, model.msd).to(device)
|
187 |
+
dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
|
188 |
+
wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
|
189 |
+
|
190 |
+
for epoch in range(start_epoch, epochs):
|
191 |
+
running_loss = 0
|
192 |
+
start_time = time.time()
|
193 |
+
|
194 |
+
_ = [model[key].train() for key in model]
|
195 |
+
|
196 |
+
for i, batch in enumerate(train_dataloader):
|
197 |
+
waves = batch[0]
|
198 |
+
batch = [b.to(device) for b in batch[1:]]
|
199 |
+
texts, input_lengths, _, _, mels, mel_input_length, _ = batch
|
200 |
+
|
201 |
+
with torch.no_grad():
|
202 |
+
mask = length_to_mask(mel_input_length // (2**n_down)).to("cuda")
|
203 |
+
text_mask = length_to_mask(input_lengths).to(texts.device)
|
204 |
+
|
205 |
+
ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
|
206 |
+
|
207 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
208 |
+
s2s_attn = s2s_attn[..., 1:]
|
209 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
210 |
+
|
211 |
+
with torch.no_grad():
|
212 |
+
attn_mask = (
|
213 |
+
(~mask)
|
214 |
+
.unsqueeze(-1)
|
215 |
+
.expand(mask.shape[0], mask.shape[1], text_mask.shape[-1])
|
216 |
+
.float()
|
217 |
+
.transpose(-1, -2)
|
218 |
+
)
|
219 |
+
attn_mask = (
|
220 |
+
attn_mask.float()
|
221 |
+
* (~text_mask)
|
222 |
+
.unsqueeze(-1)
|
223 |
+
.expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1])
|
224 |
+
.float()
|
225 |
+
)
|
226 |
+
attn_mask = attn_mask < 1
|
227 |
+
|
228 |
+
s2s_attn.masked_fill_(attn_mask, 0.0)
|
229 |
+
|
230 |
+
with torch.no_grad():
|
231 |
+
mask_ST = mask_from_lens(
|
232 |
+
s2s_attn, input_lengths, mel_input_length // (2**n_down)
|
233 |
+
)
|
234 |
+
s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
|
235 |
+
|
236 |
+
# encode
|
237 |
+
t_en = model.text_encoder(texts, input_lengths, text_mask)
|
238 |
+
|
239 |
+
# 50% of chance of using monotonic version
|
240 |
+
if bool(random.getrandbits(1)):
|
241 |
+
asr = t_en @ s2s_attn
|
242 |
+
else:
|
243 |
+
asr = t_en @ s2s_attn_mono
|
244 |
+
|
245 |
+
# get clips
|
246 |
+
mel_input_length_all = accelerator.gather(
|
247 |
+
mel_input_length
|
248 |
+
) # for balanced load
|
249 |
+
mel_len = min(
|
250 |
+
[int(mel_input_length_all.min().item() / 2 - 1), max_len // 2]
|
251 |
+
)
|
252 |
+
mel_len_st = int(mel_input_length.min().item() / 2 - 1)
|
253 |
+
|
254 |
+
en = []
|
255 |
+
gt = []
|
256 |
+
wav = []
|
257 |
+
st = []
|
258 |
+
|
259 |
+
for bib in range(len(mel_input_length)):
|
260 |
+
mel_length = int(mel_input_length[bib].item() / 2)
|
261 |
+
|
262 |
+
random_start = np.random.randint(0, mel_length - mel_len)
|
263 |
+
en.append(asr[bib, :, random_start : random_start + mel_len])
|
264 |
+
gt.append(
|
265 |
+
mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)]
|
266 |
+
)
|
267 |
+
|
268 |
+
y = waves[bib][
|
269 |
+
(random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
|
270 |
+
]
|
271 |
+
wav.append(torch.from_numpy(y).to(device))
|
272 |
+
|
273 |
+
# style reference (better to be different from the GT)
|
274 |
+
random_start = np.random.randint(0, mel_length - mel_len_st)
|
275 |
+
st.append(
|
276 |
+
mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)]
|
277 |
+
)
|
278 |
+
|
279 |
+
en = torch.stack(en)
|
280 |
+
gt = torch.stack(gt).detach()
|
281 |
+
st = torch.stack(st).detach()
|
282 |
+
|
283 |
+
wav = torch.stack(wav).float().detach()
|
284 |
+
|
285 |
+
# clip too short to be used by the style encoder
|
286 |
+
if gt.shape[-1] < 80:
|
287 |
+
continue
|
288 |
+
|
289 |
+
with torch.no_grad():
|
290 |
+
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
|
291 |
+
F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
|
292 |
+
|
293 |
+
s = model.style_encoder(
|
294 |
+
st.unsqueeze(1) if multispeaker else gt.unsqueeze(1)
|
295 |
+
)
|
296 |
+
|
297 |
+
y_rec = model.decoder(en, F0_real, real_norm, s)
|
298 |
+
|
299 |
+
# discriminator loss
|
300 |
+
|
301 |
+
if epoch >= TMA_epoch:
|
302 |
+
optimizer.zero_grad()
|
303 |
+
d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
|
304 |
+
accelerator.backward(d_loss)
|
305 |
+
optimizer.step("msd")
|
306 |
+
optimizer.step("mpd")
|
307 |
+
else:
|
308 |
+
d_loss = 0
|
309 |
+
|
310 |
+
# generator loss
|
311 |
+
optimizer.zero_grad()
|
312 |
+
loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
|
313 |
+
|
314 |
+
if epoch >= TMA_epoch: # start TMA training
|
315 |
+
loss_s2s = 0
|
316 |
+
for _s2s_pred, _text_input, _text_length in zip(
|
317 |
+
s2s_pred, texts, input_lengths
|
318 |
+
):
|
319 |
+
loss_s2s += F.cross_entropy(
|
320 |
+
_s2s_pred[:_text_length], _text_input[:_text_length]
|
321 |
+
)
|
322 |
+
loss_s2s /= texts.size(0)
|
323 |
+
|
324 |
+
loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
|
325 |
+
|
326 |
+
loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
|
327 |
+
loss_slm = wl(wav.detach(), y_rec).mean()
|
328 |
+
|
329 |
+
g_loss = (
|
330 |
+
loss_params.lambda_mel * loss_mel
|
331 |
+
+ loss_params.lambda_mono * loss_mono
|
332 |
+
+ loss_params.lambda_s2s * loss_s2s
|
333 |
+
+ loss_params.lambda_gen * loss_gen_all
|
334 |
+
+ loss_params.lambda_slm * loss_slm
|
335 |
+
)
|
336 |
+
|
337 |
+
else:
|
338 |
+
loss_s2s = 0
|
339 |
+
loss_mono = 0
|
340 |
+
loss_gen_all = 0
|
341 |
+
loss_slm = 0
|
342 |
+
g_loss = loss_mel
|
343 |
+
|
344 |
+
running_loss += accelerator.gather(loss_mel).mean().item()
|
345 |
+
|
346 |
+
accelerator.backward(g_loss)
|
347 |
+
|
348 |
+
optimizer.step("text_encoder")
|
349 |
+
optimizer.step("style_encoder")
|
350 |
+
optimizer.step("decoder")
|
351 |
+
|
352 |
+
if epoch >= TMA_epoch:
|
353 |
+
optimizer.step("text_aligner")
|
354 |
+
optimizer.step("pitch_extractor")
|
355 |
+
|
356 |
+
iters = iters + 1
|
357 |
+
|
358 |
+
if (i + 1) % log_interval == 0 and accelerator.is_main_process:
|
359 |
+
log_print(
|
360 |
+
"Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f"
|
361 |
+
% (
|
362 |
+
epoch + 1,
|
363 |
+
epochs,
|
364 |
+
i + 1,
|
365 |
+
len(train_list) // batch_size,
|
366 |
+
running_loss / log_interval,
|
367 |
+
loss_gen_all,
|
368 |
+
d_loss,
|
369 |
+
loss_mono,
|
370 |
+
loss_s2s,
|
371 |
+
loss_slm,
|
372 |
+
),
|
373 |
+
logger,
|
374 |
+
)
|
375 |
+
|
376 |
+
writer.add_scalar("train/mel_loss", running_loss / log_interval, iters)
|
377 |
+
writer.add_scalar("train/gen_loss", loss_gen_all, iters)
|
378 |
+
writer.add_scalar("train/d_loss", d_loss, iters)
|
379 |
+
writer.add_scalar("train/mono_loss", loss_mono, iters)
|
380 |
+
writer.add_scalar("train/s2s_loss", loss_s2s, iters)
|
381 |
+
writer.add_scalar("train/slm_loss", loss_slm, iters)
|
382 |
+
|
383 |
+
running_loss = 0
|
384 |
+
|
385 |
+
print("Time elasped:", time.time() - start_time)
|
386 |
+
|
387 |
+
loss_test = 0
|
388 |
+
|
389 |
+
_ = [model[key].eval() for key in model]
|
390 |
+
|
391 |
+
with torch.no_grad():
|
392 |
+
iters_test = 0
|
393 |
+
for batch_idx, batch in enumerate(val_dataloader):
|
394 |
+
optimizer.zero_grad()
|
395 |
+
|
396 |
+
waves = batch[0]
|
397 |
+
batch = [b.to(device) for b in batch[1:]]
|
398 |
+
texts, input_lengths, _, _, mels, mel_input_length, _ = batch
|
399 |
+
|
400 |
+
with torch.no_grad():
|
401 |
+
mask = length_to_mask(mel_input_length // (2**n_down)).to("cuda")
|
402 |
+
ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
|
403 |
+
|
404 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
405 |
+
s2s_attn = s2s_attn[..., 1:]
|
406 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
407 |
+
|
408 |
+
text_mask = length_to_mask(input_lengths).to(texts.device)
|
409 |
+
attn_mask = (
|
410 |
+
(~mask)
|
411 |
+
.unsqueeze(-1)
|
412 |
+
.expand(mask.shape[0], mask.shape[1], text_mask.shape[-1])
|
413 |
+
.float()
|
414 |
+
.transpose(-1, -2)
|
415 |
+
)
|
416 |
+
attn_mask = (
|
417 |
+
attn_mask.float()
|
418 |
+
* (~text_mask)
|
419 |
+
.unsqueeze(-1)
|
420 |
+
.expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1])
|
421 |
+
.float()
|
422 |
+
)
|
423 |
+
attn_mask = attn_mask < 1
|
424 |
+
s2s_attn.masked_fill_(attn_mask, 0.0)
|
425 |
+
|
426 |
+
# encode
|
427 |
+
t_en = model.text_encoder(texts, input_lengths, text_mask)
|
428 |
+
|
429 |
+
asr = t_en @ s2s_attn
|
430 |
+
|
431 |
+
# get clips
|
432 |
+
mel_input_length_all = accelerator.gather(
|
433 |
+
mel_input_length
|
434 |
+
) # for balanced load
|
435 |
+
mel_len = min(
|
436 |
+
[int(mel_input_length.min().item() / 2 - 1), max_len // 2]
|
437 |
+
)
|
438 |
+
|
439 |
+
en = []
|
440 |
+
gt = []
|
441 |
+
wav = []
|
442 |
+
for bib in range(len(mel_input_length)):
|
443 |
+
mel_length = int(mel_input_length[bib].item() / 2)
|
444 |
+
|
445 |
+
random_start = np.random.randint(0, mel_length - mel_len)
|
446 |
+
en.append(asr[bib, :, random_start : random_start + mel_len])
|
447 |
+
gt.append(
|
448 |
+
mels[
|
449 |
+
bib, :, (random_start * 2) : ((random_start + mel_len) * 2)
|
450 |
+
]
|
451 |
+
)
|
452 |
+
y = waves[bib][
|
453 |
+
(random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
|
454 |
+
]
|
455 |
+
wav.append(torch.from_numpy(y).to("cuda"))
|
456 |
+
|
457 |
+
wav = torch.stack(wav).float().detach()
|
458 |
+
|
459 |
+
en = torch.stack(en)
|
460 |
+
gt = torch.stack(gt).detach()
|
461 |
+
|
462 |
+
F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
|
463 |
+
s = model.style_encoder(gt.unsqueeze(1))
|
464 |
+
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
|
465 |
+
y_rec = model.decoder(en, F0_real, real_norm, s)
|
466 |
+
|
467 |
+
loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
|
468 |
+
|
469 |
+
loss_test += accelerator.gather(loss_mel).mean().item()
|
470 |
+
iters_test += 1
|
471 |
+
|
472 |
+
if accelerator.is_main_process:
|
473 |
+
print("Epochs:", epoch + 1)
|
474 |
+
log_print(
|
475 |
+
"Validation loss: %.3f" % (loss_test / iters_test) + "\n\n\n\n", logger
|
476 |
+
)
|
477 |
+
print("\n\n\n")
|
478 |
+
writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1)
|
479 |
+
attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze())
|
480 |
+
writer.add_figure("eval/attn", attn_image, epoch)
|
481 |
+
|
482 |
+
with torch.no_grad():
|
483 |
+
for bib in range(len(asr)):
|
484 |
+
mel_length = int(mel_input_length[bib].item())
|
485 |
+
gt = mels[bib, :, :mel_length].unsqueeze(0)
|
486 |
+
en = asr[bib, :, : mel_length // 2].unsqueeze(0)
|
487 |
+
|
488 |
+
F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
|
489 |
+
F0_real = F0_real.unsqueeze(0)
|
490 |
+
s = model.style_encoder(gt.unsqueeze(1))
|
491 |
+
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
|
492 |
+
|
493 |
+
y_rec = model.decoder(en, F0_real, real_norm, s)
|
494 |
+
|
495 |
+
writer.add_audio(
|
496 |
+
"eval/y" + str(bib),
|
497 |
+
y_rec.cpu().numpy().squeeze(),
|
498 |
+
epoch,
|
499 |
+
sample_rate=sr,
|
500 |
+
)
|
501 |
+
if epoch == 0:
|
502 |
+
writer.add_audio(
|
503 |
+
"gt/y" + str(bib),
|
504 |
+
waves[bib].squeeze(),
|
505 |
+
epoch,
|
506 |
+
sample_rate=sr,
|
507 |
+
)
|
508 |
+
|
509 |
+
if bib >= 6:
|
510 |
+
break
|
511 |
+
|
512 |
+
if epoch % saving_epoch == 0:
|
513 |
+
if (loss_test / iters_test) < best_loss:
|
514 |
+
best_loss = loss_test / iters_test
|
515 |
+
print("Saving..")
|
516 |
+
state = {
|
517 |
+
"net": {key: model[key].state_dict() for key in model},
|
518 |
+
"optimizer": optimizer.state_dict(),
|
519 |
+
"iters": iters,
|
520 |
+
"val_loss": loss_test / iters_test,
|
521 |
+
"epoch": epoch,
|
522 |
+
}
|
523 |
+
save_path = osp.join(log_dir, "epoch_1st_%05d.pth" % epoch)
|
524 |
+
torch.save(state, save_path)
|
525 |
+
|
526 |
+
if accelerator.is_main_process:
|
527 |
+
print("Saving..")
|
528 |
+
state = {
|
529 |
+
"net": {key: model[key].state_dict() for key in model},
|
530 |
+
"optimizer": optimizer.state_dict(),
|
531 |
+
"iters": iters,
|
532 |
+
"val_loss": loss_test / iters_test,
|
533 |
+
"epoch": epoch,
|
534 |
+
}
|
535 |
+
save_path = osp.join(log_dir, config.get("first_stage_path", "first_stage.pth"))
|
536 |
+
torch.save(state, save_path)
|
537 |
+
|
538 |
+
|
539 |
+
if __name__ == "__main__":
|
540 |
+
main()
|
train_second.py
ADDED
@@ -0,0 +1,958 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# load packages
|
2 |
+
import random
|
3 |
+
import yaml
|
4 |
+
import time
|
5 |
+
from munch import Munch
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchaudio
|
11 |
+
import librosa
|
12 |
+
import click
|
13 |
+
import shutil
|
14 |
+
import warnings
|
15 |
+
|
16 |
+
warnings.simplefilter("ignore")
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
|
19 |
+
from meldataset import build_dataloader
|
20 |
+
|
21 |
+
from Utils.ASR.models import ASRCNN
|
22 |
+
from Utils.JDC.model import JDCNet
|
23 |
+
from Utils.PLBERT.util import load_plbert
|
24 |
+
|
25 |
+
from models import *
|
26 |
+
from losses import *
|
27 |
+
from utils import *
|
28 |
+
|
29 |
+
from Modules.slmadv import SLMAdversarialLoss
|
30 |
+
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
|
31 |
+
|
32 |
+
from optimizers import build_optimizer
|
33 |
+
|
34 |
+
|
35 |
+
# simple fix for dataparallel that allows access to class attributes
|
36 |
+
class MyDataParallel(torch.nn.DataParallel):
|
37 |
+
def __getattr__(self, name):
|
38 |
+
try:
|
39 |
+
return super().__getattr__(name)
|
40 |
+
except AttributeError:
|
41 |
+
return getattr(self.module, name)
|
42 |
+
|
43 |
+
|
44 |
+
import logging
|
45 |
+
from logging import StreamHandler
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
logger.setLevel(logging.DEBUG)
|
49 |
+
handler = StreamHandler()
|
50 |
+
handler.setLevel(logging.DEBUG)
|
51 |
+
logger.addHandler(handler)
|
52 |
+
|
53 |
+
|
54 |
+
@click.command()
|
55 |
+
@click.option("-p", "--config_path", default="Configs/config.yml", type=str)
|
56 |
+
def main(config_path):
|
57 |
+
config = yaml.safe_load(open(config_path))
|
58 |
+
|
59 |
+
log_dir = config["log_dir"]
|
60 |
+
if not osp.exists(log_dir):
|
61 |
+
os.makedirs(log_dir, exist_ok=True)
|
62 |
+
shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
|
63 |
+
writer = SummaryWriter(log_dir + "/tensorboard")
|
64 |
+
|
65 |
+
# write logs
|
66 |
+
file_handler = logging.FileHandler(osp.join(log_dir, "train.log"))
|
67 |
+
file_handler.setLevel(logging.DEBUG)
|
68 |
+
file_handler.setFormatter(
|
69 |
+
logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")
|
70 |
+
)
|
71 |
+
logger.addHandler(file_handler)
|
72 |
+
|
73 |
+
batch_size = config.get("batch_size", 10)
|
74 |
+
|
75 |
+
epochs = config.get("epochs_2nd", 200)
|
76 |
+
save_freq = config.get("save_freq", 2)
|
77 |
+
log_interval = config.get("log_interval", 10)
|
78 |
+
saving_epoch = config.get("save_freq", 2)
|
79 |
+
|
80 |
+
data_params = config.get("data_params", None)
|
81 |
+
sr = config["preprocess_params"].get("sr", 24000)
|
82 |
+
train_path = data_params["train_data"]
|
83 |
+
val_path = data_params["val_data"]
|
84 |
+
root_path = data_params["root_path"]
|
85 |
+
min_length = data_params["min_length"]
|
86 |
+
OOD_data = data_params["OOD_data"]
|
87 |
+
|
88 |
+
max_len = config.get("max_len", 200)
|
89 |
+
|
90 |
+
loss_params = Munch(config["loss_params"])
|
91 |
+
diff_epoch = loss_params.diff_epoch
|
92 |
+
joint_epoch = loss_params.joint_epoch
|
93 |
+
|
94 |
+
optimizer_params = Munch(config["optimizer_params"])
|
95 |
+
|
96 |
+
train_list, val_list = get_data_path_list(train_path, val_path)
|
97 |
+
device = "cuda"
|
98 |
+
|
99 |
+
train_dataloader = build_dataloader(
|
100 |
+
train_list,
|
101 |
+
root_path,
|
102 |
+
OOD_data=OOD_data,
|
103 |
+
min_length=min_length,
|
104 |
+
batch_size=batch_size,
|
105 |
+
num_workers=2,
|
106 |
+
dataset_config={},
|
107 |
+
device=device,
|
108 |
+
)
|
109 |
+
|
110 |
+
val_dataloader = build_dataloader(
|
111 |
+
val_list,
|
112 |
+
root_path,
|
113 |
+
OOD_data=OOD_data,
|
114 |
+
min_length=min_length,
|
115 |
+
batch_size=batch_size,
|
116 |
+
validation=True,
|
117 |
+
num_workers=0,
|
118 |
+
device=device,
|
119 |
+
dataset_config={},
|
120 |
+
)
|
121 |
+
|
122 |
+
# load pretrained ASR model
|
123 |
+
ASR_config = config.get("ASR_config", False)
|
124 |
+
ASR_path = config.get("ASR_path", False)
|
125 |
+
text_aligner = load_ASR_models(ASR_path, ASR_config)
|
126 |
+
|
127 |
+
# load pretrained F0 model
|
128 |
+
F0_path = config.get("F0_path", False)
|
129 |
+
pitch_extractor = load_F0_models(F0_path)
|
130 |
+
|
131 |
+
# load PL-BERT model
|
132 |
+
BERT_path = config.get("PLBERT_dir", False)
|
133 |
+
plbert = load_plbert(BERT_path)
|
134 |
+
|
135 |
+
# build model
|
136 |
+
model_params = recursive_munch(config["model_params"])
|
137 |
+
multispeaker = model_params.multispeaker
|
138 |
+
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
|
139 |
+
_ = [model[key].to(device) for key in model]
|
140 |
+
|
141 |
+
# DP
|
142 |
+
for key in model:
|
143 |
+
if key != "mpd" and key != "msd" and key != "wd":
|
144 |
+
model[key] = MyDataParallel(model[key])
|
145 |
+
|
146 |
+
start_epoch = 0
|
147 |
+
iters = 0
|
148 |
+
|
149 |
+
load_pretrained = config.get("pretrained_model", "") != "" and config.get(
|
150 |
+
"second_stage_load_pretrained", False
|
151 |
+
)
|
152 |
+
|
153 |
+
if not load_pretrained:
|
154 |
+
if config.get("first_stage_path", "") != "":
|
155 |
+
first_stage_path = osp.join(
|
156 |
+
log_dir, config.get("first_stage_path", "first_stage.pth")
|
157 |
+
)
|
158 |
+
print("Loading the first stage model at %s ..." % first_stage_path)
|
159 |
+
model, _, start_epoch, iters = load_checkpoint(
|
160 |
+
model,
|
161 |
+
None,
|
162 |
+
first_stage_path,
|
163 |
+
load_only_params=True,
|
164 |
+
ignore_modules=[
|
165 |
+
"bert",
|
166 |
+
"bert_encoder",
|
167 |
+
"predictor",
|
168 |
+
"predictor_encoder",
|
169 |
+
"msd",
|
170 |
+
"mpd",
|
171 |
+
"wd",
|
172 |
+
"diffusion",
|
173 |
+
],
|
174 |
+
) # keep starting epoch for tensorboard log
|
175 |
+
|
176 |
+
# these epochs should be counted from the start epoch
|
177 |
+
diff_epoch += start_epoch
|
178 |
+
joint_epoch += start_epoch
|
179 |
+
epochs += start_epoch
|
180 |
+
|
181 |
+
model.predictor_encoder = copy.deepcopy(model.style_encoder)
|
182 |
+
else:
|
183 |
+
raise ValueError("You need to specify the path to the first stage model.")
|
184 |
+
|
185 |
+
gl = GeneratorLoss(model.mpd, model.msd).to(device)
|
186 |
+
dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
|
187 |
+
wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
|
188 |
+
|
189 |
+
gl = MyDataParallel(gl)
|
190 |
+
dl = MyDataParallel(dl)
|
191 |
+
wl = MyDataParallel(wl)
|
192 |
+
|
193 |
+
sampler = DiffusionSampler(
|
194 |
+
model.diffusion.diffusion,
|
195 |
+
sampler=ADPM2Sampler(),
|
196 |
+
sigma_schedule=KarrasSchedule(
|
197 |
+
sigma_min=0.0001, sigma_max=3.0, rho=9.0
|
198 |
+
), # empirical parameters
|
199 |
+
clamp=False,
|
200 |
+
)
|
201 |
+
|
202 |
+
scheduler_params = {
|
203 |
+
"max_lr": optimizer_params.lr,
|
204 |
+
"pct_start": float(0),
|
205 |
+
"epochs": epochs,
|
206 |
+
"steps_per_epoch": len(train_dataloader),
|
207 |
+
}
|
208 |
+
scheduler_params_dict = {key: scheduler_params.copy() for key in model}
|
209 |
+
scheduler_params_dict["bert"]["max_lr"] = optimizer_params.bert_lr * 2
|
210 |
+
scheduler_params_dict["decoder"]["max_lr"] = optimizer_params.ft_lr * 2
|
211 |
+
scheduler_params_dict["style_encoder"]["max_lr"] = optimizer_params.ft_lr * 2
|
212 |
+
|
213 |
+
optimizer = build_optimizer(
|
214 |
+
{key: model[key].parameters() for key in model},
|
215 |
+
scheduler_params_dict=scheduler_params_dict,
|
216 |
+
lr=optimizer_params.lr,
|
217 |
+
)
|
218 |
+
|
219 |
+
# adjust BERT learning rate
|
220 |
+
for g in optimizer.optimizers["bert"].param_groups:
|
221 |
+
g["betas"] = (0.9, 0.99)
|
222 |
+
g["lr"] = optimizer_params.bert_lr
|
223 |
+
g["initial_lr"] = optimizer_params.bert_lr
|
224 |
+
g["min_lr"] = 0
|
225 |
+
g["weight_decay"] = 0.01
|
226 |
+
|
227 |
+
# adjust acoustic module learning rate
|
228 |
+
for module in ["decoder", "style_encoder"]:
|
229 |
+
for g in optimizer.optimizers[module].param_groups:
|
230 |
+
g["betas"] = (0.0, 0.99)
|
231 |
+
g["lr"] = optimizer_params.ft_lr
|
232 |
+
g["initial_lr"] = optimizer_params.ft_lr
|
233 |
+
g["min_lr"] = 0
|
234 |
+
g["weight_decay"] = 1e-4
|
235 |
+
|
236 |
+
# load models if there is a model
|
237 |
+
if load_pretrained:
|
238 |
+
model, optimizer, start_epoch, iters = load_checkpoint(
|
239 |
+
model,
|
240 |
+
optimizer,
|
241 |
+
config["pretrained_model"],
|
242 |
+
load_only_params=config.get("load_only_params", True),
|
243 |
+
)
|
244 |
+
|
245 |
+
n_down = model.text_aligner.n_down
|
246 |
+
|
247 |
+
best_loss = float("inf") # best test loss
|
248 |
+
loss_train_record = list([])
|
249 |
+
loss_test_record = list([])
|
250 |
+
iters = 0
|
251 |
+
|
252 |
+
criterion = nn.L1Loss() # F0 loss (regression)
|
253 |
+
torch.cuda.empty_cache()
|
254 |
+
|
255 |
+
stft_loss = MultiResolutionSTFTLoss().to(device)
|
256 |
+
|
257 |
+
print("BERT", optimizer.optimizers["bert"])
|
258 |
+
print("decoder", optimizer.optimizers["decoder"])
|
259 |
+
|
260 |
+
start_ds = False
|
261 |
+
|
262 |
+
running_std = []
|
263 |
+
|
264 |
+
slmadv_params = Munch(config["slmadv_params"])
|
265 |
+
slmadv = SLMAdversarialLoss(
|
266 |
+
model,
|
267 |
+
wl,
|
268 |
+
sampler,
|
269 |
+
slmadv_params.min_len,
|
270 |
+
slmadv_params.max_len,
|
271 |
+
batch_percentage=slmadv_params.batch_percentage,
|
272 |
+
skip_update=slmadv_params.iter,
|
273 |
+
sig=slmadv_params.sig,
|
274 |
+
)
|
275 |
+
|
276 |
+
for epoch in range(start_epoch, epochs):
|
277 |
+
running_loss = 0
|
278 |
+
start_time = time.time()
|
279 |
+
|
280 |
+
_ = [model[key].eval() for key in model]
|
281 |
+
|
282 |
+
model.predictor.train()
|
283 |
+
model.bert_encoder.train()
|
284 |
+
model.bert.train()
|
285 |
+
model.msd.train()
|
286 |
+
model.mpd.train()
|
287 |
+
|
288 |
+
if epoch >= diff_epoch:
|
289 |
+
start_ds = True
|
290 |
+
|
291 |
+
for i, batch in enumerate(train_dataloader):
|
292 |
+
waves = batch[0]
|
293 |
+
batch = [b.to(device) for b in batch[1:]]
|
294 |
+
(
|
295 |
+
texts,
|
296 |
+
input_lengths,
|
297 |
+
ref_texts,
|
298 |
+
ref_lengths,
|
299 |
+
mels,
|
300 |
+
mel_input_length,
|
301 |
+
ref_mels,
|
302 |
+
) = batch
|
303 |
+
|
304 |
+
with torch.no_grad():
|
305 |
+
mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
|
306 |
+
mel_mask = length_to_mask(mel_input_length).to(device)
|
307 |
+
text_mask = length_to_mask(input_lengths).to(texts.device)
|
308 |
+
|
309 |
+
try:
|
310 |
+
_, _, s2s_attn = model.text_aligner(mels, mask, texts)
|
311 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
312 |
+
s2s_attn = s2s_attn[..., 1:]
|
313 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
314 |
+
except:
|
315 |
+
continue
|
316 |
+
|
317 |
+
mask_ST = mask_from_lens(
|
318 |
+
s2s_attn, input_lengths, mel_input_length // (2**n_down)
|
319 |
+
)
|
320 |
+
s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
|
321 |
+
|
322 |
+
# encode
|
323 |
+
t_en = model.text_encoder(texts, input_lengths, text_mask)
|
324 |
+
asr = t_en @ s2s_attn_mono
|
325 |
+
|
326 |
+
d_gt = s2s_attn_mono.sum(axis=-1).detach()
|
327 |
+
|
328 |
+
# compute reference styles
|
329 |
+
if multispeaker and epoch >= diff_epoch:
|
330 |
+
ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
|
331 |
+
ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
|
332 |
+
ref = torch.cat([ref_ss, ref_sp], dim=1)
|
333 |
+
|
334 |
+
# compute the style of the entire utterance
|
335 |
+
# this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
|
336 |
+
ss = []
|
337 |
+
gs = []
|
338 |
+
for bib in range(len(mel_input_length)):
|
339 |
+
mel_length = int(mel_input_length[bib].item())
|
340 |
+
mel = mels[bib, :, : mel_input_length[bib]]
|
341 |
+
s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
|
342 |
+
ss.append(s)
|
343 |
+
s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
|
344 |
+
gs.append(s)
|
345 |
+
|
346 |
+
s_dur = torch.stack(ss).squeeze() # global prosodic styles
|
347 |
+
gs = torch.stack(gs).squeeze() # global acoustic styles
|
348 |
+
s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
|
349 |
+
|
350 |
+
bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
|
351 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
352 |
+
|
353 |
+
# denoiser training
|
354 |
+
if epoch >= diff_epoch:
|
355 |
+
num_steps = np.random.randint(3, 5)
|
356 |
+
|
357 |
+
if model_params.diffusion.dist.estimate_sigma_data:
|
358 |
+
model.diffusion.module.diffusion.sigma_data = (
|
359 |
+
s_trg.std(axis=-1).mean().item()
|
360 |
+
) # batch-wise std estimation
|
361 |
+
running_std.append(model.diffusion.module.diffusion.sigma_data)
|
362 |
+
|
363 |
+
if multispeaker:
|
364 |
+
s_preds = sampler(
|
365 |
+
noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
|
366 |
+
embedding=bert_dur,
|
367 |
+
embedding_scale=1,
|
368 |
+
features=ref, # reference from the same speaker as the embedding
|
369 |
+
embedding_mask_proba=0.1,
|
370 |
+
num_steps=num_steps,
|
371 |
+
).squeeze(1)
|
372 |
+
loss_diff = model.diffusion(
|
373 |
+
s_trg.unsqueeze(1), embedding=bert_dur, features=ref
|
374 |
+
).mean() # EDM loss
|
375 |
+
loss_sty = F.l1_loss(
|
376 |
+
s_preds, s_trg.detach()
|
377 |
+
) # style reconstruction loss
|
378 |
+
else:
|
379 |
+
s_preds = sampler(
|
380 |
+
noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
|
381 |
+
embedding=bert_dur,
|
382 |
+
embedding_scale=1,
|
383 |
+
embedding_mask_proba=0.1,
|
384 |
+
num_steps=num_steps,
|
385 |
+
).squeeze(1)
|
386 |
+
loss_diff = model.diffusion.module.diffusion(
|
387 |
+
s_trg.unsqueeze(1), embedding=bert_dur
|
388 |
+
).mean() # EDM loss
|
389 |
+
loss_sty = F.l1_loss(
|
390 |
+
s_preds, s_trg.detach()
|
391 |
+
) # style reconstruction loss
|
392 |
+
else:
|
393 |
+
loss_sty = 0
|
394 |
+
loss_diff = 0
|
395 |
+
|
396 |
+
d, p = model.predictor(d_en, s_dur, input_lengths, s2s_attn_mono, text_mask)
|
397 |
+
|
398 |
+
mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
|
399 |
+
mel_len_st = int(mel_input_length.min().item() / 2 - 1)
|
400 |
+
en = []
|
401 |
+
gt = []
|
402 |
+
st = []
|
403 |
+
p_en = []
|
404 |
+
wav = []
|
405 |
+
|
406 |
+
for bib in range(len(mel_input_length)):
|
407 |
+
mel_length = int(mel_input_length[bib].item() / 2)
|
408 |
+
|
409 |
+
random_start = np.random.randint(0, mel_length - mel_len)
|
410 |
+
en.append(asr[bib, :, random_start : random_start + mel_len])
|
411 |
+
p_en.append(p[bib, :, random_start : random_start + mel_len])
|
412 |
+
gt.append(
|
413 |
+
mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)]
|
414 |
+
)
|
415 |
+
|
416 |
+
y = waves[bib][
|
417 |
+
(random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
|
418 |
+
]
|
419 |
+
wav.append(torch.from_numpy(y).to(device))
|
420 |
+
|
421 |
+
# style reference (better to be different from the GT)
|
422 |
+
random_start = np.random.randint(0, mel_length - mel_len_st)
|
423 |
+
st.append(
|
424 |
+
mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)]
|
425 |
+
)
|
426 |
+
|
427 |
+
wav = torch.stack(wav).float().detach()
|
428 |
+
|
429 |
+
en = torch.stack(en)
|
430 |
+
p_en = torch.stack(p_en)
|
431 |
+
gt = torch.stack(gt).detach()
|
432 |
+
st = torch.stack(st).detach()
|
433 |
+
|
434 |
+
if gt.size(-1) < 80:
|
435 |
+
continue
|
436 |
+
|
437 |
+
s_dur = model.predictor_encoder(
|
438 |
+
st.unsqueeze(1) if multispeaker else gt.unsqueeze(1)
|
439 |
+
)
|
440 |
+
s = model.style_encoder(
|
441 |
+
st.unsqueeze(1) if multispeaker else gt.unsqueeze(1)
|
442 |
+
)
|
443 |
+
|
444 |
+
with torch.no_grad():
|
445 |
+
F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
|
446 |
+
F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
|
447 |
+
|
448 |
+
asr_real = model.text_aligner.get_feature(gt)
|
449 |
+
|
450 |
+
N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
|
451 |
+
|
452 |
+
y_rec_gt = wav.unsqueeze(1)
|
453 |
+
y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
|
454 |
+
|
455 |
+
if epoch >= joint_epoch:
|
456 |
+
# ground truth from recording
|
457 |
+
wav = y_rec_gt # use recording since decoder is tuned
|
458 |
+
else:
|
459 |
+
# ground truth from reconstruction
|
460 |
+
wav = y_rec_gt_pred # use reconstruction since decoder is fixed
|
461 |
+
|
462 |
+
F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
|
463 |
+
|
464 |
+
y_rec = model.decoder(en, F0_fake, N_fake, s)
|
465 |
+
|
466 |
+
loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
|
467 |
+
loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
|
468 |
+
|
469 |
+
if start_ds:
|
470 |
+
optimizer.zero_grad()
|
471 |
+
d_loss = dl(wav.detach(), y_rec.detach()).mean()
|
472 |
+
d_loss.backward()
|
473 |
+
optimizer.step("msd")
|
474 |
+
optimizer.step("mpd")
|
475 |
+
else:
|
476 |
+
d_loss = 0
|
477 |
+
|
478 |
+
# generator loss
|
479 |
+
optimizer.zero_grad()
|
480 |
+
|
481 |
+
loss_mel = stft_loss(y_rec, wav)
|
482 |
+
if start_ds:
|
483 |
+
loss_gen_all = gl(wav, y_rec).mean()
|
484 |
+
else:
|
485 |
+
loss_gen_all = 0
|
486 |
+
loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
|
487 |
+
|
488 |
+
loss_ce = 0
|
489 |
+
loss_dur = 0
|
490 |
+
for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
|
491 |
+
_s2s_pred = _s2s_pred[:_text_length, :]
|
492 |
+
_text_input = _text_input[:_text_length].long()
|
493 |
+
_s2s_trg = torch.zeros_like(_s2s_pred)
|
494 |
+
for p in range(_s2s_trg.shape[0]):
|
495 |
+
_s2s_trg[p, : _text_input[p]] = 1
|
496 |
+
_dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
|
497 |
+
|
498 |
+
loss_dur += F.l1_loss(
|
499 |
+
_dur_pred[1 : _text_length - 1], _text_input[1 : _text_length - 1]
|
500 |
+
)
|
501 |
+
loss_ce += F.binary_cross_entropy_with_logits(
|
502 |
+
_s2s_pred.flatten(), _s2s_trg.flatten()
|
503 |
+
)
|
504 |
+
|
505 |
+
loss_ce /= texts.size(0)
|
506 |
+
loss_dur /= texts.size(0)
|
507 |
+
|
508 |
+
g_loss = (
|
509 |
+
loss_params.lambda_mel * loss_mel
|
510 |
+
+ loss_params.lambda_F0 * loss_F0_rec
|
511 |
+
+ loss_params.lambda_ce * loss_ce
|
512 |
+
+ loss_params.lambda_norm * loss_norm_rec
|
513 |
+
+ loss_params.lambda_dur * loss_dur
|
514 |
+
+ loss_params.lambda_gen * loss_gen_all
|
515 |
+
+ loss_params.lambda_slm * loss_lm
|
516 |
+
+ loss_params.lambda_sty * loss_sty
|
517 |
+
+ loss_params.lambda_diff * loss_diff
|
518 |
+
)
|
519 |
+
|
520 |
+
running_loss += loss_mel.item()
|
521 |
+
g_loss.backward()
|
522 |
+
if torch.isnan(g_loss):
|
523 |
+
from IPython.core.debugger import set_trace
|
524 |
+
|
525 |
+
set_trace()
|
526 |
+
|
527 |
+
optimizer.step("bert_encoder")
|
528 |
+
optimizer.step("bert")
|
529 |
+
optimizer.step("predictor")
|
530 |
+
optimizer.step("predictor_encoder")
|
531 |
+
|
532 |
+
if epoch >= diff_epoch:
|
533 |
+
optimizer.step("diffusion")
|
534 |
+
|
535 |
+
if epoch >= joint_epoch:
|
536 |
+
optimizer.step("style_encoder")
|
537 |
+
optimizer.step("decoder")
|
538 |
+
|
539 |
+
# randomly pick whether to use in-distribution text
|
540 |
+
if np.random.rand() < 0.5:
|
541 |
+
use_ind = True
|
542 |
+
else:
|
543 |
+
use_ind = False
|
544 |
+
|
545 |
+
if use_ind:
|
546 |
+
ref_lengths = input_lengths
|
547 |
+
ref_texts = texts
|
548 |
+
|
549 |
+
slm_out = slmadv(
|
550 |
+
i,
|
551 |
+
y_rec_gt,
|
552 |
+
y_rec_gt_pred,
|
553 |
+
waves,
|
554 |
+
mel_input_length,
|
555 |
+
ref_texts,
|
556 |
+
ref_lengths,
|
557 |
+
use_ind,
|
558 |
+
s_trg.detach(),
|
559 |
+
ref if multispeaker else None,
|
560 |
+
)
|
561 |
+
|
562 |
+
if slm_out is None:
|
563 |
+
continue
|
564 |
+
|
565 |
+
d_loss_slm, loss_gen_lm, y_pred = slm_out
|
566 |
+
|
567 |
+
# SLM generator loss
|
568 |
+
optimizer.zero_grad()
|
569 |
+
loss_gen_lm.backward()
|
570 |
+
|
571 |
+
# SLM discriminator loss
|
572 |
+
if d_loss_slm != 0:
|
573 |
+
optimizer.zero_grad()
|
574 |
+
d_loss_slm.backward(retain_graph=True)
|
575 |
+
optimizer.step("wd")
|
576 |
+
|
577 |
+
# compute the gradient norm
|
578 |
+
total_norm = {}
|
579 |
+
for key in model.keys():
|
580 |
+
total_norm[key] = 0
|
581 |
+
parameters = [
|
582 |
+
p
|
583 |
+
for p in model[key].parameters()
|
584 |
+
if p.grad is not None and p.requires_grad
|
585 |
+
]
|
586 |
+
for p in parameters:
|
587 |
+
param_norm = p.grad.detach().data.norm(2)
|
588 |
+
total_norm[key] += param_norm.item() ** 2
|
589 |
+
total_norm[key] = total_norm[key] ** 0.5
|
590 |
+
|
591 |
+
# gradient scaling
|
592 |
+
if total_norm["predictor"] > slmadv_params.thresh:
|
593 |
+
for key in model.keys():
|
594 |
+
for p in model[key].parameters():
|
595 |
+
if p.grad is not None:
|
596 |
+
p.grad *= 1 / total_norm["predictor"]
|
597 |
+
|
598 |
+
for p in model.predictor.duration_proj.parameters():
|
599 |
+
if p.grad is not None:
|
600 |
+
p.grad *= slmadv_params.scale
|
601 |
+
|
602 |
+
for p in model.predictor.lstm.parameters():
|
603 |
+
if p.grad is not None:
|
604 |
+
p.grad *= slmadv_params.scale
|
605 |
+
|
606 |
+
for p in model.diffusion.parameters():
|
607 |
+
if p.grad is not None:
|
608 |
+
p.grad *= slmadv_params.scale
|
609 |
+
|
610 |
+
optimizer.step("bert_encoder")
|
611 |
+
optimizer.step("bert")
|
612 |
+
optimizer.step("predictor")
|
613 |
+
optimizer.step("diffusion")
|
614 |
+
else:
|
615 |
+
d_loss_slm, loss_gen_lm = 0, 0
|
616 |
+
|
617 |
+
iters = iters + 1
|
618 |
+
|
619 |
+
if (i + 1) % log_interval == 0:
|
620 |
+
logger.info(
|
621 |
+
"Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f"
|
622 |
+
% (
|
623 |
+
epoch + 1,
|
624 |
+
epochs,
|
625 |
+
i + 1,
|
626 |
+
len(train_list) // batch_size,
|
627 |
+
running_loss / log_interval,
|
628 |
+
d_loss,
|
629 |
+
loss_dur,
|
630 |
+
loss_ce,
|
631 |
+
loss_norm_rec,
|
632 |
+
loss_F0_rec,
|
633 |
+
loss_lm,
|
634 |
+
loss_gen_all,
|
635 |
+
loss_sty,
|
636 |
+
loss_diff,
|
637 |
+
d_loss_slm,
|
638 |
+
loss_gen_lm,
|
639 |
+
)
|
640 |
+
)
|
641 |
+
|
642 |
+
writer.add_scalar("train/mel_loss", running_loss / log_interval, iters)
|
643 |
+
writer.add_scalar("train/gen_loss", loss_gen_all, iters)
|
644 |
+
writer.add_scalar("train/d_loss", d_loss, iters)
|
645 |
+
writer.add_scalar("train/ce_loss", loss_ce, iters)
|
646 |
+
writer.add_scalar("train/dur_loss", loss_dur, iters)
|
647 |
+
writer.add_scalar("train/slm_loss", loss_lm, iters)
|
648 |
+
writer.add_scalar("train/norm_loss", loss_norm_rec, iters)
|
649 |
+
writer.add_scalar("train/F0_loss", loss_F0_rec, iters)
|
650 |
+
writer.add_scalar("train/sty_loss", loss_sty, iters)
|
651 |
+
writer.add_scalar("train/diff_loss", loss_diff, iters)
|
652 |
+
writer.add_scalar("train/d_loss_slm", d_loss_slm, iters)
|
653 |
+
writer.add_scalar("train/gen_loss_slm", loss_gen_lm, iters)
|
654 |
+
|
655 |
+
running_loss = 0
|
656 |
+
|
657 |
+
print("Time elasped:", time.time() - start_time)
|
658 |
+
|
659 |
+
loss_test = 0
|
660 |
+
loss_align = 0
|
661 |
+
loss_f = 0
|
662 |
+
_ = [model[key].eval() for key in model]
|
663 |
+
|
664 |
+
with torch.no_grad():
|
665 |
+
iters_test = 0
|
666 |
+
for batch_idx, batch in enumerate(val_dataloader):
|
667 |
+
optimizer.zero_grad()
|
668 |
+
|
669 |
+
try:
|
670 |
+
waves = batch[0]
|
671 |
+
batch = [b.to(device) for b in batch[1:]]
|
672 |
+
(
|
673 |
+
texts,
|
674 |
+
input_lengths,
|
675 |
+
ref_texts,
|
676 |
+
ref_lengths,
|
677 |
+
mels,
|
678 |
+
mel_input_length,
|
679 |
+
ref_mels,
|
680 |
+
) = batch
|
681 |
+
with torch.no_grad():
|
682 |
+
mask = length_to_mask(mel_input_length // (2**n_down)).to(
|
683 |
+
"cuda"
|
684 |
+
)
|
685 |
+
text_mask = length_to_mask(input_lengths).to(texts.device)
|
686 |
+
|
687 |
+
_, _, s2s_attn = model.text_aligner(mels, mask, texts)
|
688 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
689 |
+
s2s_attn = s2s_attn[..., 1:]
|
690 |
+
s2s_attn = s2s_attn.transpose(-1, -2)
|
691 |
+
|
692 |
+
mask_ST = mask_from_lens(
|
693 |
+
s2s_attn, input_lengths, mel_input_length // (2**n_down)
|
694 |
+
)
|
695 |
+
s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
|
696 |
+
|
697 |
+
# encode
|
698 |
+
t_en = model.text_encoder(texts, input_lengths, text_mask)
|
699 |
+
asr = t_en @ s2s_attn_mono
|
700 |
+
|
701 |
+
d_gt = s2s_attn_mono.sum(axis=-1).detach()
|
702 |
+
|
703 |
+
ss = []
|
704 |
+
gs = []
|
705 |
+
|
706 |
+
for bib in range(len(mel_input_length)):
|
707 |
+
mel_length = int(mel_input_length[bib].item())
|
708 |
+
mel = mels[bib, :, : mel_input_length[bib]]
|
709 |
+
s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
|
710 |
+
ss.append(s)
|
711 |
+
s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
|
712 |
+
gs.append(s)
|
713 |
+
|
714 |
+
s = torch.stack(ss).squeeze()
|
715 |
+
gs = torch.stack(gs).squeeze()
|
716 |
+
s_trg = torch.cat([s, gs], dim=-1).detach()
|
717 |
+
|
718 |
+
bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
|
719 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
720 |
+
d, p = model.predictor(
|
721 |
+
d_en, s, input_lengths, s2s_attn_mono, text_mask
|
722 |
+
)
|
723 |
+
# get clips
|
724 |
+
mel_len = int(mel_input_length.min().item() / 2 - 1)
|
725 |
+
en = []
|
726 |
+
gt = []
|
727 |
+
p_en = []
|
728 |
+
wav = []
|
729 |
+
|
730 |
+
for bib in range(len(mel_input_length)):
|
731 |
+
mel_length = int(mel_input_length[bib].item() / 2)
|
732 |
+
|
733 |
+
random_start = np.random.randint(0, mel_length - mel_len)
|
734 |
+
en.append(asr[bib, :, random_start : random_start + mel_len])
|
735 |
+
p_en.append(p[bib, :, random_start : random_start + mel_len])
|
736 |
+
|
737 |
+
gt.append(
|
738 |
+
mels[
|
739 |
+
bib,
|
740 |
+
:,
|
741 |
+
(random_start * 2) : ((random_start + mel_len) * 2),
|
742 |
+
]
|
743 |
+
)
|
744 |
+
|
745 |
+
y = waves[bib][
|
746 |
+
(random_start * 2)
|
747 |
+
* 300 : ((random_start + mel_len) * 2)
|
748 |
+
* 300
|
749 |
+
]
|
750 |
+
wav.append(torch.from_numpy(y).to(device))
|
751 |
+
|
752 |
+
wav = torch.stack(wav).float().detach()
|
753 |
+
|
754 |
+
en = torch.stack(en)
|
755 |
+
p_en = torch.stack(p_en)
|
756 |
+
gt = torch.stack(gt).detach()
|
757 |
+
|
758 |
+
s = model.predictor_encoder(gt.unsqueeze(1))
|
759 |
+
|
760 |
+
F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
|
761 |
+
|
762 |
+
loss_dur = 0
|
763 |
+
for _s2s_pred, _text_input, _text_length in zip(
|
764 |
+
d, (d_gt), input_lengths
|
765 |
+
):
|
766 |
+
_s2s_pred = _s2s_pred[:_text_length, :]
|
767 |
+
_text_input = _text_input[:_text_length].long()
|
768 |
+
_s2s_trg = torch.zeros_like(_s2s_pred)
|
769 |
+
for bib in range(_s2s_trg.shape[0]):
|
770 |
+
_s2s_trg[bib, : _text_input[bib]] = 1
|
771 |
+
_dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
|
772 |
+
loss_dur += F.l1_loss(
|
773 |
+
_dur_pred[1 : _text_length - 1],
|
774 |
+
_text_input[1 : _text_length - 1],
|
775 |
+
)
|
776 |
+
|
777 |
+
loss_dur /= texts.size(0)
|
778 |
+
|
779 |
+
s = model.style_encoder(gt.unsqueeze(1))
|
780 |
+
|
781 |
+
y_rec = model.decoder(en, F0_fake, N_fake, s)
|
782 |
+
loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
|
783 |
+
|
784 |
+
F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
|
785 |
+
|
786 |
+
loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
|
787 |
+
|
788 |
+
loss_test += (loss_mel).mean()
|
789 |
+
loss_align += (loss_dur).mean()
|
790 |
+
loss_f += (loss_F0).mean()
|
791 |
+
|
792 |
+
iters_test += 1
|
793 |
+
except:
|
794 |
+
continue
|
795 |
+
|
796 |
+
print("Epochs:", epoch + 1)
|
797 |
+
logger.info(
|
798 |
+
"Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f"
|
799 |
+
% (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test)
|
800 |
+
+ "\n\n\n"
|
801 |
+
)
|
802 |
+
print("\n\n\n")
|
803 |
+
writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1)
|
804 |
+
writer.add_scalar("eval/dur_loss", loss_test / iters_test, epoch + 1)
|
805 |
+
writer.add_scalar("eval/F0_loss", loss_f / iters_test, epoch + 1)
|
806 |
+
|
807 |
+
if epoch < joint_epoch:
|
808 |
+
# generating reconstruction examples with GT duration
|
809 |
+
|
810 |
+
with torch.no_grad():
|
811 |
+
for bib in range(len(asr)):
|
812 |
+
mel_length = int(mel_input_length[bib].item())
|
813 |
+
gt = mels[bib, :, :mel_length].unsqueeze(0)
|
814 |
+
en = asr[bib, :, : mel_length // 2].unsqueeze(0)
|
815 |
+
|
816 |
+
F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
|
817 |
+
F0_real = F0_real.unsqueeze(0)
|
818 |
+
s = model.style_encoder(gt.unsqueeze(1))
|
819 |
+
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
|
820 |
+
|
821 |
+
y_rec = model.decoder(en, F0_real, real_norm, s)
|
822 |
+
|
823 |
+
writer.add_audio(
|
824 |
+
"eval/y" + str(bib),
|
825 |
+
y_rec.cpu().numpy().squeeze(),
|
826 |
+
epoch,
|
827 |
+
sample_rate=sr,
|
828 |
+
)
|
829 |
+
|
830 |
+
s_dur = model.predictor_encoder(gt.unsqueeze(1))
|
831 |
+
p_en = p[bib, :, : mel_length // 2].unsqueeze(0)
|
832 |
+
|
833 |
+
F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
|
834 |
+
|
835 |
+
y_pred = model.decoder(en, F0_fake, N_fake, s)
|
836 |
+
|
837 |
+
writer.add_audio(
|
838 |
+
"pred/y" + str(bib),
|
839 |
+
y_pred.cpu().numpy().squeeze(),
|
840 |
+
epoch,
|
841 |
+
sample_rate=sr,
|
842 |
+
)
|
843 |
+
|
844 |
+
if epoch == 0:
|
845 |
+
writer.add_audio(
|
846 |
+
"gt/y" + str(bib),
|
847 |
+
waves[bib].squeeze(),
|
848 |
+
epoch,
|
849 |
+
sample_rate=sr,
|
850 |
+
)
|
851 |
+
|
852 |
+
if bib >= 5:
|
853 |
+
break
|
854 |
+
else:
|
855 |
+
# generating sampled speech from text directly
|
856 |
+
with torch.no_grad():
|
857 |
+
# compute reference styles
|
858 |
+
if multispeaker and epoch >= diff_epoch:
|
859 |
+
ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
|
860 |
+
ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
|
861 |
+
ref_s = torch.cat([ref_ss, ref_sp], dim=1)
|
862 |
+
|
863 |
+
for bib in range(len(d_en)):
|
864 |
+
if multispeaker:
|
865 |
+
s_pred = sampler(
|
866 |
+
noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
|
867 |
+
embedding=bert_dur[bib].unsqueeze(0),
|
868 |
+
embedding_scale=1,
|
869 |
+
features=ref_s[bib].unsqueeze(
|
870 |
+
0
|
871 |
+
), # reference from the same speaker as the embedding
|
872 |
+
num_steps=5,
|
873 |
+
).squeeze(1)
|
874 |
+
else:
|
875 |
+
s_pred = sampler(
|
876 |
+
noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
|
877 |
+
embedding=bert_dur[bib].unsqueeze(0),
|
878 |
+
embedding_scale=1,
|
879 |
+
num_steps=5,
|
880 |
+
).squeeze(1)
|
881 |
+
|
882 |
+
s = s_pred[:, 128:]
|
883 |
+
ref = s_pred[:, :128]
|
884 |
+
|
885 |
+
d = model.predictor.text_encoder(
|
886 |
+
d_en[bib, :, : input_lengths[bib]].unsqueeze(0),
|
887 |
+
s,
|
888 |
+
input_lengths[bib, ...].unsqueeze(0),
|
889 |
+
text_mask[bib, : input_lengths[bib]].unsqueeze(0),
|
890 |
+
)
|
891 |
+
|
892 |
+
x, _ = model.predictor.lstm(d)
|
893 |
+
duration = model.predictor.duration_proj(x)
|
894 |
+
|
895 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
896 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
897 |
+
|
898 |
+
pred_dur[-1] += 5
|
899 |
+
|
900 |
+
pred_aln_trg = torch.zeros(
|
901 |
+
input_lengths[bib], int(pred_dur.sum().data)
|
902 |
+
)
|
903 |
+
c_frame = 0
|
904 |
+
for i in range(pred_aln_trg.size(0)):
|
905 |
+
pred_aln_trg[i, c_frame : c_frame + int(pred_dur[i].data)] = 1
|
906 |
+
c_frame += int(pred_dur[i].data)
|
907 |
+
|
908 |
+
# encode prosody
|
909 |
+
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(
|
910 |
+
texts.device
|
911 |
+
)
|
912 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
913 |
+
out = model.decoder(
|
914 |
+
(
|
915 |
+
t_en[bib, :, : input_lengths[bib]].unsqueeze(0)
|
916 |
+
@ pred_aln_trg.unsqueeze(0).to(texts.device)
|
917 |
+
),
|
918 |
+
F0_pred,
|
919 |
+
N_pred,
|
920 |
+
ref.squeeze().unsqueeze(0),
|
921 |
+
)
|
922 |
+
|
923 |
+
writer.add_audio(
|
924 |
+
"pred/y" + str(bib),
|
925 |
+
out.cpu().numpy().squeeze(),
|
926 |
+
epoch,
|
927 |
+
sample_rate=sr,
|
928 |
+
)
|
929 |
+
|
930 |
+
if bib >= 5:
|
931 |
+
break
|
932 |
+
|
933 |
+
if epoch % saving_epoch == 0:
|
934 |
+
if (loss_test / iters_test) < best_loss:
|
935 |
+
best_loss = loss_test / iters_test
|
936 |
+
print("Saving..")
|
937 |
+
state = {
|
938 |
+
"net": {key: model[key].state_dict() for key in model},
|
939 |
+
"optimizer": optimizer.state_dict(),
|
940 |
+
"iters": iters,
|
941 |
+
"val_loss": loss_test / iters_test,
|
942 |
+
"epoch": epoch,
|
943 |
+
}
|
944 |
+
save_path = osp.join(log_dir, "epoch_2nd_%05d.pth" % epoch)
|
945 |
+
torch.save(state, save_path)
|
946 |
+
|
947 |
+
# if estimate sigma, save the estimated simga
|
948 |
+
if model_params.diffusion.dist.estimate_sigma_data:
|
949 |
+
config["model_params"]["diffusion"]["dist"]["sigma_data"] = float(
|
950 |
+
np.mean(running_std)
|
951 |
+
)
|
952 |
+
|
953 |
+
with open(osp.join(log_dir, osp.basename(config_path)), "w") as outfile:
|
954 |
+
yaml.dump(config, outfile, default_flow_style=True)
|
955 |
+
|
956 |
+
|
957 |
+
if __name__ == "__main__":
|
958 |
+
main()
|
utils.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from monotonic_align import maximum_path
|
2 |
+
from monotonic_align import mask_from_lens
|
3 |
+
from monotonic_align.core import maximum_path_c
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import copy
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchaudio
|
10 |
+
import librosa
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
from munch import Munch
|
13 |
+
|
14 |
+
|
15 |
+
def maximum_path(neg_cent, mask):
|
16 |
+
"""Cython optimized version.
|
17 |
+
neg_cent: [b, t_t, t_s]
|
18 |
+
mask: [b, t_t, t_s]
|
19 |
+
"""
|
20 |
+
device = neg_cent.device
|
21 |
+
dtype = neg_cent.dtype
|
22 |
+
neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
|
23 |
+
path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
|
24 |
+
|
25 |
+
t_t_max = np.ascontiguousarray(
|
26 |
+
mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
|
27 |
+
)
|
28 |
+
t_s_max = np.ascontiguousarray(
|
29 |
+
mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
|
30 |
+
)
|
31 |
+
maximum_path_c(path, neg_cent, t_t_max, t_s_max)
|
32 |
+
return torch.from_numpy(path).to(device=device, dtype=dtype)
|
33 |
+
|
34 |
+
|
35 |
+
def get_data_path_list(train_path=None, val_path=None):
|
36 |
+
if train_path is None:
|
37 |
+
train_path = "Data/train_list.txt"
|
38 |
+
if val_path is None:
|
39 |
+
val_path = "Data/val_list.txt"
|
40 |
+
|
41 |
+
with open(train_path, "r", encoding="utf-8", errors="ignore") as f:
|
42 |
+
train_list = f.readlines()
|
43 |
+
with open(val_path, "r", encoding="utf-8", errors="ignore") as f:
|
44 |
+
val_list = f.readlines()
|
45 |
+
|
46 |
+
return train_list, val_list
|
47 |
+
|
48 |
+
|
49 |
+
def length_to_mask(lengths):
|
50 |
+
mask = (
|
51 |
+
torch.arange(lengths.max())
|
52 |
+
.unsqueeze(0)
|
53 |
+
.expand(lengths.shape[0], -1)
|
54 |
+
.type_as(lengths)
|
55 |
+
)
|
56 |
+
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
57 |
+
return mask
|
58 |
+
|
59 |
+
|
60 |
+
# for norm consistency loss
|
61 |
+
def log_norm(x, mean=-4, std=4, dim=2):
|
62 |
+
"""
|
63 |
+
normalized log mel -> mel -> norm -> log(norm)
|
64 |
+
"""
|
65 |
+
x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
def get_image(arrs):
|
70 |
+
plt.switch_backend("agg")
|
71 |
+
fig = plt.figure()
|
72 |
+
ax = plt.gca()
|
73 |
+
ax.imshow(arrs)
|
74 |
+
|
75 |
+
return fig
|
76 |
+
|
77 |
+
|
78 |
+
def recursive_munch(d):
|
79 |
+
if isinstance(d, dict):
|
80 |
+
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
81 |
+
elif isinstance(d, list):
|
82 |
+
return [recursive_munch(v) for v in d]
|
83 |
+
else:
|
84 |
+
return d
|
85 |
+
|
86 |
+
|
87 |
+
def log_print(message, logger):
|
88 |
+
logger.info(message)
|
89 |
+
print(message)
|