vbanonyme committed on
Commit
b690efd
1 Parent(s): 9fbceda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -65
app.py CHANGED
@@ -7,6 +7,7 @@ import uuid
7
  from io import StringIO
8
 
9
  import gradio as gr
 
10
  import torch
11
  import torchaudio
12
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
@@ -14,25 +15,25 @@ from TTS.tts.configs.xtts_config import XttsConfig
14
  from TTS.tts.models.xtts import Xtts
15
  from vinorm import TTSnorm
16
 
17
- # Download necessary components
18
  os.system("python -m unidic download")
19
 
20
- # Hugging Face token and API setup
21
  HF_TOKEN = os.environ.get("HF_TOKEN")
22
  api = HfApi(token=HF_TOKEN)
23
 
24
- # Setup checkpoint directory
 
25
  checkpoint_dir = "model/"
 
 
 
26
  os.makedirs(checkpoint_dir, exist_ok=True)
27
 
28
- # Required files for the model
29
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
30
-
31
- # Download model and configurations if not present
32
  files_in_dir = os.listdir(checkpoint_dir)
33
  if not all(file in files_in_dir for file in required_files):
34
  snapshot_download(
35
- repo_id="capleaf/viXTTS",
36
  repo_type="model",
37
  local_dir=checkpoint_dir,
38
  )
@@ -42,23 +43,21 @@ if not all(file in files_in_dir for file in required_files):
42
  local_dir=checkpoint_dir,
43
  )
44
 
45
- # Initialize XTTS model from configuration
46
  xtts_config = os.path.join(checkpoint_dir, "config.json")
47
  config = XttsConfig()
48
  config.load_json(xtts_config)
49
  MODEL = Xtts.init_from_config(config)
50
  MODEL.load_checkpoint(
51
- config, checkpoint_dir=checkpoint_dir, use_deepspeed=False
52
  )
53
  if torch.cuda.is_available():
54
  MODEL.cuda()
55
 
56
- # Supported languages for TTS
57
  supported_languages = config.languages
58
- if "vi" not in supported_languages:
59
  supported_languages.append("vi")
60
 
61
- # Function to normalize Vietnamese text
62
  def normalize_vietnamese_text(text):
63
  text = (
64
  TTSnorm(text, unknown=False, lower=False, rule=True)
@@ -74,8 +73,9 @@ def normalize_vietnamese_text(text):
74
  )
75
  return text
76
 
77
- # Function to calculate length to keep based on text properties
78
  def calculate_keep_len(text, lang):
 
79
  if lang in ["ja", "zh-cn"]:
80
  return -1
81
 
@@ -88,30 +88,65 @@ def calculate_keep_len(text, lang):
88
  return 13000 * word_count + 2000 * num_punct
89
  return -1
90
 
91
- # Function for TTS prediction
92
- def predict(prompt, language, audio_file_pth, normalize_text=True):
 
 
 
 
 
 
93
  if language not in supported_languages:
94
- return (None, gr.Warning(f"Language '{language}' is not supported."))
 
 
 
 
95
 
96
  speaker_wav = audio_file_pth
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  try:
99
- if len(prompt) < 2:
100
- return (None, gr.Warning("Please provide a longer prompt text."))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- # Normalize Vietnamese text if specified
103
  if normalize_text and language == "vi":
104
  prompt = normalize_vietnamese_text(prompt)
105
 
106
- # Get conditioning latents for the model
107
- gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
108
- audio_path=speaker_wav,
109
- gpt_cond_len=30,
110
- gpt_cond_chunk_len=4,
111
- max_ref_length=60,
112
- )
113
-
114
- # Perform inference to generate audio
115
  out = MODEL.inference(
116
  prompt,
117
  language,
@@ -121,50 +156,166 @@ def predict(prompt, language, audio_file_pth, normalize_text=True):
121
  temperature=0.75,
122
  enable_text_splitting=True,
123
  )
124
-
125
- # Calculate inference time and real-time factor
126
  inference_time = time.time() - t0
 
 
 
 
127
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
 
 
128
 
129
- # Limit output length based on text properties
130
  keep_len = calculate_keep_len(prompt, language)
