Yurii Paniv commited on
Commit
8c4d22a
1 Parent(s): 7ea68d8

#8 Move synthesis to package

Browse files
Files changed (2) hide show
  1. app.py +5 -47
  2. ukrainian_tts/tts.py +82 -0
app.py CHANGED
@@ -1,15 +1,8 @@
1
  import tempfile
2
-
3
  import gradio as gr
4
-
5
- from TTS.utils.synthesizer import Synthesizer
6
- import requests
7
- from os.path import exists
8
- from ukrainian_tts.formatter import preprocess_text
9
  from datetime import datetime
10
  from enum import Enum
11
- import torch
12
-
13
 
14
  class StressOption(Enum):
15
  AutomaticStress = "Автоматичні наголоси (за словником) 📖"
@@ -24,44 +17,11 @@ class VoiceOption(Enum):
24
  Olga = "Ольга (жіночий) 👩"
25
 
26
 
27
- def download(url, file_name):
28
- if not exists(file_name):
29
- print(f"Downloading {file_name}")
30
- r = requests.get(url, allow_redirects=True)
31
- with open(file_name, "wb") as file:
32
- file.write(r.content)
33
- else:
34
- print(f"Found {file_name}. Skipping download...")
35
-
36
-
37
- print("downloading uk/mykyta/vits-tts")
38
- release_number = "v3.0.0"
39
- model_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/model-inference.pth"
40
- config_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/config.json"
41
- speakers_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/speakers.pth"
42
-
43
- model_path = "model.pth"
44
- config_path = "config.json"
45
- speakers_path = "speakers.pth"
46
-
47
- download(model_link, model_path)
48
- download(config_link, config_path)
49
- download(speakers_link, speakers_path)
50
-
51
  badge = (
52
  "https://visitor-badge-reloaded.herokuapp.com/badge?page_id=robinhad.ukrainian-tts"
53
  )
54
 
55
- synthesizer = Synthesizer(
56
- model_path,
57
- config_path,
58
- speakers_path,
59
- None,
60
- None,
61
- )
62
-
63
- if synthesizer is None:
64
- raise NameError("model not found")
65
 
66
 
67
  def tts(text: str, voice: str, stress: str):
@@ -81,17 +41,15 @@ def tts(text: str, voice: str, stress: str):
81
  VoiceOption.Olga.value: "olga",
82
  }
83
  speaker_name = voice_mapping[voice]
84
- text = preprocess_text(text, autostress_with_model)
85
  text_limit = 7200
86
  text = (
87
  text if len(text) < text_limit else text[0:text_limit]
88
  ) # mitigate crashes on hf space
89
- print("Converted:", text)
90
 
91
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
92
- with torch.no_grad():
93
- wavs = synthesizer.tts(text, speaker_name=speaker_name)
94
- synthesizer.save_wav(wavs, fp)
95
  return fp.name, text
96
 
97
 
 
1
  import tempfile
 
2
  import gradio as gr
 
 
 
 
 
3
  from datetime import datetime
4
  from enum import Enum
5
+ from ukrainian_tts.tts import TTS
 
6
 
7
  class StressOption(Enum):
8
  AutomaticStress = "Автоматичні наголоси (за словником) 📖"
 
17
  Olga = "Ольга (жіночий) 👩"
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  badge = (
21
  "https://visitor-badge-reloaded.herokuapp.com/badge?page_id=robinhad.ukrainian-tts"
22
  )
23
 
24
+ ukr_tts = TTS()
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def tts(text: str, voice: str, stress: str):
 
41
  VoiceOption.Olga.value: "olga",
42
  }
43
  speaker_name = voice_mapping[voice]
44
+
45
  text_limit = 7200
46
  text = (
47
  text if len(text) < text_limit else text[0:text_limit]
48
  ) # mitigate crashes on hf space
49
+
50
 
51
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
52
+ ukr_tts.tts(text, speaker_name, autostress_with_model, fp)
 
 
53
  return fp.name, text
54
 
55
 
ukrainian_tts/tts.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import requests
3
+ from os.path import exists
4
+ from TTS.utils.synthesizer import Synthesizer
5
+ from enum import Enum
6
+ from .formatter import preprocess_text
7
+ from torch import no_grad
8
+
9
+ class Voices(Enum):
10
+ Olena = "olena"
11
+ Mykyta = "mykyta"
12
+ Lada = "lada"
13
+ Dmytro = "dmytro"
14
+ Olga = "olga"
15
+
16
+
17
+ class StressOption(Enum):
18
+ Dictionary = "dictionary"
19
+ Model = "model"
20
+
21
+
22
+ class TTS:
23
+ def __init__(self, cache_folder=None) -> None:
24
+ self.__setup_cache(cache_folder)
25
+
26
+
27
+ def tts(self, text: str, voice: str, stress: str, output_fp=BytesIO()):
28
+ autostress_with_model = (
29
+ True if stress == StressOption.Model.value else False
30
+ )
31
+
32
+ if voice not in [option.value for option in Voices]:
33
+ raise ValueError("Invalid value for voice selected! Please use one of the following values: {', '.join([option.value for option in Voices])}.")
34
+
35
+ text = preprocess_text(text, autostress_with_model)
36
+
37
+ with no_grad():
38
+ wavs = self.synthesizer.tts(text, speaker_name=voice)
39
+ self.synthesizer.save_wav(wavs, output_fp)
40
+
41
+ output_fp.seek(0)
42
+
43
+ return output_fp
44
+
45
+
46
+ def __setup_cache(self, cache_folder=None):
47
+ print("downloading uk/mykyta/vits-tts")
48
+ release_number = "v3.0.0"
49
+ model_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/model-inference.pth"
50
+ config_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/config.json"
51
+ speakers_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/speakers.pth"
52
+
53
+ model_path = "model.pth"
54
+ config_path = "config.json"
55
+ speakers_path = "speakers.pth"
56
+
57
+ self.__download(model_link, model_path)
58
+ self.__download(config_link, config_path)
59
+ self.__download(speakers_link, speakers_path)
60
+
61
+ self.synthesizer = Synthesizer(
62
+ model_path,
63
+ config_path,
64
+ speakers_path,
65
+ None,
66
+ None,
67
+ )
68
+
69
+ if self.synthesizer is None:
70
+ raise NameError("model not found")
71
+
72
+
73
+ def __download(self, url, file_name):
74
+ if not exists(file_name):
75
+ print(f"Downloading {file_name}")
76
+ r = requests.get(url, allow_redirects=True)
77
+ with open(file_name, "wb") as file:
78
+ file.write(r.content)
79
+ else:
80
+ print(f"Found {file_name}. Skipping download...")
81
+
82
+