Spaces:
Running
Running
Mahiruoshi
commited on
Commit
•
5220ea7
1
Parent(s):
f4cadb2
Upload 86 files
Browse files- bert_gen.py +24 -17
- clap_gen.py +1 -1
- config.yml +17 -17
- data_utils.py +7 -23
- export_onnx.py +6 -4
- infer.py +53 -100
- losses.py +95 -0
- models.py +66 -65
- onnx_infer.py +68 -0
- re_matching.py +0 -1
- requirements.txt +2 -3
- resample.py +10 -6
- resample_legacy.py +71 -0
- server.py +214 -47
- server_fastapi.py +39 -1
- slm/wavlm-base-plus/.gitattributes +27 -0
- slm/wavlm-base-plus/README.md +65 -0
- slm/wavlm-base-plus/config.json +99 -0
- slm/wavlm-base-plus/preprocessor_config.json +9 -0
- slm/wavlm-base-plus/pytorch_model.bin +3 -0
- text/__init__.py +4 -2
- text/__pycache__/__init__.cpython-311.pyc +0 -0
- text/__pycache__/bert_utils.cpython-311.pyc +0 -0
- text/__pycache__/chinese.cpython-311.pyc +0 -0
- text/__pycache__/chinese_bert.cpython-311.pyc +0 -0
- text/__pycache__/cleaner.cpython-311.pyc +0 -0
- text/__pycache__/english.cpython-311.pyc +0 -0
- text/__pycache__/english_bert_mock.cpython-311.pyc +0 -0
- text/__pycache__/japanese.cpython-311.pyc +0 -0
- text/__pycache__/japanese_bert.cpython-311.pyc +0 -0
- text/__pycache__/symbols.cpython-311.pyc +0 -0
- text/__pycache__/tone_sandhi.cpython-311.pyc +0 -0
- text/chinese_bert.py +21 -3
- text/cleaner.py +2 -2
- text/english.py +71 -29
- text/english_bert_mock.py +21 -2
- text/japanese_bert.py +23 -2
- text/tone_sandhi.py +7 -3
- tools/__pycache__/__init__.cpython-311.pyc +0 -0
- tools/__pycache__/classify_language.cpython-311.pyc +0 -0
- tools/__pycache__/log.cpython-311.pyc +0 -0
- tools/__pycache__/sentence.cpython-311.pyc +0 -0
- tools/__pycache__/translate.cpython-311.pyc +0 -0
- train_ms.py +172 -58
- utils.py +5 -1
- webui.py +194 -174
- webui_preprocess.py +10 -21
bert_gen.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1 |
-
import argparse
|
2 |
-
from multiprocessing import Pool, cpu_count
|
3 |
-
|
4 |
import torch
|
5 |
-
|
6 |
-
from tqdm import tqdm
|
7 |
-
|
8 |
import commons
|
9 |
import utils
|
|
|
|
|
|
|
|
|
10 |
from config import config
|
11 |
-
from text import cleaned_text_to_sequence, get_bert
|
12 |
|
13 |
|
14 |
-
def process_line(
|
|
|
15 |
device = config.bert_gen_config.device
|
16 |
if config.bert_gen_config.use_multi_device:
|
17 |
rank = mp.current_process()._identity
|
@@ -28,12 +27,13 @@ def process_line(line):
|
|
28 |
word2ph = [i for i in word2ph]
|
29 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
37 |
|
38 |
bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
|
39 |
|
@@ -59,16 +59,23 @@ if __name__ == "__main__":
|
|
59 |
args, _ = parser.parse_known_args()
|
60 |
config_path = args.config
|
61 |
hps = utils.get_hparams_from_file(config_path)
|
|
|
62 |
lines = []
|
63 |
with open(hps.data.training_files, encoding="utf-8") as f:
|
64 |
lines.extend(f.readlines())
|
65 |
|
66 |
with open(hps.data.validation_files, encoding="utf-8") as f:
|
67 |
lines.extend(f.readlines())
|
|
|
|
|
68 |
if len(lines) != 0:
|
69 |
-
num_processes =
|
70 |
with Pool(processes=num_processes) as pool:
|
71 |
-
for _ in tqdm(
|
72 |
-
|
|
|
|
|
|
|
|
|
73 |
|
74 |
print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
from multiprocessing import Pool
|
|
|
|
|
3 |
import commons
|
4 |
import utils
|
5 |
+
from tqdm import tqdm
|
6 |
+
from text import check_bert_models, cleaned_text_to_sequence, get_bert
|
7 |
+
import argparse
|
8 |
+
import torch.multiprocessing as mp
|
9 |
from config import config
|
|
|
10 |
|
11 |
|
12 |
+
def process_line(x):
|
13 |
+
line, add_blank = x
|
14 |
device = config.bert_gen_config.device
|
15 |
if config.bert_gen_config.use_multi_device:
|
16 |
rank = mp.current_process()._identity
|
|
|
27 |
word2ph = [i for i in word2ph]
|
28 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
29 |
|
30 |
+
if add_blank:
|
31 |
+
phone = commons.intersperse(phone, 0)
|
32 |
+
tone = commons.intersperse(tone, 0)
|
33 |
+
language = commons.intersperse(language, 0)
|
34 |
+
for i in range(len(word2ph)):
|
35 |
+
word2ph[i] = word2ph[i] * 2
|
36 |
+
word2ph[0] += 1
|
37 |
|
38 |
bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
|
39 |
|
|
|
59 |
args, _ = parser.parse_known_args()
|
60 |
config_path = args.config
|
61 |
hps = utils.get_hparams_from_file(config_path)
|
62 |
+
check_bert_models()
|
63 |
lines = []
|
64 |
with open(hps.data.training_files, encoding="utf-8") as f:
|
65 |
lines.extend(f.readlines())
|
66 |
|
67 |
with open(hps.data.validation_files, encoding="utf-8") as f:
|
68 |
lines.extend(f.readlines())
|
69 |
+
add_blank = [hps.data.add_blank] * len(lines)
|
70 |
+
|
71 |
if len(lines) != 0:
|
72 |
+
num_processes = args.num_processes
|
73 |
with Pool(processes=num_processes) as pool:
|
74 |
+
for _ in tqdm(
|
75 |
+
pool.imap_unordered(process_line, zip(lines, add_blank)),
|
76 |
+
total=len(lines),
|
77 |
+
):
|
78 |
+
# 这里是缩进的代码块,表示循环体
|
79 |
+
pass # 使用pass语句作为占位符
|
80 |
|
81 |
print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
|
clap_gen.py
CHANGED
@@ -27,7 +27,7 @@ def process_line(line):
|
|
27 |
device = torch.device("cpu")
|
28 |
wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")
|
29 |
|
30 |
-
clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.
|
31 |
if os.path.isfile(clap_path):
|
32 |
return
|
33 |
|
|
|
27 |
device = torch.device("cpu")
|
28 |
wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")
|
29 |
|
30 |
+
clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.pt")
|
31 |
if os.path.isfile(clap_path):
|
32 |
return
|
33 |
|
config.yml
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
# 拟提供通用路径配置,统一存放数据,避免数据放得很乱
|
5 |
# 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
|
6 |
# 不填或者填空则路径为相对于项目根目录的路径
|
7 |
-
dataset_path: "Data/"
|
8 |
|
9 |
# 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
|
10 |
mirror: ""
|
@@ -17,16 +17,16 @@ resample:
|
|
17 |
sampling_rate: 44100
|
18 |
# 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样
|
19 |
# 请填入相对于datasetPath的相对路径
|
20 |
-
in_dir: "
|
21 |
# 音频文件重采样后输出路径
|
22 |
-
out_dir: "
|
23 |
|
24 |
|
25 |
# preprocess_text 数据集预处理相关配置
|
26 |
# 注意, “:” 后需要加空格
|
27 |
preprocess_text:
|
28 |
# 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
|
29 |
-
transcription_path: "filelists
|
30 |
# 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
|
31 |
cleaned_path: ""
|
32 |
# 训练集路径
|
@@ -34,11 +34,11 @@ preprocess_text:
|
|
34 |
# 验证集路径
|
35 |
val_path: "filelists/val.list"
|
36 |
# 配置文件路径
|
37 |
-
config_path: "config.json"
|
38 |
# 每个语言的验证集条数
|
39 |
val_per_lang: 4
|
40 |
# 验证集最大条数,多于的会被截断并放到训练集中
|
41 |
-
max_val_total:
|
42 |
# 是否进行数据清洗
|
43 |
clean: true
|
44 |
|
@@ -47,7 +47,7 @@ preprocess_text:
|
|
47 |
# 注意, “:” 后需要加空格
|
48 |
bert_gen:
|
49 |
# 训练数据集配置文件路径
|
50 |
-
config_path: "config.json"
|
51 |
# 并行数
|
52 |
num_processes: 4
|
53 |
# 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
|
@@ -60,9 +60,9 @@ bert_gen:
|
|
60 |
# 注意, “:” 后需要加空格
|
61 |
emo_gen:
|
62 |
# 训练数据集配置文件路径
|
63 |
-
config_path: "config.json"
|
64 |
# 并行数
|
65 |
-
num_processes:
|
66 |
# 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
|
67 |
device: "cuda"
|
68 |
# 使用多卡推理
|
@@ -81,15 +81,15 @@ train_ms:
|
|
81 |
# THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
|
82 |
# 底模设置
|
83 |
base:
|
84 |
-
use_base_model:
|
85 |
repo_id: "Stardust_minus/Bert-VITS2"
|
86 |
-
model_image: "Bert-VITS2_2.
|
87 |
# 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
|
88 |
model: "models"
|
89 |
# 配置文件路径
|
90 |
-
config_path: "config.json"
|
91 |
# 训练使用的worker,不建议超过CPU核心数
|
92 |
-
num_workers:
|
93 |
# 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
|
94 |
spec_cache: True
|
95 |
# 保存的检查点数量,多于此数目的权重会被删除来节省空间。
|
@@ -102,9 +102,9 @@ webui:
|
|
102 |
# 推理设备
|
103 |
device: "cuda"
|
104 |
# 模型路径
|
105 |
-
model: "models/
|
106 |
# 配置文件路径
|
107 |
-
config_path: "config.json"
|
108 |
# 端口号
|
109 |
port: 7860
|
110 |
# 是否公开部署,对外网开放
|
@@ -172,6 +172,6 @@ server:
|
|
172 |
# 请不要在github等网站公开分享你的app id 与 key
|
173 |
translate:
|
174 |
# 你的APPID
|
175 |
-
"app_key": ""
|
176 |
# 你的密钥
|
177 |
-
"secret_key": ""
|
|
|
4 |
# 拟提供通用路径配置,统一存放数据,避免数据放得很乱
|
5 |
# 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
|
6 |
# 不填或者填空则路径为相对于项目根目录的路径
|
7 |
+
dataset_path: "Data/V23"
|
8 |
|
9 |
# 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
|
10 |
mirror: ""
|
|
|
17 |
sampling_rate: 44100
|
18 |
# 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样
|
19 |
# 请填入相对于datasetPath的相对路径
|
20 |
+
in_dir: "" # 相对于根目录的路径为 /datasetPath/in_dir
|
21 |
# 音频文件重采样后输出路径
|
22 |
+
out_dir: ""
|
23 |
|
24 |
|
25 |
# preprocess_text 数据集预处理相关配置
|
26 |
# 注意, “:” 后需要加空格
|
27 |
preprocess_text:
|
28 |
# 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
|
29 |
+
transcription_path: "filelists/whole.list"
|
30 |
# 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
|
31 |
cleaned_path: ""
|
32 |
# 训练集路径
|
|
|
34 |
# 验证集路径
|
35 |
val_path: "filelists/val.list"
|
36 |
# 配置文件路径
|
37 |
+
config_path: "configs/config.json"
|
38 |
# 每个语言的验证集条数
|
39 |
val_per_lang: 4
|
40 |
# 验证集最大条数,多于的会被截断并放到训练集中
|
41 |
+
max_val_total: 800
|
42 |
# 是否进行数据清洗
|
43 |
clean: true
|
44 |
|
|
|
47 |
# 注意, “:” 后需要加空格
|
48 |
bert_gen:
|
49 |
# 训练数据集配置文件路径
|
50 |
+
config_path: "configs/config.json"
|
51 |
# 并行数
|
52 |
num_processes: 4
|
53 |
# 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
|
|
|
60 |
# 注意, “:” 后需要加空格
|
61 |
emo_gen:
|
62 |
# 训练数据集配置文件路径
|
63 |
+
config_path: "configs/config.json"
|
64 |
# 并行数
|
65 |
+
num_processes: 16
|
66 |
# 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
|
67 |
device: "cuda"
|
68 |
# 使用多卡推理
|
|
|
81 |
# THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
|
82 |
# 底模设置
|
83 |
base:
|
84 |
+
use_base_model: True
|
85 |
repo_id: "Stardust_minus/Bert-VITS2"
|
86 |
+
model_image: "Bert-VITS2_2.3底模" # openi网页的模型名
|
87 |
# 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
|
88 |
model: "models"
|
89 |
# 配置文件路径
|
90 |
+
config_path: "configs/config.json"
|
91 |
# 训练使用的worker,不建议超过CPU核心数
|
92 |
+
num_workers: 22
|
93 |
# 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
|
94 |
spec_cache: True
|
95 |
# 保存的检查点数量,多于此数目的权重会被删除来节省空间。
|
|
|
102 |
# 推理设备
|
103 |
device: "cuda"
|
104 |
# 模型路径
|
105 |
+
model: "models/G_408000.pth"
|
106 |
# 配置文件路径
|
107 |
+
config_path: "configs/config.json"
|
108 |
# 端口号
|
109 |
port: 7860
|
110 |
# 是否公开部署,对外网开放
|
|
|
172 |
# 请不要在github等网站公开分享你的app id 与 key
|
173 |
translate:
|
174 |
# 你的APPID
|
175 |
+
"app_key": "20231117001883321"
|
176 |
# 你的密钥
|
177 |
+
"secret_key": "lMQbvZHeJveDceLof2wf"
|
data_utils.py
CHANGED
@@ -44,10 +44,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|
44 |
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
45 |
self.max_text_len = getattr(hparams, "max_text_len", 384)
|
46 |
|
47 |
-
self.empty_emo = torch.squeeze(
|
48 |
-
torch.load("empty_emo.npy", map_location="cpu"), dim=1
|
49 |
-
)
|
50 |
-
|
51 |
random.seed(1234)
|
52 |
random.shuffle(self.audiopaths_sid_text)
|
53 |
self._filter()
|
@@ -98,14 +94,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|
98 |
spec, wav = self.get_audio(audiopath)
|
99 |
sid = torch.LongTensor([int(self.spk_map[sid])])
|
100 |
|
101 |
-
|
102 |
-
emo = torch.squeeze(
|
103 |
-
torch.load(audiopath.replace(".wav", ".emo.npy"), map_location="cpu"),
|
104 |
-
dim=1,
|
105 |
-
)
|
106 |
-
else:
|
107 |
-
emo = self.empty_emo
|
108 |
-
return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
|
109 |
|
110 |
def get_audio(self, filename):
|
111 |
audio, sampling_rate = load_wav_to_torch(filename)
|
@@ -168,15 +157,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|
168 |
|
169 |
if language_str == "ZH":
|
170 |
bert = bert_ori
|
171 |
-
ja_bert = torch.
|
172 |
-
en_bert = torch.
|
173 |
elif language_str == "JP":
|
174 |
-
bert = torch.
|
175 |
ja_bert = bert_ori
|
176 |
-
en_bert = torch.
|
177 |
elif language_str == "EN":
|
178 |
-
bert = torch.
|
179 |
-
ja_bert = torch.
|
180 |
en_bert = bert_ori
|
181 |
phone = torch.LongTensor(phone)
|
182 |
tone = torch.LongTensor(tone)
|
@@ -226,7 +215,6 @@ class TextAudioSpeakerCollate:
|
|
226 |
bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
|
227 |
ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
|
228 |
en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
|
229 |
-
emo = torch.FloatTensor(len(batch), 512)
|
230 |
|
231 |
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
232 |
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
@@ -238,7 +226,6 @@ class TextAudioSpeakerCollate:
|
|
238 |
bert_padded.zero_()
|
239 |
ja_bert_padded.zero_()
|
240 |
en_bert_padded.zero_()
|
241 |
-
emo.zero_()
|
242 |
|
243 |
for i in range(len(ids_sorted_decreasing)):
|
244 |
row = batch[ids_sorted_decreasing[i]]
|
@@ -272,8 +259,6 @@ class TextAudioSpeakerCollate:
|
|
272 |
en_bert = row[8]
|
273 |
en_bert_padded[i, :, : en_bert.size(1)] = en_bert
|
274 |
|
275 |
-
emo[i, :] = row[9]
|
276 |
-
|
277 |
return (
|
278 |
text_padded,
|
279 |
text_lengths,
|
@@ -287,7 +272,6 @@ class TextAudioSpeakerCollate:
|
|
287 |
bert_padded,
|
288 |
ja_bert_padded,
|
289 |
en_bert_padded,
|
290 |
-
emo,
|
291 |
)
|
292 |
|
293 |
|
|
|
44 |
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
45 |
self.max_text_len = getattr(hparams, "max_text_len", 384)
|
46 |
|
|
|
|
|
|
|
|
|
47 |
random.seed(1234)
|
48 |
random.shuffle(self.audiopaths_sid_text)
|
49 |
self._filter()
|
|
|
94 |
spec, wav = self.get_audio(audiopath)
|
95 |
sid = torch.LongTensor([int(self.spk_map[sid])])
|
96 |
|
97 |
+
return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
def get_audio(self, filename):
|
100 |
audio, sampling_rate = load_wav_to_torch(filename)
|
|
|
157 |
|
158 |
if language_str == "ZH":
|
159 |
bert = bert_ori
|
160 |
+
ja_bert = torch.randn(1024, len(phone))
|
161 |
+
en_bert = torch.randn(1024, len(phone))
|
162 |
elif language_str == "JP":
|
163 |
+
bert = torch.randn(1024, len(phone))
|
164 |
ja_bert = bert_ori
|
165 |
+
en_bert = torch.randn(1024, len(phone))
|
166 |
elif language_str == "EN":
|
167 |
+
bert = torch.randn(1024, len(phone))
|
168 |
+
ja_bert = torch.randn(1024, len(phone))
|
169 |
en_bert = bert_ori
|
170 |
phone = torch.LongTensor(phone)
|
171 |
tone = torch.LongTensor(tone)
|
|
|
215 |
bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
|
216 |
ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
|
217 |
en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
|
|
|
218 |
|
219 |
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
220 |
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
|
|
226 |
bert_padded.zero_()
|
227 |
ja_bert_padded.zero_()
|
228 |
en_bert_padded.zero_()
|
|
|
229 |
|
230 |
for i in range(len(ids_sorted_decreasing)):
|
231 |
row = batch[ids_sorted_decreasing[i]]
|
|
|
259 |
en_bert = row[8]
|
260 |
en_bert_padded[i, :, : en_bert.size(1)] = en_bert
|
261 |
|
|
|
|
|
262 |
return (
|
263 |
text_padded,
|
264 |
text_lengths,
|
|
|
272 |
bert_padded,
|
273 |
ja_bert_padded,
|
274 |
en_bert_padded,
|
|
|
275 |
)
|
276 |
|
277 |
|
export_onnx.py
CHANGED
@@ -2,11 +2,13 @@ from onnx_modules import export_onnx
|
|
2 |
import os
|
3 |
|
4 |
if __name__ == "__main__":
|
5 |
-
export_path = "
|
6 |
-
model_path = "
|
7 |
-
config_path = "
|
|
|
|
|
8 |
if not os.path.exists("onnx"):
|
9 |
os.makedirs("onnx")
|
10 |
if not os.path.exists(f"onnx/{export_path}"):
|
11 |
os.makedirs(f"onnx/{export_path}")
|
12 |
-
export_onnx(export_path, model_path, config_path)
|
|
|
2 |
import os
|
3 |
|
4 |
if __name__ == "__main__":
|
5 |
+
export_path = "BangDreamApi"
|
6 |
+
model_path = "Data/V23/models/G_621000.pth"
|
7 |
+
config_path = "Data/V23/configs/config.json"
|
8 |
+
novq = False
|
9 |
+
dev = False
|
10 |
if not os.path.exists("onnx"):
|
11 |
os.makedirs("onnx")
|
12 |
if not os.path.exists(f"onnx/{export_path}"):
|
13 |
os.makedirs(f"onnx/{export_path}")
|
14 |
+
export_onnx(export_path, model_path, config_path, novq, dev)
|
infer.py
CHANGED
@@ -10,7 +10,8 @@
|
|
10 |
import torch
|
11 |
import commons
|
12 |
from text import cleaned_text_to_sequence, get_bert
|
13 |
-
|
|
|
14 |
from text.cleaner import clean_text
|
15 |
import utils
|
16 |
import numpy as np
|
@@ -32,7 +33,7 @@ from oldVersion.V101.text import symbols as V101symbols
|
|
32 |
from oldVersion import V111, V110, V101, V200, V210
|
33 |
|
34 |
# 当前版本信息
|
35 |
-
latest_version = "2.
|
36 |
|
37 |
# 版本兼容
|
38 |
SynthesizerTrnMap = {
|
@@ -98,7 +99,8 @@ def get_net_g(model_path: str, version: str, device: str, hps):
|
|
98 |
return net_g
|
99 |
|
100 |
|
101 |
-
def get_text(text, language_str, hps, device):
|
|
|
102 |
# 在此处实现当前版本的get_text
|
103 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
104 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
@@ -110,21 +112,23 @@ def get_text(text, language_str, hps, device):
|
|
110 |
for i in range(len(word2ph)):
|
111 |
word2ph[i] = word2ph[i] * 2
|
112 |
word2ph[0] += 1
|
113 |
-
bert_ori = get_bert(
|
|
|
|
|
114 |
del word2ph
|
115 |
assert bert_ori.shape[-1] == len(phone), phone
|
116 |
|
117 |
if language_str == "ZH":
|
118 |
bert = bert_ori
|
119 |
-
ja_bert = torch.
|
120 |
-
en_bert = torch.
|
121 |
elif language_str == "JP":
|
122 |
-
bert = torch.
|
123 |
ja_bert = bert_ori
|
124 |
-
en_bert = torch.
|
125 |
elif language_str == "EN":
|
126 |
-
bert = torch.
|
127 |
-
ja_bert = torch.
|
128 |
en_bert = bert_ori
|
129 |
else:
|
130 |
raise ValueError("language_str should be ZH, JP or EN")
|
@@ -154,84 +158,17 @@ def infer(
|
|
154 |
reference_audio=None,
|
155 |
skip_start=False,
|
156 |
skip_end=False,
|
|
|
|
|
157 |
):
|
158 |
-
# 2.2版本参数位置变了
|
159 |
-
# 2.1 参数新增 emotion reference_audio skip_start skip_end
|
160 |
-
inferMap_V3 = {
|
161 |
-
"2.1": V210.infer,
|
162 |
-
}
|
163 |
-
# 支持中日英三语版本
|
164 |
-
inferMap_V2 = {
|
165 |
-
"2.0.2-fix": V200.infer,
|
166 |
-
"2.0.1": V200.infer,
|
167 |
-
"2.0": V200.infer,
|
168 |
-
"1.1.1-fix": V111.infer_fix,
|
169 |
-
"1.1.1": V111.infer,
|
170 |
-
"1.1": V110.infer,
|
171 |
-
"1.1.0": V110.infer,
|
172 |
-
}
|
173 |
-
# 仅支持中文版本
|
174 |
-
# 在测试中,并未发现两个版本的模型不能互相通用
|
175 |
-
inferMap_V1 = {
|
176 |
-
"1.0.1": V101.infer,
|
177 |
-
"1.0": V101.infer,
|
178 |
-
"1.0.0": V101.infer,
|
179 |
-
}
|
180 |
-
version = hps.version if hasattr(hps, "version") else latest_version
|
181 |
-
# 非当前版本,根据版本号选择合适的infer
|
182 |
-
if version != latest_version:
|
183 |
-
if version in inferMap_V3.keys():
|
184 |
-
return inferMap_V3[version](
|
185 |
-
text,
|
186 |
-
sdp_ratio,
|
187 |
-
noise_scale,
|
188 |
-
noise_scale_w,
|
189 |
-
length_scale,
|
190 |
-
sid,
|
191 |
-
language,
|
192 |
-
hps,
|
193 |
-
net_g,
|
194 |
-
device,
|
195 |
-
reference_audio,
|
196 |
-
emotion,
|
197 |
-
skip_start,
|
198 |
-
skip_end,
|
199 |
-
)
|
200 |
-
if version in inferMap_V2.keys():
|
201 |
-
return inferMap_V2[version](
|
202 |
-
text,
|
203 |
-
sdp_ratio,
|
204 |
-
noise_scale,
|
205 |
-
noise_scale_w,
|
206 |
-
length_scale,
|
207 |
-
sid,
|
208 |
-
language,
|
209 |
-
hps,
|
210 |
-
net_g,
|
211 |
-
device,
|
212 |
-
)
|
213 |
-
if version in inferMap_V1.keys():
|
214 |
-
return inferMap_V1[version](
|
215 |
-
text,
|
216 |
-
sdp_ratio,
|
217 |
-
noise_scale,
|
218 |
-
noise_scale_w,
|
219 |
-
length_scale,
|
220 |
-
sid,
|
221 |
-
hps,
|
222 |
-
net_g,
|
223 |
-
device,
|
224 |
-
)
|
225 |
-
# 在此处实现当前版本的推理
|
226 |
-
# emo = get_emo_(reference_audio, emotion, sid)
|
227 |
-
if isinstance(reference_audio, np.ndarray):
|
228 |
-
emo = get_clap_audio_feature(reference_audio, device)
|
229 |
-
else:
|
230 |
-
emo = get_clap_text_feature(emotion, device)
|
231 |
-
emo = torch.squeeze(emo, dim=1)
|
232 |
|
233 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
234 |
-
text,
|
|
|
|
|
|
|
|
|
|
|
235 |
)
|
236 |
if skip_start:
|
237 |
phones = phones[3:]
|
@@ -255,7 +192,7 @@ def infer(
|
|
255 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
256 |
en_bert = en_bert.to(device).unsqueeze(0)
|
257 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
258 |
-
emo = emo.to(device).unsqueeze(0)
|
259 |
del phones
|
260 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
261 |
audio = (
|
@@ -268,7 +205,6 @@ def infer(
|
|
268 |
bert,
|
269 |
ja_bert,
|
270 |
en_bert,
|
271 |
-
emo,
|
272 |
sdp_ratio=sdp_ratio,
|
273 |
noise_scale=noise_scale,
|
274 |
noise_scale_w=noise_scale_w,
|
@@ -278,7 +214,16 @@ def infer(
|
|
278 |
.float()
|
279 |
.numpy()
|
280 |
)
|
281 |
-
del
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
if torch.cuda.is_available():
|
283 |
torch.cuda.empty_cache()
|
284 |
return audio
|
@@ -302,14 +247,14 @@ def infer_multilang(
|
|
302 |
):
|
303 |
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
|
304 |
# emo = get_emo_(reference_audio, emotion, sid)
|
305 |
-
if isinstance(reference_audio, np.ndarray):
|
306 |
-
|
307 |
-
else:
|
308 |
-
|
309 |
-
emo = torch.squeeze(emo, dim=1)
|
310 |
for idx, (txt, lang) in enumerate(zip(text, language)):
|
311 |
-
|
312 |
-
|
313 |
(
|
314 |
temp_bert,
|
315 |
temp_ja_bert,
|
@@ -318,14 +263,14 @@ def infer_multilang(
|
|
318 |
temp_tones,
|
319 |
temp_lang_ids,
|
320 |
) = get_text(txt, lang, hps, device)
|
321 |
-
if
|
322 |
temp_bert = temp_bert[:, 3:]
|
323 |
temp_ja_bert = temp_ja_bert[:, 3:]
|
324 |
temp_en_bert = temp_en_bert[:, 3:]
|
325 |
temp_phones = temp_phones[3:]
|
326 |
temp_tones = temp_tones[3:]
|
327 |
temp_lang_ids = temp_lang_ids[3:]
|
328 |
-
if
|
329 |
temp_bert = temp_bert[:, :-2]
|
330 |
temp_ja_bert = temp_ja_bert[:, :-2]
|
331 |
temp_en_bert = temp_en_bert[:, :-2]
|
@@ -351,7 +296,7 @@ def infer_multilang(
|
|
351 |
bert = bert.to(device).unsqueeze(0)
|
352 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
353 |
en_bert = en_bert.to(device).unsqueeze(0)
|
354 |
-
emo = emo.to(device).unsqueeze(0)
|
355 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
356 |
del phones
|
357 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
@@ -365,7 +310,6 @@ def infer_multilang(
|
|
365 |
bert,
|
366 |
ja_bert,
|
367 |
en_bert,
|
368 |
-
emo,
|
369 |
sdp_ratio=sdp_ratio,
|
370 |
noise_scale=noise_scale,
|
371 |
noise_scale_w=noise_scale_w,
|
@@ -375,7 +319,16 @@ def infer_multilang(
|
|
375 |
.float()
|
376 |
.numpy()
|
377 |
)
|
378 |
-
del
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
if torch.cuda.is_available():
|
380 |
torch.cuda.empty_cache()
|
381 |
return audio
|
|
|
10 |
import torch
|
11 |
import commons
|
12 |
from text import cleaned_text_to_sequence, get_bert
|
13 |
+
|
14 |
+
# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
|
15 |
from text.cleaner import clean_text
|
16 |
import utils
|
17 |
import numpy as np
|
|
|
33 |
from oldVersion import V111, V110, V101, V200, V210
|
34 |
|
35 |
# 当前版本信息
|
36 |
+
latest_version = "2.3"
|
37 |
|
38 |
# 版本兼容
|
39 |
SynthesizerTrnMap = {
|
|
|
99 |
return net_g
|
100 |
|
101 |
|
102 |
+
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
|
103 |
+
style_text = None if style_text == "" else style_text
|
104 |
# 在此处实现当前版本的get_text
|
105 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
106 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
|
|
112 |
for i in range(len(word2ph)):
|
113 |
word2ph[i] = word2ph[i] * 2
|
114 |
word2ph[0] += 1
|
115 |
+
bert_ori = get_bert(
|
116 |
+
norm_text, word2ph, language_str, device, style_text, style_weight
|
117 |
+
)
|
118 |
del word2ph
|
119 |
assert bert_ori.shape[-1] == len(phone), phone
|
120 |
|
121 |
if language_str == "ZH":
|
122 |
bert = bert_ori
|
123 |
+
ja_bert = torch.randn(1024, len(phone))
|
124 |
+
en_bert = torch.randn(1024, len(phone))
|
125 |
elif language_str == "JP":
|
126 |
+
bert = torch.randn(1024, len(phone))
|
127 |
ja_bert = bert_ori
|
128 |
+
en_bert = torch.randn(1024, len(phone))
|
129 |
elif language_str == "EN":
|
130 |
+
bert = torch.randn(1024, len(phone))
|
131 |
+
ja_bert = torch.randn(1024, len(phone))
|
132 |
en_bert = bert_ori
|
133 |
else:
|
134 |
raise ValueError("language_str should be ZH, JP or EN")
|
|
|
158 |
reference_audio=None,
|
159 |
skip_start=False,
|
160 |
skip_end=False,
|
161 |
+
style_text=None,
|
162 |
+
style_weight=0.7,
|
163 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
166 |
+
text,
|
167 |
+
language,
|
168 |
+
hps,
|
169 |
+
device,
|
170 |
+
style_text=style_text,
|
171 |
+
style_weight=style_weight,
|
172 |
)
|
173 |
if skip_start:
|
174 |
phones = phones[3:]
|
|
|
192 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
193 |
en_bert = en_bert.to(device).unsqueeze(0)
|
194 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
195 |
+
# emo = emo.to(device).unsqueeze(0)
|
196 |
del phones
|
197 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
198 |
audio = (
|
|
|
205 |
bert,
|
206 |
ja_bert,
|
207 |
en_bert,
|
|
|
208 |
sdp_ratio=sdp_ratio,
|
209 |
noise_scale=noise_scale,
|
210 |
noise_scale_w=noise_scale_w,
|
|
|
214 |
.float()
|
215 |
.numpy()
|
216 |
)
|
217 |
+
del (
|
218 |
+
x_tst,
|
219 |
+
tones,
|
220 |
+
lang_ids,
|
221 |
+
bert,
|
222 |
+
x_tst_lengths,
|
223 |
+
speakers,
|
224 |
+
ja_bert,
|
225 |
+
en_bert,
|
226 |
+
) # , emo
|
227 |
if torch.cuda.is_available():
|
228 |
torch.cuda.empty_cache()
|
229 |
return audio
|
|
|
247 |
):
|
248 |
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
|
249 |
# emo = get_emo_(reference_audio, emotion, sid)
|
250 |
+
# if isinstance(reference_audio, np.ndarray):
|
251 |
+
# emo = get_clap_audio_feature(reference_audio, device)
|
252 |
+
# else:
|
253 |
+
# emo = get_clap_text_feature(emotion, device)
|
254 |
+
# emo = torch.squeeze(emo, dim=1)
|
255 |
for idx, (txt, lang) in enumerate(zip(text, language)):
|
256 |
+
_skip_start = (idx != 0) or (skip_start and idx == 0)
|
257 |
+
_skip_end = (idx != len(language) - 1) or skip_end
|
258 |
(
|
259 |
temp_bert,
|
260 |
temp_ja_bert,
|
|
|
263 |
temp_tones,
|
264 |
temp_lang_ids,
|
265 |
) = get_text(txt, lang, hps, device)
|
266 |
+
if _skip_start:
|
267 |
temp_bert = temp_bert[:, 3:]
|
268 |
temp_ja_bert = temp_ja_bert[:, 3:]
|
269 |
temp_en_bert = temp_en_bert[:, 3:]
|
270 |
temp_phones = temp_phones[3:]
|
271 |
temp_tones = temp_tones[3:]
|
272 |
temp_lang_ids = temp_lang_ids[3:]
|
273 |
+
if _skip_end:
|
274 |
temp_bert = temp_bert[:, :-2]
|
275 |
temp_ja_bert = temp_ja_bert[:, :-2]
|
276 |
temp_en_bert = temp_en_bert[:, :-2]
|
|
|
296 |
bert = bert.to(device).unsqueeze(0)
|
297 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
298 |
en_bert = en_bert.to(device).unsqueeze(0)
|
299 |
+
# emo = emo.to(device).unsqueeze(0)
|
300 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
301 |
del phones
|
302 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
|
|
310 |
bert,
|
311 |
ja_bert,
|
312 |
en_bert,
|
|
|
313 |
sdp_ratio=sdp_ratio,
|
314 |
noise_scale=noise_scale,
|
315 |
noise_scale_w=noise_scale_w,
|
|
|
319 |
.float()
|
320 |
.numpy()
|
321 |
)
|
322 |
+
del (
|
323 |
+
x_tst,
|
324 |
+
tones,
|
325 |
+
lang_ids,
|
326 |
+
bert,
|
327 |
+
x_tst_lengths,
|
328 |
+
speakers,
|
329 |
+
ja_bert,
|
330 |
+
en_bert,
|
331 |
+
) # , emo
|
332 |
if torch.cuda.is_available():
|
333 |
torch.cuda.empty_cache()
|
334 |
return audio
|
losses.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import torch
|
|
|
|
|
2 |
|
3 |
|
4 |
def feature_loss(fmap_r, fmap_g):
|
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
|
56 |
kl = torch.sum(kl * z_mask)
|
57 |
l = kl / torch.sum(z_mask)
|
58 |
return l
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
import torchaudio
|
3 |
+
from transformers import AutoModel
|
4 |
|
5 |
|
6 |
def feature_loss(fmap_r, fmap_g):
|
|
|
58 |
kl = torch.sum(kl * z_mask)
|
59 |
l = kl / torch.sum(z_mask)
|
60 |
return l
|
61 |
+
|
62 |
+
|
63 |
+
class WavLMLoss(torch.nn.Module):
|
64 |
+
def __init__(self, model, wd, model_sr, slm_sr=16000):
|
65 |
+
super(WavLMLoss, self).__init__()
|
66 |
+
self.wavlm = AutoModel.from_pretrained(model)
|
67 |
+
self.wd = wd
|
68 |
+
self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
|
69 |
+
self.wavlm.eval()
|
70 |
+
for param in self.wavlm.parameters():
|
71 |
+
param.requires_grad = False
|
72 |
+
|
73 |
+
def forward(self, wav, y_rec):
|
74 |
+
with torch.no_grad():
|
75 |
+
wav_16 = self.resample(wav)
|
76 |
+
wav_embeddings = self.wavlm(
|
77 |
+
input_values=wav_16, output_hidden_states=True
|
78 |
+
).hidden_states
|
79 |
+
y_rec_16 = self.resample(y_rec)
|
80 |
+
y_rec_embeddings = self.wavlm(
|
81 |
+
input_values=y_rec_16.squeeze(), output_hidden_states=True
|
82 |
+
).hidden_states
|
83 |
+
|
84 |
+
floss = 0
|
85 |
+
for er, eg in zip(wav_embeddings, y_rec_embeddings):
|
86 |
+
floss += torch.mean(torch.abs(er - eg))
|
87 |
+
|
88 |
+
return floss.mean()
|
89 |
+
|
90 |
+
def generator(self, y_rec):
|
91 |
+
y_rec_16 = self.resample(y_rec)
|
92 |
+
y_rec_embeddings = self.wavlm(
|
93 |
+
input_values=y_rec_16, output_hidden_states=True
|
94 |
+
).hidden_states
|
95 |
+
y_rec_embeddings = (
|
96 |
+
torch.stack(y_rec_embeddings, dim=1)
|
97 |
+
.transpose(-1, -2)
|
98 |
+
.flatten(start_dim=1, end_dim=2)
|
99 |
+
)
|
100 |
+
y_df_hat_g = self.wd(y_rec_embeddings)
|
101 |
+
loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
|
102 |
+
|
103 |
+
return loss_gen
|
104 |
+
|
105 |
+
def discriminator(self, wav, y_rec):
|
106 |
+
with torch.no_grad():
|
107 |
+
wav_16 = self.resample(wav)
|
108 |
+
wav_embeddings = self.wavlm(
|
109 |
+
input_values=wav_16, output_hidden_states=True
|
110 |
+
).hidden_states
|
111 |
+
y_rec_16 = self.resample(y_rec)
|
112 |
+
y_rec_embeddings = self.wavlm(
|
113 |
+
input_values=y_rec_16, output_hidden_states=True
|
114 |
+
).hidden_states
|
115 |
+
|
116 |
+
y_embeddings = (
|
117 |
+
torch.stack(wav_embeddings, dim=1)
|
118 |
+
.transpose(-1, -2)
|
119 |
+
.flatten(start_dim=1, end_dim=2)
|
120 |
+
)
|
121 |
+
y_rec_embeddings = (
|
122 |
+
torch.stack(y_rec_embeddings, dim=1)
|
123 |
+
.transpose(-1, -2)
|
124 |
+
.flatten(start_dim=1, end_dim=2)
|
125 |
+
)
|
126 |
+
|
127 |
+
y_d_rs = self.wd(y_embeddings)
|
128 |
+
y_d_gs = self.wd(y_rec_embeddings)
|
129 |
+
|
130 |
+
y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
|
131 |
+
|
132 |
+
r_loss = torch.mean((1 - y_df_hat_r) ** 2)
|
133 |
+
g_loss = torch.mean((y_df_hat_g) ** 2)
|
134 |
+
|
135 |
+
loss_disc_f = r_loss + g_loss
|
136 |
+
|
137 |
+
return loss_disc_f.mean()
|
138 |
+
|
139 |
+
def discriminator_forward(self, wav):
|
140 |
+
with torch.no_grad():
|
141 |
+
wav_16 = self.resample(wav)
|
142 |
+
wav_embeddings = self.wavlm(
|
143 |
+
input_values=wav_16, output_hidden_states=True
|
144 |
+
).hidden_states
|
145 |
+
y_embeddings = (
|
146 |
+
torch.stack(wav_embeddings, dim=1)
|
147 |
+
.transpose(-1, -2)
|
148 |
+
.flatten(start_dim=1, end_dim=2)
|
149 |
+
)
|
150 |
+
|
151 |
+
y_d_rs = self.wd(y_embeddings)
|
152 |
+
|
153 |
+
return y_d_rs
|
models.py
CHANGED
@@ -40,33 +40,22 @@ class DurationDiscriminator(nn.Module): # vits2
|
|
40 |
self.norm_2 = modules.LayerNorm(filter_channels)
|
41 |
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
42 |
|
43 |
-
self.
|
44 |
-
2 * filter_channels, filter_channels,
|
45 |
)
|
46 |
-
self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
|
47 |
-
self.pre_out_conv_2 = nn.Conv1d(
|
48 |
-
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
49 |
-
)
|
50 |
-
self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
|
51 |
|
52 |
if gin_channels != 0:
|
53 |
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
54 |
|
55 |
-
self.output_layer = nn.Sequential(
|
|
|
|
|
56 |
|
57 |
-
def forward_probability(self, x,
|
58 |
dur = self.dur_proj(dur)
|
59 |
x = torch.cat([x, dur], dim=1)
|
60 |
-
x = self.pre_out_conv_1(x * x_mask)
|
61 |
-
x = torch.relu(x)
|
62 |
-
x = self.pre_out_norm_1(x)
|
63 |
-
x = self.drop(x)
|
64 |
-
x = self.pre_out_conv_2(x * x_mask)
|
65 |
-
x = torch.relu(x)
|
66 |
-
x = self.pre_out_norm_2(x)
|
67 |
-
x = self.drop(x)
|
68 |
-
x = x * x_mask
|
69 |
x = x.transpose(1, 2)
|
|
|
70 |
output_prob = self.output_layer(x)
|
71 |
return output_prob
|
72 |
|
@@ -86,7 +75,7 @@ class DurationDiscriminator(nn.Module): # vits2
|
|
86 |
|
87 |
output_probs = []
|
88 |
for dur in [dur_r, dur_hat]:
|
89 |
-
output_prob = self.forward_probability(x,
|
90 |
output_probs.append(output_prob)
|
91 |
|
92 |
return output_probs
|
@@ -354,7 +343,6 @@ class TextEncoder(nn.Module):
|
|
354 |
n_layers,
|
355 |
kernel_size,
|
356 |
p_dropout,
|
357 |
-
n_speakers,
|
358 |
gin_channels=0,
|
359 |
):
|
360 |
super().__init__()
|
@@ -376,31 +364,6 @@ class TextEncoder(nn.Module):
|
|
376 |
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
377 |
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
378 |
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
379 |
-
# self.emo_proj = nn.Linear(512, hidden_channels)
|
380 |
-
self.in_feature_net = nn.Sequential(
|
381 |
-
# input is assumed to an already normalized embedding
|
382 |
-
nn.Linear(512, 1028, bias=False),
|
383 |
-
nn.GELU(),
|
384 |
-
nn.LayerNorm(1028),
|
385 |
-
*[Block(1028, 512) for _ in range(1)],
|
386 |
-
nn.Linear(1028, 512, bias=False),
|
387 |
-
# normalize before passing to VQ?
|
388 |
-
# nn.GELU(),
|
389 |
-
# nn.LayerNorm(512),
|
390 |
-
)
|
391 |
-
self.emo_vq = VectorQuantize(
|
392 |
-
dim=512,
|
393 |
-
codebook_size=64,
|
394 |
-
codebook_dim=32,
|
395 |
-
commitment_weight=0.1,
|
396 |
-
decay=0.85,
|
397 |
-
heads=32,
|
398 |
-
kmeans_iters=20,
|
399 |
-
separate_codebook_per_head=True,
|
400 |
-
stochastic_sample_codes=True,
|
401 |
-
threshold_ema_dead_code=2,
|
402 |
-
)
|
403 |
-
self.out_feature_net = nn.Linear(512, hidden_channels)
|
404 |
|
405 |
self.encoder = attentions.Encoder(
|
406 |
hidden_channels,
|
@@ -413,18 +376,10 @@ class TextEncoder(nn.Module):
|
|
413 |
)
|
414 |
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
415 |
|
416 |
-
def forward(
|
417 |
-
self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
|
418 |
-
):
|
419 |
-
sid = sid.cpu()
|
420 |
bert_emb = self.bert_proj(bert).transpose(1, 2)
|
421 |
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
|
422 |
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
|
423 |
-
emo_emb = self.in_feature_net(emo)
|
424 |
-
emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
|
425 |
-
loss_commit = loss_commit.mean()
|
426 |
-
emo_emb = self.out_feature_net(emo_emb)
|
427 |
-
# emo_emb = self.emo_proj(emo.unsqueeze(1))
|
428 |
x = (
|
429 |
self.emb(x)
|
430 |
+ self.tone_emb(tone)
|
@@ -432,7 +387,6 @@ class TextEncoder(nn.Module):
|
|
432 |
+ bert_emb
|
433 |
+ ja_bert_emb
|
434 |
+ en_bert_emb
|
435 |
-
+ emo_emb
|
436 |
) * math.sqrt(
|
437 |
self.hidden_channels
|
438 |
) # [b, t, h]
|
@@ -445,7 +399,7 @@ class TextEncoder(nn.Module):
|
|
445 |
stats = self.proj(x) * x_mask
|
446 |
|
447 |
m, logs = torch.split(stats, self.out_channels, dim=1)
|
448 |
-
return x, m, logs, x_mask
|
449 |
|
450 |
|
451 |
class ResidualCouplingBlock(nn.Module):
|
@@ -748,6 +702,55 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
|
748 |
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
749 |
|
750 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
751 |
class ReferenceEncoder(nn.Module):
|
752 |
"""
|
753 |
inputs --- [N, Ty/r, n_mels*r] mels
|
@@ -878,7 +881,6 @@ class SynthesizerTrn(nn.Module):
|
|
878 |
n_layers,
|
879 |
kernel_size,
|
880 |
p_dropout,
|
881 |
-
self.n_speakers,
|
882 |
gin_channels=self.enc_gin_channels,
|
883 |
)
|
884 |
self.dec = Generator(
|
@@ -946,14 +948,13 @@ class SynthesizerTrn(nn.Module):
|
|
946 |
bert,
|
947 |
ja_bert,
|
948 |
en_bert,
|
949 |
-
emo=None,
|
950 |
):
|
951 |
if self.n_speakers > 0:
|
952 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
953 |
else:
|
954 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
955 |
-
x, m_p, logs_p, x_mask
|
956 |
-
x, x_lengths, tone, language, bert, ja_bert, en_bert,
|
957 |
)
|
958 |
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
959 |
z_p = self.flow(z, y_mask, g=g)
|
@@ -996,9 +997,11 @@ class SynthesizerTrn(nn.Module):
|
|
996 |
|
997 |
logw_ = torch.log(w + 1e-6) * x_mask
|
998 |
logw = self.dp(x, x_mask, g=g)
|
|
|
999 |
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
|
1000 |
x_mask
|
1001 |
) # for averaging
|
|
|
1002 |
|
1003 |
l_length = l_length_dp + l_length_sdp
|
1004 |
|
@@ -1018,9 +1021,8 @@ class SynthesizerTrn(nn.Module):
|
|
1018 |
x_mask,
|
1019 |
y_mask,
|
1020 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
1021 |
-
(x, logw, logw_),
|
1022 |
g,
|
1023 |
-
loss_commit,
|
1024 |
)
|
1025 |
|
1026 |
def infer(
|
@@ -1033,7 +1035,6 @@ class SynthesizerTrn(nn.Module):
|
|
1033 |
bert,
|
1034 |
ja_bert,
|
1035 |
en_bert,
|
1036 |
-
emo=None,
|
1037 |
noise_scale=0.667,
|
1038 |
length_scale=1,
|
1039 |
noise_scale_w=0.8,
|
@@ -1047,8 +1048,8 @@ class SynthesizerTrn(nn.Module):
|
|
1047 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1048 |
else:
|
1049 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1050 |
-
x, m_p, logs_p, x_mask
|
1051 |
-
x, x_lengths, tone, language, bert, ja_bert, en_bert,
|
1052 |
)
|
1053 |
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
|
1054 |
sdp_ratio
|
|
|
40 |
self.norm_2 = modules.LayerNorm(filter_channels)
|
41 |
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
42 |
|
43 |
+
self.LSTM = nn.LSTM(
|
44 |
+
2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
|
45 |
)
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
if gin_channels != 0:
|
48 |
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
49 |
|
50 |
+
self.output_layer = nn.Sequential(
|
51 |
+
nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
|
52 |
+
)
|
53 |
|
54 |
+
def forward_probability(self, x, dur):
|
55 |
dur = self.dur_proj(dur)
|
56 |
x = torch.cat([x, dur], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
x = x.transpose(1, 2)
|
58 |
+
x, _ = self.LSTM(x)
|
59 |
output_prob = self.output_layer(x)
|
60 |
return output_prob
|
61 |
|
|
|
75 |
|
76 |
output_probs = []
|
77 |
for dur in [dur_r, dur_hat]:
|
78 |
+
output_prob = self.forward_probability(x, dur)
|
79 |
output_probs.append(output_prob)
|
80 |
|
81 |
return output_probs
|
|
|
343 |
n_layers,
|
344 |
kernel_size,
|
345 |
p_dropout,
|
|
|
346 |
gin_channels=0,
|
347 |
):
|
348 |
super().__init__()
|
|
|
364 |
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
365 |
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
366 |
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
368 |
self.encoder = attentions.Encoder(
|
369 |
hidden_channels,
|
|
|
376 |
)
|
377 |
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
378 |
|
379 |
+
def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
|
|
|
|
|
|
|
380 |
bert_emb = self.bert_proj(bert).transpose(1, 2)
|
381 |
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
|
382 |
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
383 |
x = (
|
384 |
self.emb(x)
|
385 |
+ self.tone_emb(tone)
|
|
|
387 |
+ bert_emb
|
388 |
+ ja_bert_emb
|
389 |
+ en_bert_emb
|
|
|
390 |
) * math.sqrt(
|
391 |
self.hidden_channels
|
392 |
) # [b, t, h]
|
|
|
399 |
stats = self.proj(x) * x_mask
|
400 |
|
401 |
m, logs = torch.split(stats, self.out_channels, dim=1)
|
402 |
+
return x, m, logs, x_mask
|
403 |
|
404 |
|
405 |
class ResidualCouplingBlock(nn.Module):
|
|
|
702 |
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
703 |
|
704 |
|
705 |
+
class WavLMDiscriminator(nn.Module):
|
706 |
+
"""docstring for Discriminator."""
|
707 |
+
|
708 |
+
def __init__(
|
709 |
+
self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
|
710 |
+
):
|
711 |
+
super(WavLMDiscriminator, self).__init__()
|
712 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
713 |
+
self.pre = norm_f(
|
714 |
+
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
|
715 |
+
)
|
716 |
+
|
717 |
+
self.convs = nn.ModuleList(
|
718 |
+
[
|
719 |
+
norm_f(
|
720 |
+
nn.Conv1d(
|
721 |
+
initial_channel, initial_channel * 2, kernel_size=5, padding=2
|
722 |
+
)
|
723 |
+
),
|
724 |
+
norm_f(
|
725 |
+
nn.Conv1d(
|
726 |
+
initial_channel * 2,
|
727 |
+
initial_channel * 4,
|
728 |
+
kernel_size=5,
|
729 |
+
padding=2,
|
730 |
+
)
|
731 |
+
),
|
732 |
+
norm_f(
|
733 |
+
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
|
734 |
+
),
|
735 |
+
]
|
736 |
+
)
|
737 |
+
|
738 |
+
self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
|
739 |
+
|
740 |
+
def forward(self, x):
|
741 |
+
x = self.pre(x)
|
742 |
+
|
743 |
+
fmap = []
|
744 |
+
for l in self.convs:
|
745 |
+
x = l(x)
|
746 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
747 |
+
fmap.append(x)
|
748 |
+
x = self.conv_post(x)
|
749 |
+
x = torch.flatten(x, 1, -1)
|
750 |
+
|
751 |
+
return x
|
752 |
+
|
753 |
+
|
754 |
class ReferenceEncoder(nn.Module):
|
755 |
"""
|
756 |
inputs --- [N, Ty/r, n_mels*r] mels
|
|
|
881 |
n_layers,
|
882 |
kernel_size,
|
883 |
p_dropout,
|
|
|
884 |
gin_channels=self.enc_gin_channels,
|
885 |
)
|
886 |
self.dec = Generator(
|
|
|
948 |
bert,
|
949 |
ja_bert,
|
950 |
en_bert,
|
|
|
951 |
):
|
952 |
if self.n_speakers > 0:
|
953 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
954 |
else:
|
955 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
956 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
957 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
|
958 |
)
|
959 |
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
960 |
z_p = self.flow(z, y_mask, g=g)
|
|
|
997 |
|
998 |
logw_ = torch.log(w + 1e-6) * x_mask
|
999 |
logw = self.dp(x, x_mask, g=g)
|
1000 |
+
logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
|
1001 |
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
|
1002 |
x_mask
|
1003 |
) # for averaging
|
1004 |
+
l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
1005 |
|
1006 |
l_length = l_length_dp + l_length_sdp
|
1007 |
|
|
|
1021 |
x_mask,
|
1022 |
y_mask,
|
1023 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
1024 |
+
(x, logw, logw_, logw_sdp),
|
1025 |
g,
|
|
|
1026 |
)
|
1027 |
|
1028 |
def infer(
|
|
|
1035 |
bert,
|
1036 |
ja_bert,
|
1037 |
en_bert,
|
|
|
1038 |
noise_scale=0.667,
|
1039 |
length_scale=1,
|
1040 |
noise_scale_w=0.8,
|
|
|
1048 |
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1049 |
else:
|
1050 |
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1051 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
1052 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
|
1053 |
)
|
1054 |
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
|
1055 |
sdp_ratio
|
onnx_infer.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from onnx_modules.V220_OnnxInference import OnnxInferenceSession
|
2 |
+
import numpy as np
|
3 |
+
Session = OnnxInferenceSession(
|
4 |
+
{
|
5 |
+
"enc" : "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
|
6 |
+
"emb_g" : "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
|
7 |
+
"dp" : "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
|
8 |
+
"sdp" : "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
|
9 |
+
"flow" : "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
|
10 |
+
"dec" : "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx"
|
11 |
+
},
|
12 |
+
Providers = ["CPUExecutionProvider"]
|
13 |
+
)
|
14 |
+
|
15 |
+
#这里的输入和原版是一样的,只需要在原版预处理结果出来之后加上.numpy()即可
|
16 |
+
x = np.array(
|
17 |
+
[
|
18 |
+
0,
|
19 |
+
97,
|
20 |
+
0,
|
21 |
+
8,
|
22 |
+
0,
|
23 |
+
78,
|
24 |
+
0,
|
25 |
+
8,
|
26 |
+
0,
|
27 |
+
76,
|
28 |
+
0,
|
29 |
+
37,
|
30 |
+
0,
|
31 |
+
40,
|
32 |
+
0,
|
33 |
+
97,
|
34 |
+
0,
|
35 |
+
8,
|
36 |
+
0,
|
37 |
+
23,
|
38 |
+
0,
|
39 |
+
8,
|
40 |
+
0,
|
41 |
+
74,
|
42 |
+
0,
|
43 |
+
26,
|
44 |
+
0,
|
45 |
+
104,
|
46 |
+
0,
|
47 |
+
]
|
48 |
+
)
|
49 |
+
tone = np.zeros_like(x)
|
50 |
+
language = np.zeros_like(x)
|
51 |
+
sid = np.array([0])
|
52 |
+
bert = np.random.randn(x.shape[0], 1024)
|
53 |
+
ja_bert = np.random.randn(x.shape[0], 1024)
|
54 |
+
en_bert = np.random.randn(x.shape[0], 1024)
|
55 |
+
emo = np.random.randn(512, 1)
|
56 |
+
|
57 |
+
audio = Session(
|
58 |
+
x,
|
59 |
+
tone,
|
60 |
+
language,
|
61 |
+
bert,
|
62 |
+
ja_bert,
|
63 |
+
en_bert,
|
64 |
+
emo,
|
65 |
+
sid
|
66 |
+
)
|
67 |
+
|
68 |
+
print(audio)
|
re_matching.py
CHANGED
@@ -44,7 +44,6 @@ def text_matching(text: str) -> list:
|
|
44 |
result = []
|
45 |
for speaker, dialogue in matches:
|
46 |
result.append(extract_language_and_text_updated(speaker, dialogue))
|
47 |
-
print(result)
|
48 |
return result
|
49 |
|
50 |
|
|
|
44 |
result = []
|
45 |
for speaker, dialogue in matches:
|
46 |
result.append(extract_language_and_text_updated(speaker, dialogue))
|
|
|
47 |
return result
|
48 |
|
49 |
|
requirements.txt
CHANGED
@@ -11,7 +11,7 @@ jieba
|
|
11 |
transformers
|
12 |
pypinyin
|
13 |
cn2an
|
14 |
-
gradio==3.
|
15 |
av
|
16 |
mecab-python3
|
17 |
loguru
|
@@ -21,8 +21,7 @@ fugashi
|
|
21 |
num2words
|
22 |
PyYAML
|
23 |
requests
|
24 |
-
pyopenjtalk
|
25 |
-
openjtalk; sys_platform != 'linux'
|
26 |
jaconv
|
27 |
psutil
|
28 |
GPUtil
|
|
|
11 |
transformers
|
12 |
pypinyin
|
13 |
cn2an
|
14 |
+
gradio==3.50.2
|
15 |
av
|
16 |
mecab-python3
|
17 |
loguru
|
|
|
21 |
num2words
|
22 |
PyYAML
|
23 |
requests
|
24 |
+
pyopenjtalk-prebuilt
|
|
|
25 |
jaconv
|
26 |
psutil
|
27 |
GPUtil
|
resample.py
CHANGED
@@ -10,11 +10,11 @@ from config import config
|
|
10 |
|
11 |
|
12 |
def process(item):
|
13 |
-
wav_name, args = item
|
14 |
-
wav_path = os.path.join(args.in_dir, wav_name)
|
15 |
if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
|
16 |
wav, sr = librosa.load(wav_path, sr=args.sr)
|
17 |
-
soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
|
18 |
|
19 |
|
20 |
if __name__ == "__main__":
|
@@ -54,11 +54,15 @@ if __name__ == "__main__":
|
|
54 |
tasks = []
|
55 |
|
56 |
for dirpath, _, filenames in os.walk(args.in_dir):
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
for filename in filenames:
|
60 |
if filename.lower().endswith(".wav"):
|
61 |
-
|
|
|
62 |
|
63 |
for _ in tqdm(
|
64 |
pool.imap_unordered(process, tasks),
|
|
|
10 |
|
11 |
|
12 |
def process(item):
|
13 |
+
spkdir, wav_name, args = item
|
14 |
+
wav_path = os.path.join(args.in_dir, spkdir, wav_name)
|
15 |
if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
|
16 |
wav, sr = librosa.load(wav_path, sr=args.sr)
|
17 |
+
soundfile.write(os.path.join(args.out_dir, spkdir, wav_name), wav, sr)
|
18 |
|
19 |
|
20 |
if __name__ == "__main__":
|
|
|
54 |
tasks = []
|
55 |
|
56 |
for dirpath, _, filenames in os.walk(args.in_dir):
|
57 |
+
# 子级目录
|
58 |
+
spk_dir = os.path.relpath(dirpath, args.in_dir)
|
59 |
+
spk_dir_out = os.path.join(args.out_dir, spk_dir)
|
60 |
+
if not os.path.isdir(spk_dir_out):
|
61 |
+
os.makedirs(spk_dir_out, exist_ok=True)
|
62 |
for filename in filenames:
|
63 |
if filename.lower().endswith(".wav"):
|
64 |
+
twople = (spk_dir, filename, args)
|
65 |
+
tasks.append(twople)
|
66 |
|
67 |
for _ in tqdm(
|
68 |
pool.imap_unordered(process, tasks),
|
resample_legacy.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import librosa
|
4 |
+
from multiprocessing import Pool, cpu_count
|
5 |
+
|
6 |
+
import soundfile
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from config import config
|
10 |
+
|
11 |
+
|
12 |
+
def process(item):
|
13 |
+
wav_name, args = item
|
14 |
+
wav_path = os.path.join(args.in_dir, wav_name)
|
15 |
+
if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
|
16 |
+
wav, sr = librosa.load(wav_path, sr=args.sr)
|
17 |
+
soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument(
|
23 |
+
"--sr",
|
24 |
+
type=int,
|
25 |
+
default=config.resample_config.sampling_rate,
|
26 |
+
help="sampling rate",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--in_dir",
|
30 |
+
type=str,
|
31 |
+
default=config.resample_config.in_dir,
|
32 |
+
help="path to source dir",
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--out_dir",
|
36 |
+
type=str,
|
37 |
+
default=config.resample_config.out_dir,
|
38 |
+
help="path to target dir",
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--processes",
|
42 |
+
type=int,
|
43 |
+
default=0,
|
44 |
+
help="cpu_processes",
|
45 |
+
)
|
46 |
+
args, _ = parser.parse_known_args()
|
47 |
+
# autodl 无卡模式会识别出46个cpu
|
48 |
+
if args.processes == 0:
|
49 |
+
processes = cpu_count() - 2 if cpu_count() > 4 else 1
|
50 |
+
else:
|
51 |
+
processes = args.processes
|
52 |
+
pool = Pool(processes=processes)
|
53 |
+
|
54 |
+
tasks = []
|
55 |
+
|
56 |
+
for dirpath, _, filenames in os.walk(args.in_dir):
|
57 |
+
if not os.path.isdir(args.out_dir):
|
58 |
+
os.makedirs(args.out_dir, exist_ok=True)
|
59 |
+
for filename in filenames:
|
60 |
+
if filename.lower().endswith(".wav"):
|
61 |
+
tasks.append((filename, args))
|
62 |
+
|
63 |
+
for _ in tqdm(
|
64 |
+
pool.imap_unordered(process, tasks),
|
65 |
+
):
|
66 |
+
pass
|
67 |
+
|
68 |
+
pool.close()
|
69 |
+
pool.join()
|
70 |
+
|
71 |
+
print("音频重采样完毕!")
|
server.py
CHANGED
@@ -4,9 +4,6 @@ from pathlib import Path
|
|
4 |
|
5 |
import logging
|
6 |
import re_matching
|
7 |
-
import uuid
|
8 |
-
from flask import Flask, request, jsonify, render_template_string
|
9 |
-
from flask_cors import CORS
|
10 |
|
11 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
12 |
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
@@ -18,6 +15,7 @@ logging.basicConfig(
|
|
18 |
)
|
19 |
|
20 |
logger = logging.getLogger(__name__)
|
|
|
21 |
import librosa
|
22 |
import numpy as np
|
23 |
import torch
|
@@ -25,25 +23,31 @@ import torch.nn as nn
|
|
25 |
from torch.utils.data import Dataset
|
26 |
from torch.utils.data import DataLoader, Dataset
|
27 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
import utils
|
30 |
from config import config
|
31 |
-
|
32 |
import torch
|
33 |
import commons
|
34 |
from text import cleaned_text_to_sequence, get_bert
|
35 |
-
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
|
36 |
-
|
37 |
from text.cleaner import clean_text
|
38 |
import utils
|
39 |
|
40 |
from models import SynthesizerTrn
|
41 |
from text.symbols import symbols
|
42 |
import sys
|
43 |
-
|
44 |
from scipy.io.wavfile import write
|
|
|
45 |
|
46 |
net_g = None
|
|
|
47 |
device = (
|
48 |
"cuda:0"
|
49 |
if torch.cuda.is_available()
|
@@ -54,7 +58,22 @@ device = (
|
|
54 |
)
|
55 |
)
|
56 |
|
57 |
-
#device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def get_net_g(model_path: str, device: str, hps):
|
60 |
net_g = SynthesizerTrn(
|
@@ -68,11 +87,11 @@ def get_net_g(model_path: str, device: str, hps):
|
|
68 |
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
69 |
return net_g
|
70 |
|
71 |
-
|
72 |
-
|
73 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
74 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
75 |
-
|
76 |
if hps.data.add_blank:
|
77 |
phone = commons.intersperse(phone, 0)
|
78 |
tone = commons.intersperse(tone, 0)
|
@@ -80,18 +99,24 @@ def get_text(text, language_str, hps, device):
|
|
80 |
for i in range(len(word2ph)):
|
81 |
word2ph[i] = word2ph[i] * 2
|
82 |
word2ph[0] += 1
|
83 |
-
bert_ori = get_bert(
|
|
|
|
|
84 |
del word2ph
|
85 |
assert bert_ori.shape[-1] == len(phone), phone
|
86 |
|
87 |
if language_str == "ZH":
|
88 |
bert = bert_ori
|
89 |
-
ja_bert = torch.
|
90 |
-
en_bert = torch.
|
91 |
elif language_str == "JP":
|
92 |
-
bert = torch.
|
93 |
ja_bert = bert_ori
|
94 |
-
en_bert = torch.
|
|
|
|
|
|
|
|
|
95 |
else:
|
96 |
raise ValueError("language_str should be ZH, JP or EN")
|
97 |
|
@@ -104,6 +129,7 @@ def get_text(text, language_str, hps, device):
|
|
104 |
language = torch.LongTensor(language)
|
105 |
return bert, ja_bert, en_bert, phone, tone, language
|
106 |
|
|
|
107 |
def infer(
|
108 |
text,
|
109 |
sdp_ratio,
|
@@ -111,18 +137,18 @@ def infer(
|
|
111 |
noise_scale_w,
|
112 |
length_scale,
|
113 |
sid,
|
114 |
-
|
115 |
-
|
116 |
):
|
117 |
|
118 |
language= 'JP' if is_japanese(text) else 'ZH'
|
119 |
-
if isinstance(reference_audio, np.ndarray):
|
120 |
-
emo = get_clap_audio_feature(reference_audio, device)
|
121 |
-
else:
|
122 |
-
emo = get_clap_text_feature(emotion, device)
|
123 |
-
emo = torch.squeeze(emo, dim=1)
|
124 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
125 |
-
text,
|
|
|
|
|
|
|
|
|
|
|
126 |
)
|
127 |
with torch.no_grad():
|
128 |
x_tst = phones.to(device).unsqueeze(0)
|
@@ -132,7 +158,7 @@ def infer(
|
|
132 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
133 |
en_bert = en_bert.to(device).unsqueeze(0)
|
134 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
135 |
-
emo = emo.to(device).unsqueeze(0)
|
136 |
del phones
|
137 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
138 |
audio = (
|
@@ -145,7 +171,6 @@ def infer(
|
|
145 |
bert,
|
146 |
ja_bert,
|
147 |
en_bert,
|
148 |
-
emo,
|
149 |
sdp_ratio=sdp_ratio,
|
150 |
noise_scale=noise_scale,
|
151 |
noise_scale_w=noise_scale_w,
|
@@ -155,7 +180,80 @@ def infer(
|
|
155 |
.float()
|
156 |
.numpy()
|
157 |
)
|
158 |
-
del
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
if torch.cuda.is_available():
|
160 |
torch.cuda.empty_cache()
|
161 |
unique_filename = f"temp{uuid.uuid4()}.wav"
|
@@ -176,19 +274,11 @@ def loadmodel(model):
|
|
176 |
except:
|
177 |
return "error"
|
178 |
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
data = {'text': text}
|
183 |
-
try:
|
184 |
-
response = requests.post(url, files=files,data=data)
|
185 |
-
return response.status_code, response.text
|
186 |
-
except Exception as e:
|
187 |
-
return 500, str(e)
|
188 |
|
189 |
-
|
190 |
-
CORS(app)
|
191 |
-
@app.route('/')
|
192 |
|
193 |
def tts():
|
194 |
global last_text, last_model
|
@@ -197,7 +287,8 @@ def tts():
|
|
197 |
noise_scale = float(request.args.get('noise_scale', 0.6))
|
198 |
noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
|
199 |
length_scale = float(request.args.get('length_scale', 1))
|
200 |
-
|
|
|
201 |
text = request.args.get('text')
|
202 |
is_chat = request.args.get('is_chat', 'false').lower() == 'true'
|
203 |
model = request.args.get('model',modelPaths[-1])
|
@@ -210,7 +301,7 @@ def tts():
|
|
210 |
<title>TTS API Documentation</title>
|
211 |
</head>
|
212 |
<body>
|
213 |
-
<iframe src="http://
|
214 |
</body>
|
215 |
</html>
|
216 |
""")
|
@@ -225,9 +316,7 @@ def tts():
|
|
225 |
write(unique_filename , 44100, silence)
|
226 |
else:
|
227 |
last_text = text
|
228 |
-
unique_filename =
|
229 |
-
status_code, response_text = send_audio_to_server(unique_filename,text)
|
230 |
-
print(f"Response from server: {response_text} (Status code: {status_code})")
|
231 |
with open(unique_filename ,'rb') as bit:
|
232 |
wav_bytes = bit.read()
|
233 |
os.remove(unique_filename)
|
@@ -236,14 +325,16 @@ def tts():
|
|
236 |
'Text': unique_filename .encode('utf-8')}
|
237 |
return wav_bytes, 200, headers
|
238 |
|
|
|
|
|
239 |
|
240 |
if __name__ == "__main__":
|
241 |
languages = [ "Auto", "ZH", "JP"]
|
242 |
modelPaths = []
|
243 |
-
for dirpath, dirnames, filenames in os.walk(
|
244 |
for filename in filenames:
|
245 |
modelPaths.append(os.path.join(dirpath, filename))
|
246 |
-
hps = utils.get_hparams_from_file('Data/
|
247 |
net_g = get_net_g(
|
248 |
model_path=modelPaths[-1], device=device, hps=hps
|
249 |
)
|
@@ -251,4 +342,80 @@ if __name__ == "__main__":
|
|
251 |
speakers = list(speaker_ids.keys())
|
252 |
last_text = ""
|
253 |
last_model = modelPaths[-1]
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
import logging
|
6 |
import re_matching
|
|
|
|
|
|
|
7 |
|
8 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
9 |
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
|
|
15 |
)
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
import librosa
|
20 |
import numpy as np
|
21 |
import torch
|
|
|
23 |
from torch.utils.data import Dataset
|
24 |
from torch.utils.data import DataLoader, Dataset
|
25 |
from tqdm import tqdm
|
26 |
+
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
|
27 |
+
|
28 |
+
import uuid
|
29 |
+
from flask import Flask, request, jsonify, render_template_string
|
30 |
+
from flask_cors import CORS
|
31 |
+
|
32 |
+
import gradio as gr
|
33 |
|
34 |
import utils
|
35 |
from config import config
|
36 |
+
|
37 |
import torch
|
38 |
import commons
|
39 |
from text import cleaned_text_to_sequence, get_bert
|
|
|
|
|
40 |
from text.cleaner import clean_text
|
41 |
import utils
|
42 |
|
43 |
from models import SynthesizerTrn
|
44 |
from text.symbols import symbols
|
45 |
import sys
|
|
|
46 |
from scipy.io.wavfile import write
|
47 |
+
from threading import Thread
|
48 |
|
49 |
net_g = None
|
50 |
+
|
51 |
device = (
|
52 |
"cuda:0"
|
53 |
if torch.cuda.is_available()
|
|
|
58 |
)
|
59 |
)
|
60 |
|
61 |
+
#device = "cpu"
|
62 |
+
BandList = {
|
63 |
+
"PoppinParty":["香澄","有咲","たえ","りみ","沙綾"],
|
64 |
+
"Afterglow":["蘭","モカ","ひまり","巴","つぐみ"],
|
65 |
+
"HelloHappyWorld":["こころ","美咲","薫","花音","はぐみ"],
|
66 |
+
"PastelPalettes":["彩","日菜","千聖","イヴ","麻弥"],
|
67 |
+
"Roselia":["友希那","紗夜","リサ","燐子","あこ"],
|
68 |
+
"RaiseASuilen":["レイヤ","ロック","ますき","チュチュ","パレオ"],
|
69 |
+
"Morfonica":["ましろ","瑠唯","つくし","七深","透子"],
|
70 |
+
"MyGo":["燈","愛音","そよ","立希","楽奈"],
|
71 |
+
"AveMujica":["祥子","睦","海鈴","にゃむ","初華"],
|
72 |
+
"圣翔音乐学园":["華戀","光","香子","雙葉","真晝","純那","克洛迪娜","真矢","奈奈"],
|
73 |
+
"凛明馆女子学校":["珠緒","壘","文","悠悠子","一愛"],
|
74 |
+
"弗隆提亚艺术学校":["艾露","艾露露","菈樂菲","司","靜羽"],
|
75 |
+
"西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
|
76 |
+
}
|
77 |
|
78 |
def get_net_g(model_path: str, device: str, hps):
|
79 |
net_g = SynthesizerTrn(
|
|
|
87 |
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
88 |
return net_g
|
89 |
|
90 |
+
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
|
91 |
+
style_text = None if style_text == "" else style_text
|
92 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
93 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
94 |
+
|
95 |
if hps.data.add_blank:
|
96 |
phone = commons.intersperse(phone, 0)
|
97 |
tone = commons.intersperse(tone, 0)
|
|
|
99 |
for i in range(len(word2ph)):
|
100 |
word2ph[i] = word2ph[i] * 2
|
101 |
word2ph[0] += 1
|
102 |
+
bert_ori = get_bert(
|
103 |
+
norm_text, word2ph, language_str, device, style_text, style_weight
|
104 |
+
)
|
105 |
del word2ph
|
106 |
assert bert_ori.shape[-1] == len(phone), phone
|
107 |
|
108 |
if language_str == "ZH":
|
109 |
bert = bert_ori
|
110 |
+
ja_bert = torch.randn(1024, len(phone))
|
111 |
+
en_bert = torch.randn(1024, len(phone))
|
112 |
elif language_str == "JP":
|
113 |
+
bert = torch.randn(1024, len(phone))
|
114 |
ja_bert = bert_ori
|
115 |
+
en_bert = torch.randn(1024, len(phone))
|
116 |
+
elif language_str == "EN":
|
117 |
+
bert = torch.randn(1024, len(phone))
|
118 |
+
ja_bert = torch.randn(1024, len(phone))
|
119 |
+
en_bert = bert_ori
|
120 |
else:
|
121 |
raise ValueError("language_str should be ZH, JP or EN")
|
122 |
|
|
|
129 |
language = torch.LongTensor(language)
|
130 |
return bert, ja_bert, en_bert, phone, tone, language
|
131 |
|
132 |
+
|
133 |
def infer(
|
134 |
text,
|
135 |
sdp_ratio,
|
|
|
137 |
noise_scale_w,
|
138 |
length_scale,
|
139 |
sid,
|
140 |
+
style_text=None,
|
141 |
+
style_weight=0.7,
|
142 |
):
|
143 |
|
144 |
language= 'JP' if is_japanese(text) else 'ZH'
|
|
|
|
|
|
|
|
|
|
|
145 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
146 |
+
text,
|
147 |
+
language,
|
148 |
+
hps,
|
149 |
+
device,
|
150 |
+
style_text=style_text,
|
151 |
+
style_weight=style_weight,
|
152 |
)
|
153 |
with torch.no_grad():
|
154 |
x_tst = phones.to(device).unsqueeze(0)
|
|
|
158 |
ja_bert = ja_bert.to(device).unsqueeze(0)
|
159 |
en_bert = en_bert.to(device).unsqueeze(0)
|
160 |
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
161 |
+
# emo = emo.to(device).unsqueeze(0)
|
162 |
del phones
|
163 |
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
164 |
audio = (
|
|
|
171 |
bert,
|
172 |
ja_bert,
|
173 |
en_bert,
|
|
|
174 |
sdp_ratio=sdp_ratio,
|
175 |
noise_scale=noise_scale,
|
176 |
noise_scale_w=noise_scale_w,
|
|
|
180 |
.float()
|
181 |
.numpy()
|
182 |
)
|
183 |
+
del (
|
184 |
+
x_tst,
|
185 |
+
tones,
|
186 |
+
lang_ids,
|
187 |
+
bert,
|
188 |
+
x_tst_lengths,
|
189 |
+
speakers,
|
190 |
+
ja_bert,
|
191 |
+
en_bert,
|
192 |
+
) # , emo
|
193 |
+
if torch.cuda.is_available():
|
194 |
+
torch.cuda.empty_cache()
|
195 |
+
return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
|
196 |
+
|
197 |
+
def inferAPI(
|
198 |
+
text,
|
199 |
+
sdp_ratio,
|
200 |
+
noise_scale,
|
201 |
+
noise_scale_w,
|
202 |
+
length_scale,
|
203 |
+
sid,
|
204 |
+
style_text=None,
|
205 |
+
style_weight=0.7,
|
206 |
+
):
|
207 |
+
|
208 |
+
language= 'JP' if is_japanese(text) else 'ZH'
|
209 |
+
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
210 |
+
text,
|
211 |
+
language,
|
212 |
+
hps,
|
213 |
+
device,
|
214 |
+
style_text=style_text,
|
215 |
+
style_weight=style_weight,
|
216 |
+
)
|
217 |
+
with torch.no_grad():
|
218 |
+
x_tst = phones.to(device).unsqueeze(0)
|
219 |
+
tones = tones.to(device).unsqueeze(0)
|
220 |
+
lang_ids = lang_ids.to(device).unsqueeze(0)
|
221 |
+
bert = bert.to(device).unsqueeze(0)
|
222 |
+
ja_bert = ja_bert.to(device).unsqueeze(0)
|
223 |
+
en_bert = en_bert.to(device).unsqueeze(0)
|
224 |
+
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
225 |
+
# emo = emo.to(device).unsqueeze(0)
|
226 |
+
del phones
|
227 |
+
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
228 |
+
audio = (
|
229 |
+
net_g.infer(
|
230 |
+
x_tst,
|
231 |
+
x_tst_lengths,
|
232 |
+
speakers,
|
233 |
+
tones,
|
234 |
+
lang_ids,
|
235 |
+
bert,
|
236 |
+
ja_bert,
|
237 |
+
en_bert,
|
238 |
+
sdp_ratio=sdp_ratio,
|
239 |
+
noise_scale=noise_scale,
|
240 |
+
noise_scale_w=noise_scale_w,
|
241 |
+
length_scale=length_scale,
|
242 |
+
)[0][0, 0]
|
243 |
+
.data.cpu()
|
244 |
+
.float()
|
245 |
+
.numpy()
|
246 |
+
)
|
247 |
+
del (
|
248 |
+
x_tst,
|
249 |
+
tones,
|
250 |
+
lang_ids,
|
251 |
+
bert,
|
252 |
+
x_tst_lengths,
|
253 |
+
speakers,
|
254 |
+
ja_bert,
|
255 |
+
en_bert,
|
256 |
+
) # , emo
|
257 |
if torch.cuda.is_available():
|
258 |
torch.cuda.empty_cache()
|
259 |
unique_filename = f"temp{uuid.uuid4()}.wav"
|
|
|
274 |
except:
|
275 |
return "error"
|
276 |
|
277 |
+
Flaskapp = Flask(__name__)
|
278 |
+
CORS(Flaskapp)
|
279 |
+
@Flaskapp.route('/')
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
|
281 |
+
@Flaskapp.route('/')
|
|
|
|
|
282 |
|
283 |
def tts():
|
284 |
global last_text, last_model
|
|
|
287 |
noise_scale = float(request.args.get('noise_scale', 0.6))
|
288 |
noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
|
289 |
length_scale = float(request.args.get('length_scale', 1))
|
290 |
+
style_weight = float(request.args.get('style_weight', 0.7))
|
291 |
+
style_text = request.args.get('style_text', 'happy')
|
292 |
text = request.args.get('text')
|
293 |
is_chat = request.args.get('is_chat', 'false').lower() == 'true'
|
294 |
model = request.args.get('model',modelPaths[-1])
|
|
|
301 |
<title>TTS API Documentation</title>
|
302 |
</head>
|
303 |
<body>
|
304 |
+
<iframe src="http://127.0.0.1:7860" style="width:100%; height:100vh; border:none;"></iframe>
|
305 |
</body>
|
306 |
</html>
|
307 |
""")
|
|
|
316 |
write(unique_filename , 44100, silence)
|
317 |
else:
|
318 |
last_text = text
|
319 |
+
unique_filename = inferAPI(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale,sid = speaker, style_text=style_text, style_weight=style_weight)
|
|
|
|
|
320 |
with open(unique_filename ,'rb') as bit:
|
321 |
wav_bytes = bit.read()
|
322 |
os.remove(unique_filename)
|
|
|
325 |
'Text': unique_filename .encode('utf-8')}
|
326 |
return wav_bytes, 200, headers
|
327 |
|
328 |
+
def gradio_interface():
|
329 |
+
return app.launch(share=True)
|
330 |
|
331 |
if __name__ == "__main__":
|
332 |
languages = [ "Auto", "ZH", "JP"]
|
333 |
modelPaths = []
|
334 |
+
for dirpath, dirnames, filenames in os.walk('Data/V23/models/'):
|
335 |
for filename in filenames:
|
336 |
modelPaths.append(os.path.join(dirpath, filename))
|
337 |
+
hps = utils.get_hparams_from_file('Data/V23/configs/config.json')
|
338 |
net_g = get_net_g(
|
339 |
model_path=modelPaths[-1], device=device, hps=hps
|
340 |
)
|
|
|
342 |
speakers = list(speaker_ids.keys())
|
343 |
last_text = ""
|
344 |
last_model = modelPaths[-1]
|
345 |
+
with gr.Blocks() as app:
|
346 |
+
for band in BandList:
|
347 |
+
with gr.TabItem(band):
|
348 |
+
for name in BandList[band]:
|
349 |
+
with gr.TabItem(name):
|
350 |
+
with gr.Row():
|
351 |
+
with gr.Column():
|
352 |
+
with gr.Row():
|
353 |
+
gr.Markdown(
|
354 |
+
'<div align="center">'
|
355 |
+
f'<img style="width:auto;height:400px;" src="https://mahiruoshi-bangdream-bert-vits2.hf.space/file/image/{name}.png">'
|
356 |
+
'</div>'
|
357 |
+
)
|
358 |
+
length_scale = gr.Slider(
|
359 |
+
minimum=0.1, maximum=2, value=1, step=0.01, label="语速调节"
|
360 |
+
)
|
361 |
+
with gr.Accordion(label="参数设定", open=False):
|
362 |
+
sdp_ratio = gr.Slider(
|
363 |
+
minimum=0, maximum=1, value=0.5, step=0.01, label="SDP/DP混合比"
|
364 |
+
)
|
365 |
+
noise_scale = gr.Slider(
|
366 |
+
minimum=0.1, maximum=2, value=0.6, step=0.01, label="感情调节"
|
367 |
+
)
|
368 |
+
noise_scale_w = gr.Slider(
|
369 |
+
minimum=0.1, maximum=2, value=0.667, step=0.01, label="音素长度"
|
370 |
+
)
|
371 |
+
speaker = gr.Dropdown(
|
372 |
+
choices=speakers, value=name, label="说话人"
|
373 |
+
)
|
374 |
+
with gr.Accordion(label="切换模型", open=False):
|
375 |
+
modelstrs = gr.Dropdown(label = "模型", choices = modelPaths, value = modelPaths[0], type = "value")
|
376 |
+
btnMod = gr.Button("载入模型")
|
377 |
+
statusa = gr.TextArea()
|
378 |
+
btnMod.click(loadmodel, inputs=[modelstrs], outputs = [statusa])
|
379 |
+
with gr.Column():
|
380 |
+
text = gr.TextArea(
|
381 |
+
label="输入纯日语或者中文",
|
382 |
+
placeholder="输入纯日语或者中文",
|
383 |
+
value="为什么要演奏春日影!",
|
384 |
+
)
|
385 |
+
style_text = gr.Textbox(label="辅助文本")
|
386 |
+
style_weight = gr.Slider(
|
387 |
+
minimum=0,
|
388 |
+
maximum=1,
|
389 |
+
value=0.7,
|
390 |
+
step=0.1,
|
391 |
+
label="Weight",
|
392 |
+
info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
|
393 |
+
)
|
394 |
+
btn = gr.Button("点击生成", variant="primary")
|
395 |
+
audio_output = gr.Audio(label="Output Audio")
|
396 |
+
'''
|
397 |
+
btntran = gr.Button("快速中翻日")
|
398 |
+
translateResult = gr.TextArea("从这复制翻译后的文本")
|
399 |
+
btntran.click(translate, inputs=[text], outputs = [translateResult])
|
400 |
+
'''
|
401 |
+
btn.click(
|
402 |
+
infer,
|
403 |
+
inputs=[
|
404 |
+
text,
|
405 |
+
sdp_ratio,
|
406 |
+
noise_scale,
|
407 |
+
noise_scale_w,
|
408 |
+
length_scale,
|
409 |
+
speaker,
|
410 |
+
style_text,
|
411 |
+
style_weight,
|
412 |
+
],
|
413 |
+
outputs=[audio_output],
|
414 |
+
)
|
415 |
+
|
416 |
+
api_thread = Thread(target=Flaskapp.run, args=("0.0.0.0", 5000))
|
417 |
+
gradio_thread = Thread(target=gradio_interface)
|
418 |
+
gradio_thread.start()
|
419 |
+
print("推理页面已开启!")
|
420 |
+
api_thread.start()
|
421 |
+
print("api页面已开启!运行在5000端口")
|
server_fastapi.py
CHANGED
@@ -5,6 +5,7 @@ import logging
|
|
5 |
import gc
|
6 |
import random
|
7 |
|
|
|
8 |
import gradio
|
9 |
import numpy as np
|
10 |
import utils
|
@@ -203,28 +204,48 @@ if __name__ == "__main__":
|
|
203 |
auto_split: bool,
|
204 |
emotion: Optional[Union[int, str]] = None,
|
205 |
reference_audio=None,
|
|
|
|
|
206 |
) -> Union[Response, Dict[str, any]]:
|
207 |
"""TTS实现函数"""
|
208 |
# 检查模型是否存在
|
209 |
if model_id not in loaded_models.models.keys():
|
|
|
210 |
return {"status": 10, "detail": f"模型model_id={model_id}未加载"}
|
211 |
# 检查是否提供speaker
|
212 |
if speaker_name is None and speaker_id is None:
|
|
|
213 |
return {"status": 11, "detail": "请提供speaker_name或speaker_id"}
|
214 |
elif speaker_name is None:
|
215 |
# 检查speaker_id是否存在
|
216 |
if speaker_id not in loaded_models.models[model_id].id2spk.keys():
|
|
|
217 |
return {"status": 12, "detail": f"角色speaker_id={speaker_id}不存在"}
|
218 |
speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
|
219 |
# 检查speaker_name是否存在
|
220 |
if speaker_name not in loaded_models.models[model_id].spk2id.keys():
|
|
|
221 |
return {"status": 13, "detail": f"角色speaker_name={speaker_name}不存在"}
|
|
|
222 |
if language is None:
|
223 |
language = loaded_models.models[model_id].language
|
|
|
224 |
if auto_translate:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
text = trans.translate(Sentence=text, to_Language=language.lower())
|
226 |
if reference_audio is not None:
|
227 |
ref_audio = BytesIO(await reference_audio.read())
|
|
|
|
|
|
|
|
|
228 |
else:
|
229 |
ref_audio = reference_audio
|
230 |
if not auto_split:
|
@@ -242,6 +263,8 @@ if __name__ == "__main__":
|
|
242 |
device=loaded_models.models[model_id].device,
|
243 |
emotion=emotion,
|
244 |
reference_audio=ref_audio,
|
|
|
|
|
245 |
)
|
246 |
audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
|
247 |
else:
|
@@ -263,6 +286,8 @@ if __name__ == "__main__":
|
|
263 |
device=loaded_models.models[model_id].device,
|
264 |
emotion=emotion,
|
265 |
reference_audio=ref_audio,
|
|
|
|
|
266 |
)
|
267 |
)
|
268 |
audios.append(np.zeros(int(44100 * 0.2)))
|
@@ -293,6 +318,8 @@ if __name__ == "__main__":
|
|
293 |
auto_split: bool = Query(False, description="自动切分"),
|
294 |
emotion: Optional[Union[int, str]] = Query(None, description="emo"),
|
295 |
reference_audio: UploadFile = File(None),
|
|
|
|
|
296 |
):
|
297 |
"""语音接口,若需要上传参考音频请仅使用post请求"""
|
298 |
logger.info(
|
@@ -312,6 +339,8 @@ if __name__ == "__main__":
|
|
312 |
auto_split=auto_split,
|
313 |
emotion=emotion,
|
314 |
reference_audio=reference_audio,
|
|
|
|
|
315 |
)
|
316 |
|
317 |
@app.get("/voice")
|
@@ -331,6 +360,8 @@ if __name__ == "__main__":
|
|
331 |
auto_translate: bool = Query(False, description="自动翻译"),
|
332 |
auto_split: bool = Query(False, description="自动切分"),
|
333 |
emotion: Optional[Union[int, str]] = Query(None, description="emo"),
|
|
|
|
|
334 |
):
|
335 |
"""语音接口"""
|
336 |
logger.info(
|
@@ -349,6 +380,8 @@ if __name__ == "__main__":
|
|
349 |
auto_translate=auto_translate,
|
350 |
auto_split=auto_split,
|
351 |
emotion=emotion,
|
|
|
|
|
352 |
)
|
353 |
|
354 |
@app.get("/models/info")
|
@@ -370,7 +403,9 @@ if __name__ == "__main__":
|
|
370 |
)
|
371 |
result = loaded_models.del_model(model_id)
|
372 |
if result is None:
|
|
|
373 |
return {"status": 14, "detail": f"模型{model_id}不存在,删除失败"}
|
|
|
374 |
return {"status": 0, "detail": "删除成功"}
|
375 |
|
376 |
@app.get("/models/add")
|
@@ -394,6 +429,7 @@ if __name__ == "__main__":
|
|
394 |
elif os.path.isfile(os.path.join(model_dir, "../config.json")):
|
395 |
config_path = os.path.join(model_dir, "../config.json")
|
396 |
else:
|
|
|
397 |
return {
|
398 |
"status": 15,
|
399 |
"detail": "查询未传入配置文件路径,同时默认路径./与../中不存在配置文件config.json。",
|
@@ -628,8 +664,10 @@ if __name__ == "__main__":
|
|
628 |
f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
|
629 |
)
|
630 |
if not os.path.isfile(path):
|
|
|
631 |
return {"status": 18, "detail": "指定音频不存在"}
|
632 |
-
if not path.endswith(".wav"):
|
|
|
633 |
return {"status": 19, "detail": "非wav格式文件"}
|
634 |
return FileResponse(path=path)
|
635 |
|
|
|
5 |
import gc
|
6 |
import random
|
7 |
|
8 |
+
import librosa
|
9 |
import gradio
|
10 |
import numpy as np
|
11 |
import utils
|
|
|
204 |
auto_split: bool,
|
205 |
emotion: Optional[Union[int, str]] = None,
|
206 |
reference_audio=None,
|
207 |
+
style_text: Optional[str] = None,
|
208 |
+
style_weight: float = 0.7,
|
209 |
) -> Union[Response, Dict[str, any]]:
|
210 |
"""TTS实现函数"""
|
211 |
# 检查模型是否存在
|
212 |
if model_id not in loaded_models.models.keys():
|
213 |
+
logger.error(f"/voice 请求错误:模型model_id={model_id}未加载")
|
214 |
return {"status": 10, "detail": f"模型model_id={model_id}未加载"}
|
215 |
# 检查是否提供speaker
|
216 |
if speaker_name is None and speaker_id is None:
|
217 |
+
logger.error("/voice 请求错误:推理请求未提供speaker_name或speaker_id")
|
218 |
return {"status": 11, "detail": "请提供speaker_name或speaker_id"}
|
219 |
elif speaker_name is None:
|
220 |
# 检查speaker_id是否存在
|
221 |
if speaker_id not in loaded_models.models[model_id].id2spk.keys():
|
222 |
+
logger.error(f"/voice 请求错误:角色speaker_id={speaker_id}不存在")
|
223 |
return {"status": 12, "detail": f"角色speaker_id={speaker_id}不存在"}
|
224 |
speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
|
225 |
# 检查speaker_name是否存在
|
226 |
if speaker_name not in loaded_models.models[model_id].spk2id.keys():
|
227 |
+
logger.error(f"/voice 请求错误:角色speaker_name={speaker_name}不存在")
|
228 |
return {"status": 13, "detail": f"角色speaker_name={speaker_name}不存在"}
|
229 |
+
# 未传入则使用默认语言
|
230 |
if language is None:
|
231 |
language = loaded_models.models[model_id].language
|
232 |
+
# 翻译会破坏mix结构,auto也会变得无意义。不要在这两个模式下使用
|
233 |
if auto_translate:
|
234 |
+
if language == "auto" or language == "mix":
|
235 |
+
logger.error(
|
236 |
+
f"/voice 请求错误:请勿同时使用language = {language}与auto_translate模式"
|
237 |
+
)
|
238 |
+
return {
|
239 |
+
"status": 20,
|
240 |
+
"detail": f"请勿同时使用language = {language}与auto_translate模式",
|
241 |
+
}
|
242 |
text = trans.translate(Sentence=text, to_Language=language.lower())
|
243 |
if reference_audio is not None:
|
244 |
ref_audio = BytesIO(await reference_audio.read())
|
245 |
+
# 2.2 适配
|
246 |
+
if loaded_models.models[model_id].version == "2.2":
|
247 |
+
ref_audio, _ = librosa.load(ref_audio, 48000)
|
248 |
+
|
249 |
else:
|
250 |
ref_audio = reference_audio
|
251 |
if not auto_split:
|
|
|
263 |
device=loaded_models.models[model_id].device,
|
264 |
emotion=emotion,
|
265 |
reference_audio=ref_audio,
|
266 |
+
style_text=style_text,
|
267 |
+
style_weight=style_weight,
|
268 |
)
|
269 |
audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
|
270 |
else:
|
|
|
286 |
device=loaded_models.models[model_id].device,
|
287 |
emotion=emotion,
|
288 |
reference_audio=ref_audio,
|
289 |
+
style_text=style_text,
|
290 |
+
style_weight=style_weight,
|
291 |
)
|
292 |
)
|
293 |
audios.append(np.zeros(int(44100 * 0.2)))
|
|
|
318 |
auto_split: bool = Query(False, description="自动切分"),
|
319 |
emotion: Optional[Union[int, str]] = Query(None, description="emo"),
|
320 |
reference_audio: UploadFile = File(None),
|
321 |
+
style_text: Optional[str] = Form(None, description="风格文本"),
|
322 |
+
style_weight: float = Query(0.7, description="风格权重"),
|
323 |
):
|
324 |
"""语音接口,若需要上传参考音频请仅使用post请求"""
|
325 |
logger.info(
|
|
|
339 |
auto_split=auto_split,
|
340 |
emotion=emotion,
|
341 |
reference_audio=reference_audio,
|
342 |
+
style_text=style_text,
|
343 |
+
style_weight=style_weight,
|
344 |
)
|
345 |
|
346 |
@app.get("/voice")
|
|
|
360 |
auto_translate: bool = Query(False, description="自动翻译"),
|
361 |
auto_split: bool = Query(False, description="自动切分"),
|
362 |
emotion: Optional[Union[int, str]] = Query(None, description="emo"),
|
363 |
+
style_text: Optional[str] = Query(None, description="风格文本"),
|
364 |
+
style_weight: float = Query(0.7, description="风格权重"),
|
365 |
):
|
366 |
"""语音接口"""
|
367 |
logger.info(
|
|
|
380 |
auto_translate=auto_translate,
|
381 |
auto_split=auto_split,
|
382 |
emotion=emotion,
|
383 |
+
style_text=style_text,
|
384 |
+
style_weight=style_weight,
|
385 |
)
|
386 |
|
387 |
@app.get("/models/info")
|
|
|
403 |
)
|
404 |
result = loaded_models.del_model(model_id)
|
405 |
if result is None:
|
406 |
+
logger.error(f"/models/delete 模型删除错误:模型{model_id}不存在,删除失败")
|
407 |
return {"status": 14, "detail": f"模型{model_id}不存在,删除失败"}
|
408 |
+
|
409 |
return {"status": 0, "detail": "删除成功"}
|
410 |
|
411 |
@app.get("/models/add")
|
|
|
429 |
elif os.path.isfile(os.path.join(model_dir, "../config.json")):
|
430 |
config_path = os.path.join(model_dir, "../config.json")
|
431 |
else:
|
432 |
+
logger.error("/models/add 模型添加失败:未在模型所在目录以及上级目录找到config.json文件")
|
433 |
return {
|
434 |
"status": 15,
|
435 |
"detail": "查询未传入配置文件路径,同时默认路径./与../中不存在配置文件config.json。",
|
|
|
664 |
f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
|
665 |
)
|
666 |
if not os.path.isfile(path):
|
667 |
+
logger.error(f"/tools/get_audio 获取音频错误:指定音频{path}不存在")
|
668 |
return {"status": 18, "detail": "指定音频不存在"}
|
669 |
+
if not path.lower().endswith(".wav"):
|
670 |
+
logger.error(f"/tools/get_audio 获取音频错误:音频{path}非wav文件")
|
671 |
return {"status": 19, "detail": "非wav格式文件"}
|
672 |
return FileResponse(path=path)
|
673 |
|
slm/wavlm-base-plus/.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
slm/wavlm-base-plus/README.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
datasets:
|
5 |
+
tags:
|
6 |
+
- speech
|
7 |
+
inference: false
|
8 |
+
---
|
9 |
+
|
10 |
+
# WavLM-Base-Plus
|
11 |
+
|
12 |
+
[Microsoft's WavLM](https://github.com/microsoft/unilm/tree/master/wavlm)
|
13 |
+
|
14 |
+
The base model pretrained on 16kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16kHz.
|
15 |
+
|
16 |
+
**Note**: This model does not have a tokenizer as it was pretrained on audio alone. In order to use this model **speech recognition**, a tokenizer should be created and the model should be fine-tuned on labeled text data. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more in-detail explanation of how to fine-tune the model.
|
17 |
+
|
18 |
+
The model was pre-trained on:
|
19 |
+
|
20 |
+
- 60,000 hours of [Libri-Light](https://arxiv.org/abs/1912.07875)
|
21 |
+
- 10,000 hours of [GigaSpeech](https://arxiv.org/abs/2106.06909)
|
22 |
+
- 24,000 hours of [VoxPopuli](https://arxiv.org/abs/2101.00390)
|
23 |
+
|
24 |
+
[Paper: WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900)
|
25 |
+
|
26 |
+
Authors: Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei
|
27 |
+
|
28 |
+
**Abstract**
|
29 |
+
*Self-supervised learning (SSL) achieves great success in speech recognition, while limited exploration has been attempted for other speech processing tasks. As speech signal contains multi-faceted information including speaker identity, paralinguistics, spoken content, etc., learning universal representations for all speech tasks is challenging. In this paper, we propose a new pre-trained model, WavLM, to solve full-stack downstream speech tasks. WavLM is built based on the HuBERT framework, with an emphasis on both spoken content modeling and speaker identity preservation. We first equip the Transformer structure with gated relative position bias to improve its capability on recognition tasks. For better speaker discrimination, we propose an utterance mixing training strategy, where additional overlapped utterances are created unsupervisely and incorporated during model training. Lastly, we scale up the training dataset from 60k hours to 94k hours. WavLM Large achieves state-of-the-art performance on the SUPERB benchmark, and brings significant improvements for various speech processing tasks on their representative benchmarks.*
|
30 |
+
|
31 |
+
The original model can be found under https://github.com/microsoft/unilm/tree/master/wavlm.
|
32 |
+
|
33 |
+
# Usage
|
34 |
+
|
35 |
+
This is an English pre-trained speech model that has to be fine-tuned on a downstream task like speech recognition or audio classification before it can be
|
36 |
+
used in inference. The model was pre-trained in English and should therefore perform well only in English. The model has been shown to work well on the [SUPERB benchmark](https://superbbenchmark.org/).
|
37 |
+
|
38 |
+
**Note**: The model was pre-trained on phonemes rather than characters. This means that one should make sure that the input text is converted to a sequence
|
39 |
+
of phonemes before fine-tuning.
|
40 |
+
|
41 |
+
## Speech Recognition
|
42 |
+
|
43 |
+
To fine-tune the model for speech recognition, see [the official speech recognition example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/speech-recognition).
|
44 |
+
|
45 |
+
## Speech Classification
|
46 |
+
|
47 |
+
To fine-tune the model for speech classification, see [the official audio classification example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/audio-classification).
|
48 |
+
|
49 |
+
## Speaker Verification
|
50 |
+
|
51 |
+
TODO
|
52 |
+
|
53 |
+
## Speaker Diarization
|
54 |
+
|
55 |
+
TODO
|
56 |
+
|
57 |
+
# Contribution
|
58 |
+
|
59 |
+
The model was contributed by [cywang](https://huggingface.co/cywang) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
|
60 |
+
|
61 |
+
# License
|
62 |
+
|
63 |
+
The official license can be found [here](https://github.com/microsoft/UniSpeech/blob/main/LICENSE)
|
64 |
+
|
65 |
+
![design](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/wavlm.png)
|
slm/wavlm-base-plus/config.json
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "wavlm-base-plus",
|
3 |
+
"activation_dropout": 0.0,
|
4 |
+
"adapter_kernel_size": 3,
|
5 |
+
"adapter_stride": 2,
|
6 |
+
"add_adapter": false,
|
7 |
+
"apply_spec_augment": true,
|
8 |
+
"architectures": [
|
9 |
+
"WavLMModel"
|
10 |
+
],
|
11 |
+
"attention_dropout": 0.1,
|
12 |
+
"bos_token_id": 1,
|
13 |
+
"classifier_proj_size": 256,
|
14 |
+
"codevector_dim": 256,
|
15 |
+
"contrastive_logits_temperature": 0.1,
|
16 |
+
"conv_bias": false,
|
17 |
+
"conv_dim": [
|
18 |
+
512,
|
19 |
+
512,
|
20 |
+
512,
|
21 |
+
512,
|
22 |
+
512,
|
23 |
+
512,
|
24 |
+
512
|
25 |
+
],
|
26 |
+
"conv_kernel": [
|
27 |
+
10,
|
28 |
+
3,
|
29 |
+
3,
|
30 |
+
3,
|
31 |
+
3,
|
32 |
+
2,
|
33 |
+
2
|
34 |
+
],
|
35 |
+
"conv_stride": [
|
36 |
+
5,
|
37 |
+
2,
|
38 |
+
2,
|
39 |
+
2,
|
40 |
+
2,
|
41 |
+
2,
|
42 |
+
2
|
43 |
+
],
|
44 |
+
"ctc_loss_reduction": "sum",
|
45 |
+
"ctc_zero_infinity": false,
|
46 |
+
"diversity_loss_weight": 0.1,
|
47 |
+
"do_stable_layer_norm": false,
|
48 |
+
"eos_token_id": 2,
|
49 |
+
"feat_extract_activation": "gelu",
|
50 |
+
"feat_extract_norm": "group",
|
51 |
+
"feat_proj_dropout": 0.1,
|
52 |
+
"feat_quantizer_dropout": 0.0,
|
53 |
+
"final_dropout": 0.0,
|
54 |
+
"freeze_feat_extract_train": true,
|
55 |
+
"hidden_act": "gelu",
|
56 |
+
"hidden_dropout": 0.1,
|
57 |
+
"hidden_size": 768,
|
58 |
+
"initializer_range": 0.02,
|
59 |
+
"intermediate_size": 3072,
|
60 |
+
"layer_norm_eps": 1e-05,
|
61 |
+
"layerdrop": 0.05,
|
62 |
+
"mask_channel_length": 10,
|
63 |
+
"mask_channel_min_space": 1,
|
64 |
+
"mask_channel_other": 0.0,
|
65 |
+
"mask_channel_prob": 0.0,
|
66 |
+
"mask_channel_selection": "static",
|
67 |
+
"mask_feature_length": 10,
|
68 |
+
"mask_feature_min_masks": 0,
|
69 |
+
"mask_feature_prob": 0.0,
|
70 |
+
"mask_time_length": 10,
|
71 |
+
"mask_time_min_masks": 2,
|
72 |
+
"mask_time_min_space": 1,
|
73 |
+
"mask_time_other": 0.0,
|
74 |
+
"mask_time_prob": 0.05,
|
75 |
+
"mask_time_selection": "static",
|
76 |
+
"model_type": "wavlm",
|
77 |
+
"no_mask_channel_overlap": false,
|
78 |
+
"no_mask_time_overlap": false,
|
79 |
+
"num_adapter_layers": 3,
|
80 |
+
"num_attention_heads": 12,
|
81 |
+
"num_buckets": 320,
|
82 |
+
"num_codevector_groups": 2,
|
83 |
+
"num_codevectors_per_group": 320,
|
84 |
+
"num_conv_pos_embedding_groups": 16,
|
85 |
+
"num_conv_pos_embeddings": 128,
|
86 |
+
"num_ctc_classes": 80,
|
87 |
+
"num_feat_extract_layers": 7,
|
88 |
+
"num_hidden_layers": 12,
|
89 |
+
"num_negatives": 100,
|
90 |
+
"output_hidden_size": 768,
|
91 |
+
"pad_token_id": 0,
|
92 |
+
"proj_codevector_dim": 256,
|
93 |
+
"replace_prob": 0.5,
|
94 |
+
"torch_dtype": "float32",
|
95 |
+
"transformers_version": "4.13.0.dev0",
|
96 |
+
"use_weighted_layer_sum": false,
|
97 |
+
"vocab_size": 32,
|
98 |
+
"tokenizer_class": "Wav2Vec2CTCTokenizer"
|
99 |
+
}
|
slm/wavlm-base-plus/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"do_normalize": false,
|
3 |
+
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
4 |
+
"feature_size": 1,
|
5 |
+
"padding_side": "right",
|
6 |
+
"padding_value": 0.0,
|
7 |
+
"return_attention_mask": true,
|
8 |
+
"sampling_rate": 16000
|
9 |
+
}
|
slm/wavlm-base-plus/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3bb273a6ace99408b50cfc81afdbb7ef2de02da2eab0234e18db608ce692fe51
|
3 |
+
size 377617425
|
text/__init__.py
CHANGED
@@ -18,13 +18,15 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
|
|
18 |
return phones, tones, lang_ids
|
19 |
|
20 |
|
21 |
-
def get_bert(norm_text, word2ph, language, device):
|
22 |
from .chinese_bert import get_bert_feature as zh_bert
|
23 |
from .english_bert_mock import get_bert_feature as en_bert
|
24 |
from .japanese_bert import get_bert_feature as jp_bert
|
25 |
|
26 |
lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
|
27 |
-
bert = lang_bert_func_map[language](
|
|
|
|
|
28 |
return bert
|
29 |
|
30 |
|
|
|
18 |
return phones, tones, lang_ids
|
19 |
|
20 |
|
21 |
+
def get_bert(norm_text, word2ph, language, device, style_text=None, style_weight=0.7):
|
22 |
from .chinese_bert import get_bert_feature as zh_bert
|
23 |
from .english_bert_mock import get_bert_feature as en_bert
|
24 |
from .japanese_bert import get_bert_feature as jp_bert
|
25 |
|
26 |
lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
|
27 |
+
bert = lang_bert_func_map[language](
|
28 |
+
norm_text, word2ph, device, style_text, style_weight
|
29 |
+
)
|
30 |
return bert
|
31 |
|
32 |
|
text/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/__init__.cpython-311.pyc and b/text/__pycache__/__init__.cpython-311.pyc differ
|
|
text/__pycache__/bert_utils.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/bert_utils.cpython-311.pyc and b/text/__pycache__/bert_utils.cpython-311.pyc differ
|
|
text/__pycache__/chinese.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/chinese.cpython-311.pyc and b/text/__pycache__/chinese.cpython-311.pyc differ
|
|
text/__pycache__/chinese_bert.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/chinese_bert.cpython-311.pyc and b/text/__pycache__/chinese_bert.cpython-311.pyc differ
|
|
text/__pycache__/cleaner.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/cleaner.cpython-311.pyc and b/text/__pycache__/cleaner.cpython-311.pyc differ
|
|
text/__pycache__/english.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/english.cpython-311.pyc and b/text/__pycache__/english.cpython-311.pyc differ
|
|
text/__pycache__/english_bert_mock.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/english_bert_mock.cpython-311.pyc and b/text/__pycache__/english_bert_mock.cpython-311.pyc differ
|
|
text/__pycache__/japanese.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/japanese.cpython-311.pyc and b/text/__pycache__/japanese.cpython-311.pyc differ
|
|
text/__pycache__/japanese_bert.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/japanese_bert.cpython-311.pyc and b/text/__pycache__/japanese_bert.cpython-311.pyc differ
|
|
text/__pycache__/symbols.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/symbols.cpython-311.pyc and b/text/__pycache__/symbols.cpython-311.pyc differ
|
|
text/__pycache__/tone_sandhi.cpython-311.pyc
CHANGED
Binary files a/text/__pycache__/tone_sandhi.cpython-311.pyc and b/text/__pycache__/tone_sandhi.cpython-311.pyc differ
|
|
text/chinese_bert.py
CHANGED
@@ -12,7 +12,13 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
|
|
12 |
models = dict()
|
13 |
|
14 |
|
15 |
-
def get_bert_feature(
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
if (
|
17 |
sys.platform == "darwin"
|
18 |
and torch.backends.mps.is_available()
|
@@ -29,12 +35,24 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
|
|
29 |
inputs[i] = inputs[i].to(device)
|
30 |
res = models[device](**inputs, output_hidden_states=True)
|
31 |
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
assert len(word2ph) == len(text) + 2
|
34 |
word2phone = word2ph
|
35 |
phone_level_feature = []
|
36 |
for i in range(len(word2phone)):
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
phone_level_feature.append(repeat_feature)
|
39 |
|
40 |
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
|
|
12 |
models = dict()
|
13 |
|
14 |
|
15 |
+
def get_bert_feature(
|
16 |
+
text,
|
17 |
+
word2ph,
|
18 |
+
device=config.bert_gen_config.device,
|
19 |
+
style_text=None,
|
20 |
+
style_weight=0.7,
|
21 |
+
):
|
22 |
if (
|
23 |
sys.platform == "darwin"
|
24 |
and torch.backends.mps.is_available()
|
|
|
35 |
inputs[i] = inputs[i].to(device)
|
36 |
res = models[device](**inputs, output_hidden_states=True)
|
37 |
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
38 |
+
if style_text:
|
39 |
+
style_inputs = tokenizer(style_text, return_tensors="pt")
|
40 |
+
for i in style_inputs:
|
41 |
+
style_inputs[i] = style_inputs[i].to(device)
|
42 |
+
style_res = models[device](**style_inputs, output_hidden_states=True)
|
43 |
+
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
|
44 |
+
style_res_mean = style_res.mean(0)
|
45 |
assert len(word2ph) == len(text) + 2
|
46 |
word2phone = word2ph
|
47 |
phone_level_feature = []
|
48 |
for i in range(len(word2phone)):
|
49 |
+
if style_text:
|
50 |
+
repeat_feature = (
|
51 |
+
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
|
52 |
+
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
repeat_feature = res[i].repeat(word2phone[i], 1)
|
56 |
phone_level_feature.append(repeat_feature)
|
57 |
|
58 |
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
text/cleaner.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
from text import chinese, japanese, cleaned_text_to_sequence
|
2 |
|
3 |
|
4 |
-
language_module_map = {"ZH": chinese, "JP": japanese}
|
5 |
|
6 |
|
7 |
def clean_text(text, language):
|
|
|
1 |
+
from text import chinese, japanese, english, cleaned_text_to_sequence
|
2 |
|
3 |
|
4 |
+
language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}
|
5 |
|
6 |
|
7 |
def clean_text(text, language):
|
text/english.py
CHANGED
@@ -5,6 +5,7 @@ from g2p_en import G2p
|
|
5 |
from transformers import DebertaV2Tokenizer
|
6 |
|
7 |
from text import symbols
|
|
|
8 |
|
9 |
current_file_path = os.path.dirname(__file__)
|
10 |
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
|
@@ -217,6 +218,8 @@ def refine_ph(phn):
|
|
217 |
if re.search(r"\d$", phn):
|
218 |
tone = int(phn[-1]) + 1
|
219 |
phn = phn[:-1]
|
|
|
|
|
220 |
return phn.lower(), tone
|
221 |
|
222 |
|
@@ -389,45 +392,84 @@ def sep_text(text):
|
|
389 |
return words
|
390 |
|
391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
def g2p(text):
|
393 |
phones = []
|
394 |
tones = []
|
395 |
-
|
396 |
-
words = sep_text(text)
|
397 |
-
tokens = [tokenizer.tokenize(i) for i in words]
|
|
|
|
|
398 |
for word in words:
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
|
421 |
word2ph = []
|
422 |
-
for token,
|
423 |
-
phone_len = len(phoneme)
|
424 |
word_len = len(token)
|
425 |
|
426 |
-
aaa = distribute_phone(
|
427 |
word2ph += aaa
|
428 |
|
429 |
-
phones = ["_"] +
|
430 |
-
tones = [0] +
|
431 |
word2ph = [1] + word2ph + [1]
|
432 |
assert len(phones) == len(tones), text
|
433 |
assert len(phones) == sum(word2ph), text
|
|
|
5 |
from transformers import DebertaV2Tokenizer
|
6 |
|
7 |
from text import symbols
|
8 |
+
from text.symbols import punctuation
|
9 |
|
10 |
current_file_path = os.path.dirname(__file__)
|
11 |
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
|
|
|
218 |
if re.search(r"\d$", phn):
|
219 |
tone = int(phn[-1]) + 1
|
220 |
phn = phn[:-1]
|
221 |
+
else:
|
222 |
+
tone = 3
|
223 |
return phn.lower(), tone
|
224 |
|
225 |
|
|
|
392 |
return words
|
393 |
|
394 |
|
395 |
+
def text_to_words(text):
|
396 |
+
tokens = tokenizer.tokenize(text)
|
397 |
+
words = []
|
398 |
+
for idx, t in enumerate(tokens):
|
399 |
+
if t.startswith("▁"):
|
400 |
+
words.append([t[1:]])
|
401 |
+
else:
|
402 |
+
if t in punctuation:
|
403 |
+
if idx == len(tokens) - 1:
|
404 |
+
words.append([f"{t}"])
|
405 |
+
else:
|
406 |
+
if (
|
407 |
+
not tokens[idx + 1].startswith("▁")
|
408 |
+
and tokens[idx + 1] not in punctuation
|
409 |
+
):
|
410 |
+
if idx == 0:
|
411 |
+
words.append([])
|
412 |
+
words[-1].append(f"{t}")
|
413 |
+
else:
|
414 |
+
words.append([f"{t}"])
|
415 |
+
else:
|
416 |
+
if idx == 0:
|
417 |
+
words.append([])
|
418 |
+
words[-1].append(f"{t}")
|
419 |
+
return words
|
420 |
+
|
421 |
+
|
422 |
def g2p(text):
|
423 |
phones = []
|
424 |
tones = []
|
425 |
+
phone_len = []
|
426 |
+
# words = sep_text(text)
|
427 |
+
# tokens = [tokenizer.tokenize(i) for i in words]
|
428 |
+
words = text_to_words(text)
|
429 |
+
|
430 |
for word in words:
|
431 |
+
temp_phones, temp_tones = [], []
|
432 |
+
if len(word) > 1:
|
433 |
+
if "'" in word:
|
434 |
+
word = ["".join(word)]
|
435 |
+
for w in word:
|
436 |
+
if w in punctuation:
|
437 |
+
temp_phones.append(w)
|
438 |
+
temp_tones.append(0)
|
439 |
+
continue
|
440 |
+
if w.upper() in eng_dict:
|
441 |
+
phns, tns = refine_syllables(eng_dict[w.upper()])
|
442 |
+
temp_phones += [post_replace_ph(i) for i in phns]
|
443 |
+
temp_tones += tns
|
444 |
+
# w2ph.append(len(phns))
|
445 |
+
else:
|
446 |
+
phone_list = list(filter(lambda p: p != " ", _g2p(w)))
|
447 |
+
phns = []
|
448 |
+
tns = []
|
449 |
+
for ph in phone_list:
|
450 |
+
if ph in arpa:
|
451 |
+
ph, tn = refine_ph(ph)
|
452 |
+
phns.append(ph)
|
453 |
+
tns.append(tn)
|
454 |
+
else:
|
455 |
+
phns.append(ph)
|
456 |
+
tns.append(0)
|
457 |
+
temp_phones += [post_replace_ph(i) for i in phns]
|
458 |
+
temp_tones += tns
|
459 |
+
phones += temp_phones
|
460 |
+
tones += temp_tones
|
461 |
+
phone_len.append(len(temp_phones))
|
462 |
+
# phones = [post_replace_ph(i) for i in phones]
|
463 |
|
464 |
word2ph = []
|
465 |
+
for token, pl in zip(words, phone_len):
|
|
|
466 |
word_len = len(token)
|
467 |
|
468 |
+
aaa = distribute_phone(pl, word_len)
|
469 |
word2ph += aaa
|
470 |
|
471 |
+
phones = ["_"] + phones + ["_"]
|
472 |
+
tones = [0] + tones + [0]
|
473 |
word2ph = [1] + word2ph + [1]
|
474 |
assert len(phones) == len(tones), text
|
475 |
assert len(phones) == sum(word2ph), text
|
text/english_bert_mock.py
CHANGED
@@ -13,7 +13,13 @@ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
|
|
13 |
models = dict()
|
14 |
|
15 |
|
16 |
-
def get_bert_feature(
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
if (
|
18 |
sys.platform == "darwin"
|
19 |
and torch.backends.mps.is_available()
|
@@ -30,11 +36,24 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
|
|
30 |
inputs[i] = inputs[i].to(device)
|
31 |
res = models[device](**inputs, output_hidden_states=True)
|
32 |
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
|
34 |
word2phone = word2ph
|
35 |
phone_level_feature = []
|
36 |
for i in range(len(word2phone)):
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
phone_level_feature.append(repeat_feature)
|
39 |
|
40 |
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
|
|
13 |
models = dict()
|
14 |
|
15 |
|
16 |
+
def get_bert_feature(
|
17 |
+
text,
|
18 |
+
word2ph,
|
19 |
+
device=config.bert_gen_config.device,
|
20 |
+
style_text=None,
|
21 |
+
style_weight=0.7,
|
22 |
+
):
|
23 |
if (
|
24 |
sys.platform == "darwin"
|
25 |
and torch.backends.mps.is_available()
|
|
|
36 |
inputs[i] = inputs[i].to(device)
|
37 |
res = models[device](**inputs, output_hidden_states=True)
|
38 |
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
39 |
+
if style_text:
|
40 |
+
style_inputs = tokenizer(style_text, return_tensors="pt")
|
41 |
+
for i in style_inputs:
|
42 |
+
style_inputs[i] = style_inputs[i].to(device)
|
43 |
+
style_res = models[device](**style_inputs, output_hidden_states=True)
|
44 |
+
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
|
45 |
+
style_res_mean = style_res.mean(0)
|
46 |
assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
|
47 |
word2phone = word2ph
|
48 |
phone_level_feature = []
|
49 |
for i in range(len(word2phone)):
|
50 |
+
if style_text:
|
51 |
+
repeat_feature = (
|
52 |
+
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
|
53 |
+
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
repeat_feature = res[i].repeat(word2phone[i], 1)
|
57 |
phone_level_feature.append(repeat_feature)
|
58 |
|
59 |
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
text/japanese_bert.py
CHANGED
@@ -13,8 +13,16 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
|
|
13 |
models = dict()
|
14 |
|
15 |
|
16 |
-
def get_bert_feature(
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
text = "".join(text2sep_kata(text)[0])
|
|
|
|
|
18 |
if (
|
19 |
sys.platform == "darwin"
|
20 |
and torch.backends.mps.is_available()
|
@@ -31,12 +39,25 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
|
|
31 |
inputs[i] = inputs[i].to(device)
|
32 |
res = models[device](**inputs, output_hidden_states=True)
|
33 |
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
assert len(word2ph) == len(text) + 2
|
36 |
word2phone = word2ph
|
37 |
phone_level_feature = []
|
38 |
for i in range(len(word2phone)):
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
phone_level_feature.append(repeat_feature)
|
41 |
|
42 |
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
|
|
13 |
models = dict()
|
14 |
|
15 |
|
16 |
+
def get_bert_feature(
|
17 |
+
text,
|
18 |
+
word2ph,
|
19 |
+
device=config.bert_gen_config.device,
|
20 |
+
style_text=None,
|
21 |
+
style_weight=0.7,
|
22 |
+
):
|
23 |
text = "".join(text2sep_kata(text)[0])
|
24 |
+
if style_text:
|
25 |
+
style_text = "".join(text2sep_kata(style_text)[0])
|
26 |
if (
|
27 |
sys.platform == "darwin"
|
28 |
and torch.backends.mps.is_available()
|
|
|
39 |
inputs[i] = inputs[i].to(device)
|
40 |
res = models[device](**inputs, output_hidden_states=True)
|
41 |
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
42 |
+
if style_text:
|
43 |
+
style_inputs = tokenizer(style_text, return_tensors="pt")
|
44 |
+
for i in style_inputs:
|
45 |
+
style_inputs[i] = style_inputs[i].to(device)
|
46 |
+
style_res = models[device](**style_inputs, output_hidden_states=True)
|
47 |
+
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
|
48 |
+
style_res_mean = style_res.mean(0)
|
49 |
|
50 |
assert len(word2ph) == len(text) + 2
|
51 |
word2phone = word2ph
|
52 |
phone_level_feature = []
|
53 |
for i in range(len(word2phone)):
|
54 |
+
if style_text:
|
55 |
+
repeat_feature = (
|
56 |
+
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
|
57 |
+
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
|
58 |
+
)
|
59 |
+
else:
|
60 |
+
repeat_feature = res[i].repeat(word2phone[i], 1)
|
61 |
phone_level_feature.append(repeat_feature)
|
62 |
|
63 |
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
text/tone_sandhi.py
CHANGED
@@ -634,9 +634,11 @@ class ToneSandhi:
|
|
634 |
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
|
635 |
# output seg: [['听一听', 'v']]
|
636 |
def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
637 |
-
new_seg = []
|
638 |
# function 1
|
639 |
-
|
|
|
|
|
640 |
if (
|
641 |
i - 1 >= 0
|
642 |
and word == "一"
|
@@ -645,6 +647,7 @@ class ToneSandhi:
|
|
645 |
and seg[i - 1][1] == "v"
|
646 |
):
|
647 |
new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
|
|
|
648 |
else:
|
649 |
if (
|
650 |
i - 2 >= 0
|
@@ -655,7 +658,8 @@ class ToneSandhi:
|
|
655 |
continue
|
656 |
else:
|
657 |
new_seg.append([word, pos])
|
658 |
-
|
|
|
659 |
new_seg = []
|
660 |
# function 2
|
661 |
for i, (word, pos) in enumerate(seg):
|
|
|
634 |
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
|
635 |
# output seg: [['听一听', 'v']]
|
636 |
def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
637 |
+
new_seg = [] * len(seg)
|
638 |
# function 1
|
639 |
+
i = 0
|
640 |
+
while i < len(seg):
|
641 |
+
word, pos = seg[i]
|
642 |
if (
|
643 |
i - 1 >= 0
|
644 |
and word == "一"
|
|
|
647 |
and seg[i - 1][1] == "v"
|
648 |
):
|
649 |
new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
|
650 |
+
i += 2
|
651 |
else:
|
652 |
if (
|
653 |
i - 2 >= 0
|
|
|
658 |
continue
|
659 |
else:
|
660 |
new_seg.append([word, pos])
|
661 |
+
i += 1
|
662 |
+
seg = [i for i in new_seg if len(i) > 0]
|
663 |
new_seg = []
|
664 |
# function 2
|
665 |
for i, (word, pos) in enumerate(seg):
|
tools/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/tools/__pycache__/__init__.cpython-311.pyc and b/tools/__pycache__/__init__.cpython-311.pyc differ
|
|
tools/__pycache__/classify_language.cpython-311.pyc
CHANGED
Binary files a/tools/__pycache__/classify_language.cpython-311.pyc and b/tools/__pycache__/classify_language.cpython-311.pyc differ
|
|
tools/__pycache__/log.cpython-311.pyc
ADDED
Binary file (547 Bytes). View file
|
|
tools/__pycache__/sentence.cpython-311.pyc
CHANGED
Binary files a/tools/__pycache__/sentence.cpython-311.pyc and b/tools/__pycache__/sentence.cpython-311.pyc differ
|
|
tools/__pycache__/translate.cpython-311.pyc
CHANGED
Binary files a/tools/__pycache__/translate.cpython-311.pyc and b/tools/__pycache__/translate.cpython-311.pyc differ
|
|
train_ms.py
CHANGED
@@ -27,8 +27,15 @@ from models import (
|
|
27 |
SynthesizerTrn,
|
28 |
MultiPeriodDiscriminator,
|
29 |
DurationDiscriminator,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
)
|
31 |
-
from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
32 |
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
33 |
from text.symbols import symbols
|
34 |
|
@@ -42,7 +49,6 @@ torch.backends.cuda.enable_flash_sdp(True)
|
|
42 |
torch.backends.cuda.enable_mem_efficient_sdp(
|
43 |
True
|
44 |
) # Not available if torch version is lower than 2.0
|
45 |
-
torch.backends.cuda.enable_math_sdp(True)
|
46 |
global_step = 0
|
47 |
|
48 |
|
@@ -173,6 +179,8 @@ def run():
|
|
173 |
0.1,
|
174 |
gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
|
175 |
).cuda(local_rank)
|
|
|
|
|
176 |
if (
|
177 |
"use_spk_conditioned_encoder" in hps.model.keys()
|
178 |
and hps.model.use_spk_conditioned_encoder is True
|
@@ -210,6 +218,9 @@ def run():
|
|
210 |
param.requires_grad = False
|
211 |
|
212 |
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
|
|
|
|
|
|
|
213 |
optim_g = torch.optim.AdamW(
|
214 |
filter(lambda p: p.requires_grad, net_g.parameters()),
|
215 |
hps.train.learning_rate,
|
@@ -222,6 +233,12 @@ def run():
|
|
222 |
betas=hps.train.betas,
|
223 |
eps=hps.train.eps,
|
224 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
if net_dur_disc is not None:
|
226 |
optim_dur_disc = torch.optim.AdamW(
|
227 |
net_dur_disc.parameters(),
|
@@ -233,12 +250,11 @@ def run():
|
|
233 |
optim_dur_disc = None
|
234 |
net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
|
235 |
net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
|
236 |
-
|
237 |
if net_dur_disc is not None:
|
238 |
net_dur_disc = DDP(
|
239 |
net_dur_disc,
|
240 |
device_ids=[local_rank],
|
241 |
-
find_unused_parameters=True,
|
242 |
bucket_cap_mb=512,
|
243 |
)
|
244 |
|
@@ -250,9 +266,10 @@ def run():
|
|
250 |
token=config.openi_token,
|
251 |
mirror=config.mirror,
|
252 |
)
|
253 |
-
|
254 |
-
|
255 |
-
|
|
|
256 |
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
|
257 |
utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
|
258 |
net_dur_disc,
|
@@ -261,28 +278,32 @@ def run():
|
|
261 |
if "skip_optimizer" in hps.train
|
262 |
else True,
|
263 |
)
|
264 |
-
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
|
265 |
-
utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
|
266 |
-
net_g,
|
267 |
-
optim_g,
|
268 |
-
skip_optimizer=hps.train.skip_optimizer
|
269 |
-
if "skip_optimizer" in hps.train
|
270 |
-
else True,
|
271 |
-
)
|
272 |
-
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
|
273 |
-
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
|
274 |
-
net_d,
|
275 |
-
optim_d,
|
276 |
-
skip_optimizer=hps.train.skip_optimizer
|
277 |
-
if "skip_optimizer" in hps.train
|
278 |
-
else True,
|
279 |
-
)
|
280 |
-
if not optim_g.param_groups[0].get("initial_lr"):
|
281 |
-
optim_g.param_groups[0]["initial_lr"] = g_resume_lr
|
282 |
-
if not optim_d.param_groups[0].get("initial_lr"):
|
283 |
-
optim_d.param_groups[0]["initial_lr"] = d_resume_lr
|
284 |
if not optim_dur_disc.param_groups[0].get("initial_lr"):
|
285 |
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
epoch_str = max(epoch_str, 1)
|
288 |
# global_step = (epoch_str - 1) * len(train_loader)
|
@@ -297,21 +318,43 @@ def run():
|
|
297 |
epoch_str = 1
|
298 |
global_step = 0
|
299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
301 |
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
302 |
)
|
303 |
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
304 |
optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
305 |
)
|
|
|
|
|
|
|
306 |
if net_dur_disc is not None:
|
307 |
-
if not optim_dur_disc.param_groups[0].get("initial_lr"):
|
308 |
-
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
|
309 |
scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
|
310 |
optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
311 |
)
|
312 |
else:
|
313 |
scheduler_dur_disc = None
|
314 |
-
scaler = GradScaler(enabled=hps.train.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
|
316 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
317 |
if rank == 0:
|
@@ -320,9 +363,9 @@ def run():
|
|
320 |
local_rank,
|
321 |
epoch,
|
322 |
hps,
|
323 |
-
[net_g, net_d, net_dur_disc],
|
324 |
-
[optim_g, optim_d, optim_dur_disc],
|
325 |
-
[scheduler_g, scheduler_d, scheduler_dur_disc],
|
326 |
scaler,
|
327 |
[train_loader, eval_loader],
|
328 |
logger,
|
@@ -334,9 +377,9 @@ def run():
|
|
334 |
local_rank,
|
335 |
epoch,
|
336 |
hps,
|
337 |
-
[net_g, net_d, net_dur_disc],
|
338 |
-
[optim_g, optim_d, optim_dur_disc],
|
339 |
-
[scheduler_g, scheduler_d, scheduler_dur_disc],
|
340 |
scaler,
|
341 |
[train_loader, None],
|
342 |
None,
|
@@ -344,6 +387,7 @@ def run():
|
|
344 |
)
|
345 |
scheduler_g.step()
|
346 |
scheduler_d.step()
|
|
|
347 |
if net_dur_disc is not None:
|
348 |
scheduler_dur_disc.step()
|
349 |
|
@@ -361,9 +405,9 @@ def train_and_evaluate(
|
|
361 |
logger,
|
362 |
writers,
|
363 |
):
|
364 |
-
net_g, net_d, net_dur_disc = nets
|
365 |
-
optim_g, optim_d, optim_dur_disc = optims
|
366 |
-
scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
|
367 |
train_loader, eval_loader = loaders
|
368 |
if writers is not None:
|
369 |
writer, writer_eval = writers
|
@@ -373,6 +417,7 @@ def train_and_evaluate(
|
|
373 |
|
374 |
net_g.train()
|
375 |
net_d.train()
|
|
|
376 |
if net_dur_disc is not None:
|
377 |
net_dur_disc.train()
|
378 |
for batch_idx, (
|
@@ -388,7 +433,6 @@ def train_and_evaluate(
|
|
388 |
bert,
|
389 |
ja_bert,
|
390 |
en_bert,
|
391 |
-
emo,
|
392 |
) in enumerate(tqdm(train_loader)):
|
393 |
if net_g.module.use_noise_scaled_mas:
|
394 |
current_mas_noise_scale = (
|
@@ -411,9 +455,8 @@ def train_and_evaluate(
|
|
411 |
bert = bert.cuda(local_rank, non_blocking=True)
|
412 |
ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
|
413 |
en_bert = en_bert.cuda(local_rank, non_blocking=True)
|
414 |
-
emo = emo.cuda(local_rank, non_blocking=True)
|
415 |
|
416 |
-
with autocast(enabled=hps.train.
|
417 |
(
|
418 |
y_hat,
|
419 |
l_length,
|
@@ -422,9 +465,8 @@ def train_and_evaluate(
|
|
422 |
x_mask,
|
423 |
z_mask,
|
424 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
425 |
-
(hidden_x, logw, logw_),
|
426 |
g,
|
427 |
-
loss_commit,
|
428 |
) = net_g(
|
429 |
x,
|
430 |
x_lengths,
|
@@ -436,7 +478,6 @@ def train_and_evaluate(
|
|
436 |
bert,
|
437 |
ja_bert,
|
438 |
en_bert,
|
439 |
-
emo,
|
440 |
)
|
441 |
mel = spec_to_mel_torch(
|
442 |
spec,
|
@@ -450,7 +491,7 @@ def train_and_evaluate(
|
|
450 |
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
451 |
)
|
452 |
y_hat_mel = mel_spectrogram_torch(
|
453 |
-
y_hat.squeeze(1),
|
454 |
hps.data.filter_length,
|
455 |
hps.data.n_mel_channels,
|
456 |
hps.data.sampling_rate,
|
@@ -466,7 +507,7 @@ def train_and_evaluate(
|
|
466 |
|
467 |
# Discriminator
|
468 |
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
469 |
-
with autocast(enabled=
|
470 |
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
471 |
y_d_hat_r, y_d_hat_g
|
472 |
)
|
@@ -475,11 +516,20 @@ def train_and_evaluate(
|
|
475 |
y_dur_hat_r, y_dur_hat_g = net_dur_disc(
|
476 |
hidden_x.detach(),
|
477 |
x_mask.detach(),
|
|
|
478 |
logw.detach(),
|
|
|
|
|
|
|
|
|
|
|
479 |
logw_.detach(),
|
|
|
480 |
g.detach(),
|
481 |
)
|
482 |
-
|
|
|
|
|
483 |
# TODO: I think need to mean using the mask, but for now, just mean all
|
484 |
(
|
485 |
loss_dur_disc,
|
@@ -490,31 +540,60 @@ def train_and_evaluate(
|
|
490 |
optim_dur_disc.zero_grad()
|
491 |
scaler.scale(loss_dur_disc_all).backward()
|
492 |
scaler.unscale_(optim_dur_disc)
|
493 |
-
|
|
|
|
|
|
|
|
|
|
|
494 |
scaler.step(optim_dur_disc)
|
495 |
|
496 |
optim_d.zero_grad()
|
497 |
scaler.scale(loss_disc_all).backward()
|
498 |
scaler.unscale_(optim_d)
|
|
|
|
|
499 |
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
500 |
scaler.step(optim_d)
|
501 |
|
502 |
-
with autocast(enabled=hps.train.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
# Generator
|
504 |
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
505 |
if net_dur_disc is not None:
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
with autocast(enabled=
|
510 |
loss_dur = torch.sum(l_length.float())
|
511 |
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
512 |
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
513 |
|
514 |
loss_fm = feature_loss(fmap_r, fmap_g)
|
515 |
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
|
|
|
|
|
|
|
|
516 |
loss_gen_all = (
|
517 |
-
loss_gen
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
)
|
519 |
if net_dur_disc is not None:
|
520 |
loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
|
@@ -522,6 +601,8 @@ def train_and_evaluate(
|
|
522 |
optim_g.zero_grad()
|
523 |
scaler.scale(loss_gen_all).backward()
|
524 |
scaler.unscale_(optim_g)
|
|
|
|
|
525 |
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
526 |
scaler.step(optim_g)
|
527 |
scaler.update()
|
@@ -540,9 +621,12 @@ def train_and_evaluate(
|
|
540 |
scalar_dict = {
|
541 |
"loss/g/total": loss_gen_all,
|
542 |
"loss/d/total": loss_disc_all,
|
|
|
543 |
"learning_rate": lr,
|
544 |
"grad_norm_d": grad_norm_d,
|
545 |
"grad_norm_g": grad_norm_g,
|
|
|
|
|
546 |
}
|
547 |
scalar_dict.update(
|
548 |
{
|
@@ -550,6 +634,8 @@ def train_and_evaluate(
|
|
550 |
"loss/g/mel": loss_mel,
|
551 |
"loss/g/dur": loss_dur,
|
552 |
"loss/g/kl": loss_kl,
|
|
|
|
|
553 |
}
|
554 |
)
|
555 |
scalar_dict.update(
|
@@ -562,6 +648,30 @@ def train_and_evaluate(
|
|
562 |
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
|
563 |
)
|
564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
image_dict = {
|
566 |
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
567 |
y_mel[0].data.cpu().numpy()
|
@@ -599,6 +709,13 @@ def train_and_evaluate(
|
|
599 |
epoch,
|
600 |
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
601 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
602 |
if net_dur_disc is not None:
|
603 |
utils.save_checkpoint(
|
604 |
net_dur_disc,
|
@@ -642,7 +759,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|
642 |
bert,
|
643 |
ja_bert,
|
644 |
en_bert,
|
645 |
-
emo,
|
646 |
) in enumerate(eval_loader):
|
647 |
x, x_lengths = x.cuda(), x_lengths.cuda()
|
648 |
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
@@ -653,7 +769,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|
653 |
en_bert = en_bert.cuda()
|
654 |
tone = tone.cuda()
|
655 |
language = language.cuda()
|
656 |
-
emo = emo.cuda()
|
657 |
for use_sdp in [True, False]:
|
658 |
y_hat, attn, mask, *_ = generator.module.infer(
|
659 |
x,
|
@@ -664,7 +779,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|
664 |
bert,
|
665 |
ja_bert,
|
666 |
en_bert,
|
667 |
-
emo,
|
668 |
y=spec,
|
669 |
max_len=1000,
|
670 |
sdp_ratio=0.0 if not use_sdp else 1.0,
|
|
|
27 |
SynthesizerTrn,
|
28 |
MultiPeriodDiscriminator,
|
29 |
DurationDiscriminator,
|
30 |
+
WavLMDiscriminator,
|
31 |
+
)
|
32 |
+
from losses import (
|
33 |
+
generator_loss,
|
34 |
+
discriminator_loss,
|
35 |
+
feature_loss,
|
36 |
+
kl_loss,
|
37 |
+
WavLMLoss,
|
38 |
)
|
|
|
39 |
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
40 |
from text.symbols import symbols
|
41 |
|
|
|
49 |
torch.backends.cuda.enable_mem_efficient_sdp(
|
50 |
True
|
51 |
) # Not available if torch version is lower than 2.0
|
|
|
52 |
global_step = 0
|
53 |
|
54 |
|
|
|
179 |
0.1,
|
180 |
gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
|
181 |
).cuda(local_rank)
|
182 |
+
else:
|
183 |
+
net_dur_disc = None
|
184 |
if (
|
185 |
"use_spk_conditioned_encoder" in hps.model.keys()
|
186 |
and hps.model.use_spk_conditioned_encoder is True
|
|
|
218 |
param.requires_grad = False
|
219 |
|
220 |
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
|
221 |
+
net_wd = WavLMDiscriminator(
|
222 |
+
hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel
|
223 |
+
).cuda(local_rank)
|
224 |
optim_g = torch.optim.AdamW(
|
225 |
filter(lambda p: p.requires_grad, net_g.parameters()),
|
226 |
hps.train.learning_rate,
|
|
|
233 |
betas=hps.train.betas,
|
234 |
eps=hps.train.eps,
|
235 |
)
|
236 |
+
optim_wd = torch.optim.AdamW(
|
237 |
+
net_wd.parameters(),
|
238 |
+
hps.train.learning_rate,
|
239 |
+
betas=hps.train.betas,
|
240 |
+
eps=hps.train.eps,
|
241 |
+
)
|
242 |
if net_dur_disc is not None:
|
243 |
optim_dur_disc = torch.optim.AdamW(
|
244 |
net_dur_disc.parameters(),
|
|
|
250 |
optim_dur_disc = None
|
251 |
net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
|
252 |
net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
|
253 |
+
net_wd = DDP(net_wd, device_ids=[local_rank], bucket_cap_mb=512)
|
254 |
if net_dur_disc is not None:
|
255 |
net_dur_disc = DDP(
|
256 |
net_dur_disc,
|
257 |
device_ids=[local_rank],
|
|
|
258 |
bucket_cap_mb=512,
|
259 |
)
|
260 |
|
|
|
266 |
token=config.openi_token,
|
267 |
mirror=config.mirror,
|
268 |
)
|
269 |
+
dur_resume_lr = hps.train.learning_rate
|
270 |
+
wd_resume_lr = hps.train.learning_rate
|
271 |
+
if net_dur_disc is not None:
|
272 |
+
try:
|
273 |
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
|
274 |
utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
|
275 |
net_dur_disc,
|
|
|
278 |
if "skip_optimizer" in hps.train
|
279 |
else True,
|
280 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
if not optim_dur_disc.param_groups[0].get("initial_lr"):
|
282 |
optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
|
283 |
+
except:
|
284 |
+
print("Initialize dur_disc")
|
285 |
+
|
286 |
+
try:
|
287 |
+
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
|
288 |
+
utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
|
289 |
+
net_g,
|
290 |
+
optim_g,
|
291 |
+
skip_optimizer=hps.train.skip_optimizer
|
292 |
+
if "skip_optimizer" in hps.train
|
293 |
+
else True,
|
294 |
+
)
|
295 |
+
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
|
296 |
+
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
|
297 |
+
net_d,
|
298 |
+
optim_d,
|
299 |
+
skip_optimizer=hps.train.skip_optimizer
|
300 |
+
if "skip_optimizer" in hps.train
|
301 |
+
else True,
|
302 |
+
)
|
303 |
+
if not optim_g.param_groups[0].get("initial_lr"):
|
304 |
+
optim_g.param_groups[0]["initial_lr"] = g_resume_lr
|
305 |
+
if not optim_d.param_groups[0].get("initial_lr"):
|
306 |
+
optim_d.param_groups[0]["initial_lr"] = d_resume_lr
|
307 |
|
308 |
epoch_str = max(epoch_str, 1)
|
309 |
# global_step = (epoch_str - 1) * len(train_loader)
|
|
|
318 |
epoch_str = 1
|
319 |
global_step = 0
|
320 |
|
321 |
+
try:
|
322 |
+
_, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint(
|
323 |
+
utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
|
324 |
+
net_wd,
|
325 |
+
optim_wd,
|
326 |
+
skip_optimizer=hps.train.skip_optimizer
|
327 |
+
if "skip_optimizer" in hps.train
|
328 |
+
else True,
|
329 |
+
)
|
330 |
+
if not optim_wd.param_groups[0].get("initial_lr"):
|
331 |
+
optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
|
332 |
+
except Exception as e:
|
333 |
+
print(e)
|
334 |
+
|
335 |
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
336 |
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
337 |
)
|
338 |
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
339 |
optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
340 |
)
|
341 |
+
scheduler_wd = torch.optim.lr_scheduler.ExponentialLR(
|
342 |
+
optim_wd, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
343 |
+
)
|
344 |
if net_dur_disc is not None:
|
|
|
|
|
345 |
scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
|
346 |
optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
|
347 |
)
|
348 |
else:
|
349 |
scheduler_dur_disc = None
|
350 |
+
scaler = GradScaler(enabled=hps.train.bf16_run)
|
351 |
+
|
352 |
+
wl = WavLMLoss(
|
353 |
+
hps.model.slm.model,
|
354 |
+
net_wd,
|
355 |
+
hps.data.sampling_rate,
|
356 |
+
hps.model.slm.sr,
|
357 |
+
).to(local_rank)
|
358 |
|
359 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
360 |
if rank == 0:
|
|
|
363 |
local_rank,
|
364 |
epoch,
|
365 |
hps,
|
366 |
+
[net_g, net_d, net_dur_disc, net_wd, wl],
|
367 |
+
[optim_g, optim_d, optim_dur_disc, optim_wd],
|
368 |
+
[scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
|
369 |
scaler,
|
370 |
[train_loader, eval_loader],
|
371 |
logger,
|
|
|
377 |
local_rank,
|
378 |
epoch,
|
379 |
hps,
|
380 |
+
[net_g, net_d, net_dur_disc, net_wd, wl],
|
381 |
+
[optim_g, optim_d, optim_dur_disc, optim_wd],
|
382 |
+
[scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
|
383 |
scaler,
|
384 |
[train_loader, None],
|
385 |
None,
|
|
|
387 |
)
|
388 |
scheduler_g.step()
|
389 |
scheduler_d.step()
|
390 |
+
scheduler_wd.step()
|
391 |
if net_dur_disc is not None:
|
392 |
scheduler_dur_disc.step()
|
393 |
|
|
|
405 |
logger,
|
406 |
writers,
|
407 |
):
|
408 |
+
net_g, net_d, net_dur_disc, net_wd, wl = nets
|
409 |
+
optim_g, optim_d, optim_dur_disc, optim_wd = optims
|
410 |
+
scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
|
411 |
train_loader, eval_loader = loaders
|
412 |
if writers is not None:
|
413 |
writer, writer_eval = writers
|
|
|
417 |
|
418 |
net_g.train()
|
419 |
net_d.train()
|
420 |
+
net_wd.train()
|
421 |
if net_dur_disc is not None:
|
422 |
net_dur_disc.train()
|
423 |
for batch_idx, (
|
|
|
433 |
bert,
|
434 |
ja_bert,
|
435 |
en_bert,
|
|
|
436 |
) in enumerate(tqdm(train_loader)):
|
437 |
if net_g.module.use_noise_scaled_mas:
|
438 |
current_mas_noise_scale = (
|
|
|
455 |
bert = bert.cuda(local_rank, non_blocking=True)
|
456 |
ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
|
457 |
en_bert = en_bert.cuda(local_rank, non_blocking=True)
|
|
|
458 |
|
459 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
460 |
(
|
461 |
y_hat,
|
462 |
l_length,
|
|
|
465 |
x_mask,
|
466 |
z_mask,
|
467 |
(z, z_p, m_p, logs_p, m_q, logs_q),
|
468 |
+
(hidden_x, logw, logw_, logw_sdp),
|
469 |
g,
|
|
|
470 |
) = net_g(
|
471 |
x,
|
472 |
x_lengths,
|
|
|
478 |
bert,
|
479 |
ja_bert,
|
480 |
en_bert,
|
|
|
481 |
)
|
482 |
mel = spec_to_mel_torch(
|
483 |
spec,
|
|
|
491 |
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
492 |
)
|
493 |
y_hat_mel = mel_spectrogram_torch(
|
494 |
+
y_hat.squeeze(1).float(),
|
495 |
hps.data.filter_length,
|
496 |
hps.data.n_mel_channels,
|
497 |
hps.data.sampling_rate,
|
|
|
507 |
|
508 |
# Discriminator
|
509 |
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
510 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
511 |
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
512 |
y_d_hat_r, y_d_hat_g
|
513 |
)
|
|
|
516 |
y_dur_hat_r, y_dur_hat_g = net_dur_disc(
|
517 |
hidden_x.detach(),
|
518 |
x_mask.detach(),
|
519 |
+
logw_.detach(),
|
520 |
logw.detach(),
|
521 |
+
g.detach(),
|
522 |
+
)
|
523 |
+
y_dur_hat_r_sdp, y_dur_hat_g_sdp = net_dur_disc(
|
524 |
+
hidden_x.detach(),
|
525 |
+
x_mask.detach(),
|
526 |
logw_.detach(),
|
527 |
+
logw_sdp.detach(),
|
528 |
g.detach(),
|
529 |
)
|
530 |
+
y_dur_hat_r = y_dur_hat_r + y_dur_hat_r_sdp
|
531 |
+
y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
|
532 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
533 |
# TODO: I think need to mean using the mask, but for now, just mean all
|
534 |
(
|
535 |
loss_dur_disc,
|
|
|
540 |
optim_dur_disc.zero_grad()
|
541 |
scaler.scale(loss_dur_disc_all).backward()
|
542 |
scaler.unscale_(optim_dur_disc)
|
543 |
+
# torch.nn.utils.clip_grad_norm_(
|
544 |
+
# parameters=net_dur_disc.parameters(), max_norm=100
|
545 |
+
# )
|
546 |
+
grad_norm_dur = commons.clip_grad_value_(
|
547 |
+
net_dur_disc.parameters(), None
|
548 |
+
)
|
549 |
scaler.step(optim_dur_disc)
|
550 |
|
551 |
optim_d.zero_grad()
|
552 |
scaler.scale(loss_disc_all).backward()
|
553 |
scaler.unscale_(optim_d)
|
554 |
+
if getattr(hps.train, "bf16_run", False):
|
555 |
+
torch.nn.utils.clip_grad_norm_(parameters=net_d.parameters(), max_norm=200)
|
556 |
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
557 |
scaler.step(optim_d)
|
558 |
|
559 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
560 |
+
loss_slm = wl.discriminator(
|
561 |
+
y.detach().squeeze(), y_hat.detach().squeeze()
|
562 |
+
).mean()
|
563 |
+
|
564 |
+
optim_wd.zero_grad()
|
565 |
+
scaler.scale(loss_slm).backward()
|
566 |
+
scaler.unscale_(optim_wd)
|
567 |
+
# torch.nn.utils.clip_grad_norm_(parameters=net_wd.parameters(), max_norm=200)
|
568 |
+
grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
|
569 |
+
scaler.step(optim_wd)
|
570 |
+
|
571 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
572 |
# Generator
|
573 |
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
574 |
if net_dur_disc is not None:
|
575 |
+
_, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw, g)
|
576 |
+
_, y_dur_hat_g_sdp = net_dur_disc(hidden_x, x_mask, logw_, logw_sdp, g)
|
577 |
+
y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
|
578 |
+
with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
|
579 |
loss_dur = torch.sum(l_length.float())
|
580 |
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
581 |
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
582 |
|
583 |
loss_fm = feature_loss(fmap_r, fmap_g)
|
584 |
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
585 |
+
|
586 |
+
loss_lm = wl(y.detach().squeeze(), y_hat.squeeze()).mean()
|
587 |
+
loss_lm_gen = wl.generator(y_hat.squeeze())
|
588 |
+
|
589 |
loss_gen_all = (
|
590 |
+
loss_gen
|
591 |
+
+ loss_fm
|
592 |
+
+ loss_mel
|
593 |
+
+ loss_dur
|
594 |
+
+ loss_kl
|
595 |
+
+ loss_lm
|
596 |
+
+ loss_lm_gen
|
597 |
)
|
598 |
if net_dur_disc is not None:
|
599 |
loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
|
|
|
601 |
optim_g.zero_grad()
|
602 |
scaler.scale(loss_gen_all).backward()
|
603 |
scaler.unscale_(optim_g)
|
604 |
+
if getattr(hps.train, "bf16_run", False):
|
605 |
+
torch.nn.utils.clip_grad_norm_(parameters=net_g.parameters(), max_norm=500)
|
606 |
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
607 |
scaler.step(optim_g)
|
608 |
scaler.update()
|
|
|
621 |
scalar_dict = {
|
622 |
"loss/g/total": loss_gen_all,
|
623 |
"loss/d/total": loss_disc_all,
|
624 |
+
"loss/wd/total": loss_slm,
|
625 |
"learning_rate": lr,
|
626 |
"grad_norm_d": grad_norm_d,
|
627 |
"grad_norm_g": grad_norm_g,
|
628 |
+
"grad_norm_dur": grad_norm_dur,
|
629 |
+
"grad_norm_wd": grad_norm_wd,
|
630 |
}
|
631 |
scalar_dict.update(
|
632 |
{
|
|
|
634 |
"loss/g/mel": loss_mel,
|
635 |
"loss/g/dur": loss_dur,
|
636 |
"loss/g/kl": loss_kl,
|
637 |
+
"loss/g/lm": loss_lm,
|
638 |
+
"loss/g/lm_gen": loss_lm_gen,
|
639 |
}
|
640 |
)
|
641 |
scalar_dict.update(
|
|
|
648 |
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
|
649 |
)
|
650 |
|
651 |
+
if net_dur_disc is not None:
|
652 |
+
scalar_dict.update({"loss/dur_disc/total": loss_dur_disc_all})
|
653 |
+
|
654 |
+
scalar_dict.update(
|
655 |
+
{
|
656 |
+
"loss/dur_disc_g/{}".format(i): v
|
657 |
+
for i, v in enumerate(losses_dur_disc_g)
|
658 |
+
}
|
659 |
+
)
|
660 |
+
scalar_dict.update(
|
661 |
+
{
|
662 |
+
"loss/dur_disc_r/{}".format(i): v
|
663 |
+
for i, v in enumerate(losses_dur_disc_r)
|
664 |
+
}
|
665 |
+
)
|
666 |
+
|
667 |
+
scalar_dict.update({"loss/g/dur_gen": loss_dur_gen})
|
668 |
+
scalar_dict.update(
|
669 |
+
{
|
670 |
+
"loss/g/dur_gen_{}".format(i): v
|
671 |
+
for i, v in enumerate(losses_dur_gen)
|
672 |
+
}
|
673 |
+
)
|
674 |
+
|
675 |
image_dict = {
|
676 |
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
677 |
y_mel[0].data.cpu().numpy()
|
|
|
709 |
epoch,
|
710 |
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
711 |
)
|
712 |
+
utils.save_checkpoint(
|
713 |
+
net_wd,
|
714 |
+
optim_wd,
|
715 |
+
hps.train.learning_rate,
|
716 |
+
epoch,
|
717 |
+
os.path.join(hps.model_dir, "WD_{}.pth".format(global_step)),
|
718 |
+
)
|
719 |
if net_dur_disc is not None:
|
720 |
utils.save_checkpoint(
|
721 |
net_dur_disc,
|
|
|
759 |
bert,
|
760 |
ja_bert,
|
761 |
en_bert,
|
|
|
762 |
) in enumerate(eval_loader):
|
763 |
x, x_lengths = x.cuda(), x_lengths.cuda()
|
764 |
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
|
|
769 |
en_bert = en_bert.cuda()
|
770 |
tone = tone.cuda()
|
771 |
language = language.cuda()
|
|
|
772 |
for use_sdp in [True, False]:
|
773 |
y_hat, attn, mask, *_ = generator.module.infer(
|
774 |
x,
|
|
|
779 |
bert,
|
780 |
ja_bert,
|
781 |
en_bert,
|
|
|
782 |
y=spec,
|
783 |
max_len=1000,
|
784 |
sdp_ratio=0.0 if not use_sdp else 1.0,
|
utils.py
CHANGED
@@ -301,7 +301,11 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
|
|
301 |
|
302 |
to_del = [
|
303 |
os.path.join(path_to_models, fn)
|
304 |
-
for fn in (
|
|
|
|
|
|
|
|
|
305 |
]
|
306 |
|
307 |
def del_info(fn):
|
|
|
301 |
|
302 |
to_del = [
|
303 |
os.path.join(path_to_models, fn)
|
304 |
+
for fn in (
|
305 |
+
x_sorted("G")[:-n_ckpts_to_keep]
|
306 |
+
+ x_sorted("D")[:-n_ckpts_to_keep]
|
307 |
+
+ x_sorted("WD")[:-n_ckpts_to_keep]
|
308 |
+
)
|
309 |
]
|
310 |
|
311 |
def del_info(fn):
|
webui.py
CHANGED
@@ -42,6 +42,8 @@ def generate_audio(
|
|
42 |
language,
|
43 |
reference_audio,
|
44 |
emotion,
|
|
|
|
|
45 |
skip_start=False,
|
46 |
skip_end=False,
|
47 |
):
|
@@ -49,8 +51,8 @@ def generate_audio(
|
|
49 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
50 |
with torch.no_grad():
|
51 |
for idx, piece in enumerate(slices):
|
52 |
-
skip_start =
|
53 |
-
skip_end =
|
54 |
audio = infer(
|
55 |
piece,
|
56 |
reference_audio=reference_audio,
|
@@ -66,10 +68,11 @@ def generate_audio(
|
|
66 |
device=device,
|
67 |
skip_start=skip_start,
|
68 |
skip_end=skip_end,
|
|
|
|
|
69 |
)
|
70 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
71 |
audio_list.append(audio16bit)
|
72 |
-
# audio_list.append(silence) # 将静音添加到列表中
|
73 |
return audio_list
|
74 |
|
75 |
|
@@ -90,8 +93,8 @@ def generate_audio_multilang(
|
|
90 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
91 |
with torch.no_grad():
|
92 |
for idx, piece in enumerate(slices):
|
93 |
-
skip_start =
|
94 |
-
skip_end =
|
95 |
audio = infer_multilang(
|
96 |
piece,
|
97 |
reference_audio=reference_audio,
|
@@ -110,7 +113,6 @@ def generate_audio_multilang(
|
|
110 |
)
|
111 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
112 |
audio_list.append(audio16bit)
|
113 |
-
# audio_list.append(silence) # 将静音添加到列表中
|
114 |
return audio_list
|
115 |
|
116 |
|
@@ -127,63 +129,50 @@ def tts_split(
|
|
127 |
interval_between_sent,
|
128 |
reference_audio,
|
129 |
emotion,
|
|
|
|
|
130 |
):
|
131 |
-
if language == "mix":
|
132 |
-
return ("invalid", None)
|
133 |
while text.find("\n\n") != -1:
|
134 |
text = text.replace("\n\n", "\n")
|
|
|
135 |
para_list = re_matching.cut_para(text)
|
|
|
136 |
audio_list = []
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
skip_end = idx != len(para_list) - 1
|
141 |
-
audio = infer(
|
142 |
p,
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
device=device,
|
154 |
-
skip_start=skip_start,
|
155 |
-
skip_end=skip_end,
|
156 |
)
|
157 |
-
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
158 |
-
audio_list.append(audio16bit)
|
159 |
silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
|
160 |
audio_list.append(silence)
|
161 |
-
|
162 |
-
for idx, p in enumerate(para_list):
|
163 |
-
skip_start = idx != 0
|
164 |
-
skip_end = idx != len(para_list) - 1
|
165 |
audio_list_sent = []
|
166 |
sent_list = re_matching.cut_sent(p)
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
audio = infer(
|
171 |
s,
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
device=device,
|
183 |
-
skip_start=skip_start,
|
184 |
-
skip_end=skip_end,
|
185 |
)
|
186 |
-
audio_list_sent.append(audio)
|
187 |
silence = np.zeros((int)(44100 * interval_between_sent))
|
188 |
audio_list_sent.append(silence)
|
189 |
if (interval_between_para - interval_between_sent) > 0:
|
@@ -196,10 +185,47 @@ def tts_split(
|
|
196 |
) # 对完整句子做音量归一
|
197 |
audio_list.append(audio16bit)
|
198 |
audio_concat = np.concatenate(audio_list)
|
199 |
-
return ("Success", (
|
200 |
|
201 |
|
202 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
text: str,
|
204 |
speaker,
|
205 |
sdp_ratio,
|
@@ -209,15 +235,9 @@ def tts_fn(
|
|
209 |
language,
|
210 |
reference_audio,
|
211 |
emotion,
|
212 |
-
|
|
|
213 |
):
|
214 |
-
if prompt_mode == "Audio prompt":
|
215 |
-
if reference_audio == None:
|
216 |
-
return ("Invalid audio prompt", None)
|
217 |
-
else:
|
218 |
-
reference_audio = load_audio(reference_audio)[1]
|
219 |
-
else:
|
220 |
-
reference_audio = None
|
221 |
audio_list = []
|
222 |
if language == "mix":
|
223 |
bool_valid, str_valid = re_matching.validate_text(text)
|
@@ -226,120 +246,40 @@ def tts_fn(
|
|
226 |
hps.data.sampling_rate,
|
227 |
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
|
228 |
)
|
229 |
-
result = []
|
230 |
for slice in re_matching.text_matching(text):
|
231 |
-
_speaker = slice
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
temp_lang += temp_
|
247 |
-
else:
|
248 |
-
if len(temp_contant) == 0:
|
249 |
-
temp_contant.append([])
|
250 |
-
temp_lang.append([])
|
251 |
-
temp_contant[-1].append(content)
|
252 |
-
temp_lang[-1].append(lang)
|
253 |
-
for i, j in zip(temp_lang, temp_contant):
|
254 |
-
result.append([*zip(i, j), _speaker])
|
255 |
-
for i, one in enumerate(result):
|
256 |
-
skip_start = i != 0
|
257 |
-
skip_end = i != len(result) - 1
|
258 |
-
_speaker = one.pop()
|
259 |
-
idx = 0
|
260 |
-
while idx < len(one):
|
261 |
-
text_to_generate = []
|
262 |
-
lang_to_generate = []
|
263 |
-
while True:
|
264 |
-
lang, content = one[idx]
|
265 |
-
temp_text = [content]
|
266 |
-
if len(text_to_generate) > 0:
|
267 |
-
text_to_generate[-1] += [temp_text.pop(0)]
|
268 |
-
lang_to_generate[-1] += [lang]
|
269 |
-
if len(temp_text) > 0:
|
270 |
-
text_to_generate += [[i] for i in temp_text]
|
271 |
-
lang_to_generate += [[lang]] * len(temp_text)
|
272 |
-
if idx + 1 < len(one):
|
273 |
-
idx += 1
|
274 |
-
else:
|
275 |
-
break
|
276 |
-
skip_start = (idx != 0) and skip_start
|
277 |
-
skip_end = (idx != len(one) - 1) and skip_end
|
278 |
-
print(text_to_generate, lang_to_generate)
|
279 |
-
audio_list.extend(
|
280 |
-
generate_audio_multilang(
|
281 |
-
text_to_generate,
|
282 |
-
sdp_ratio,
|
283 |
-
noise_scale,
|
284 |
-
noise_scale_w,
|
285 |
-
length_scale,
|
286 |
-
_speaker,
|
287 |
-
lang_to_generate,
|
288 |
-
reference_audio,
|
289 |
-
emotion,
|
290 |
-
skip_start,
|
291 |
-
skip_end,
|
292 |
-
)
|
293 |
)
|
294 |
-
|
295 |
elif language.lower() == "auto":
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
)
|
304 |
-
|
305 |
-
while idx < len(sentences_list):
|
306 |
-
text_to_generate = []
|
307 |
-
lang_to_generate = []
|
308 |
-
while True:
|
309 |
-
content, lang = sentences_list[idx]
|
310 |
-
temp_text = [content]
|
311 |
-
lang = lang.upper()
|
312 |
-
if lang == "JA":
|
313 |
-
lang = "JP"
|
314 |
-
if len(text_to_generate) > 0:
|
315 |
-
text_to_generate[-1] += [temp_text.pop(0)]
|
316 |
-
lang_to_generate[-1] += [lang]
|
317 |
-
if len(temp_text) > 0:
|
318 |
-
text_to_generate += [[i] for i in temp_text]
|
319 |
-
lang_to_generate += [[lang]] * len(temp_text)
|
320 |
-
if idx + 1 < len(sentences_list):
|
321 |
-
idx += 1
|
322 |
-
else:
|
323 |
-
break
|
324 |
-
skip_start = (idx != 0) and skip_start
|
325 |
-
skip_end = (idx != len(sentences_list) - 1) and skip_end
|
326 |
-
print(text_to_generate, lang_to_generate)
|
327 |
-
audio_list.extend(
|
328 |
-
generate_audio_multilang(
|
329 |
-
text_to_generate,
|
330 |
-
sdp_ratio,
|
331 |
-
noise_scale,
|
332 |
-
noise_scale_w,
|
333 |
-
length_scale,
|
334 |
-
speaker,
|
335 |
-
lang_to_generate,
|
336 |
-
reference_audio,
|
337 |
-
emotion,
|
338 |
-
skip_start,
|
339 |
-
skip_end,
|
340 |
-
)
|
341 |
-
)
|
342 |
-
idx += 1
|
343 |
else:
|
344 |
audio_list.extend(
|
345 |
generate_audio(
|
@@ -352,13 +292,65 @@ def tts_fn(
|
|
352 |
language,
|
353 |
reference_audio,
|
354 |
emotion,
|
|
|
|
|
355 |
)
|
356 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
357 |
|
358 |
audio_concat = np.concatenate(audio_list)
|
359 |
return "Success", (hps.data.sampling_rate, audio_concat)
|
360 |
|
361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
def load_audio(path):
|
363 |
audio, sr = librosa.load(path, 48000)
|
364 |
# audio = librosa.resample(audio, 44100, 48000)
|
@@ -408,34 +400,37 @@ if __name__ == "__main__":
|
|
408 |
)
|
409 |
trans = gr.Button("中翻日", variant="primary")
|
410 |
slicer = gr.Button("快速切分", variant="primary")
|
|
|
411 |
speaker = gr.Dropdown(
|
412 |
choices=speakers, value=speakers[0], label="Speaker"
|
413 |
)
|
414 |
_ = gr.Markdown(
|
415 |
-
value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n"
|
|
|
416 |
)
|
417 |
prompt_mode = gr.Radio(
|
418 |
["Text prompt", "Audio prompt"],
|
419 |
label="Prompt Mode",
|
420 |
value="Text prompt",
|
|
|
421 |
)
|
422 |
text_prompt = gr.Textbox(
|
423 |
label="Text prompt",
|
424 |
placeholder="用文字描述生成风格。如:Happy",
|
425 |
value="Happy",
|
426 |
-
visible=
|
427 |
)
|
428 |
audio_prompt = gr.Audio(
|
429 |
label="Audio prompt", type="filepath", visible=False
|
430 |
)
|
431 |
sdp_ratio = gr.Slider(
|
432 |
-
minimum=0, maximum=1, value=0.
|
433 |
)
|
434 |
noise_scale = gr.Slider(
|
435 |
minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
|
436 |
)
|
437 |
noise_scale_w = gr.Slider(
|
438 |
-
minimum=0.1, maximum=2, value=0.
|
439 |
)
|
440 |
length_scale = gr.Slider(
|
441 |
minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
|
@@ -445,6 +440,21 @@ if __name__ == "__main__":
|
|
445 |
)
|
446 |
btn = gr.Button("生成音频!", variant="primary")
|
447 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
with gr.Row():
|
449 |
with gr.Column():
|
450 |
interval_between_sent = gr.Slider(
|
@@ -487,6 +497,8 @@ if __name__ == "__main__":
|
|
487 |
audio_prompt,
|
488 |
text_prompt,
|
489 |
prompt_mode,
|
|
|
|
|
490 |
],
|
491 |
outputs=[text_output, audio_output],
|
492 |
)
|
@@ -511,6 +523,8 @@ if __name__ == "__main__":
|
|
511 |
interval_between_sent,
|
512 |
audio_prompt,
|
513 |
text_prompt,
|
|
|
|
|
514 |
],
|
515 |
outputs=[text_output, audio_output],
|
516 |
)
|
@@ -527,6 +541,12 @@ if __name__ == "__main__":
|
|
527 |
outputs=[audio_prompt],
|
528 |
)
|
529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
print("推理页面已开启!")
|
531 |
webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
|
532 |
app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
|
|
|
42 |
language,
|
43 |
reference_audio,
|
44 |
emotion,
|
45 |
+
style_text,
|
46 |
+
style_weight,
|
47 |
skip_start=False,
|
48 |
skip_end=False,
|
49 |
):
|
|
|
51 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
52 |
with torch.no_grad():
|
53 |
for idx, piece in enumerate(slices):
|
54 |
+
skip_start = idx != 0
|
55 |
+
skip_end = idx != len(slices) - 1
|
56 |
audio = infer(
|
57 |
piece,
|
58 |
reference_audio=reference_audio,
|
|
|
68 |
device=device,
|
69 |
skip_start=skip_start,
|
70 |
skip_end=skip_end,
|
71 |
+
style_text=style_text,
|
72 |
+
style_weight=style_weight,
|
73 |
)
|
74 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
75 |
audio_list.append(audio16bit)
|
|
|
76 |
return audio_list
|
77 |
|
78 |
|
|
|
93 |
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
|
94 |
with torch.no_grad():
|
95 |
for idx, piece in enumerate(slices):
|
96 |
+
skip_start = idx != 0
|
97 |
+
skip_end = idx != len(slices) - 1
|
98 |
audio = infer_multilang(
|
99 |
piece,
|
100 |
reference_audio=reference_audio,
|
|
|
113 |
)
|
114 |
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
|
115 |
audio_list.append(audio16bit)
|
|
|
116 |
return audio_list
|
117 |
|
118 |
|
|
|
129 |
interval_between_sent,
|
130 |
reference_audio,
|
131 |
emotion,
|
132 |
+
style_text,
|
133 |
+
style_weight,
|
134 |
):
|
|
|
|
|
135 |
while text.find("\n\n") != -1:
|
136 |
text = text.replace("\n\n", "\n")
|
137 |
+
text = text.replace("|", "")
|
138 |
para_list = re_matching.cut_para(text)
|
139 |
+
para_list = [p for p in para_list if p != ""]
|
140 |
audio_list = []
|
141 |
+
for p in para_list:
|
142 |
+
if not cut_by_sent:
|
143 |
+
audio_list += process_text(
|
|
|
|
|
144 |
p,
|
145 |
+
speaker,
|
146 |
+
sdp_ratio,
|
147 |
+
noise_scale,
|
148 |
+
noise_scale_w,
|
149 |
+
length_scale,
|
150 |
+
language,
|
151 |
+
reference_audio,
|
152 |
+
emotion,
|
153 |
+
style_text,
|
154 |
+
style_weight,
|
|
|
|
|
|
|
155 |
)
|
|
|
|
|
156 |
silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
|
157 |
audio_list.append(silence)
|
158 |
+
else:
|
|
|
|
|
|
|
159 |
audio_list_sent = []
|
160 |
sent_list = re_matching.cut_sent(p)
|
161 |
+
sent_list = [s for s in sent_list if s != ""]
|
162 |
+
for s in sent_list:
|
163 |
+
audio_list_sent += process_text(
|
|
|
164 |
s,
|
165 |
+
speaker,
|
166 |
+
sdp_ratio,
|
167 |
+
noise_scale,
|
168 |
+
noise_scale_w,
|
169 |
+
length_scale,
|
170 |
+
language,
|
171 |
+
reference_audio,
|
172 |
+
emotion,
|
173 |
+
style_text,
|
174 |
+
style_weight,
|
|
|
|
|
|
|
175 |
)
|
|
|
176 |
silence = np.zeros((int)(44100 * interval_between_sent))
|
177 |
audio_list_sent.append(silence)
|
178 |
if (interval_between_para - interval_between_sent) > 0:
|
|
|
185 |
) # 对完整句子做音量归一
|
186 |
audio_list.append(audio16bit)
|
187 |
audio_concat = np.concatenate(audio_list)
|
188 |
+
return ("Success", (hps.data.sampling_rate, audio_concat))
|
189 |
|
190 |
|
191 |
+
def process_mix(slice):
|
192 |
+
_speaker = slice.pop()
|
193 |
+
_text, _lang = [], []
|
194 |
+
for lang, content in slice:
|
195 |
+
content = content.split("|")
|
196 |
+
content = [part for part in content if part != ""]
|
197 |
+
if len(content) == 0:
|
198 |
+
continue
|
199 |
+
if len(_text) == 0:
|
200 |
+
_text = [[part] for part in content]
|
201 |
+
_lang = [[lang] for part in content]
|
202 |
+
else:
|
203 |
+
_text[-1].append(content[0])
|
204 |
+
_lang[-1].append(lang)
|
205 |
+
if len(content) > 1:
|
206 |
+
_text += [[part] for part in content[1:]]
|
207 |
+
_lang += [[lang] for part in content[1:]]
|
208 |
+
return _text, _lang, _speaker
|
209 |
+
|
210 |
+
|
211 |
+
def process_auto(text):
|
212 |
+
_text, _lang = [], []
|
213 |
+
for slice in text.split("|"):
|
214 |
+
if slice == "":
|
215 |
+
continue
|
216 |
+
temp_text, temp_lang = [], []
|
217 |
+
sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
|
218 |
+
for sentence, lang in sentences_list:
|
219 |
+
if sentence == "":
|
220 |
+
continue
|
221 |
+
temp_text.append(sentence)
|
222 |
+
temp_lang.append(lang.upper())
|
223 |
+
_text.append(temp_text)
|
224 |
+
_lang.append(temp_lang)
|
225 |
+
return _text, _lang
|
226 |
+
|
227 |
+
|
228 |
+
def process_text(
|
229 |
text: str,
|
230 |
speaker,
|
231 |
sdp_ratio,
|
|
|
235 |
language,
|
236 |
reference_audio,
|
237 |
emotion,
|
238 |
+
style_text=None,
|
239 |
+
style_weight=0,
|
240 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
audio_list = []
|
242 |
if language == "mix":
|
243 |
bool_valid, str_valid = re_matching.validate_text(text)
|
|
|
246 |
hps.data.sampling_rate,
|
247 |
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
|
248 |
)
|
|
|
249 |
for slice in re_matching.text_matching(text):
|
250 |
+
_text, _lang, _speaker = process_mix(slice)
|
251 |
+
if _speaker is None:
|
252 |
+
continue
|
253 |
+
print(f"Text: {_text}\nLang: {_lang}")
|
254 |
+
audio_list.extend(
|
255 |
+
generate_audio_multilang(
|
256 |
+
_text,
|
257 |
+
sdp_ratio,
|
258 |
+
noise_scale,
|
259 |
+
noise_scale_w,
|
260 |
+
length_scale,
|
261 |
+
_speaker,
|
262 |
+
_lang,
|
263 |
+
reference_audio,
|
264 |
+
emotion,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
)
|
266 |
+
)
|
267 |
elif language.lower() == "auto":
|
268 |
+
_text, _lang = process_auto(text)
|
269 |
+
print(f"Text: {_text}\nLang: {_lang}")
|
270 |
+
audio_list.extend(
|
271 |
+
generate_audio_multilang(
|
272 |
+
_text,
|
273 |
+
sdp_ratio,
|
274 |
+
noise_scale,
|
275 |
+
noise_scale_w,
|
276 |
+
length_scale,
|
277 |
+
speaker,
|
278 |
+
_lang,
|
279 |
+
reference_audio,
|
280 |
+
emotion,
|
281 |
)
|
282 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
else:
|
284 |
audio_list.extend(
|
285 |
generate_audio(
|
|
|
292 |
language,
|
293 |
reference_audio,
|
294 |
emotion,
|
295 |
+
style_text,
|
296 |
+
style_weight,
|
297 |
)
|
298 |
)
|
299 |
+
return audio_list
|
300 |
+
|
301 |
+
|
302 |
+
def tts_fn(
|
303 |
+
text: str,
|
304 |
+
speaker,
|
305 |
+
sdp_ratio,
|
306 |
+
noise_scale,
|
307 |
+
noise_scale_w,
|
308 |
+
length_scale,
|
309 |
+
language,
|
310 |
+
reference_audio,
|
311 |
+
emotion,
|
312 |
+
prompt_mode,
|
313 |
+
style_text=None,
|
314 |
+
style_weight=0,
|
315 |
+
):
|
316 |
+
if style_text == "":
|
317 |
+
style_text = None
|
318 |
+
if prompt_mode == "Audio prompt":
|
319 |
+
if reference_audio == None:
|
320 |
+
return ("Invalid audio prompt", None)
|
321 |
+
else:
|
322 |
+
reference_audio = load_audio(reference_audio)[1]
|
323 |
+
else:
|
324 |
+
reference_audio = None
|
325 |
+
|
326 |
+
audio_list = process_text(
|
327 |
+
text,
|
328 |
+
speaker,
|
329 |
+
sdp_ratio,
|
330 |
+
noise_scale,
|
331 |
+
noise_scale_w,
|
332 |
+
length_scale,
|
333 |
+
language,
|
334 |
+
reference_audio,
|
335 |
+
emotion,
|
336 |
+
style_text,
|
337 |
+
style_weight,
|
338 |
+
)
|
339 |
|
340 |
audio_concat = np.concatenate(audio_list)
|
341 |
return "Success", (hps.data.sampling_rate, audio_concat)
|
342 |
|
343 |
|
344 |
+
def format_utils(text, speaker):
|
345 |
+
_text, _lang = process_auto(text)
|
346 |
+
res = f"[{speaker}]"
|
347 |
+
for lang_s, content_s in zip(_lang, _text):
|
348 |
+
for lang, content in zip(lang_s, content_s):
|
349 |
+
res += f"<{lang.lower()}>{content}"
|
350 |
+
res += "|"
|
351 |
+
return "mix", res[:-1]
|
352 |
+
|
353 |
+
|
354 |
def load_audio(path):
|
355 |
audio, sr = librosa.load(path, 48000)
|
356 |
# audio = librosa.resample(audio, 44100, 48000)
|
|
|
400 |
)
|
401 |
trans = gr.Button("中翻日", variant="primary")
|
402 |
slicer = gr.Button("快速切分", variant="primary")
|
403 |
+
formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
|
404 |
speaker = gr.Dropdown(
|
405 |
choices=speakers, value=speakers[0], label="Speaker"
|
406 |
)
|
407 |
_ = gr.Markdown(
|
408 |
+
value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
|
409 |
+
visible=False,
|
410 |
)
|
411 |
prompt_mode = gr.Radio(
|
412 |
["Text prompt", "Audio prompt"],
|
413 |
label="Prompt Mode",
|
414 |
value="Text prompt",
|
415 |
+
visible=False,
|
416 |
)
|
417 |
text_prompt = gr.Textbox(
|
418 |
label="Text prompt",
|
419 |
placeholder="用文字描述生成风格。如:Happy",
|
420 |
value="Happy",
|
421 |
+
visible=False,
|
422 |
)
|
423 |
audio_prompt = gr.Audio(
|
424 |
label="Audio prompt", type="filepath", visible=False
|
425 |
)
|
426 |
sdp_ratio = gr.Slider(
|
427 |
+
minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
|
428 |
)
|
429 |
noise_scale = gr.Slider(
|
430 |
minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
|
431 |
)
|
432 |
noise_scale_w = gr.Slider(
|
433 |
+
minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
|
434 |
)
|
435 |
length_scale = gr.Slider(
|
436 |
minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
|
|
|
440 |
)
|
441 |
btn = gr.Button("生成音频!", variant="primary")
|
442 |
with gr.Column():
|
443 |
+
with gr.Accordion("融合文本语义", open=False):
|
444 |
+
gr.Markdown(
|
445 |
+
value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
|
446 |
+
"**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
|
447 |
+
"效果较不明确,留空即为不使用该功能"
|
448 |
+
)
|
449 |
+
style_text = gr.Textbox(label="辅助文本")
|
450 |
+
style_weight = gr.Slider(
|
451 |
+
minimum=0,
|
452 |
+
maximum=1,
|
453 |
+
value=0.7,
|
454 |
+
step=0.1,
|
455 |
+
label="Weight",
|
456 |
+
info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
|
457 |
+
)
|
458 |
with gr.Row():
|
459 |
with gr.Column():
|
460 |
interval_between_sent = gr.Slider(
|
|
|
497 |
audio_prompt,
|
498 |
text_prompt,
|
499 |
prompt_mode,
|
500 |
+
style_text,
|
501 |
+
style_weight,
|
502 |
],
|
503 |
outputs=[text_output, audio_output],
|
504 |
)
|
|
|
523 |
interval_between_sent,
|
524 |
audio_prompt,
|
525 |
text_prompt,
|
526 |
+
style_text,
|
527 |
+
style_weight,
|
528 |
],
|
529 |
outputs=[text_output, audio_output],
|
530 |
)
|
|
|
541 |
outputs=[audio_prompt],
|
542 |
)
|
543 |
|
544 |
+
formatter.click(
|
545 |
+
format_utils,
|
546 |
+
inputs=[text, speaker],
|
547 |
+
outputs=[language, text],
|
548 |
+
)
|
549 |
+
|
550 |
print("推理页面已开启!")
|
551 |
webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
|
552 |
app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
|
webui_preprocess.py
CHANGED
@@ -19,9 +19,9 @@ def generate_config(data_dir, batch_size):
|
|
19 |
assert data_dir != "", "数据集名称不能为空"
|
20 |
start_path, _, train_path, val_path, config_path = get_path(data_dir)
|
21 |
if os.path.isfile(config_path):
|
22 |
-
config = json.load(open(config_path))
|
23 |
else:
|
24 |
-
config = json.load(open("configs/config.json"))
|
25 |
config["data"]["training_files"] = train_path
|
26 |
config["data"]["validation_files"] = val_path
|
27 |
config["train"]["batch_size"] = batch_size
|
@@ -44,7 +44,7 @@ def resample(data_dir):
|
|
44 |
in_dir = os.path.join(start_path, "raw")
|
45 |
out_dir = os.path.join(start_path, "wavs")
|
46 |
subprocess.run(
|
47 |
-
f"python
|
48 |
f"--sr 44100 "
|
49 |
f"--in_dir {in_dir} "
|
50 |
f"--out_dir {out_dir} ",
|
@@ -60,7 +60,9 @@ def preprocess_text(data_dir):
|
|
60 |
with open(lbl_path, "w", encoding="utf-8") as f:
|
61 |
for line in lines:
|
62 |
path, spk, language, text = line.strip().split("|")
|
63 |
-
path = os.path.join(start_path, "wavs", os.path.basename(path))
|
|
|
|
|
64 |
f.writelines(f"{path}|{spk}|{language}|{text}\n")
|
65 |
subprocess.run(
|
66 |
f"python preprocess_text.py "
|
@@ -83,16 +85,6 @@ def bert_gen(data_dir):
|
|
83 |
return "BERT 特征文件生成完成"
|
84 |
|
85 |
|
86 |
-
def clap_gen(data_dir):
|
87 |
-
assert data_dir != "", "数据集名称不能为空"
|
88 |
-
_, _, _, _, config_path = get_path(data_dir)
|
89 |
-
subprocess.run(
|
90 |
-
f"python clap_gen.py " f"--config {config_path}",
|
91 |
-
shell=True,
|
92 |
-
)
|
93 |
-
return "CLAP 特征文件生成完成"
|
94 |
-
|
95 |
-
|
96 |
if __name__ == "__main__":
|
97 |
with gr.Blocks() as app:
|
98 |
with gr.Row():
|
@@ -100,13 +92,13 @@ if __name__ == "__main__":
|
|
100 |
_ = gr.Markdown(
|
101 |
value="# Bert-VITS2 数据预处理\n"
|
102 |
"## 预先准备:\n"
|
103 |
-
"下载 BERT 和
|
104 |
"- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
|
105 |
"- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
|
106 |
"- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
|
107 |
-
"- [
|
108 |
"\n"
|
109 |
-
"将 BERT 模型放置到 `bert` 文件夹下,
|
110 |
"\n"
|
111 |
"数据准备:\n"
|
112 |
"将数据放置在 data 文件夹下,按照如下结构组织:\n"
|
@@ -156,12 +148,10 @@ if __name__ == "__main__":
|
|
156 |
preprocess_text_btn = gr.Button(value="执行", variant="primary")
|
157 |
_ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
|
158 |
bert_gen_btn = gr.Button(value="执行", variant="primary")
|
159 |
-
_ = gr.Markdown(value="## 第五步:生成 CLAP 特征文件")
|
160 |
-
clap_gen_btn = gr.Button(value="执行", variant="primary")
|
161 |
_ = gr.Markdown(
|
162 |
value="## 训练模型及部署:\n"
|
163 |
"修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
|
164 |
-
"- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
|
165 |
"- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
|
166 |
)
|
167 |
|
@@ -171,7 +161,6 @@ if __name__ == "__main__":
|
|
171 |
resample_btn.click(resample, inputs=[data_dir], outputs=[info])
|
172 |
preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
|
173 |
bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
|
174 |
-
clap_gen_btn.click(clap_gen, inputs=[data_dir], outputs=[info])
|
175 |
|
176 |
webbrowser.open("http://127.0.0.1:7860")
|
177 |
app.launch(share=False, server_port=7860)
|
|
|
19 |
assert data_dir != "", "数据集名称不能为空"
|
20 |
start_path, _, train_path, val_path, config_path = get_path(data_dir)
|
21 |
if os.path.isfile(config_path):
|
22 |
+
config = json.load(open(config_path, "r", encoding="utf-8"))
|
23 |
else:
|
24 |
+
config = json.load(open("configs/config.json", "r", encoding="utf-8"))
|
25 |
config["data"]["training_files"] = train_path
|
26 |
config["data"]["validation_files"] = val_path
|
27 |
config["train"]["batch_size"] = batch_size
|
|
|
44 |
in_dir = os.path.join(start_path, "raw")
|
45 |
out_dir = os.path.join(start_path, "wavs")
|
46 |
subprocess.run(
|
47 |
+
f"python resample_legacy.py "
|
48 |
f"--sr 44100 "
|
49 |
f"--in_dir {in_dir} "
|
50 |
f"--out_dir {out_dir} ",
|
|
|
60 |
with open(lbl_path, "w", encoding="utf-8") as f:
|
61 |
for line in lines:
|
62 |
path, spk, language, text = line.strip().split("|")
|
63 |
+
path = os.path.join(start_path, "wavs", os.path.basename(path)).replace(
|
64 |
+
"\\", "/"
|
65 |
+
)
|
66 |
f.writelines(f"{path}|{spk}|{language}|{text}\n")
|
67 |
subprocess.run(
|
68 |
f"python preprocess_text.py "
|
|
|
85 |
return "BERT 特征文件生成完成"
|
86 |
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
if __name__ == "__main__":
|
89 |
with gr.Blocks() as app:
|
90 |
with gr.Row():
|
|
|
92 |
_ = gr.Markdown(
|
93 |
value="# Bert-VITS2 数据预处理\n"
|
94 |
"## 预先准备:\n"
|
95 |
+
"下载 BERT 和 WavLM 模型:\n"
|
96 |
"- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
|
97 |
"- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
|
98 |
"- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
|
99 |
+
"- [WavLM](https://huggingface.co/microsoft/wavlm-base-plus)\n"
|
100 |
"\n"
|
101 |
+
"将 BERT 模型放置到 `bert` 文件夹下,WavLM 模型放置到 `slm` 文件夹下,覆盖同名文件夹。\n"
|
102 |
"\n"
|
103 |
"数据准备:\n"
|
104 |
"将数据放置在 data 文件夹下,按照如下结构组织:\n"
|
|
|
148 |
preprocess_text_btn = gr.Button(value="执行", variant="primary")
|
149 |
_ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
|
150 |
bert_gen_btn = gr.Button(value="执行", variant="primary")
|
|
|
|
|
151 |
_ = gr.Markdown(
|
152 |
value="## 训练模型及部署:\n"
|
153 |
"修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
|
154 |
+
"- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth`、`WD_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
|
155 |
"- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
|
156 |
)
|
157 |
|
|
|
161 |
resample_btn.click(resample, inputs=[data_dir], outputs=[info])
|
162 |
preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
|
163 |
bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
|
|
|
164 |
|
165 |
webbrowser.open("http://127.0.0.1:7860")
|
166 |
app.launch(share=False, server_port=7860)
|