131
  out["wav"] = out["wav"][:keep_len]
132
 
133
- # Save generated audio
134
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
135
 
136
  except RuntimeError as e:
137
- # Handle runtime errors
138
- error_message = "Unexpected error occurred during inference."
139
- return (None, gr.Warning(error_message))
140
-
141
- return ("output.wav", "")
142
-
143
- # Gradio interface setup
144
- demo = gr.Interface(
145
- fn=predict,
146
- inputs=[
147
- gr.Textbox(
148
- label="Text Prompt (Văn bản cần đọc)",
149
- placeholder="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
150
- ),
151
- gr.Dropdown(
152
- label="Language (Ngôn ngữ)",
153
- choices=[
154
- "vi", "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
155
- "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi"
156
- ],
157
- value="vi"
158
- ),
159
- gr.File(label="Reference Audio (Giọng mẫu)", type="filepath"),
160
- gr.Checkbox(label="Normalize Vietnamese text"),
161
- ],
162
- outputs=[
163
- gr.Audio(label="Synthesized Audio"),
164
- gr.Textbox(label="Metrics"),
165
- ],
166
- title="viXTTS Demo ✨",
167
- description="Generate speech from text input using viXTTS.",
168
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- demo.launch()
 
 
7
  from io import StringIO
8
 
9
  import gradio as gr
10
+ import spaces
11
  import torch
12
  import torchaudio
13
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 
15
  from TTS.tts.models.xtts import Xtts
16
  from vinorm import TTSnorm
17
 
18
# Fetch the unidic dictionary required by MeCab (used for Japanese text).
os.system("python -m unidic download")

# Hugging Face credentials are taken from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN)

# Checkpoint location and download configuration; the actual download is
# triggered further below when required files are missing.
print("Downloading if not downloaded viXTTS")
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)
31
 
 
32
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
 
 
33
  files_in_dir = os.listdir(checkpoint_dir)
34
  if not all(file in files_in_dir for file in required_files):
35
  snapshot_download(
36
+ repo_id=repo_id,
37
  repo_type="model",
38
  local_dir=checkpoint_dir,
39
  )
 
43
  local_dir=checkpoint_dir,
44
  )
45
 
 
46
# Build the XTTS model from the downloaded checkpoint and its JSON config.
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
if torch.cuda.is_available():
    MODEL.cuda()

# The stock config may not list Vietnamese; make sure "vi" passes the
# language check in predict().  (Idiom fix: `x not in y`, not `not x in y`.)
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
59
 
60
+
61
  def normalize_vietnamese_text(text):
62
  text = (
63
  TTSnorm(text, unknown=False, lower=False, rule=True)
 
73
  )
74
  return text
75
 
76
+
77
  def calculate_keep_len(text, lang):
78
+ """Simple hack for short sentences"""
79
  if lang in ["ja", "zh-cn"]:
80
  return -1
81
 
 
88
  return 13000 * word_count + 2000 * num_punct
89
  return -1
90
 
