zhzluke96 committed on
Commit
da8d589
1 Parent(s): c4c6bff
modules/Denoiser/AudioDenoiser.py ADDED
@@ -0,0 +1,140 @@
+ import logging
+ import math
+ from typing import Union
+ import torch
+ import torchaudio
+ from torch import nn
+ from audio_denoiser.helpers.torch_helper import batched_apply
+ from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
+ from audio_denoiser.helpers.audio_helper import (
+     create_spectrogram,
+     reconstruct_from_spectrogram,
+ )
+
+ _expected_t_std = 0.23
+ _recommended_backend = "soundfile"
+
+
+ # ref: https://github.com/jose-solorzano/audio-denoiser
+ class AudioDenoiser:
+     def __init__(
+         self,
+         local_dir: str,
+         device: Union[str, torch.device] = None,
+         num_iterations: int = 100,
+     ):
+         super().__init__()
+         if device is None:
+             is_cuda = torch.cuda.is_available()
+             if not is_cuda:
+                 logging.warning("CUDA not available. Will use CPU.")
+             device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
+         self.device = device
+         self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
+         self.model.eval()
+         self.model_sample_rate = self.model.sample_rate
+         self.scaler = self.model.scaler
+         self.n_fft = self.model.n_fft
+         self.segment_num_frames = self.model.num_frames
+         self.num_iterations = num_iterations
+
+     @staticmethod
+     def _sp_log(spectrogram: torch.Tensor, eps=0.01):
+         return torch.log(spectrogram + eps)
+
+     @staticmethod
+     def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
+         return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)
+
+     @staticmethod
+     def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
+         # Expected for training data is ~0.23
+         abs_waveform = torch.abs(waveform)
+         quantile_value = torch.quantile(abs_waveform, q).item()
+         trimmed_values = waveform[abs_waveform >= quantile_value]
+         return torch.std(trimmed_values).item()
+
+     def process_waveform(
+         self,
+         waveform: torch.Tensor,
+         sample_rate: int,
+         return_cpu_tensor: bool = False,
+         auto_scale: bool = False,
+     ) -> torch.Tensor:
+         """
+         Denoises a waveform.
+         @param waveform: A waveform tensor. Use torchaudio structure.
+         @param sample_rate: The sample rate of the waveform in Hz.
+         @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
+         @param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio.
+         @return: A denoised waveform.
+         """
+         waveform = waveform.cpu()
+         if auto_scale:
+             w_t_std = self._trimmed_dev(waveform)
+             waveform = waveform * _expected_t_std / w_t_std
+         if sample_rate != self.model_sample_rate:
+             transform = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=self.model_sample_rate
+             )
+             waveform = transform(waveform)
+         hop_len = self.n_fft // 2
+         spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
+         spectrogram = spectrogram.to(self.device)
+         num_a_channels = spectrogram.size(0)
+         with torch.no_grad():
+             results = []
+             for c in range(num_a_channels):
+                 c_spectrogram = spectrogram[c]
+                 # c_spectrogram: (257, num_frames)
+                 fft_size, num_frames = c_spectrogram.shape
+                 num_segments = math.ceil(num_frames / self.segment_num_frames)
+                 adj_num_frames = num_segments * self.segment_num_frames
+                 if adj_num_frames > num_frames:
+                     c_spectrogram = nn.functional.pad(
+                         c_spectrogram, (0, adj_num_frames - num_frames)
+                     )
+                 c_spectrogram = c_spectrogram.view(
+                     fft_size, num_segments, self.segment_num_frames
+                 )
+                 # c_spectrogram: (257, num_segments, 32)
+                 c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))
+                 # c_spectrogram: (num_segments, 257, 32)
+                 log_c_spectrogram = self._sp_log(c_spectrogram)
+                 scaled_log_c_sp = self.scaler(log_c_spectrogram)
+                 pred_noise_log_sp = batched_apply(
+                     self.model, scaled_log_c_sp, detached=True
+                 )
+                 log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
+                 denoised_sp = self._sp_exp(log_denoised_sp)
+                 # denoised_sp: (num_segments, 257, 32)
+                 denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
+                 # denoised_sp: (257, num_segments, 32)
+                 denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
+                 # denoised_sp: (1, 257, adj_num_frames)
+                 denoised_sp = denoised_sp[:, :, :num_frames]
+                 denoised_sp = denoised_sp.cpu()
+                 denoised_waveform = reconstruct_from_spectrogram(
+                     denoised_sp, num_iterations=self.num_iterations
+                 )
+                 # denoised_waveform: (1, num_samples)
+                 results.append(denoised_waveform)
+             cpu_results = torch.cat(results)
+             return cpu_results if return_cpu_tensor else cpu_results.to(self.device)
+
+     def process_audio_file(
+         self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
+     ):
+         """
+         Denoises an audio file.
+         @param in_audio_file: An input audio file with a format supported by torchaudio.
+         @param out_audio_file: An output audio file with a format supported by torchaudio.
+         @param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio.
+         """
+         waveform, sample_rate = torchaudio.load(in_audio_file)
+         denoised_waveform = self.process_waveform(
+             waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
+         )
+         torchaudio.save(
+             out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
+         )
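For context, a minimal usage sketch of the class above (not part of the commit). The model directory mirrors the path this commit wires up in modules/denoise.py; the audio file names are placeholders.

```python
import torchaudio

from modules.Denoiser.AudioDenoiser import AudioDenoiser

# local_dir must contain the config.json and pytorch_model.bin that
# load_audio_denosier_model() reads.
denoiser = AudioDenoiser(local_dir="models/Denoise/audio-denoiser-512-32-v1")

# File to file: resamples to the model sample rate and writes the result.
denoiser.process_audio_file("noisy.wav", "clean.wav", auto_scale=True)

# Tensor to tensor: waveform is (channels, samples), as torchaudio.load returns.
waveform, sr = torchaudio.load("noisy.wav")
clean = denoiser.process_waveform(waveform, sr, return_cpu_tensor=True)
```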
modules/Denoiser/AudioNosiseModel.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ import torch.nn as nn
+
+ from audio_denoiser.modules.Permute import Permute
+ from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
+ from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
+
+ import json
+
+
+ class AudioNoiseModel(nn.Module):
+     def __init__(self, config: dict):
+         super(AudioNoiseModel, self).__init__()
+
+         # Encoder layers
+         self.config = config
+         scaler_dict = config["scaler"]
+         self.scaler = SpectrogramScaler.from_dict(scaler_dict)
+         self.in_channels = config.get("in_channels", 257)
+         self.roberta_hidden_size = config.get("roberta_hidden_size", 768)
+         self.model1 = nn.Sequential(
+             nn.Conv1d(self.in_channels, 1024, kernel_size=1),
+             nn.ELU(),
+             nn.Conv1d(1024, 1024, kernel_size=1),
+             nn.ELU(),
+             nn.Conv1d(1024, self.in_channels, kernel_size=1),
+         )
+         self.model2 = nn.Sequential(
+             Permute(0, 2, 1),
+             nn.Linear(self.in_channels, self.roberta_hidden_size),
+             SimpleRoberta(num_hidden_layers=5, hidden_size=self.roberta_hidden_size),
+             nn.Linear(self.roberta_hidden_size, self.in_channels),
+             Permute(0, 2, 1),
+         )
+
+     @property
+     def sample_rate(self) -> int:
+         return self.config.get("sample_rate", 16000)
+
+     @property
+     def n_fft(self) -> int:
+         return self.config.get("n_fft", 512)
+
+     @property
+     def num_frames(self) -> int:
+         return self.config.get("num_frames", 32)
+
+     def forward(self, x, use_scaler: bool = False, out_scale: float = 1.0):
+         if use_scaler:
+             x = self.scaler(x)
+         x1 = self.model1(x)
+         x2 = self.model2(x)
+         x = x1 + x2
+         return x * out_scale
+
+
+ def load_audio_denosier_model(dir_path: str, device) -> AudioNoiseModel:
+     config = json.load(open(f"{dir_path}/config.json", "r"))
+     model = AudioNoiseModel(config)
+     model.load_state_dict(torch.load(f"{dir_path}/pytorch_model.bin"))
+
+     model.to(device)
+     model.model1.to(device)
+     model.model2.to(device)
+
+     return model
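For reference, the on-disk layout that `load_audio_denosier_model` expects, inferred from the loader above (a sketch, not part of the commit): a directory holding `config.json` and `pytorch_model.bin`. The keys and defaults below mirror the `config.get(...)` calls; the `scaler` payload is whatever `SpectrogramScaler.from_dict` consumes and is left as a placeholder here.

```python
import json

# <dir_path>/config.json        -> AudioNoiseModel(config)
# <dir_path>/pytorch_model.bin  -> model.load_state_dict(...)
config = {
    "sample_rate": 16000,        # default used by model.sample_rate
    "n_fft": 512,                # default used by model.n_fft
    "num_frames": 32,            # default used by model.num_frames
    "in_channels": 257,          # n_fft // 2 + 1 spectrogram bins
    "roberta_hidden_size": 768,
    "scaler": {},                # placeholder; real values ship with the checkpoint
}

with open("config.json", "w") as f:
    json.dump(config, f)
```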
modules/Denoiser/__init__.py ADDED
File without changes
modules/Enhancer/ResembleEnhance.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ from typing import List
+ from resemble_enhance.enhancer.enhancer import Enhancer
+ from resemble_enhance.enhancer.hparams import HParams
+ from resemble_enhance.inference import inference
+
+ import torch
+
+ from modules.utils.constants import MODELS_DIR
+ from pathlib import Path
+
+ from threading import Lock
+
+ resemble_enhance = None
+ lock = Lock()
+
+
+ def load_enhancer(device: torch.device):
+     global resemble_enhance
+     with lock:
+         if resemble_enhance is None:
+             resemble_enhance = ResembleEnhance(device)
+             resemble_enhance.load_model()
+     return resemble_enhance
+
+
+ class ResembleEnhance:
+     hparams: HParams
+     enhancer: Enhancer
+
+     def __init__(self, device: torch.device):
+         self.device = device
+
+         self.enhancer = None
+         self.hparams = None
+
+     def load_model(self):
+         hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
+         enhancer = Enhancer(hparams)
+         state_dict = torch.load(
+             Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
+             map_location="cpu",
+         )["module"]
+         enhancer.load_state_dict(state_dict)
+         enhancer.eval()
+         enhancer.to(self.device)
+         enhancer.denoiser.to(self.device)
+
+         self.hparams = hparams
+         self.enhancer = enhancer
+
+     @torch.inference_mode()
+     def denoise(self, dwav, sr, device) -> tuple[torch.Tensor, int]:
+         assert self.enhancer is not None, "Model not loaded"
+         assert self.enhancer.denoiser is not None, "Denoiser not loaded"
+         enhancer = self.enhancer
+         return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
+
+     @torch.inference_mode()
+     def enhance(
+         self,
+         dwav,
+         sr,
+         device,
+         nfe=32,
+         solver="midpoint",
+         lambd=0.5,
+         tau=0.5,
+     ) -> tuple[torch.Tensor, int]:
+         assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
+         assert solver in (
+             "midpoint",
+             "rk4",
+             "euler",
+         ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
+         assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
+         assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
+         assert self.enhancer is not None, "Model not loaded"
+         enhancer = self.enhancer
+         enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
+         return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
+
+
+ if __name__ == "__main__":
+     import torchaudio
+     from modules.models import load_chat_tts
+
+     load_chat_tts()
+
+     device = torch.device("cuda")
+     ench = ResembleEnhance(device)
+     ench.load_model()
+
+     wav, sr = torchaudio.load("test.wav")
+
+     print(wav.shape, type(wav), sr, type(sr))
+     exit()
+
+     wav = wav.squeeze(0).cuda()
+
+     print(wav.device)
+
+     denoised, d_sr = ench.denoise(wav.cpu(), sr, device)
+     denoised = denoised.unsqueeze(0)
+     print(denoised.shape)
+     torchaudio.save("denoised.wav", denoised, d_sr)
+
+     for solver in ("midpoint", "rk4", "euler"):
+         for lambd in (0.1, 0.5, 0.9):
+             for tau in (0.1, 0.5, 0.9):
+                 enhanced, e_sr = ench.enhance(
+                     wav.cpu(), sr, device, solver=solver, lambd=lambd, tau=tau, nfe=128
+                 )
+                 enhanced = enhanced.unsqueeze(0)
+                 print(enhanced.shape)
+                 torchaudio.save(f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced, e_sr)
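A usage sketch for the loader above (not part of the commit). It assumes the checkpoint layout the code reads, models/resemble-enhance/ with hparams plus mp_rank_00_model_states.pt; the audio file names are placeholders.

```python
import torch
import torchaudio

from modules.Enhancer.ResembleEnhance import load_enhancer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer = load_enhancer(device)  # cached module-level instance, guarded by the lock

dwav, sr = torchaudio.load("input.wav")
dwav = dwav.squeeze(0)  # the __main__ demo above passes a 1-D waveform

denoised, d_sr = enhancer.denoise(dwav, sr, device)
enhanced, e_sr = enhancer.enhance(dwav, sr, device, nfe=64, solver="midpoint")

torchaudio.save("enhanced.wav", enhanced.unsqueeze(0).cpu(), e_sr)
```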
modules/Enhancer/__init__.py ADDED
File without changes
modules/SynthesizeSegments.py CHANGED
@@ -1,17 +1,18 @@
+ from box import Box
  from pydub import AudioSegment
- from typing import Any, List, Dict, Union
+ from typing import List, Union
  from scipy.io.wavfile import write
  import io
+ from modules.api.utils import calc_spk_style
+ from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
  from modules.utils import rng
  from modules.utils.audio import time_stretch, pitch_shift
  from modules import generate_audio
  from modules.normalization import text_normalize
  import logging
  import json
- import copy
- import numpy as np

- from modules.speaker import Speaker
+ from modules.speaker import Speaker, speaker_mgr

  logger = logging.getLogger(__name__)

@@ -24,7 +25,7 @@ def audio_data_to_segment(audio_data, sr):
      return AudioSegment.from_file(byte_io, format="wav")


- def combine_audio_segments(audio_segments: list) -> AudioSegment:
+ def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
      combined_audio = AudioSegment.empty()
      for segment in audio_segments:
          combined_audio += segment
@@ -54,230 +55,191 @@ def to_number(value, t, default=0):
      return default


+ class TTSAudioSegment(Box):
+     text: str
+     temperature: float
+     top_P: float
+     top_K: int
+     spk: int
+     infer_seed: int
+     prompt1: str
+     prompt2: str
+     prefix: str
+
+     _type: str
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+
  class SynthesizeSegments:
      def __init__(self, batch_size: int = 8):
          self.batch_size = batch_size
          self.batch_default_spk_seed = rng.np_rng()
          self.batch_default_infer_seed = rng.np_rng()

-     def segment_to_generate_params(self, segment: Dict[str, Any]) -> Dict[str, Any]:
+     def segment_to_generate_params(
+         self, segment: Union[SSMLSegment, SSMLBreak]
+     ) -> TTSAudioSegment:
+         if isinstance(segment, SSMLBreak):
+             return TTSAudioSegment(_type="break")
+
          if segment.get("params", None) is not None:
-             return segment["params"]
+             return TTSAudioSegment(**segment.get("params"))

          text = segment.get("text", "")
          is_end = segment.get("is_end", False)

          text = str(text).strip()

-         attrs = segment.get("attrs", {})
-         spk = attrs.get("spk", "")
-         if isinstance(spk, str):
-             spk = int(spk)
-         seed = to_number(attrs.get("seed", ""), int, -1)
-         top_k = to_number(attrs.get("top_k", ""), int, None)
-         top_p = to_number(attrs.get("top_p", ""), float, None)
-         temp = to_number(attrs.get("temp", ""), float, None)
-
-         prompt1 = attrs.get("prompt1", "")
-         prompt2 = attrs.get("prompt2", "")
-         prefix = attrs.get("prefix", "")
+         attrs = segment.attrs
+         spk = attrs.spk
+         style = attrs.style
+
+         ss_params = calc_spk_style(spk, style)
+
+         if "spk" in ss_params:
+             spk = ss_params["spk"]
+
+         seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
+         top_k = to_number(attrs.top_k, int, None)
+         top_p = to_number(attrs.top_p, float, None)
+         temp = to_number(attrs.temp, float, None)
+
+         prompt1 = attrs.prompt1 or ss_params.get("prompt1")
+         prompt2 = attrs.prompt2 or ss_params.get("prompt2")
+         prefix = attrs.prefix or ss_params.get("prefix")
          disable_normalize = attrs.get("normalize", "") == "False"

-         params = {
-             "text": text,
-             "temperature": temp if temp is not None else 0.3,
-             "top_P": top_p if top_p is not None else 0.5,
-             "top_K": top_k if top_k is not None else 20,
-             "spk": spk if spk else -1,
-             "infer_seed": seed if seed else -1,
-             "prompt1": prompt1 if prompt1 else "",
-             "prompt2": prompt2 if prompt2 else "",
-             "prefix": prefix if prefix else "",
-         }
+         seg = TTSAudioSegment(
+             _type="voice",
+             text=text,
+             temperature=temp if temp is not None else 0.3,
+             top_P=top_p if top_p is not None else 0.5,
+             top_K=top_k if top_k is not None else 20,
+             spk=spk if spk else -1,
+             infer_seed=seed if seed else -1,
+             prompt1=prompt1 if prompt1 else "",
+             prompt2=prompt2 if prompt2 else "",
+             prefix=prefix if prefix else "",
+         )

          if not disable_normalize:
-             params["text"] = text_normalize(text, is_end=is_end)
+             seg.text = text_normalize(text, is_end=is_end)

-         # Set default values for spk and infer_seed
-         if params["spk"] == -1:
-             params["spk"] = self.batch_default_spk_seed
-         if params["infer_seed"] == -1:
-             params["infer_seed"] = self.batch_default_infer_seed
+         # NOTE: keep each batch's default seed consistent across calls, even when no spk is set
+         if seg.spk == -1:
+             seg.spk = self.batch_default_spk_seed
+         if seg.infer_seed == -1:
+             seg.infer_seed = self.batch_default_infer_seed

-         return params
+         return seg
+
+     def process_break_segments(
+         self,
+         src_segments: List[SSMLBreak],
+         bucket_segments: List[SSMLBreak],
+         audio_segments: List[AudioSegment],
+     ):
+         for segment in bucket_segments:
+             index = src_segments.index(segment)
+             audio_segments[index] = AudioSegment.silent(
+                 duration=int(segment.attrs.duration)
+             )
+
+     def process_voice_segments(
+         self,
+         src_segments: List[SSMLSegment],
+         bucket: List[SSMLSegment],
+         audio_segments: List[AudioSegment],
+     ):
+         for i in range(0, len(bucket), self.batch_size):
+             batch = bucket[i : i + self.batch_size]
+             param_arr = [self.segment_to_generate_params(segment) for segment in batch]
+             texts = [params.text for params in param_arr]
+
+             params = param_arr[0]
+             audio_datas = generate_audio.generate_audio_batch(
+                 texts=texts,
+                 temperature=params.temperature,
+                 top_P=params.top_P,
+                 top_K=params.top_K,
+                 spk=params.spk,
+                 infer_seed=params.infer_seed,
+                 prompt1=params.prompt1,
+                 prompt2=params.prompt2,
+                 prefix=params.prefix,
+             )
+             for idx, segment in enumerate(batch):
+                 sr, audio_data = audio_datas[idx]
+                 rate = float(segment.get("rate", "1.0"))
+                 volume = float(segment.get("volume", "0"))
+                 pitch = float(segment.get("pitch", "0"))
+
+                 audio_segment = audio_data_to_segment(audio_data, sr)
+                 audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
+                 original_index = src_segments.index(segment)
+                 audio_segments[original_index] = audio_segment

      def bucket_segments(
-         self, segments: List[Dict[str, Any]]
-     ) -> List[List[Dict[str, Any]]]:
-         # Create a dictionary to hold buckets
-         buckets = {}
+         self, segments: List[Union[SSMLSegment, SSMLBreak]]
+     ) -> List[List[Union[SSMLSegment, SSMLBreak]]]:
+         buckets = {"<break>": []}
          for segment in segments:
+             if isinstance(segment, SSMLBreak):
+                 buckets["<break>"].append(segment)
+                 continue
+
              params = self.segment_to_generate_params(segment)

-             key_params = copy.copy(params)
-             if isinstance(key_params.get("spk"), Speaker):
-                 key_params["spk"] = str(key_params["spk"].id)
+             if isinstance(params.spk, Speaker):
+                 params.spk = str(params.spk.id)
+
              key = json.dumps(
-                 {k: v for k, v in key_params.items() if k != "text"}, sort_keys=True
+                 {k: v for k, v in params.items() if k != "text"}, sort_keys=True
              )
              if key not in buckets:
                  buckets[key] = []
              buckets[key].append(segment)

-         # Convert dictionary to list of buckets
-         bucket_list = list(buckets.values())
-         return bucket_list
+         return buckets

-     def synthesize_segments(self, segments: List[Dict[str, Any]]) -> List[AudioSegment]:
-         audio_segments = [None] * len(
-             segments
-         )  # Create a list with the same length as segments
+     def synthesize_segments(
+         self, segments: List[Union[SSMLSegment, SSMLBreak]]
+     ) -> List[AudioSegment]:
+         audio_segments = [None] * len(segments)
          buckets = self.bucket_segments(segments)
-         logger.debug(f"segments len: {len(segments)}")
-         logger.debug(f"bucket pool size: {len(buckets)}")
-         for bucket in buckets:
-             for i in range(0, len(bucket), self.batch_size):
-                 batch = bucket[i : i + self.batch_size]
-                 param_arr = [
-                     self.segment_to_generate_params(segment) for segment in batch
-                 ]
-                 texts = [params["text"] for params in param_arr]
-
-                 params = param_arr[0]  # Use the first segment to get the parameters
-                 audio_datas = generate_audio.generate_audio_batch(
-                     texts=texts,
-                     temperature=params["temperature"],
-                     top_P=params["top_P"],
-                     top_K=params["top_K"],
-                     spk=params["spk"],
-                     infer_seed=params["infer_seed"],
-                     prompt1=params["prompt1"],
-                     prompt2=params["prompt2"],
-                     prefix=params["prefix"],
-                 )
-                 for idx, segment in enumerate(batch):
-                     (sr, audio_data) = audio_datas[idx]
-                     rate = float(segment.get("rate", "1.0"))
-                     volume = float(segment.get("volume", "0"))
-                     pitch = float(segment.get("pitch", "0"))
-
-                     audio_segment = audio_data_to_segment(audio_data, sr)
-                     audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
-                     original_index = segments.index(
-                         segment
-                     )  # Get the original index of the segment
-                     audio_segments[original_index] = (
-                         audio_segment  # Place the audio_segment in the correct position
-                     )
-
-         return audio_segments
+
+         break_segments = buckets.pop("<break>")
+         self.process_break_segments(segments, break_segments, audio_segments)
+
+         buckets = list(buckets.values())
+
+         for bucket in buckets:
+             self.process_voice_segments(segments, bucket, audio_segments)
+
+         return audio_segments


- def generate_audio_segment(
-     text: str,
-     spk: int = -1,
-     seed: int = -1,
-     top_p: float = 0.5,
-     top_k: int = 20,
-     temp: float = 0.3,
-     prompt1: str = "",
-     prompt2: str = "",
-     prefix: str = "",
-     enable_normalize=True,
-     is_end: bool = False,
- ) -> AudioSegment:
-     if enable_normalize:
-         text = text_normalize(text, is_end=is_end)
-
-     logger.debug(f"generate segment: {text}")
-
-     sample_rate, audio_data = generate_audio.generate_audio(
-         text=text,
-         temperature=temp if temp is not None else 0.3,
-         top_P=top_p if top_p is not None else 0.5,
-         top_K=top_k if top_k is not None else 20,
-         spk=spk if spk else -1,
-         infer_seed=seed if seed else -1,
-         prompt1=prompt1 if prompt1 else "",
-         prompt2=prompt2 if prompt2 else "",
-         prefix=prefix if prefix else "",
-     )
-
-     byte_io = io.BytesIO()
-     write(byte_io, sample_rate, audio_data)
-     byte_io.seek(0)
-
-     return AudioSegment.from_file(byte_io, format="wav")
-
-
- def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
-     if "break" in segment:
-         pause_segment = AudioSegment.silent(duration=segment["break"])
-         return pause_segment
-
-     attrs = segment.get("attrs", {})
-     text = segment.get("text", "")
-     is_end = segment.get("is_end", False)
-
-     text = str(text).strip()
-
-     if text == "":
-         return None
-
-     spk = attrs.get("spk", "")
-     if isinstance(spk, str):
-         spk = int(spk)
-     seed = to_number(attrs.get("seed", ""), int, -1)
-     top_k = to_number(attrs.get("top_k", ""), int, None)
-     top_p = to_number(attrs.get("top_p", ""), float, None)
-     temp = to_number(attrs.get("temp", ""), float, None)
-
-     prompt1 = attrs.get("prompt1", "")
-     prompt2 = attrs.get("prompt2", "")
-     prefix = attrs.get("prefix", "")
-     disable_normalize = attrs.get("normalize", "") == "False"
-
-     audio_segment = generate_audio_segment(
-         text,
-         enable_normalize=not disable_normalize,
-         spk=spk,
-         seed=seed,
-         top_k=top_k,
-         top_p=top_p,
-         temp=temp,
-         prompt1=prompt1,
-         prompt2=prompt2,
-         prefix=prefix,
-         is_end=is_end,
-     )
-
-     rate = float(attrs.get("rate", "1.0"))
-     volume = float(attrs.get("volume", "0"))
-     pitch = float(attrs.get("pitch", "0"))
-
-     audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
-
-     return audio_segment


  # Example usage
  if __name__ == "__main__":
+     ctx1 = SSMLContext()
+     ctx1.spk = 1
+     ctx1.seed = 42
+     ctx1.temp = 0.1
+     ctx2 = SSMLContext()
+     ctx2.spk = 2
+     ctx2.seed = 42
+     ctx2.temp = 0.1
      ssml_segments = [
-         {
-             "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
-             "attrs": {"spk": 2, "temp": 0.1, "seed": 42},
-         },
-         {
-             "text": "大🍉,一个大🍉,嘿,你的感觉真的很奇妙 [lbreak]",
-             "attrs": {"spk": 2, "temp": 0.1, "seed": 42},
-         },
-         {
-             "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
-             "attrs": {"spk": 2, "temp": 0.3, "seed": 42},
-         },
+         SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
+         SSMLBreak(duration_ms=1000),
+         SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
+         SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
      ]

      synthesizer = SynthesizeSegments(batch_size=2)
      audio_segments = synthesizer.synthesize_segments(ssml_segments)
+     print(audio_segments)
      combined_audio = combine_audio_segments(audio_segments)
      combined_audio.export("output.wav", format="wav")
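The bucketing above groups segments whose generation parameters, everything except `text`, serialize to the same JSON string, so a single `generate_audio_batch` call is valid for the whole bucket. The key construction in isolation (a sketch, not committed code):

```python
import json


def bucket_key(params: dict) -> str:
    # same key as bucket_segments(): drop the text, then serialize deterministically
    return json.dumps({k: v for k, v in params.items() if k != "text"}, sort_keys=True)


a = {"text": "hello", "spk": 2, "infer_seed": 42, "temperature": 0.1}
b = {"text": "world", "spk": 2, "infer_seed": 42, "temperature": 0.1}
assert bucket_key(a) == bucket_key(b)  # same bucket -> batched in one call
```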
modules/api/impl/google_api.py CHANGED
@@ -18,7 +18,6 @@ from modules.ssml import parse_ssml
  from modules.SynthesizeSegments import (
      SynthesizeSegments,
      combine_audio_segments,
-     synthesize_segment,
  )

  from modules.api import utils as api_utils
modules/api/impl/speaker_api.py CHANGED
@@ -7,11 +7,11 @@ from modules.api.Api import APIManager


  class CreateSpeaker(BaseModel):
-     seed: int
      name: str
      gender: str
      describe: str
-     tensor: list
+     tensor: list = None
+     seed: int = None


  class UpdateSpeaker(BaseModel):
@@ -76,7 +76,7 @@ def setup(app: APIManager):
                  gender=request.gender,
                  describe=request.describe,
              )
-         else:
+         elif request.seed:
              # from seed
              speaker = speaker_mgr.create_speaker_from_seed(
                  seed=request.seed,
@@ -84,6 +84,10 @@ def setup(app: APIManager):
                  gender=request.gender,
                  describe=request.describe,
              )
+         else:
+             raise HTTPException(
+                 status_code=400, detail="Missing tensor or seed in request"
+             )
          return {"message": "ok", "data": speaker.to_json()}

      @app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
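The net effect on the request model: `tensor` and `seed` are now both optional, and a request that supplies neither is rejected with HTTP 400. A reduced sketch of the new flow (the FastAPI wiring is elided; the branch order mirrors the endpoint above, and the string results are stand-ins for the speaker_mgr calls):

```python
from pydantic import BaseModel


class CreateSpeaker(BaseModel):
    name: str
    gender: str
    describe: str
    tensor: list = None
    seed: int = None


def create(request: CreateSpeaker) -> str:
    if request.tensor:
        return "create_speaker_from_tensor(...)"  # stand-in
    elif request.seed:
        return "create_speaker_from_seed(...)"    # stand-in
    raise ValueError("Missing tensor or seed in request")  # HTTP 400 in the API


print(create(CreateSpeaker(name="alice", gender="female", describe="demo", seed=1234)))
```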
modules/api/impl/ssml_api.py CHANGED
@@ -10,7 +10,6 @@ from modules.normalization import text_normalize
  from modules.ssml import parse_ssml
  from modules.SynthesizeSegments import (
      SynthesizeSegments,
-     synthesize_segment,
      combine_audio_segments,
  )

@@ -23,6 +22,8 @@
  class SSMLRequest(BaseModel):
      ssml: str
      format: str = "mp3"
+
+     # NOTE: 🤔 maybe this should be a system-level setting? passing it in per request feels odd
      batch_size: int = 4


@@ -48,29 +49,15 @@ async def synthesize_ssml(
          for seg in segments:
              seg["text"] = text_normalize(seg["text"], is_end=True)

-         if batch_size != 1:
-             synthesize = SynthesizeSegments(batch_size)
-             audio_segments = synthesize.synthesize_segments(segments)
-             combined_audio = combine_audio_segments(audio_segments)
-             buffer = io.BytesIO()
-             combined_audio.export(buffer, format="wav")
-             buffer.seek(0)
-             if format == "mp3":
-                 buffer = api_utils.wav_to_mp3(buffer)
-             return StreamingResponse(buffer, media_type=f"audio/{format}")
-         else:
-
-             def audio_streamer():
-                 for segment in segments:
-                     audio_segment = synthesize_segment(segment=segment)
-                     buffer = io.BytesIO()
-                     audio_segment.export(buffer, format="wav")
-                     buffer.seek(0)
-                     if format == "mp3":
-                         buffer = api_utils.wav_to_mp3(buffer)
-                     yield buffer.read()
-
-             return StreamingResponse(audio_streamer(), media_type=f"audio/{format}")
+         synthesize = SynthesizeSegments(batch_size)
+         audio_segments = synthesize.synthesize_segments(segments)
+         combined_audio = combine_audio_segments(audio_segments)
+         buffer = io.BytesIO()
+         combined_audio.export(buffer, format="wav")
+         buffer.seek(0)
+         if format == "mp3":
+             buffer = api_utils.wav_to_mp3(buffer)
+         return StreamingResponse(buffer, media_type=f"audio/{format}")

      except Exception as e:
          import logging
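A client-side sketch of the simplified endpoint (not part of the commit; the route path and port are assumptions based on the handler name). With the streaming branch removed, every request now goes through batched synthesis and returns a single combined audio body.

```python
import requests

payload = {
    "ssml": "<speak version='0.1'><voice spk='2' seed='42'>你好</voice></speak>",
    "format": "mp3",
    "batch_size": 4,
}

resp = requests.post("http://localhost:8000/v1/ssml", json=payload)  # URL is hypothetical
with open("out.mp3", "wb") as f:
    f.write(resp.content)
```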
modules/api/utils.py CHANGED
@@ -52,7 +52,6 @@ def to_number(value, t, default=0):
  def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
      voice_attrs = {
          "spk": None,
-         "seed": None,
          "prompt1": None,
          "prompt2": None,
          "prefix": None,
@@ -85,7 +84,6 @@ def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
      merge_prompt(voice_attrs, params)

      voice_attrs["spk"] = params.get("spk", voice_attrs.get("spk", None))
-     voice_attrs["seed"] = params.get("seed", voice_attrs.get("seed", None))
      voice_attrs["temperature"] = params.get(
          "temp", voice_attrs.get("temperature", None)
      )
modules/denoise.py CHANGED
@@ -1,7 +1,51 @@
- from audio_denoiser.AudioDenoiser import AudioDenoiser
+ import os
+ from typing import Union
+
  import torch
  import torchaudio
+ from modules.Denoiser.AudioDenoiser import AudioDenoiser
+
+ from modules.utils.constants import MODELS_DIR
+
+ from modules.devices import devices
+
+ import soundfile as sf
+
+ ad: Union[AudioDenoiser, None] = None


  class TTSAudioDenoiser:
-     pass
+
+     def load_ad(self):
+         global ad
+         if ad is None:
+             ad = AudioDenoiser(
+                 os.path.join(
+                     MODELS_DIR,
+                     "Denoise",
+                     "audio-denoiser-512-32-v1",
+                 ),
+                 device=devices.device,
+             )
+             ad.model.to(devices.device)
+         return ad
+
+     def denoise(self, audio_data, sample_rate, auto_scale=False):
+         ad = self.load_ad()
+         sr = ad.model_sample_rate
+         # auto_scale is passed by keyword so it does not land in return_cpu_tensor
+         return sr, ad.process_waveform(audio_data, sample_rate, auto_scale=auto_scale)
+
+
+ if __name__ == "__main__":
+     tts_deno = TTSAudioDenoiser()
+     data, sr = sf.read("test.wav")
+     audio_tensor = torch.from_numpy(data).unsqueeze(0).float()
+     print(audio_tensor)
+
+     # data, sr = torchaudio.load("test.wav")
+     # print(data)
+     # data = data.to(devices.device)
+
+     sr, denoised = tts_deno.denoise(audio_data=audio_tensor, sample_rate=sr)
+     denoised = denoised.cpu()
+     torchaudio.save("denoised.wav", denoised, sample_rate=sr)
modules/generate_audio.py CHANGED
@@ -79,7 +79,7 @@ def generate_audio_batch(
          params_infer_code["spk_emb"] = spk.emb
          logger.info(("spk", spk.name))
      else:
-         raise ValueError("spk must be int or Speaker")
+         raise ValueError(f"spk must be int or Speaker, but: <{type(spk)}> {spk}")

      logger.info(
          {
modules/models.py CHANGED
@@ -37,17 +37,9 @@ def load_chat_tts_in_thread():
      logger.info("ChatTTS models loaded")


- def initialize_chat_tts():
+ def load_chat_tts():
      with lock:
          if chat_tts is None:
-             model_thread = threading.Thread(target=load_chat_tts_in_thread)
-             model_thread.start()
-             model_thread.join()
-
-
- def load_chat_tts():
-     if chat_tts is None:
-         with lock:
              load_chat_tts_in_thread()
      if chat_tts is None:
          raise Exception("Failed to load ChatTTS models")
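The consolidated loader is a plain check-inside-lock initializer: the removed `load_chat_tts` checked `chat_tts` before taking the lock, so two first-time callers could both see `None` and both start a load. The new shape, reduced to essentials (`expensive_load` is a hypothetical stand-in for `load_chat_tts_in_thread`):

```python
import threading

_model = None
_lock = threading.Lock()


def expensive_load():
    return object()  # hypothetical stand-in for the real model load


def load_model():
    global _model
    with _lock:
        # the None-check happens inside the lock, so first-time callers
        # serialize instead of racing the check
        if _model is None:
            _model = expensive_load()
    if _model is None:
        raise Exception("Failed to load model")
    return _model
```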
modules/speaker.py CHANGED
@@ -1,5 +1,6 @@
  import os
  from typing import Union
+ from box import Box
  import torch

  from modules import models
@@ -16,6 +17,18 @@ def create_speaker_from_seed(seed):


  class Speaker:
+     @staticmethod
+     def from_file(file_like):
+         speaker = torch.load(file_like, map_location=torch.device("cpu"))
+         speaker.fix()
+         return speaker
+
+     @staticmethod
+     def from_tensor(tensor):
+         speaker = Speaker(seed=-2)
+         speaker.emb = tensor
+         return speaker
+
      def __init__(self, seed, name="", gender="", describe=""):
          self.id = uuid.uuid4()
          self.seed = seed
@@ -24,15 +37,20 @@ class Speaker:
          self.describe = describe
          self.emb = None

+         # TODO replace emb => tokens
+         self.tokens = []
+
      def to_json(self, with_emb=False):
-         return {
-             "id": str(self.id),
-             "seed": self.seed,
-             "name": self.name,
-             "gender": self.gender,
-             "describe": self.describe,
-             "emb": self.emb.tolist() if with_emb else None,
-         }
+         return Box(
+             **{
+                 "id": str(self.id),
+                 "seed": self.seed,
+                 "name": self.name,
+                 "gender": self.gender,
+                 "describe": self.describe,
+                 "emb": self.emb.tolist() if with_emb else None,
+             }
+         )

      def fix(self):
          is_update = False
@@ -78,14 +96,9 @@ class SpeakerManager:
          self.speakers = {}
          for speaker_file in os.listdir(self.speaker_dir):
              if speaker_file.endswith(".pt"):
-                 speaker = torch.load(
-                     self.speaker_dir + speaker_file, map_location=torch.device("cpu")
+                 self.speakers[speaker_file] = Speaker.from_file(
+                     self.speaker_dir + speaker_file
                  )
-                 self.speakers[speaker_file] = speaker
-
-                 is_update = speaker.fix()
-                 if is_update:
-                     torch.save(speaker, self.speaker_dir + speaker_file)

      def list_speakers(self):
          return list(self.speakers.values())
@@ -103,8 +116,8 @@ class SpeakerManager:
      def create_speaker_from_tensor(
          self, tensor, filename="", name="", gender="", describe=""
      ):
-         if name == "":
-             name = filename
+         if filename == "":
+             filename = name
          speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
          if isinstance(tensor, torch.Tensor):
              speaker.emb = tensor
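A round-trip sketch for the two new constructors (not part of the commit; the file name and embedding size are placeholders):

```python
import torch

from modules.speaker import Speaker

spk = Speaker.from_tensor(torch.randn(768))  # seed=-2 marks a tensor-born speaker
spk.name = "demo"
torch.save(spk, "demo_speaker.pt")           # same format SpeakerManager reads

restored = Speaker.from_file("demo_speaker.pt")  # torch.load + fix()
```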
modules/ssml_parser/SSMLParser.py ADDED
@@ -0,0 +1,178 @@
+ from lxml import etree
+
+
+ from typing import Any, List, Dict, Union
+ import logging
+
+ from modules.data import styles_mgr
+ from modules.speaker import speaker_mgr
+ from box import Box
+ import copy
+
+
+ class SSMLContext(Box):
+     def __init__(self, parent=None):
+         self.parent: Union[SSMLContext, None] = parent
+
+         self.style = None
+         self.spk = None
+         self.volume = None
+         self.rate = None
+         self.pitch = None
+         # temperature
+         self.temp = None
+         self.top_p = None
+         self.top_k = None
+         self.seed = None
+         self.normalize = None
+         self.prompt1 = None
+         self.prompt2 = None
+         self.prefix = None
+
+
+ class SSMLSegment(Box):
+     def __init__(self, text: str, attrs=None):
+         # NOTE: avoid a shared mutable default SSMLContext
+         self.attrs = SSMLContext() if attrs is None else attrs
+         self.text = text
+         self.params = None
+
+
+ class SSMLBreak:
+     def __init__(self, duration_ms: Union[str, int, float]):
+         # TODO: support units other than ms
+         duration_ms = int(str(duration_ms).replace("ms", ""))
+         self.attrs = Box(**{"duration": duration_ms})
+
+
+ class SSMLParser:
+
+     def __init__(self):
+         self.logger = logging.getLogger(__name__)
+         self.logger.debug("SSMLParser.__init__()")
+         self.resolvers = []
+
+     def resolver(self, tag: str):
+         def decorator(func):
+             self.resolvers.append((tag, func))
+             return func
+
+         return decorator
+
+     def parse(self, ssml: str) -> List[Union[SSMLSegment, SSMLBreak]]:
+         root = etree.fromstring(ssml)
+
+         root_ctx = SSMLContext()
+         segments = []
+         self.resolve(root, root_ctx, segments)
+
+         return segments
+
+     def resolve(
+         self, element: etree.Element, context: SSMLContext, segments: List[SSMLSegment]
+     ):
+         resolver = [resolver for tag, resolver in self.resolvers if tag == element.tag]
+         if len(resolver) == 0:
+             raise NotImplementedError(f"Tag {element.tag} not supported.")
+         else:
+             resolver = resolver[0]
+
+         resolver(element, context, segments, self)
+
+
+ def create_ssml_parser():
+     parser = SSMLParser()
+
+     @parser.resolver("speak")
+     def tag_speak(element, context, segments, parser):
+         ctx = copy.deepcopy(context)
+
+         version = element.get("version")
+         if version != "0.1":
+             raise ValueError(f"Unsupported SSML version {version}")
+
+         for child in element:
+             parser.resolve(child, ctx, segments)
+
+     @parser.resolver("voice")
+     def tag_voice(element, context, segments, parser):
+         ctx = copy.deepcopy(context)
+
+         ctx.spk = element.get("spk", ctx.spk)
+         ctx.style = element.get("style", ctx.style)
+         ctx.volume = element.get("volume", ctx.volume)
+         ctx.rate = element.get("rate", ctx.rate)
+         ctx.pitch = element.get("pitch", ctx.pitch)
+         # temperature
+         ctx.temp = element.get("temp", ctx.temp)
+         ctx.top_p = element.get("top_p", ctx.top_p)
+         ctx.top_k = element.get("top_k", ctx.top_k)
+         ctx.seed = element.get("seed", ctx.seed)
+         ctx.normalize = element.get("normalize", ctx.normalize)
+         ctx.prompt1 = element.get("prompt1", ctx.prompt1)
+         ctx.prompt2 = element.get("prompt2", ctx.prompt2)
+         ctx.prefix = element.get("prefix", ctx.prefix)
+
+         # text at the start of the <voice> element
+         if element.text and element.text.strip():
+             segments.append(SSMLSegment(element.text.strip(), ctx))
+
+         for child in element:
+             parser.resolve(child, ctx, segments)
+
+             # tail text after each child inside <voice>
+             if child.tail and child.tail.strip():
+                 segments.append(SSMLSegment(child.tail.strip(), ctx))
+
+     @parser.resolver("break")
+     def tag_break(element, context, segments, parser):
+         time_ms = int(element.get("time", "0").replace("ms", ""))
+         segments.append(SSMLBreak(time_ms))
+
+     @parser.resolver("prosody")
+     def tag_prosody(element, context, segments, parser):
+         ctx = copy.deepcopy(context)
+
+         ctx.spk = element.get("spk", ctx.spk)
+         ctx.style = element.get("style", ctx.style)
+         ctx.volume = element.get("volume", ctx.volume)
+         ctx.rate = element.get("rate", ctx.rate)
+         ctx.pitch = element.get("pitch", ctx.pitch)
+         # temperature
+         ctx.temp = element.get("temp", ctx.temp)
+         ctx.top_p = element.get("top_p", ctx.top_p)
+         ctx.top_k = element.get("top_k", ctx.top_k)
+         ctx.seed = element.get("seed", ctx.seed)
+         ctx.normalize = element.get("normalize", ctx.normalize)
+         ctx.prompt1 = element.get("prompt1", ctx.prompt1)
+         ctx.prompt2 = element.get("prompt2", ctx.prompt2)
+         ctx.prefix = element.get("prefix", ctx.prefix)
+
+         if element.text and element.text.strip():
+             segments.append(SSMLSegment(element.text.strip(), ctx))
+
+     return parser
+
+
+ if __name__ == "__main__":
+     parser = create_ssml_parser()
+
+     ssml = """
+     <speak version="0.1">
+         <voice spk="xiaoyan" style="news">
+             <prosody rate="fast">你好</prosody>
+             <break time="500ms"/>
+             <prosody rate="slow">你好</prosody>
+         </voice>
+     </speak>
+     """
+
+     segments = parser.parse(ssml)
+     for segment in segments:
+         if isinstance(segment, SSMLBreak):
+             print("<break>", segment.attrs)
+         elif isinstance(segment, SSMLSegment):
+             print(segment.text, segment.attrs)
+         else:
+             raise ValueError("Unknown segment type")
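Because resolvers are registered through the decorator, the tag set is extensible without touching the parser class. A hypothetical extra tag could be registered like this (a sketch, not part of the commit):

```python
import copy

from modules.ssml_parser.SSMLParser import SSMLSegment, create_ssml_parser

parser = create_ssml_parser()


@parser.resolver("emphasis")  # hypothetical tag, for illustration only
def tag_emphasis(element, context, segments, parser):
    ctx = copy.deepcopy(context)
    ctx.temp = element.get("temp", ctx.temp)
    if element.text and element.text.strip():
        segments.append(SSMLSegment(element.text.strip(), ctx))
```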
modules/ssml_parser/__init__.py ADDED
File without changes
modules/ssml_parser/test_ssml_parser.py ADDED
@@ -0,0 +1,104 @@
+ import pytest
+ from lxml import etree
+ from modules.ssml_parser.SSMLParser import (
+     create_ssml_parser,
+     SSMLSegment,
+     SSMLBreak,
+     SSMLContext,
+ )
+
+
+ @pytest.fixture
+ def parser():
+     return create_ssml_parser()
+
+
+ @pytest.mark.ssml_parser
+ def test_speak_tag(parser):
+     ssml = """
+     <speak version="0.1">
+         <voice spk="xiaoyan" style="news">
+             <prosody rate="fast">你好</prosody>
+             <break time="500ms"/>
+             <prosody rate="slow">你好</prosody>
+         </voice>
+     </speak>
+     """
+     segments = parser.parse(ssml)
+     assert len(segments) == 3
+     assert isinstance(segments[0], SSMLSegment)
+     assert segments[0].text == "你好"
+     assert segments[0].attrs.rate == "fast"
+     assert isinstance(segments[1], SSMLBreak)
+     assert segments[1].attrs.duration == 500
+     assert isinstance(segments[2], SSMLSegment)
+     assert segments[2].text == "你好"
+     assert segments[2].attrs.rate == "slow"
+
+
+ @pytest.mark.ssml_parser
+ def test_voice_tag(parser):
+     ssml = """
+     <speak version="0.1">
+         <voice spk="xiaoyan" style="news">你好</voice>
+     </speak>
+     """
+     segments = parser.parse(ssml)
+     assert len(segments) == 1
+     assert isinstance(segments[0], SSMLSegment)
+     assert segments[0].text == "你好"
+     assert segments[0].attrs.spk == "xiaoyan"
+     assert segments[0].attrs.style == "news"
+
+
+ @pytest.mark.ssml_parser
+ def test_break_tag(parser):
+     ssml = """
+     <speak version="0.1">
+         <break time="500ms"/>
+     </speak>
+     """
+     segments = parser.parse(ssml)
+     assert len(segments) == 1
+     assert isinstance(segments[0], SSMLBreak)
+     assert segments[0].attrs.duration == 500
+
+
+ @pytest.mark.ssml_parser
+ def test_prosody_tag(parser):
+     ssml = """
+     <speak version="0.1">
+         <prosody rate="fast">你好</prosody>
+     </speak>
+     """
+     segments = parser.parse(ssml)
+     assert len(segments) == 1
+     assert isinstance(segments[0], SSMLSegment)
+     assert segments[0].text == "你好"
+     assert segments[0].attrs.rate == "fast"
+
+
+ @pytest.mark.ssml_parser
+ def test_unsupported_version(parser):
+     ssml = """
+     <speak version="0.2">
+         <voice spk="xiaoyan" style="news">你好</voice>
+     </speak>
+     """
+     with pytest.raises(ValueError, match=r"Unsupported SSML version 0.2"):
+         parser.parse(ssml)
+
+
+ @pytest.mark.ssml_parser
+ def test_unsupported_tag(parser):
+     ssml = """
+     <speak version="0.1">
+         <unsupported>你好</unsupported>
+     </speak>
+     """
+     with pytest.raises(NotImplementedError, match=r"Tag unsupported not supported."):
+         parser.parse(ssml)
+
+
+ if __name__ == "__main__":
+     pytest.main()
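These tests rely on a custom `ssml_parser` mark; pytest warns on unregistered marks (and rejects them under `--strict-markers`). One way to register it, assuming a conftest.py at the repository root (a sketch, not part of the commit):

```python
# conftest.py (hypothetical location)
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "ssml_parser: tests for modules/ssml_parser"
    )
```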
modules/utils/JsonObject.py CHANGED
@@ -8,6 +8,9 @@ class JsonObject:
          # If no initial dictionary is provided, use an empty dictionary
          self._dict_obj = initial_dict if initial_dict is not None else {}

+         if self._dict_obj is self:
+             raise ValueError("JsonObject cannot be initialized with itself")
+
      def __getattr__(self, name):
          """
          Get an attribute value. If the attribute does not exist,
@@ -111,3 +114,19 @@
          :return: A list of values.
          """
          return self._dict_obj.values()
+
+     def clone(self):
+         """
+         Clone the JsonObject.
+
+         :return: A new JsonObject with the same internal dictionary.
+         """
+         return JsonObject(self._dict_obj.copy())
+
+     def merge(self, other):
+         """
+         Merge the internal dictionary with another dictionary.
+
+         :param other: The other dictionary to merge.
+         """
+         self._dict_obj.update(other)
modules/utils/constants.py CHANGED
@@ -10,4 +10,4 @@ DATA_DIR = os.path.join(ROOT_DIR, "data")

  MODELS_DIR = os.path.join(ROOT_DIR, "models")

- speakers_dir = os.path.join(DATA_DIR, "speakers")
+ SPEAKERS_DIR = os.path.join(DATA_DIR, "speakers")
modules/webui/app.py CHANGED
@@ -5,7 +5,9 @@ import torch
  import gradio as gr

  from modules import config
+ from modules.webui import webui_config

+ from modules.webui.system_tab import create_system_tab
  from modules.webui.tts_tab import create_tts_interface
  from modules.webui.ssml_tab import create_ssml_interface
  from modules.webui.spliter_tab import create_spliter_tab
@@ -93,15 +95,15 @@ def create_interface():
              with gr.TabItem("Spilter"):
                  create_spliter_tab(ssml_input, tabs=tabs)

-             if config.runtime_env_vars.webui_experimental:
-                 with gr.TabItem("Speaker"):
-                     create_speaker_panel()
-                 with gr.TabItem("Denoise"):
-                     gr.Markdown("🚧 Under construction")
-                 with gr.TabItem("Inpainting"):
-                     gr.Markdown("🚧 Under construction")
-                 with gr.TabItem("ASR"):
-                     gr.Markdown("🚧 Under construction")
+             with gr.TabItem("Speaker"):
+                 create_speaker_panel()
+             with gr.TabItem("Inpainting", visible=webui_config.experimental):
+                 gr.Markdown("🚧 Under construction")
+             with gr.TabItem("ASR", visible=webui_config.experimental):
+                 gr.Markdown("🚧 Under construction")
+
+             with gr.TabItem("System"):
+                 create_system_tab()

              with gr.TabItem("README"):
                  create_readme_tab()
modules/webui/speaker_tab.py CHANGED
@@ -1,13 +1,259 @@
+ import io
  import gradio as gr
+ import torch

- from modules.webui.webui_utils import get_speakers
+ from modules.hf import spaces
+ from modules.webui.webui_utils import get_speakers, tts_generate
+ from modules.speaker import speaker_mgr, Speaker
+
+ import tempfile
+
+
+ def spk_to_tensor(spk):
+     spk = spk.split(" : ")[1].strip() if " : " in spk else spk
+     if spk == "None" or spk == "":
+         return None
+     return speaker_mgr.get_speaker(spk).emb
+
+
+ def get_speaker_show_name(spk):
+     if spk.gender == "*" or spk.gender == "":
+         return spk.name
+     return f"{spk.gender} : {spk.name}"
+
+
+ def merge_spk(
+     spk_a,
+     spk_a_w,
+     spk_b,
+     spk_b_w,
+     spk_c,
+     spk_c_w,
+     spk_d,
+     spk_d_w,
+ ):
+     tensor_a = spk_to_tensor(spk_a)
+     tensor_b = spk_to_tensor(spk_b)
+     tensor_c = spk_to_tensor(spk_c)
+     tensor_d = spk_to_tensor(spk_d)
+
+     assert (
+         tensor_a is not None
+         or tensor_b is not None
+         or tensor_c is not None
+         or tensor_d is not None
+     ), "At least one speaker should be selected"
+
+     merge_tensor = torch.zeros_like(
+         tensor_a
+         if tensor_a is not None
+         else (
+             tensor_b
+             if tensor_b is not None
+             else tensor_c if tensor_c is not None else tensor_d
+         )
+     )
+
+     total_weight = 0
+     if tensor_a is not None:
+         merge_tensor += spk_a_w * tensor_a
+         total_weight += spk_a_w
+     if tensor_b is not None:
+         merge_tensor += spk_b_w * tensor_b
+         total_weight += spk_b_w
+     if tensor_c is not None:
+         merge_tensor += spk_c_w * tensor_c
+         total_weight += spk_c_w
+     if tensor_d is not None:
+         merge_tensor += spk_d_w * tensor_d
+         total_weight += spk_d_w
+
+     if total_weight > 0:
+         merge_tensor /= total_weight
+
+     merged_spk = Speaker.from_tensor(merge_tensor)
+     merged_spk.name = "<MIX>"
+
+     return merged_spk
+
+
+ @torch.inference_mode()
+ @spaces.GPU
+ def merge_and_test_spk_voice(
+     spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
+ ):
+     merged_spk = merge_spk(
+         spk_a,
+         spk_a_w,
+         spk_b,
+         spk_b_w,
+         spk_c,
+         spk_c_w,
+         spk_d,
+         spk_d_w,
+     )
+     return tts_generate(
+         spk=merged_spk,
+         text=test_text,
+     )
+
+
+ @torch.inference_mode()
+ @spaces.GPU
+ def merge_spk_to_file(
+     spk_a,
+     spk_a_w,
+     spk_b,
+     spk_b_w,
+     spk_c,
+     spk_c_w,
+     spk_d,
+     spk_d_w,
+     speaker_name,
+     speaker_gender,
+     speaker_desc,
+ ):
+     merged_spk = merge_spk(
+         spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
+     )
+     merged_spk.name = speaker_name
+     merged_spk.gender = speaker_gender
+     merged_spk.desc = speaker_desc
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
+         torch.save(merged_spk, tmp_file)
+         tmp_file_path = tmp_file.name
+
+     return tmp_file_path
+
+
+ merge_desc = """
+ ## Speaker Merger
+
+ In this panel you can select multiple speakers, give each a weight, synthesize a merged voice, and audition it. The features in detail:
+
+ ### 1. Select speakers
+ Pick up to four speakers (A, B, C, D) from the dropdowns. Each speaker has a weight slider from 0 to 10; the weight controls how much that speaker contributes to the merged voice.
+
+ ### 2. Synthesize speech
+ After choosing speakers and weights, enter text in the "Test Text" box and click "Test Voice" to generate and play the merged voice.
+
+ ### 3. Save the speaker
+ Fill in the new speaker's name, gender, and description on the right, then click "Save Speaker". The saved speaker file appears under "Merged Speaker" for download.
+ """


  # Four pickers (A/B/C/D): select one or more speakers, audition, then export
  def create_speaker_panel():
      speakers = get_speakers()

-     def get_speaker_show_name(spk):
-         pass
+     speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]

-     gr.Markdown("🚧 Under construction")
+     with gr.Tabs():
+         with gr.TabItem("Merger"):
+             gr.Markdown(merge_desc)
+
+             with gr.Row():
+                 with gr.Column(scale=5):
+                     with gr.Row():
+                         with gr.Group():
+                             spk_a = gr.Dropdown(
+                                 choices=speaker_names, value="None", label="Speaker A"
+                             )
+                             spk_a_w = gr.Slider(
+                                 value=1, minimum=0, maximum=10, step=1, label="Weight A"
+                             )
+
+                         with gr.Group():
+                             spk_b = gr.Dropdown(
+                                 choices=speaker_names, value="None", label="Speaker B"
+                             )
+                             spk_b_w = gr.Slider(
+                                 value=1, minimum=0, maximum=10, step=1, label="Weight B"
+                             )
+
+                         with gr.Group():
+                             spk_c = gr.Dropdown(
+                                 choices=speaker_names, value="None", label="Speaker C"
+                             )
+                             spk_c_w = gr.Slider(
+                                 value=1, minimum=0, maximum=10, step=1, label="Weight C"
+                             )
+
+                         with gr.Group():
+                             spk_d = gr.Dropdown(
+                                 choices=speaker_names, value="None", label="Speaker D"
+                             )
+                             spk_d_w = gr.Slider(
+                                 value=1, minimum=0, maximum=10, step=1, label="Weight D"
+                             )
+
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     with gr.Group():
+                         gr.Markdown("🎤Test voice")
+                         with gr.Row():
+                             test_voice_btn = gr.Button(
+                                 "Test Voice", variant="secondary"
+                             )
+
+                             with gr.Column(scale=4):
+                                 test_text = gr.Textbox(
+                                     label="Test Text",
+                                     placeholder="Please input test text",
+                                     value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
+                                 )
+
+                         output_audio = gr.Audio(label="Output Audio")
+
+                 with gr.Column(scale=1):
+                     with gr.Group():
+                         gr.Markdown("🗃️Save to file")
+
+                         speaker_name = gr.Textbox(
+                             label="Name", value="forge_speaker_merged"
+                         )
+                         speaker_gender = gr.Textbox(label="Gender", value="*")
+                         speaker_desc = gr.Textbox(
+                             label="Description", value="merged speaker"
+                         )
+
+                         save_btn = gr.Button("Save Speaker", variant="primary")
+
+                         merged_spker = gr.File(
+                             label="Merged Speaker", interactive=False, type="binary"
+                         )
+
+             test_voice_btn.click(
+                 merge_and_test_spk_voice,
+                 inputs=[
+                     spk_a,
+                     spk_a_w,
+                     spk_b,
+                     spk_b_w,
+                     spk_c,
+                     spk_c_w,
+                     spk_d,
+                     spk_d_w,
+                     test_text,
+                 ],
+                 outputs=[output_audio],
+             )
+
+             save_btn.click(
+                 merge_spk_to_file,
+                 inputs=[
+                     spk_a,
+                     spk_a_w,
+                     spk_b,
+                     spk_b_w,
+                     spk_c,
+                     spk_c_w,
+                     spk_d,
+                     spk_d_w,
+                     speaker_name,
+                     speaker_gender,
+                     speaker_desc,
+                 ],
+                 outputs=[merged_spker],
+             )
@@ -9,6 +9,7 @@ from modules.webui.webui_utils import (
9
  from modules.hf import spaces
10
 
11
 
 
12
  @torch.inference_mode()
13
  @spaces.GPU
14
  def merge_dataframe_to_ssml(dataframe, spk, style, seed):
@@ -31,7 +32,7 @@ def merge_dataframe_to_ssml(dataframe, spk, style, seed):
31
  if seed:
32
  ssml += f' seed="{seed}"'
33
  ssml += ">\n"
34
- ssml += f"{indent}{indent}{text_normalize(row[1])}\n"
35
  ssml += f"{indent}</voice>\n"
36
  return f"<speak version='0.1'>\n{ssml}</speak>"
37
 
 
9
  from modules.hf import spaces
10
 
11
 
12
+ # NOTE: because text_normalize needs to use the tokenizer
13
  @torch.inference_mode()
14
  @spaces.GPU
15
  def merge_dataframe_to_ssml(dataframe, spk, style, seed):
 
32
  if seed:
33
  ssml += f' seed="{seed}"'
34
  ssml += ">\n"
35
+ ssml += f"{indent}{indent}{text_normalize(row.iloc[1])}\n"
36
  ssml += f"{indent}</voice>\n"
37
  return f"<speak version='0.1'>\n{ssml}</speak>"
38
 
modules/webui/system_tab.py ADDED
@@ -0,0 +1,15 @@
1
+ import gradio as gr
2
+ from modules.webui import webui_config
3
+
4
+
5
+ def create_system_tab():
6
+ with gr.Row():
7
+ with gr.Column(scale=1):
8
+ gr.Markdown("info")
9
+
10
+ with gr.Column(scale=5):
11
+ toggle_experimental = gr.Checkbox(
12
+ label="Enable Experimental Features",
13
+ value=webui_config.experimental,
14
+ interactive=False,
15
+ )
modules/webui/tts_tab.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from modules.webui.webui_utils import (
4
  get_speakers,
5
  get_styles,
 
6
  refine_text,
7
  tts_generate,
8
  )
@@ -10,6 +11,13 @@ from modules.webui import webui_config
10
  from modules.webui.examples import example_texts
11
  from modules import config
12
 
13
 
14
  def create_tts_interface():
15
  speakers = get_speakers()
@@ -90,15 +98,18 @@ def create_tts_interface():
90
  outputs=[spk_input_text],
91
  )
92
 
93
- if config.runtime_env_vars.webui_experimental:
94
- with gr.Tab(label="Upload"):
95
- spk_input_upload = gr.File(label="Speaker (Upload)")
96
- # TODO 读取 speaker
97
- # spk_input_upload.change(
98
- # fn=lambda x: x.read().decode("utf-8"),
99
- # inputs=[spk_input_upload],
100
- # outputs=[spk_input_text],
101
- # )
 
 
 
102
  with gr.Group():
103
  gr.Markdown("💃Inference Seed")
104
  infer_seed_input = gr.Number(
@@ -122,85 +133,62 @@ def create_tts_interface():
122
  prompt2_input = gr.Textbox(label="Prompt 2")
123
  prefix_input = gr.Textbox(label="Prefix")
124
 
125
- if config.runtime_env_vars.webui_experimental:
126
- prompt_audio = gr.File(label="prompt_audio")
 
127
 
128
  infer_seed_rand_button.click(
129
  lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
130
  inputs=[infer_seed_input],
131
  outputs=[infer_seed_input],
132
  )
133
- with gr.Column(scale=3):
134
- with gr.Row():
135
- with gr.Column(scale=4):
136
- with gr.Group():
137
- input_title = gr.Markdown(
138
- "📝Text Input",
139
- elem_id="input-title",
140
- )
141
- gr.Markdown(
142
- f"- 字数限制{webui_config.tts_max:,}字,超过部分截断"
143
- )
144
- gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
145
- gr.Markdown(
146
- "- If the input text is all in English, it is recommended to check disable_normalize"
147
- )
148
- text_input = gr.Textbox(
149
- show_label=False,
150
- label="Text to Speech",
151
- lines=10,
152
- placeholder="输入文本或选择示例",
153
- elem_id="text-input",
154
- )
155
- # TODO character count; easy to implement, but it triggers the loading overlay and needs a backend round trip...
156
- # text_input.change(
157
- # fn=lambda x: (
158
- # f"📝Text Input ({len(x)} char)"
159
- # if x
160
- # else (
161
- # "📝Text Input (0 char)"
162
- # if not x
163
- # else "📝Text Input (0 char)"
164
- # )
165
- # ),
166
- # inputs=[text_input],
167
- # outputs=[input_title],
168
- # )
169
- with gr.Row():
170
- control_tokens = [
171
- "[laugh]",
172
- "[uv_break]",
173
- "[v_break]",
174
- "[lbreak]",
175
- ]
176
-
177
- for tk in contorl_tokens:
178
- t_btn = gr.Button(tk)
179
- t_btn.click(
180
- lambda text, tk=tk: text + " " + tk,
181
- inputs=[text_input],
182
- outputs=[text_input],
183
- )
184
- with gr.Column(scale=1):
185
- with gr.Group():
186
- gr.Markdown("🎶Refiner")
187
- refine_prompt_input = gr.Textbox(
188
- label="Refine Prompt",
189
- value="[oral_2][laugh_0][break_6]",
190
- )
191
- refine_button = gr.Button("✍️Refine Text")
192
- # TODO split into sentences, build SSML with the current settings, then send it to the SSML tab
193
- # send_button = gr.Button("📩Split and send to SSML")
194
-
195
- with gr.Group():
196
- gr.Markdown("🔊Generate")
197
- disable_normalize_input = gr.Checkbox(
198
- value=False, label="Disable Normalize"
199
- )
200
- tts_button = gr.Button(
201
- "🔊Generate Audio",
202
- variant="primary",
203
- elem_classes="big-button",
204
  )
205
 
206
  with gr.Group():
@@ -220,6 +208,31 @@ def create_tts_interface():
220
  with gr.Group():
221
  gr.Markdown("🎨Output")
222
  tts_output = gr.Audio(label="Generated Audio")
223
 
224
  refine_button.click(
225
  refine_text,
@@ -243,6 +256,9 @@ def create_tts_interface():
243
  style_input_dropdown,
244
  disable_normalize_input,
245
  batch_size_input,
 
 
 
246
  ],
247
  outputs=tts_output,
248
  )
 
3
  from modules.webui.webui_utils import (
4
  get_speakers,
5
  get_styles,
6
+ load_spk_info,
7
  refine_text,
8
  tts_generate,
9
  )
 
11
  from modules.webui.examples import example_texts
12
  from modules import config
13
 
14
+ default_text_content = """
15
+ chat T T S 是一款强大的对话式文本转语音模型。它有中英混读和多说话人的能力。
16
+ chat T T S 不仅能够生成自然流畅的语音,还能控制[laugh]笑声啊[laugh],
17
+ 停顿啊[uv_break]语气词啊等副语言现象[uv_break]。这个韵律超越了许多开源模型[uv_break]。
18
+ 请注意,chat T T S 的使用应遵守法律和伦理准则,避免滥用的安全风险。[uv_break]
19
+ """
20
+
21
 
22
  def create_tts_interface():
23
  speakers = get_speakers()
 
98
  outputs=[spk_input_text],
99
  )
100
 
101
+ with gr.Tab(label="Upload"):
102
+ spk_file_upload = gr.File(label="Speaker (Upload)")
103
+
104
+ gr.Markdown("📝Speaker info")
105
+ infos = gr.Markdown("empty")
106
+
107
+ spk_file_upload.change(
108
+ fn=load_spk_info,
109
+ inputs=[spk_file_upload],
110
+ outputs=[infos],
111
+ )
112
+
113
  with gr.Group():
114
  gr.Markdown("💃Inference Seed")
115
  infer_seed_input = gr.Number(
 
133
  prompt2_input = gr.Textbox(label="Prompt 2")
134
  prefix_input = gr.Textbox(label="Prefix")
135
 
136
+ prompt_audio = gr.File(
137
+ label="prompt_audio", visible=webui_config.experimental
138
+ )
139
 
140
  infer_seed_rand_button.click(
141
  lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
142
  inputs=[infer_seed_input],
143
  outputs=[infer_seed_input],
144
  )
145
+ with gr.Column(scale=4):
146
+ with gr.Group():
147
+ input_title = gr.Markdown(
148
+ "📝Text Input",
149
+ elem_id="input-title",
150
+ )
151
+ gr.Markdown(f"- 字数限制{webui_config.tts_max:,}字,超过部分截断")
152
+ gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
153
+ gr.Markdown(
154
+ "- If the input text is all in English, it is recommended to check disable_normalize"
155
+ )
156
+ text_input = gr.Textbox(
157
+ show_label=False,
158
+ label="Text to Speech",
159
+ lines=10,
160
+ placeholder="输入文本或选择示例",
161
+ elem_id="text-input",
162
+ value=default_text_content,
163
+ )
164
+ # TODO character count; easy to implement, but it triggers the loading overlay and needs a backend round trip...
165
+ # text_input.change(
166
+ # fn=lambda x: (
167
+ # f"📝Text Input ({len(x)} char)"
168
+ # if x
169
+ # else (
170
+ # "📝Text Input (0 char)"
171
+ # if not x
172
+ # else "📝Text Input (0 char)"
173
+ # )
174
+ # ),
175
+ # inputs=[text_input],
176
+ # outputs=[input_title],
177
+ # )
178
+ with gr.Row():
179
+ control_tokens = [
180
+ "[laugh]",
181
+ "[uv_break]",
182
+ "[v_break]",
183
+ "[lbreak]",
184
+ ]
185
+
186
+ for tk in contorl_tokens:
187
+ t_btn = gr.Button(tk)
188
+ t_btn.click(
189
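+ # tk=tk binds the current token at definition time; a plain closure
+ # would late-bind and make every button append the last token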
+ lambda text, tk=tk: text + " " + tk,
190
+ inputs=[text_input],
191
+ outputs=[text_input],
192
  )
193
 
194
  with gr.Group():
 
208
  with gr.Group():
209
  gr.Markdown("🎨Output")
210
  tts_output = gr.Audio(label="Generated Audio")
211
+ with gr.Column(scale=1):
212
+ with gr.Group():
213
+ gr.Markdown("🎶Refiner")
214
+ refine_prompt_input = gr.Textbox(
215
+ label="Refine Prompt",
216
+ value="[oral_2][laugh_0][break_6]",
217
+ )
218
+ refine_button = gr.Button("✍️Refine Text")
219
+
220
+ with gr.Group():
221
+ gr.Markdown("🔊Generate")
222
+ disable_normalize_input = gr.Checkbox(
223
+ value=False, label="Disable Normalize"
224
+ )
225
+
226
+ # FIXME: not sure why, but this is very slow here; running the script standalone is fast
227
+ with gr.Group(visible=webui_config.experimental):
228
+ gr.Markdown("💪🏼Enhance")
229
+ enable_enhance = gr.Checkbox(value=False, label="Enable Enhance")
230
+ enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
231
+ tts_button = gr.Button(
232
+ "🔊Generate Audio",
233
+ variant="primary",
234
+ elem_classes="big-button",
235
+ )
236
 
237
  refine_button.click(
238
  refine_text,
 
256
  style_input_dropdown,
257
  disable_normalize_input,
258
  batch_size_input,
259
+ enable_enhance,
260
+ enable_de_noise,
261
+ spk_file_upload,
262
  ],
263
  outputs=tts_output,
264
  )
modules/webui/webui_config.py CHANGED
@@ -1,4 +1,8 @@
 
 
 
1
  tts_max = 1000
2
  ssml_max = 1000
3
  spliter_threshold = 100
4
  max_batch_size = 8
 
 
1
+ from typing import Literal
2
+
3
+
4
  tts_max = 1000
5
  ssml_max = 1000
6
  spliter_threshold = 100
7
  max_batch_size = 8
8
+ experimental = False
modules/webui/webui_utils.py CHANGED
@@ -1,37 +1,26 @@
1
- import os
2
- import logging
3
- import sys
4
-
5
  import numpy as np
6
 
 
7
  from modules.devices import devices
8
  from modules.synthesize_audio import synthesize_audio
9
  from modules.hf import spaces
10
  from modules.webui import webui_config
11
 
12
- logging.basicConfig(
13
- level=os.getenv("LOG_LEVEL", "INFO"),
14
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
15
- )
16
-
17
-
18
- import gradio as gr
19
-
20
  import torch
21
 
22
- from modules.ssml import parse_ssml
23
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
24
 
25
- from modules.speaker import speaker_mgr
26
  from modules.data import styles_mgr
27
 
28
  from modules.api.utils import calc_spk_style
29
- import modules.generate_audio as generate
30
 
31
  from modules.normalization import text_normalize
32
- from modules import refiner, config
33
 
34
- from modules.utils import env, audio
35
  from modules.SentenceSplitter import SentenceSplitter
36
 
37
 
@@ -43,11 +32,30 @@ def get_styles():
43
  return styles_mgr.list_items()
44
 
45
 
46
- def segments_length_limit(segments, total_max: int):
47
  ret_segments = []
48
  total_len = 0
49
  for seg in segments:
50
- if "text" not in seg:
 
51
  continue
52
  total_len += len(seg["text"])
53
  if total_len > total_max:
@@ -56,6 +64,28 @@ def segments_length_limit(segments, total_max: int):
56
  return ret_segments
57
 
58
 
59
  @torch.inference_mode()
60
  @spaces.GPU
61
  def synthesize_ssml(ssml: str, batch_size=4):
@@ -69,7 +99,8 @@ def synthesize_ssml(ssml: str, batch_size=4):
69
  if ssml == "":
70
  return None
71
 
72
- segments = parse_ssml(ssml)
 
73
  max_len = webui_config.ssml_max
74
  segments = segments_length_limit(segments, max_len)
75
 
@@ -87,18 +118,21 @@ def synthesize_ssml(ssml: str, batch_size=4):
87
  @spaces.GPU
88
  def tts_generate(
89
  text,
90
- temperature,
91
- top_p,
92
- top_k,
93
- spk,
94
- infer_seed,
95
- use_decoder,
96
- prompt1,
97
- prompt2,
98
- prefix,
99
- style,
100
  disable_normalize=False,
101
  batch_size=4,
 
 
 
102
  ):
103
  try:
104
  batch_size = int(batch_size)
@@ -126,12 +160,15 @@ def tts_generate(
126
  prompt1 = prompt1 or params.get("prompt1", "")
127
  prompt2 = prompt2 or params.get("prompt2", "")
128
 
129
- infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.int64)
130
  infer_seed = int(infer_seed)
131
 
132
  if not disable_normalize:
133
  text = text_normalize(text)
134
 
 
 
 
135
  sample_rate, audio_data = synthesize_audio(
136
  text=text,
137
  temperature=temperature,
@@ -146,6 +183,10 @@ def tts_generate(
146
  batch_size=batch_size,
147
  )
148
 
 
 
 
 
149
  audio_data = audio.audio_to_int16(audio_data)
150
  return sample_rate, audio_data
151
 
 
1
+ from typing import Union
 
 
 
2
  import numpy as np
3
 
4
+ from modules.Enhancer.ResembleEnhance import load_enhancer
5
  from modules.devices import devices
6
  from modules.synthesize_audio import synthesize_audio
7
  from modules.hf import spaces
8
  from modules.webui import webui_config
9
 
10
  import torch
11
 
12
+ from modules.ssml_parser.SSMLParser import create_ssml_parser, SSMLBreak, SSMLSegment
13
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
14
 
15
+ from modules.speaker import speaker_mgr, Speaker
16
  from modules.data import styles_mgr
17
 
18
  from modules.api.utils import calc_spk_style
 
19
 
20
  from modules.normalization import text_normalize
21
+ from modules import refiner
22
 
23
+ from modules.utils import audio
24
  from modules.SentenceSplitter import SentenceSplitter
25
 
26
 
 
32
  return styles_mgr.list_items()
33
 
34
 
35
+ def load_spk_info(file):
36
+ if file is None:
37
+ return "empty"
38
+ try:
39
+ spk: Speaker = Speaker.from_file(file)
40
+ # read the fields straight off the Speaker object; merge_spk_to_file
41
+ # sets name / gender / desc on Speaker, so they are safe to access here
42
+ return f"""
43
+ - name: {spk.name}
44
+ - gender: {spk.gender}
45
+ - describe: {spk.desc}
46
+ """.strip()
47
+ except Exception:
48
+ return "load failed"
49
+
50
+
51
+ def segments_length_limit(
52
+ segments: list[Union[SSMLBreak, SSMLSegment]], total_max: int
53
+ ) -> list[Union[SSMLBreak, SSMLSegment]]:
54
  ret_segments = []
55
  total_len = 0
56
  for seg in segments:
57
+ if isinstance(seg, SSMLBreak):
58
+ ret_segments.append(seg)
59
  continue
60
  total_len += len(seg["text"])
61
  if total_len > total_max:
 
64
  return ret_segments
65
 
66
 
67
+ @torch.inference_mode()
68
+ @spaces.GPU
69
+ def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
70
+ audio_data = torch.from_numpy(audio_data).float().squeeze().cpu()
71
+ if enable_denoise or enable_enhance:
72
+ enhancer = load_enhancer(devices.device)
73
+ if enable_denoise:
74
+ audio_data, sr = enhancer.denoise(audio_data, sr, devices.device)
75
+ if enable_enhance:
76
+ audio_data, sr = enhancer.enhance(
77
+ audio_data,
78
+ sr,
79
+ devices.device,
80
+ tau=0.9,
81
+ nfe=64,
82
+ solver="euler",
83
+ lambd=0.5,
84
+ )
85
+ audio_data = audio_data.cpu().numpy()
86
+ return audio_data, int(sr)
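A usage sketch for the helper above, assuming `audio_data` is the float waveform array returned by `synthesize_audio` (running it for real requires the enhancer weights to be available):

```python
import numpy as np

sr = 24000  # hypothetical sample rate
audio_data = np.random.uniform(-0.5, 0.5, sr).astype(np.float32)  # 1 s clip

# denoise first, then enhance -- the same order tts_generate relies on below
audio_data, sr = apply_audio_enhance(
    audio_data, sr, enable_denoise=True, enable_enhance=True
)
```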
87
+
88
+
89
  @torch.inference_mode()
90
  @spaces.GPU
91
  def synthesize_ssml(ssml: str, batch_size=4):
 
99
  if ssml == "":
100
  return None
101
 
102
+ parser = create_ssml_parser()
103
+ segments = parser.parse(ssml)
104
  max_len = webui_config.ssml_max
105
  segments = segments_length_limit(segments, max_len)
106
 
 
118
  @spaces.GPU
119
  def tts_generate(
120
  text,
121
+ temperature=0.3,
122
+ top_p=0.7,
123
+ top_k=20,
124
+ spk=-1,
125
+ infer_seed=-1,
126
+ use_decoder=True,
127
+ prompt1="",
128
+ prompt2="",
129
+ prefix="",
130
+ style="",
131
  disable_normalize=False,
132
  batch_size=4,
133
+ enable_enhance=False,
134
+ enable_denoise=False,
135
+ spk_file=None,
136
  ):
137
  try:
138
  batch_size = int(batch_size)
 
160
  prompt1 = prompt1 or params.get("prompt1", "")
161
  prompt2 = prompt2 or params.get("prompt2", "")
162
 
163
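+ # presumably float64 to avoid numpy's same_kind casting error when the
+ # seed arrives from the UI as a float; int() below truncates it anyway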
+ infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
164
  infer_seed = int(infer_seed)
165
 
166
  if not disable_normalize:
167
  text = text_normalize(text)
168
 
169
+ if spk_file:
170
+ spk = Speaker.from_file(spk_file)
171
+
172
  sample_rate, audio_data = synthesize_audio(
173
  text=text,
174
  temperature=temperature,
 
183
  batch_size=batch_size,
184
  )
185
 
186
+ audio_data, sample_rate = apply_audio_enhance(
187
+ audio_data, sample_rate, enable_denoise, enable_enhance
188
+ )
189
+
190
  audio_data = audio.audio_to_int16(audio_data)
191
  return sample_rate, audio_data
192
 
webui.py CHANGED
@@ -93,8 +93,10 @@ if __name__ == "__main__":
93
  device_id = get_and_update_env(args, "device_id", None, str)
94
  use_cpu = get_and_update_env(args, "use_cpu", [], list)
95
  compile = get_and_update_env(args, "compile", False, bool)
96
- webui_experimental = get_and_update_env(args, "webui_experimental", False, bool)
97
 
 
 
 
98
  webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
99
  webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
100
  webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
 
93
  device_id = get_and_update_env(args, "device_id", None, str)
94
  use_cpu = get_and_update_env(args, "use_cpu", [], list)
95
  compile = get_and_update_env(args, "compile", False, bool)
 
96
 
97
+ webui_config.experimental = get_and_update_env(
98
+ args, "webui_experimental", False, bool
99
+ )
100
  webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
101
  webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
102
  webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)