jhj0517 committed
Commit b065a65 • 1 Parent(s): 1f8abba

refactoring to use data class

Files changed (2):
  1. app.py +29 -9
  2. modules/faster_whisper_inference.py +72 -170
app.py CHANGED
@@ -8,6 +8,8 @@ from modules.nllb_inference import NLLBInference
 from ui.htmls import *
 from modules.youtube_manager import get_ytmetas
 from modules.deepl_api import DeepLAPI
+from modules.whisper_data_class import *
+
 
 class App:
     def __init__(self, args):
@@ -68,10 +70,16 @@ class App:
                     files_subtitles = gr.Files(label="Downloadable output file", scale=4, interactive=False)
                     btn_openfolder = gr.Button('📂', scale=1)
 
-                params = [input_file, dd_model, dd_lang, dd_file_format, cb_translate, cb_timestamp]
-                advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
+                params = [input_file, dd_file_format, cb_timestamp]
+                whisper_params = WhisperGradioComponents(model_size=dd_model,
+                                                         lang=dd_lang,
+                                                         is_translate=cb_translate,
+                                                         beam_size=nb_beam_size,
+                                                         log_prob_threshold=nb_log_prob_threshold,
+                                                         no_speech_threshold=nb_no_speech_threshold,
+                                                         compute_type=dd_compute_type)
                 btn_run.click(fn=self.whisper_inf.transcribe_file,
-                              inputs=params + advanced_params,
+                              inputs=params + whisper_params.to_list(),
                               outputs=[tb_indicator, files_subtitles])
                 btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
                 dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
@@ -108,10 +116,16 @@ class App:
                     files_subtitles = gr.Files(label="Downloadable output file", scale=4)
                     btn_openfolder = gr.Button('📂', scale=1)
 
-                params = [tb_youtubelink, dd_model, dd_lang, dd_file_format, cb_translate, cb_timestamp]
-                advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
+                params = [tb_youtubelink, dd_file_format, cb_timestamp]
+                whisper_params = WhisperGradioComponents(model_size=dd_model,
+                                                         lang=dd_lang,
+                                                         is_translate=cb_translate,
+                                                         beam_size=nb_beam_size,
+                                                         log_prob_threshold=nb_log_prob_threshold,
+                                                         no_speech_threshold=nb_no_speech_threshold,
+                                                         compute_type=dd_compute_type)
                 btn_run.click(fn=self.whisper_inf.transcribe_youtube,
-                              inputs=params + advanced_params,
+                              inputs=params + whisper_params.to_list(),
                               outputs=[tb_indicator, files_subtitles])
                 tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
                                       outputs=[img_thumbnail, tb_title, tb_description])
@@ -141,10 +155,16 @@ class App:
                     files_subtitles = gr.Files(label="Downloadable output file", scale=4)
                     btn_openfolder = gr.Button('📂', scale=1)
 
-                params = [mic_input, dd_model, dd_lang, dd_file_format, cb_translate]
-                advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
+                params = [mic_input, dd_file_format]
+                whisper_params = WhisperGradioComponents(model_size=dd_model,
+                                                         lang=dd_lang,
+                                                         is_translate=cb_translate,
+                                                         beam_size=nb_beam_size,
+                                                         log_prob_threshold=nb_log_prob_threshold,
+                                                         no_speech_threshold=nb_no_speech_threshold,
+                                                         compute_type=dd_compute_type)
                 btn_run.click(fn=self.whisper_inf.transcribe_mic,
-                              inputs=params + advanced_params,
+                              inputs=params + whisper_params.to_list(),
                               outputs=[tb_indicator, files_subtitles])
                 btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
                 dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
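The new modules/whisper_data_class.py is not among the two changed files, so the exact definition of WhisperGradioComponents is not visible in this commit. What follows is a minimal sketch inferred from the usage above, assuming a plain dataclass whose fields hold the Gradio components in a fixed declaration order; everything beyond the names WhisperGradioComponents, to_list(), and to_values() is an assumption.

# Hypothetical sketch of modules/whisper_data_class.py, inferred from usage.
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class WhisperGradioComponents:
    # At UI-build time each field holds a Gradio component; after submission,
    # to_values() rebuilds an instance whose fields hold the plain values.
    model_size: Any
    lang: Any
    is_translate: Any
    beam_size: Any
    log_prob_threshold: Any
    no_speech_threshold: Any
    compute_type: Any

    def to_list(self) -> list:
        # Components in declaration order, so the positional values Gradio
        # passes back can be matched to fields by to_values().
        return [getattr(self, f.name) for f in fields(self)]

    @staticmethod
    def to_values(*args) -> "WhisperGradioComponents":
        # Rebuild an instance from the positional values Gradio forwards
        # to the event handler.
        return WhisperGradioComponents(*args)

With a shape like this, inputs=params + whisper_params.to_list() registers the components with the click event, and Gradio forwards their submitted values positionally into the *whisper_params catch-all on the inference side.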
modules/faster_whisper_inference.py CHANGED
@@ -1,10 +1,9 @@
 import os
 
-import tqdm
 import time
 import numpy as np
 from typing import BinaryIO, Union, Tuple, List
-from datetime import datetime, timedelta
+from datetime import datetime
 
 import faster_whisper
 import ctranslate2
@@ -15,6 +14,7 @@ import gradio as gr
 from .base_interface import BaseInterface
 from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.youtube_manager import get_ytdata, get_ytaudio
+from modules.whisper_data_class import *
 
 
 class FasterWhisperInference(BaseInterface):
@@ -26,22 +26,17 @@ class FasterWhisperInference(BaseInterface):
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.available_compute_types = ctranslate2.get_supported_compute_types("cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
+        self.available_compute_types = ctranslate2.get_supported_compute_types(
+            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
         self.default_beam_size = 1
 
     def transcribe_file(self,
                         fileobjs: list,
-                        model_size: str,
-                        lang: str,
                         file_format: str,
-                        istranslate: bool,
                         add_timestamp: bool,
-                        beam_size: int,
-                        log_prob_threshold: float,
-                        no_speech_threshold: float,
-                        compute_type: str,
-                        progress=gr.Progress()
+                        progress=gr.Progress(),
+                        *whisper_params,
                         ) -> list:
         """
         Write subtitle file from Files
@@ -50,31 +45,14 @@ class FasterWhisperInference(BaseInterface):
         ----------
         fileobjs: list
             List of files to transcribe from gr.Files()
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
         file_format: str
-            File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
+            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
         add_timestamp: bool
-            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
-        compute_type: str
-            compute type from gr.Dropdown().
-            see more info : https://opennmt.net/CTranslate2/quantization.html
+            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.
 
         Returns
         ----------
@@ -83,18 +61,12 @@ class FasterWhisperInference(BaseInterface):
             Files to return to gr.Files()
         """
         try:
-            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
-
             files_info = {}
             for fileobj in fileobjs:
                 transcribed_segments, time_for_task = self.transcribe(
-                    audio=fileobj.name,
-                    lang=lang,
-                    istranslate=istranslate,
-                    beam_size=beam_size,
-                    log_prob_threshold=log_prob_threshold,
-                    no_speech_threshold=no_speech_threshold,
-                    progress=progress
+                    fileobj.name,
+                    progress,
+                    *whisper_params,
                 )
 
                 file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
@@ -105,7 +77,7 @@
                     add_timestamp=add_timestamp,
                     file_format=file_format
                 )
-                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
+                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
 
             total_result = ''
             total_time = 0
@@ -115,10 +87,10 @@
                 total_result += f'{info["subtitle"]}'
                 total_time += info["time_for_task"]
 
-            gr_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
-            gr_file_path = [info['path'] for info in files_info.values()]
+            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
+            result_file_path = [info['path'] for info in files_info.values()]
 
-            return [gr_str, gr_file_path]
+            return [result_str, result_file_path]
 
         except Exception as e:
             print(f"Error transcribing file on line {e}")
@@ -128,50 +100,27 @@
             self.remove_input_files([fileobj.name for fileobj in fileobjs])
 
     def transcribe_youtube(self,
-                           youtubelink: str,
-                           model_size: str,
-                           lang: str,
+                           youtube_link: str,
                            file_format: str,
-                           istranslate: bool,
                            add_timestamp: bool,
-                           beam_size: int,
-                           log_prob_threshold: float,
-                           no_speech_threshold: float,
-                           compute_type: str,
-                           progress=gr.Progress()
+                           progress=gr.Progress(),
+                           *whisper_params,
                            ) -> list:
         """
         Write subtitle file from Youtube
 
         Parameters
         ----------
-        youtubelink: str
-            Link of Youtube to transcribe from gr.Textbox()
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
+        youtube_link: str
+            URL of the Youtube video to transcribe from gr.Textbox()
         file_format: str
-            File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
+            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
         add_timestamp: bool
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
-        compute_type: str
-            compute type from gr.Dropdown().
-            see more info : https://opennmt.net/CTranslate2/quantization.html
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.
 
         Returns
         ----------
@@ -180,20 +129,14 @@
             Files to return to gr.Files()
         """
         try:
-            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
-
             progress(0, desc="Loading Audio from Youtube..")
-            yt = get_ytdata(youtubelink)
+            yt = get_ytdata(youtube_link)
             audio = get_ytaudio(yt)
 
             transcribed_segments, time_for_task = self.transcribe(
-                audio=audio,
-                lang=lang,
-                istranslate=istranslate,
-                beam_size=beam_size,
-                log_prob_threshold=log_prob_threshold,
-                no_speech_threshold=no_speech_threshold,
-                progress=progress
+                audio,
+                progress,
+                *whisper_params,
             )
 
             progress(1, desc="Completed!")
@@ -214,7 +157,7 @@
         finally:
             try:
                 if 'yt' not in locals():
-                    yt = get_ytdata(youtubelink)
+                    yt = get_ytdata(youtube_link)
                     file_path = get_ytaudio(yt)
                 else:
                     file_path = get_ytaudio(yt)
@@ -225,47 +168,24 @@
                 pass
 
     def transcribe_mic(self,
-                       micaudio: str,
-                       model_size: str,
-                       lang: str,
+                       mic_audio: str,
                        file_format: str,
-                       istranslate: bool,
-                       beam_size: int,
-                       log_prob_threshold: float,
-                       no_speech_threshold: float,
-                       compute_type: str,
-                       progress=gr.Progress()
+                       progress=gr.Progress(),
+                       *whisper_params,
                        ) -> list:
         """
         Write subtitle file from microphone
 
         Parameters
         ----------
-        micaudio: str
+        mic_audio: str
            Audio file path from gr.Microphone()
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
        file_format: str
-            File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-        compute_type: str
-            compute type from gr.Dropdown().
-            see more info : https://opennmt.net/CTranslate2/quantization.html
-            consider the segment as silent.
+            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        progress: gr.Progress
            Indicator to show progress directly in gradio.
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.
 
        Returns
        ----------
@@ -274,18 +194,11 @@
            Files to return to gr.Files()
        """
        try:
-            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
-
            progress(0, desc="Loading Audio..")
-
            transcribed_segments, time_for_task = self.transcribe(
-                audio=micaudio,
-                lang=lang,
-                istranslate=istranslate,
-                beam_size=beam_size,
-                log_prob_threshold=log_prob_threshold,
-                no_speech_threshold=no_speech_threshold,
-                progress=progress
+                mic_audio,
+                progress,
+                *whisper_params,
            )
            progress(1, desc="Completed!")
 
@@ -302,16 +215,12 @@
            print(f"Error transcribing file on line {e}")
        finally:
            self.release_cuda_memory()
-            self.remove_input_files([micaudio])
+            self.remove_input_files([mic_audio])
 
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
-                   lang: str,
-                   istranslate: bool,
-                   beam_size: int,
-                   log_prob_threshold: float,
-                   no_speech_threshold: float,
-                   progress: gr.Progress
+                   progress: gr.Progress,
+                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.
@@ -320,22 +229,10 @@
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.
 
        Returns
        ----------
@@ -346,18 +243,24 @@
        """
        start_time = time.time()
 
-        if lang == "Automatic Detection":
+        params = WhisperGradioComponents.to_values(*whisper_params)
+
+        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
+            self.update_model(params.model_size, params.compute_type, progress)
+
+        if params.lang == "Automatic Detection":
            lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-            lang = language_code_dict[lang]
+            lang = language_code_dict[params.lang]
+
        segments, info = self.model.transcribe(
            audio=audio,
            language=lang,
-            task="translate" if istranslate and self.current_model_size in self.translatable_models else "transcribe",
-            beam_size=beam_size,
-            log_prob_threshold=log_prob_threshold,
-            no_speech_threshold=no_speech_threshold,
+            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
+            beam_size=params.beam_size,
+            log_prob_threshold=params.log_prob_threshold,
+            no_speech_threshold=params.no_speech_threshold,
        )
        progress(0, desc="Loading audio..")
 
@@ -373,24 +276,23 @@
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time
 
-    def update_model_if_needed(self,
-                               model_size: str,
-                               compute_type: str,
-                               progress: gr.Progress
-                               ):
+    def update_model(self,
+                     model_size: str,
+                     compute_type: str,
+                     progress: gr.Progress
+                     ):
        """
-        Initialize model if it doesn't match with current model setting
+        update current model setting
        """
-        if model_size != self.current_model_size or self.model is None or self.current_compute_type != compute_type:
-            progress(0, desc="Initializing Model..")
-            self.current_model_size = model_size
-            self.current_compute_type = compute_type
-            self.model = faster_whisper.WhisperModel(
-                device=self.device,
-                model_size_or_path=model_size,
-                download_root=os.path.join("models", "Whisper", "faster-whisper"),
-                compute_type=self.current_compute_type
-            )
+        progress(0, desc="Initializing Model..")
+        self.current_model_size = model_size
+        self.current_compute_type = compute_type
+        self.model = faster_whisper.WhisperModel(
+            device=self.device,
+            model_size_or_path=model_size,
+            download_root=os.path.join("models", "Whisper", "faster-whisper"),
+            compute_type=self.current_compute_type
+        )
 
    @staticmethod
    def generate_and_write_file(file_name: str,