91
+
92
@spaces.GPU
def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
):
    """Synthesize speech for ``prompt`` in ``language``, cloning the voice
    from the reference clip at ``audio_file_pth``.

    Returns a tuple ``(wav_path_or_None, metrics_or_warning)`` matching the
    (audio, text) Gradio outputs.  On any validation or runtime failure the
    first element is ``None`` and the second is a ``gr.Warning``.
    """
    if language not in supported_languages:
        # BUG FIX: original message duplicated "is not in is not in".
        metrics_text = gr.Warning(
            f"Language you put {language} in is not in our Supported Languages, please choose from dropdown"
        )
        return (None, metrics_text)

    speaker_wav = audio_file_pth

    if len(prompt) < 2:
        metrics_text = gr.Warning("Please give a longer prompt text")
        return (None, metrics_text)

    # BUG FIX: the limit was a nonsensical 132-digit constant, so this check
    # could never fire; the warning text itself says 250 characters.
    if len(prompt) > 250:
        metrics_text = gr.Warning(
            str(len(prompt))
            + " characters.\n"
            + "Your prompt is too long, please keep it under 250 characters\n"
            + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
        )
        return (None, metrics_text)

    try:
        metrics_text = ""

        # Voice cloning: extract conditioning latents from the reference clip.
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            metrics_text = gr.Warning(
                "It appears something wrong with reference, did you unmute your microphone?"
            )
            return (None, metrics_text)

        # Double sentence-final punctuation after a word char / non-ASCII char
        # (nudges XTTS toward a cleaner sentence ending).
        prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)

        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        print("I: Generating new audio...")
        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            # NOTE(review): these three argument lines were elided in the diff
            # view; reconstructed from the standard XTTS demo — confirm.
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=0.75,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
        metrics_text += (
            f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
        )
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        print(f"Real-time factor (RTF): {real_time_factor}")
        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"

        # Temporary hack for short sentences: trim trailing audio.
        keep_len = calculate_keep_len(prompt, language)
        out["wav"] = out["wav"][:keep_len]

        torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)

    except RuntimeError as e:
        if "device-side assert" in str(e):
            # A CUDA device-side assert is unrecoverable; the Space must restart.
            print(
                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
                flush=True,
            )
            gr.Warning("Unhandled Exception encounter, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")

            # Flag the failing input (text + reference audio) to a dataset repo
            # so the crash can be reproduced later.
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                error_time,
                prompt,
                language,
                audio_file_pth,
            ]
            error_data = [
                item if isinstance(item, str) else str(item) for item in error_data
            ]
            print(error_data)
            print(speaker_wav)
            write_io = StringIO()
            csv.writer(write_io).writerows([error_data])
            csv_upload = write_io.getvalue().encode()

            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
            print("Writing error csv")
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=csv_upload,
                path_in_repo=filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            # Also upload the reference audio that triggered the crash.
            print("Writing error reference audio")
            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=speaker_wav,
                path_in_repo=speaker_filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            # HF Space specific: restart the Space unless it is still building.
            # NOTE(review): repo_id here is the MODEL repo ("capleaf/viXTTS"),
            # not a Space id — verify this is the intended restart target.
            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
            else:
                print("TRIED TO RESTART but space is building")

        else:
            if "Failed to decode" in str(e):
                print("Speaker encoding error", str(e))
                # BUG FIX: original passed the message via an invalid
                # ``metrics_text=`` keyword to gr.Warning (TypeError).
                metrics_text = gr.Warning(
                    "It appears something wrong with reference, did you unmute your microphone?"
                )
            else:
                print("RuntimeError: non device-side assert error:", str(e))
                metrics_text = gr.Warning(
                    "Something unexpected happened please retry again."
                )
        return (None, metrics_text)
    return ("output.wav", metrics_text)
239
+
240
+
241
# Gradio UI: text + language + reference audio in, synthesized wav + metrics out.
with gr.Blocks(analytics_enabled=False) as demo:
    # Header row: project links on the left, empty column balances the layout.
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # viXTTS Demo ✨
                - Github: https://github.com/thinhlpg/vixtts-demo/
                - viVoice: https://github.com/thinhlpg/viVoice
                """
            )
        with gr.Column():
            # placeholder to align the image
            pass

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Text Prompt (Văn bản cần đọc)",
                info="Mỗi câu nên từ 10 từ trở lên. Tối đa 250 ký tự (khoảng 2 - 3 câu).",
                value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
            )
            language_gr = gr.Dropdown(
                label="Language (Ngôn ngữ)",
                choices="vi en es fr de it pt pl tr ru nl cs ar zh-cn ja ko hu hi".split(),
                max_choices=1,
                value="vi",
            )
            normalize_text = gr.Checkbox(
                label="Chuẩn hóa văn bản tiếng Việt",
                info="Normalize Vietnamese text",
                value=True,
            )
            ref_gr = gr.Audio(
                label="Reference Audio (Giọng mẫu)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button(
                "Đọc 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )

        with gr.Column():
            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
            out_text_gr = gr.Text(label="Metrics")

    # Wire the button to the TTS entry point.
    tts_button.click(
        predict,
        [input_text_gr, language_gr, ref_gr, normalize_text],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue()
demo.launch(debug=True, show_api=True, share=True)