jhj0517 committed
Commit 04fe334 • 2 Parent(s): 10a154c e3a6426

Merge pull request #134 from jhj0517/feature/more-parameters

app.py CHANGED
@@ -8,6 +8,8 @@ from modules.nllb_inference import NLLBInference
8
  from ui.htmls import *
9
  from modules.youtube_manager import get_ytmetas
10
  from modules.deepl_api import DeepLAPI
 
 
11
 
12
  class App:
13
  def __init__(self, args):
@@ -61,6 +63,8 @@ class App:
61
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
62
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
63
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
 
 
64
  with gr.Row():
65
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
66
  with gr.Row():
@@ -68,10 +72,18 @@ class App:
68
  files_subtitles = gr.Files(label="Downloadable output file", scale=4, interactive=False)
69
  btn_openfolder = gr.Button('📂', scale=1)
70
 
71
- params = [input_file, dd_model, dd_lang, dd_file_format, cb_translate, cb_timestamp]
72
- advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
 
73
  btn_run.click(fn=self.whisper_inf.transcribe_file,
74
- inputs=params + advanced_params,
75
  outputs=[tb_indicator, files_subtitles])
76
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
77
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
@@ -101,6 +113,8 @@ class App:
101
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
102
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
103
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
 
 
104
  with gr.Row():
105
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
106
  with gr.Row():
@@ -108,10 +122,18 @@ class App:
108
  files_subtitles = gr.Files(label="Downloadable output file", scale=4)
109
  btn_openfolder = gr.Button('📂', scale=1)
110
 
111
- params = [tb_youtubelink, dd_model, dd_lang, dd_file_format, cb_translate, cb_timestamp]
112
- advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
 
113
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
114
- inputs=params + advanced_params,
115
  outputs=[tb_indicator, files_subtitles])
116
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
117
  outputs=[img_thumbnail, tb_title, tb_description])
@@ -134,6 +156,8 @@ class App:
134
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
135
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
136
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
 
 
137
  with gr.Row():
138
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
139
  with gr.Row():
@@ -141,10 +165,18 @@ class App:
141
  files_subtitles = gr.Files(label="Downloadable output file", scale=4)
142
  btn_openfolder = gr.Button('📂', scale=1)
143
 
144
- params = [mic_input, dd_model, dd_lang, dd_file_format, cb_translate]
145
- advanced_params = [nb_beam_size, nb_log_prob_threshold, nb_no_speech_threshold, dd_compute_type]
 
146
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
147
- inputs=params + advanced_params,
148
  outputs=[tb_indicator, files_subtitles])
149
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
150
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
 
8
  from ui.htmls import *
9
  from modules.youtube_manager import get_ytmetas
10
  from modules.deepl_api import DeepLAPI
11
+ from modules.whisper_data_class import *
12
+
13
 
14
  class App:
15
  def __init__(self, args):
 
63
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
64
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
65
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
66
+ nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
67
+ nb_patience = gr.Number(label="Patience", value=1, interactive=True)
68
  with gr.Row():
69
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
70
  with gr.Row():
 
72
  files_subtitles = gr.Files(label="Downloadable output file", scale=4, interactive=False)
73
  btn_openfolder = gr.Button('📂', scale=1)
74
 
75
+ params = [input_file, dd_file_format, cb_timestamp]
76
+ whisper_params = WhisperGradioComponents(model_size=dd_model,
77
+ lang=dd_lang,
78
+ is_translate=cb_translate,
79
+ beam_size=nb_beam_size,
80
+ log_prob_threshold=nb_log_prob_threshold,
81
+ no_speech_threshold=nb_no_speech_threshold,
82
+ compute_type=dd_compute_type,
83
+ best_of=nb_best_of,
84
+ patience=nb_patience)
85
  btn_run.click(fn=self.whisper_inf.transcribe_file,
86
+ inputs=params + whisper_params.to_list(),
87
  outputs=[tb_indicator, files_subtitles])
88
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
89
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
 
113
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
114
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
115
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
116
+ nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
117
+ nb_patience = gr.Number(label="Patience", value=1, interactive=True)
118
  with gr.Row():
119
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
120
  with gr.Row():
 
122
  files_subtitles = gr.Files(label="Downloadable output file", scale=4)
123
  btn_openfolder = gr.Button('📂', scale=1)
124
 
125
+ params = [tb_youtubelink, dd_file_format, cb_timestamp]
126
+ whisper_params = WhisperGradioComponents(model_size=dd_model,
127
+ lang=dd_lang,
128
+ is_translate=cb_translate,
129
+ beam_size=nb_beam_size,
130
+ log_prob_threshold=nb_log_prob_threshold,
131
+ no_speech_threshold=nb_no_speech_threshold,
132
+ compute_type=dd_compute_type,
133
+ best_of=nb_best_of,
134
+ patience=nb_patience)
135
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
136
+ inputs=params + whisper_params.to_list(),
137
  outputs=[tb_indicator, files_subtitles])
138
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
139
  outputs=[img_thumbnail, tb_title, tb_description])
 
156
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
157
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
158
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
159
+ nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
160
+ nb_patience = gr.Number(label="Patience", value=1, interactive=True)
161
  with gr.Row():
162
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
163
  with gr.Row():
 
165
  files_subtitles = gr.Files(label="Downloadable output file", scale=4)
166
  btn_openfolder = gr.Button('📂', scale=1)
167
 
168
+ params = [mic_input, dd_file_format]
169
+ whisper_params = WhisperGradioComponents(model_size=dd_model,
170
+ lang=dd_lang,
171
+ is_translate=cb_translate,
172
+ beam_size=nb_beam_size,
173
+ log_prob_threshold=nb_log_prob_threshold,
174
+ no_speech_threshold=nb_no_speech_threshold,
175
+ compute_type=dd_compute_type,
176
+ best_of=nb_best_of,
177
+ patience=nb_patience)
178
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
179
+ inputs=params + whisper_params.to_list(),
180
  outputs=[tb_indicator, files_subtitles])
181
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
182
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
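In the app.py changes above, the flat `params + advanced_params` input lists are replaced by a short `params` list plus a `WhisperGradioComponents` instance whose `to_list()` output is appended to `inputs`. Gradio delivers the `inputs` values to the handler positionally and injects `gr.Progress()` on its own, so everything after the fixed arguments is collected into `*whisper_params`. A minimal, self-contained sketch of that pattern (the `echo_params` handler and the reduced component set are illustrative stand-ins, not part of this commit):

import gradio as gr

def echo_params(input_text, file_format, add_timestamp, progress=gr.Progress(), *whisper_params):
    # The first three values come from the fixed `params` list. gr.Progress() is
    # injected by Gradio and never appears in `inputs`. Every remaining input value
    # is collected positionally into *whisper_params, in the order produced by
    # WhisperGradioComponents.to_list().
    return f"{input_text}, {file_format}, {add_timestamp}, whisper_params={whisper_params}"

with gr.Blocks() as demo:
    tb_input = gr.Textbox(label="Input")
    dd_file_format = gr.Dropdown(label="File Format", choices=["SRT", "WebVTT", "txt"], value="SRT")
    cb_timestamp = gr.Checkbox(label="Add a timestamp to the end of filename", value=True)
    dd_model = gr.Dropdown(label="Model", choices=["base", "large-v3"], value="base")
    nb_beam_size = gr.Number(label="Beam Size", value=1)
    tb_indicator = gr.Textbox(label="Output")
    btn_run = gr.Button("RUN", variant="primary")
    # Same shape as the diff: fixed params first, Whisper-related components appended last.
    btn_run.click(fn=echo_params,
                  inputs=[tb_input, dd_file_format, cb_timestamp, dd_model, nb_beam_size],
                  outputs=[tb_indicator])

if __name__ == "__main__":
    demo.launch()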
modules/faster_whisper_inference.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
 
3
- import tqdm
4
  import time
5
  import numpy as np
6
  from typing import BinaryIO, Union, Tuple, List
7
- from datetime import datetime, timedelta
8
 
9
  import faster_whisper
10
  import ctranslate2
@@ -15,6 +14,7 @@ import gradio as gr
15
  from .base_interface import BaseInterface
16
  from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
17
  from modules.youtube_manager import get_ytdata, get_ytaudio
 
18
 
19
 
20
  class FasterWhisperInference(BaseInterface):
@@ -26,78 +26,51 @@ class FasterWhisperInference(BaseInterface):
26
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
27
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
28
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
- self.available_compute_types = ctranslate2.get_supported_compute_types("cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
 
30
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
31
  self.default_beam_size = 1
32
 
33
  def transcribe_file(self,
34
- fileobjs: list,
35
- model_size: str,
36
- lang: str,
37
  file_format: str,
38
- istranslate: bool,
39
  add_timestamp: bool,
40
- beam_size: int,
41
- log_prob_threshold: float,
42
- no_speech_threshold: float,
43
- compute_type: str,
44
- progress=gr.Progress()
45
  ) -> list:
46
  """
47
  Write subtitle file from Files
48
 
49
  Parameters
50
  ----------
51
- fileobjs: list
52
  List of files to transcribe from gr.Files()
53
- model_size: str
54
- Whisper model size from gr.Dropdown()
55
- lang: str
56
- Source language of the file to transcribe from gr.Dropdown()
57
  file_format: str
58
- File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
59
- istranslate: bool
60
- Boolean value from gr.Checkbox() that determines whether to translate to English.
61
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
62
  add_timestamp: bool
63
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
64
- beam_size: int
65
- Int value from gr.Number() that is used for decoding option.
66
- log_prob_threshold: float
67
- float value from gr.Number(). If the average log probability over sampled tokens is
68
- below this value, treat as failed.
69
- no_speech_threshold: float
70
- float value from gr.Number(). If the no_speech probability is higher than this value AND
71
- the average log probability over sampled tokens is below `log_prob_threshold`,
72
- consider the segment as silent.
73
- compute_type: str
74
- compute type from gr.Dropdown().
75
- see more info : https://opennmt.net/CTranslate2/quantization.html
76
  progress: gr.Progress
77
  Indicator to show progress directly in gradio.
 
 
78
 
79
  Returns
80
  ----------
81
- A List of
82
- String to return to gr.Textbox()
83
- Files to return to gr.Files()
 
84
  """
85
  try:
86
- self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
87
-
88
  files_info = {}
89
- for fileobj in fileobjs:
90
  transcribed_segments, time_for_task = self.transcribe(
91
- audio=fileobj.name,
92
- lang=lang,
93
- istranslate=istranslate,
94
- beam_size=beam_size,
95
- log_prob_threshold=log_prob_threshold,
96
- no_speech_threshold=no_speech_threshold,
97
- progress=progress
98
  )
99
 
100
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
101
  file_name = safe_filename(file_name)
102
  subtitle, file_path = self.generate_and_write_file(
103
  file_name=file_name,
@@ -105,7 +78,7 @@ class FasterWhisperInference(BaseInterface):
105
  add_timestamp=add_timestamp,
106
  file_format=file_format
107
  )
108
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
109
 
110
  total_result = ''
111
  total_time = 0
@@ -115,106 +88,78 @@ class FasterWhisperInference(BaseInterface):
115
  total_result += f'{info["subtitle"]}'
116
  total_time += info["time_for_task"]
117
 
118
- gr_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
119
- gr_file_path = [info['path'] for info in files_info.values()]
120
 
121
- return [gr_str, gr_file_path]
122
 
123
  except Exception as e:
124
- print(f"Error transcribing file on line {e}")
125
  finally:
126
  self.release_cuda_memory()
127
- if not fileobjs:
128
- self.remove_input_files([fileobj.name for fileobj in fileobjs])
129
 
130
  def transcribe_youtube(self,
131
- youtubelink: str,
132
- model_size: str,
133
- lang: str,
134
  file_format: str,
135
- istranslate: bool,
136
  add_timestamp: bool,
137
- beam_size: int,
138
- log_prob_threshold: float,
139
- no_speech_threshold: float,
140
- compute_type: str,
141
- progress=gr.Progress()
142
  ) -> list:
143
  """
144
  Write subtitle file from Youtube
145
 
146
  Parameters
147
  ----------
148
- youtubelink: str
149
- Link of Youtube to transcribe from gr.Textbox()
150
- model_size: str
151
- Whisper model size from gr.Dropdown()
152
- lang: str
153
- Source language of the file to transcribe from gr.Dropdown()
154
  file_format: str
155
- File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
156
- istranslate: bool
157
- Boolean value from gr.Checkbox() that determines whether to translate to English.
158
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
159
  add_timestamp: bool
160
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
161
- beam_size: int
162
- Int value from gr.Number() that is used for decoding option.
163
- log_prob_threshold: float
164
- float value from gr.Number(). If the average log probability over sampled tokens is
165
- below this value, treat as failed.
166
- no_speech_threshold: float
167
- float value from gr.Number(). If the no_speech probability is higher than this value AND
168
- the average log probability over sampled tokens is below `log_prob_threshold`,
169
- consider the segment as silent.
170
- compute_type: str
171
- compute type from gr.Dropdown().
172
- see more info : https://opennmt.net/CTranslate2/quantization.html
173
  progress: gr.Progress
174
  Indicator to show progress directly in gradio.
 
 
175
 
176
  Returns
177
  ----------
178
- A List of
179
- String to return to gr.Textbox()
180
- Files to return to gr.Files()
 
181
  """
182
  try:
183
- self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
184
-
185
  progress(0, desc="Loading Audio from Youtube..")
186
- yt = get_ytdata(youtubelink)
187
  audio = get_ytaudio(yt)
188
 
189
  transcribed_segments, time_for_task = self.transcribe(
190
- audio=audio,
191
- lang=lang,
192
- istranslate=istranslate,
193
- beam_size=beam_size,
194
- log_prob_threshold=log_prob_threshold,
195
- no_speech_threshold=no_speech_threshold,
196
- progress=progress
197
  )
198
 
199
  progress(1, desc="Completed!")
200
 
201
  file_name = safe_filename(yt.title)
202
- subtitle, file_path = self.generate_and_write_file(
203
  file_name=file_name,
204
  transcribed_segments=transcribed_segments,
205
  add_timestamp=add_timestamp,
206
  file_format=file_format
207
  )
208
- gr_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
209
 
210
- return [gr_str, file_path]
211
 
212
  except Exception as e:
213
- print(f"Error transcribing file on line {e}")
214
  finally:
215
  try:
216
  if 'yt' not in locals():
217
- yt = get_ytdata(youtubelink)
218
  file_path = get_ytaudio(yt)
219
  else:
220
  file_path = get_ytaudio(yt)
@@ -225,93 +170,60 @@ class FasterWhisperInference(BaseInterface):
225
  pass
226
 
227
  def transcribe_mic(self,
228
- micaudio: str,
229
- model_size: str,
230
- lang: str,
231
  file_format: str,
232
- istranslate: bool,
233
- beam_size: int,
234
- log_prob_threshold: float,
235
- no_speech_threshold: float,
236
- compute_type: str,
237
- progress=gr.Progress()
238
  ) -> list:
239
  """
240
  Write subtitle file from microphone
241
 
242
  Parameters
243
  ----------
244
- micaudio: str
245
  Audio file path from gr.Microphone()
246
- model_size: str
247
- Whisper model size from gr.Dropdown()
248
- lang: str
249
- Source language of the file to transcribe from gr.Dropdown()
250
  file_format: str
251
- File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
252
- istranslate: bool
253
- Boolean value from gr.Checkbox() that determines whether to translate to English.
254
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
255
- beam_size: int
256
- Int value from gr.Number() that is used for decoding option.
257
- log_prob_threshold: float
258
- float value from gr.Number(). If the average log probability over sampled tokens is
259
- below this value, treat as failed.
260
- no_speech_threshold: float
261
- float value from gr.Number(). If the no_speech probability is higher than this value AND
262
- the average log probability over sampled tokens is below `log_prob_threshold`,
263
- compute_type: str
264
- compute type from gr.Dropdown().
265
- see more info : https://opennmt.net/CTranslate2/quantization.html
266
- consider the segment as silent.
267
  progress: gr.Progress
268
  Indicator to show progress directly in gradio.
 
 
269
 
270
  Returns
271
  ----------
272
- A List of
273
- String to return to gr.Textbox()
274
- Files to return to gr.Files()
 
275
  """
276
  try:
277
- self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
278
-
279
  progress(0, desc="Loading Audio..")
280
-
281
  transcribed_segments, time_for_task = self.transcribe(
282
- audio=micaudio,
283
- lang=lang,
284
- istranslate=istranslate,
285
- beam_size=beam_size,
286
- log_prob_threshold=log_prob_threshold,
287
- no_speech_threshold=no_speech_threshold,
288
- progress=progress
289
  )
290
  progress(1, desc="Completed!")
291
 
292
- subtitle, file_path = self.generate_and_write_file(
293
  file_name="Mic",
294
  transcribed_segments=transcribed_segments,
295
  add_timestamp=True,
296
  file_format=file_format
297
  )
298
 
299
- gr_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
300
- return [gr_str, file_path]
301
  except Exception as e:
302
- print(f"Error transcribing file on line {e}")
303
  finally:
304
  self.release_cuda_memory()
305
- self.remove_input_files([micaudio])
306
 
307
  def transcribe(self,
308
  audio: Union[str, BinaryIO, np.ndarray],
309
- lang: str,
310
- istranslate: bool,
311
- beam_size: int,
312
- log_prob_threshold: float,
313
- no_speech_threshold: float,
314
- progress: gr.Progress
315
  ) -> Tuple[List[dict], float]:
316
  """
317
  transcribe method for faster-whisper.
@@ -320,22 +232,10 @@ class FasterWhisperInference(BaseInterface):
320
  ----------
321
  audio: Union[str, BinaryIO, np.ndarray]
322
  Audio path or file binary or Audio numpy array
323
- lang: str
324
- Source language of the file to transcribe from gr.Dropdown()
325
- istranslate: bool
326
- Boolean value from gr.Checkbox() that determines whether to translate to English.
327
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
328
- beam_size: int
329
- Int value from gr.Number() that is used for decoding option.
330
- log_prob_threshold: float
331
- float value from gr.Number(). If the average log probability over sampled tokens is
332
- below this value, treat as failed.
333
- no_speech_threshold: float
334
- float value from gr.Number(). If the no_speech probability is higher than this value AND
335
- the average log probability over sampled tokens is below `log_prob_threshold`,
336
- consider the segment as silent.
337
  progress: gr.Progress
338
  Indicator to show progress directly in gradio.
 
 
339
 
340
  Returns
341
  ----------
@@ -346,18 +246,26 @@ class FasterWhisperInference(BaseInterface):
346
  """
347
  start_time = time.time()
348
 
349
- if lang == "Automatic Detection":
350
- lang = None
 
 
 
 
 
351
  else:
352
  language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
353
- lang = language_code_dict[lang]
 
354
  segments, info = self.model.transcribe(
355
  audio=audio,
356
- language=lang,
357
- task="translate" if istranslate and self.current_model_size in self.translatable_models else "transcribe",
358
- beam_size=beam_size,
359
- log_prob_threshold=log_prob_threshold,
360
- no_speech_threshold=no_speech_threshold,
 
 
361
  )
362
  progress(0, desc="Loading audio..")
363
 
@@ -373,24 +281,33 @@ class FasterWhisperInference(BaseInterface):
373
  elapsed_time = time.time() - start_time
374
  return segments_result, elapsed_time
375
 
376
- def update_model_if_needed(self,
377
- model_size: str,
378
- compute_type: str,
379
- progress: gr.Progress
380
- ):
381
  """
382
- Initialize model if it doesn't match with current model setting
 
 
 
 
 
 
 
 
 
 
383
  """
384
- if model_size != self.current_model_size or self.model is None or self.current_compute_type != compute_type:
385
- progress(0, desc="Initializing Model..")
386
- self.current_model_size = model_size
387
- self.current_compute_type = compute_type
388
- self.model = faster_whisper.WhisperModel(
389
- device=self.device,
390
- model_size_or_path=model_size,
391
- download_root=os.path.join("models", "Whisper", "faster-whisper"),
392
- compute_type=self.current_compute_type
393
- )
394
 
395
  @staticmethod
396
  def generate_and_write_file(file_name: str,
@@ -399,7 +316,25 @@ class FasterWhisperInference(BaseInterface):
399
  file_format: str,
400
  ) -> str:
401
  """
402
- This method writes subtitle file and returns str to gr.Textbox
 
403
  """
404
  timestamp = datetime.now().strftime("%m%d%H%M%S")
405
  if add_timestamp:
@@ -425,6 +360,18 @@ class FasterWhisperInference(BaseInterface):
425
 
426
  @staticmethod
427
  def format_time(elapsed_time: float) -> str:
 
428
  hours, rem = divmod(elapsed_time, 3600)
429
  minutes, seconds = divmod(rem, 60)
430
 
 
1
  import os
2
 
 
3
  import time
4
  import numpy as np
5
  from typing import BinaryIO, Union, Tuple, List
6
+ from datetime import datetime
7
 
8
  import faster_whisper
9
  import ctranslate2
 
14
  from .base_interface import BaseInterface
15
  from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
16
  from modules.youtube_manager import get_ytdata, get_ytaudio
17
+ from modules.whisper_data_class import *
18
 
19
 
20
  class FasterWhisperInference(BaseInterface):
 
26
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
27
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
28
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ self.available_compute_types = ctranslate2.get_supported_compute_types(
30
+ "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
31
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
32
  self.default_beam_size = 1
33
 
34
  def transcribe_file(self,
35
+ files: list,
 
 
36
  file_format: str,
 
37
  add_timestamp: bool,
38
+ progress=gr.Progress(),
39
+ *whisper_params,
 
 
 
40
  ) -> list:
41
  """
42
  Write subtitle file from Files
43
 
44
  Parameters
45
  ----------
46
+ files: list
47
  List of files to transcribe from gr.Files()
 
 
 
 
48
  file_format: str
49
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
 
 
 
50
  add_timestamp: bool
51
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
 
52
  progress: gr.Progress
53
  Indicator to show progress directly in gradio.
54
+ *whisper_params: tuple
55
+ Gradio components related to Whisper. see whisper_data_class.py for details.
56
 
57
  Returns
58
  ----------
59
+ result_str:
60
+ Result of transcription to return to gr.Textbox()
61
+ result_file_path:
62
+ Output file path to return to gr.Files()
63
  """
64
  try:
 
 
65
  files_info = {}
66
+ for file in files:
67
  transcribed_segments, time_for_task = self.transcribe(
68
+ file.name,
69
+ progress,
70
+ *whisper_params,
 
 
 
 
71
  )
72
 
73
+ file_name, file_ext = os.path.splitext(os.path.basename(file.name))
74
  file_name = safe_filename(file_name)
75
  subtitle, file_path = self.generate_and_write_file(
76
  file_name=file_name,
 
78
  add_timestamp=add_timestamp,
79
  file_format=file_format
80
  )
81
+ files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
82
 
83
  total_result = ''
84
  total_time = 0
 
88
  total_result += f'{info["subtitle"]}'
89
  total_time += info["time_for_task"]
90
 
91
+ result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
92
+ result_file_path = [info['path'] for info in files_info.values()]
93
 
94
+ return [result_str, result_file_path]
95
 
96
  except Exception as e:
97
+ print(f"Error transcribing file: {e}")
98
  finally:
99
  self.release_cuda_memory()
100
+ if not files:
101
+ self.remove_input_files([file.name for file in files])
102
 
103
  def transcribe_youtube(self,
104
+ youtube_link: str,
 
 
105
  file_format: str,
 
106
  add_timestamp: bool,
107
+ progress=gr.Progress(),
108
+ *whisper_params,
 
 
 
109
  ) -> list:
110
  """
111
  Write subtitle file from Youtube
112
 
113
  Parameters
114
  ----------
115
+ youtube_link: str
116
+ URL of the Youtube video to transcribe from gr.Textbox()
 
 
 
 
117
  file_format: str
118
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
 
 
 
119
  add_timestamp: bool
120
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
 
121
  progress: gr.Progress
122
  Indicator to show progress directly in gradio.
123
+ *whisper_params: tuple
124
+ Gradio components related to Whisper. see whisper_data_class.py for details.
125
 
126
  Returns
127
  ----------
128
+ result_str:
129
+ Result of transcription to return to gr.Textbox()
130
+ result_file_path:
131
+ Output file path to return to gr.Files()
132
  """
133
  try:
 
 
134
  progress(0, desc="Loading Audio from Youtube..")
135
+ yt = get_ytdata(youtube_link)
136
  audio = get_ytaudio(yt)
137
 
138
  transcribed_segments, time_for_task = self.transcribe(
139
+ audio,
140
+ progress,
141
+ *whisper_params,
 
 
 
 
142
  )
143
 
144
  progress(1, desc="Completed!")
145
 
146
  file_name = safe_filename(yt.title)
147
+ subtitle, result_file_path = self.generate_and_write_file(
148
  file_name=file_name,
149
  transcribed_segments=transcribed_segments,
150
  add_timestamp=add_timestamp,
151
  file_format=file_format
152
  )
153
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
154
 
155
+ return [result_str, result_file_path]
156
 
157
  except Exception as e:
158
+ print(f"Error transcribing file: {e}")
159
  finally:
160
  try:
161
  if 'yt' not in locals():
162
+ yt = get_ytdata(youtube_link)
163
  file_path = get_ytaudio(yt)
164
  else:
165
  file_path = get_ytaudio(yt)
 
170
  pass
171
 
172
  def transcribe_mic(self,
173
+ mic_audio: str,
 
 
174
  file_format: str,
175
+ progress=gr.Progress(),
176
+ *whisper_params,
 
 
 
 
177
  ) -> list:
178
  """
179
  Write subtitle file from microphone
180
 
181
  Parameters
182
  ----------
183
+ mic_audio: str
184
  Audio file path from gr.Microphone()
 
 
 
 
185
  file_format: str
186
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
 
187
  progress: gr.Progress
188
  Indicator to show progress directly in gradio.
189
+ *whisper_params: tuple
190
+ Gradio components related to Whisper. see whisper_data_class.py for details.
191
 
192
  Returns
193
  ----------
194
+ result_str:
195
+ Result of transcription to return to gr.Textbox()
196
+ result_file_path:
197
+ Output file path to return to gr.Files()
198
  """
199
  try:
 
 
200
  progress(0, desc="Loading Audio..")
 
201
  transcribed_segments, time_for_task = self.transcribe(
202
+ mic_audio,
203
+ progress,
204
+ *whisper_params,
 
 
 
 
205
  )
206
  progress(1, desc="Completed!")
207
 
208
+ subtitle, result_file_path = self.generate_and_write_file(
209
  file_name="Mic",
210
  transcribed_segments=transcribed_segments,
211
  add_timestamp=True,
212
  file_format=file_format
213
  )
214
 
215
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
216
+ return [result_str, result_file_path]
217
  except Exception as e:
218
+ print(f"Error transcribing file: {e}")
219
  finally:
220
  self.release_cuda_memory()
221
+ self.remove_input_files([mic_audio])
222
 
223
  def transcribe(self,
224
  audio: Union[str, BinaryIO, np.ndarray],
225
+ progress: gr.Progress,
226
+ *whisper_params,
 
 
 
 
227
  ) -> Tuple[List[dict], float]:
228
  """
229
  transcribe method for faster-whisper.
 
232
  ----------
233
  audio: Union[str, BinaryIO, np.ndarray]
234
  Audio path or file binary or Audio numpy array
 
235
  progress: gr.Progress
236
  Indicator to show progress directly in gradio.
237
+ *whisper_params: tuple
238
+ Gradio components related to Whisper. see whisper_data_class.py for details.
239
 
240
  Returns
241
  ----------
 
246
  """
247
  start_time = time.time()
248
 
249
+ params = WhisperGradioComponents.to_values(*whisper_params)
250
+
251
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
252
+ self.update_model(params.model_size, params.compute_type, progress)
253
+
254
+ if params.lang == "Automatic Detection":
255
+ params.lang = None
256
  else:
257
  language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
258
+ params.lang = language_code_dict[params.lang]
259
+
260
  segments, info = self.model.transcribe(
261
  audio=audio,
262
+ language=params.lang,
263
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
264
+ beam_size=params.beam_size,
265
+ log_prob_threshold=params.log_prob_threshold,
266
+ no_speech_threshold=params.no_speech_threshold,
267
+ best_of=params.best_of,
268
+ patience=params.patience
269
  )
270
  progress(0, desc="Loading audio..")
271
 
 
281
  elapsed_time = time.time() - start_time
282
  return segments_result, elapsed_time
283
 
284
+ def update_model(self,
285
+ model_size: str,
286
+ compute_type: str,
287
+ progress: gr.Progress
288
+ ):
289
  """
290
+ Update current model setting
291
+
292
+ Parameters
293
+ ----------
294
+ model_size: str
295
+ Size of whisper model
296
+ compute_type: str
297
+ Compute type for transcription.
298
+ see more info : https://opennmt.net/CTranslate2/quantization.html
299
+ progress: gr.Progress
300
+ Indicator to show progress directly in gradio.
301
  """
302
+ progress(0, desc="Initializing Model..")
303
+ self.current_model_size = model_size
304
+ self.current_compute_type = compute_type
305
+ self.model = faster_whisper.WhisperModel(
306
+ device=self.device,
307
+ model_size_or_path=model_size,
308
+ download_root=os.path.join("models", "Whisper", "faster-whisper"),
309
+ compute_type=self.current_compute_type
310
+ )
 
311
 
312
  @staticmethod
313
  def generate_and_write_file(file_name: str,
 
316
  file_format: str,
317
  ) -> str:
318
  """
319
+ Writes subtitle file
320
+
321
+ Parameters
322
+ ----------
323
+ file_name: str
324
+ Output file name
325
+ transcribed_segments: list
326
+ Text segments transcribed from audio
327
+ add_timestamp: bool
328
+ Determines whether to add a timestamp to the end of the filename.
329
+ file_format: str
330
+ File format to write. Supported formats: [SRT, WebVTT, txt]
331
+
332
+ Returns
333
+ ----------
334
+ content: str
335
+ Result of the transcription
336
+ output_path: str
337
+ output file path
338
  """
339
  timestamp = datetime.now().strftime("%m%d%H%M%S")
340
  if add_timestamp:
 
360
 
361
  @staticmethod
362
  def format_time(elapsed_time: float) -> str:
363
+ """
364
+ Get {hours} {minutes} {seconds} time format string
365
+
366
+ Parameters
367
+ ----------
368
+ elapsed_time: float
369
+ Elapsed time for transcription
370
+
371
+ Returns
372
+ ----------
373
+ Time format string
374
+ """
375
  hours, rem = divmod(elapsed_time, 3600)
376
  minutes, seconds = divmod(rem, 60)
377
 
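On the backend side, `transcribe()` no longer receives each decoding option as a named parameter. It takes the raw `*whisper_params` tuple, converts it with `WhisperGradioComponents.to_values()` into a `WhisperValues` dataclass, reloads the model only when the requested size or compute type differs from the loaded one, and maps the language display name to its Whisper code. A short sketch of just the conversion and language-mapping step, outside the class (the sample tuple values are illustrative):

import whisper  # openai-whisper, used here only for its language table
from modules.whisper_data_class import WhisperGradioComponents

# After Gradio pre-processing, *whisper_params arrives as plain values in the
# field order declared on WhisperGradioComponents / WhisperValues.
whisper_params = ("large-v3", "english", False, 1, -1.0, 0.6, "float16", 5, 1)
params = WhisperGradioComponents.to_values(*whisper_params)

# "Automatic Detection" becomes None so Whisper detects the language itself;
# otherwise the display name is mapped back to its code, e.g. "english" -> "en".
if params.lang == "Automatic Detection":
    params.lang = None
else:
    language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
    params.lang = language_code_dict[params.lang]

print(params.model_size, params.lang, params.beam_size, params.best_of, params.patience)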
modules/whisper_Inference.py CHANGED
@@ -10,6 +10,7 @@ import torch
10
  from .base_interface import BaseInterface
11
  from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
  from modules.youtube_manager import get_ytdata, get_ytaudio
 
13
 
14
  DEFAULT_MODEL_SIZE = "large-v3"
15
 
@@ -21,82 +22,54 @@ class WhisperInference(BaseInterface):
21
  self.model = None
22
  self.available_models = whisper.available_models()
23
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
 
24
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
  self.available_compute_types = ["float16", "float32"]
26
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
27
  self.default_beam_size = 1
28
 
29
  def transcribe_file(self,
30
- fileobjs: list,
31
- model_size: str,
32
- lang: str,
33
  file_format: str,
34
- istranslate: bool,
35
  add_timestamp: bool,
36
- beam_size: int,
37
- log_prob_threshold: float,
38
- no_speech_threshold: float,
39
- compute_type: str,
40
- progress=gr.Progress()) -> list:
41
  """
42
  Write subtitle file from Files
43
 
44
  Parameters
45
  ----------
46
- fileobjs: list
47
  List of files to transcribe from gr.Files()
48
- model_size: str
49
- Whisper model size from gr.Dropdown()
50
- lang: str
51
- Source language of the file to transcribe from gr.Dropdown()
52
  file_format: str
53
- File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
54
- istranslate: bool
55
- Boolean value from gr.Checkbox() that determines whether to translate to English.
56
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
57
  add_timestamp: bool
58
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
59
- beam_size: int
60
- Int value from gr.Number() that is used for decoding option.
61
- log_prob_threshold: float
62
- float value from gr.Number(). If the average log probability over sampled tokens is
63
- below this value, treat as failed.
64
- no_speech_threshold: float
65
- float value from gr.Number(). If the no_speech probability is higher than this value AND
66
- the average log probability over sampled tokens is below `log_prob_threshold`,
67
- consider the segment as silent.
68
- compute_type: str
69
- compute type from gr.Dropdown().
70
  progress: gr.Progress
71
  Indicator to show progress directly in gradio.
72
- I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
 
73
 
74
  Returns
75
  ----------
76
- A List of
77
- String to return to gr.Textbox()
78
- Files to return to gr.Files()
 
79
  """
80
  try:
81
- self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
82
-
83
  files_info = {}
84
- for fileobj in fileobjs:
85
  progress(0, desc="Loading Audio..")
86
- audio = whisper.load_audio(fileobj.name)
87
-
88
- result, elapsed_time = self.transcribe(audio=audio,
89
- lang=lang,
90
- istranslate=istranslate,
91
- beam_size=beam_size,
92
- log_prob_threshold=log_prob_threshold,
93
- no_speech_threshold=no_speech_threshold,
94
- compute_type=compute_type,
95
- progress=progress
96
- )
97
  progress(1, desc="Completed!")
98
 
99
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
100
  file_name = safe_filename(file_name)
101
  subtitle, file_path = self.generate_and_write_file(
102
  file_name=file_name,
@@ -104,7 +77,7 @@ class WhisperInference(BaseInterface):
104
  add_timestamp=add_timestamp,
105
  file_format=file_format
106
  )
107
- files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time, "path": file_path}
108
 
109
  total_result = ''
110
  total_time = 0
@@ -114,100 +87,71 @@ class WhisperInference(BaseInterface):
114
  total_result += f"{info['subtitle']}"
115
  total_time += info["elapsed_time"]
116
 
117
- gr_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
118
- gr_file_path = [info['path'] for info in files_info.values()]
119
 
120
- return [gr_str, gr_file_path]
121
  except Exception as e:
122
  print(f"Error transcribing file: {str(e)}")
123
  finally:
124
  self.release_cuda_memory()
125
- self.remove_input_files([fileobj.name for fileobj in fileobjs])
126
 
127
  def transcribe_youtube(self,
128
- youtubelink: str,
129
- model_size: str,
130
- lang: str,
131
  file_format: str,
132
- istranslate: bool,
133
  add_timestamp: bool,
134
- beam_size: int,
135
- log_prob_threshold: float,
136
- no_speech_threshold: float,
137
- compute_type: str,
138
- progress=gr.Progress()) -> list:
139
  """
140
  Write subtitle file from Youtube
141
 
142
  Parameters
143
  ----------
144
- youtubelink: str
145
- Link of Youtube to transcribe from gr.Textbox()
146
- model_size: str
147
- Whisper model size from gr.Dropdown()
148
- lang: str
149
- Source language of the file to transcribe from gr.Dropdown()
150
  file_format: str
151
- File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
152
- istranslate: bool
153
- Boolean value from gr.Checkbox() that determines whether to translate to English.
154
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
155
  add_timestamp: bool
156
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
157
- beam_size: int
158
- Int value from gr.Number() that is used for decoding option.
159
- log_prob_threshold: float
160
- float value from gr.Number(). If the average log probability over sampled tokens is
161
- below this value, treat as failed.
162
- no_speech_threshold: float
163
- float value from gr.Number(). If the no_speech probability is higher than this value AND
164
- the average log probability over sampled tokens is below `log_prob_threshold`,
165
- consider the segment as silent.
166
- compute_type: str
167
- compute type from gr.Dropdown().
168
  progress: gr.Progress
169
  Indicator to show progress directly in gradio.
170
- I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
 
171
 
172
  Returns
173
  ----------
174
- A List of
175
- String to return to gr.Textbox()
176
- Files to return to gr.Files()
 
177
  """
178
  try:
179
- self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
180
-
181
  progress(0, desc="Loading Audio from Youtube..")
182
- yt = get_ytdata(youtubelink)
183
  audio = whisper.load_audio(get_ytaudio(yt))
184
 
185
- result, elapsed_time = self.transcribe(audio=audio,
186
- lang=lang,
187
- istranslate=istranslate,
188
- beam_size=beam_size,
189
- log_prob_threshold=log_prob_threshold,
190
- no_speech_threshold=no_speech_threshold,
191
- compute_type=compute_type,
192
- progress=progress)
193
  progress(1, desc="Completed!")
194
 
195
  file_name = safe_filename(yt.title)
196
- subtitle, file_path = self.generate_and_write_file(
197
  file_name=file_name,
198
  transcribed_segments=result,
199
  add_timestamp=add_timestamp,
200
  file_format=file_format
201
  )
202
 
203
- gr_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
204
- return [gr_str, file_path]
205
  except Exception as e:
206
  print(f"Error transcribing youtube video: {str(e)}")
207
  finally:
208
  try:
209
  if 'yt' not in locals():
210
- yt = get_ytdata(youtubelink)
211
  file_path = get_ytaudio(yt)
212
  else:
213
  file_path = get_ytaudio(yt)
@@ -218,116 +162,71 @@ class WhisperInference(BaseInterface):
218
  pass
219
 
220
  def transcribe_mic(self,
221
- micaudio: str,
222
- model_size: str,
223
- lang: str,
224
  file_format: str,
225
- istranslate: bool,
226
- beam_size: int,
227
- log_prob_threshold: float,
228
- no_speech_threshold: float,
229
- compute_type: str,
230
- progress=gr.Progress()) -> list:
231
  """
232
  Write subtitle file from microphone
233
 
234
  Parameters
235
  ----------
236
- micaudio: str
237
  Audio file path from gr.Microphone()
238
- model_size: str
239
- Whisper model size from gr.Dropdown()
240
- lang: str
241
- Source language of the file to transcribe from gr.Dropdown()
242
  file_format: str
243
- Subtitle format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
244
- istranslate: bool
245
- Boolean value from gr.Checkbox() that determines whether to translate to English.
246
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
247
- beam_size: int
248
- Int value from gr.Number() that is used for decoding option.
249
- log_prob_threshold: float
250
- float value from gr.Number(). If the average log probability over sampled tokens is
251
- below this value, treat as failed.
252
- no_speech_threshold: float
253
- float value from gr.Number(). If the no_speech probability is higher than this value AND
254
- the average log probability over sampled tokens is below `log_prob_threshold`,
255
- consider the segment as silent.
256
- compute_type: str
257
- compute type from gr.Dropdown().
258
  progress: gr.Progress
259
  Indicator to show progress directly in gradio.
260
- I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
 
261
 
262
  Returns
263
  ----------
264
- A List of
265
- String to return to gr.Textbox()
266
- Files to return to gr.Files()
 
267
  """
268
  try:
269
- self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
270
-
271
- result, elapsed_time = self.transcribe(audio=micaudio,
272
- lang=lang,
273
- istranslate=istranslate,
274
- beam_size=beam_size,
275
- log_prob_threshold=log_prob_threshold,
276
- no_speech_threshold=no_speech_threshold,
277
- compute_type=compute_type,
278
- progress=progress)
279
  progress(1, desc="Completed!")
280
 
281
- subtitle, file_path = self.generate_and_write_file(
282
  file_name="Mic",
283
  transcribed_segments=result,
284
  add_timestamp=True,
285
  file_format=file_format
286
  )
287
 
288
- gr_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
289
- return [gr_str, file_path]
290
  except Exception as e:
291
  print(f"Error transcribing mic: {str(e)}")
292
  finally:
293
  self.release_cuda_memory()
294
- self.remove_input_files([micaudio])
295
 
296
  def transcribe(self,
297
  audio: Union[str, np.ndarray, torch.Tensor],
298
- lang: str,
299
- istranslate: bool,
300
- beam_size: int,
301
- log_prob_threshold: float,
302
- no_speech_threshold: float,
303
- compute_type: str,
304
- progress: gr.Progress
305
  ) -> Tuple[List[dict], float]:
306
  """
307
- transcribe method for OpenAI's Whisper implementation.
308
 
309
  Parameters
310
  ----------
311
- audio: Union[str, BinaryIO, torch.Tensor]
312
  Audio path or file binary or Audio numpy array
313
- lang: str
314
- Source language of the file to transcribe from gr.Dropdown()
315
- istranslate: bool
316
- Boolean value from gr.Checkbox() that determines whether to translate to English.
317
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
318
- beam_size: int
319
- Int value from gr.Number() that is used for decoding option.
320
- log_prob_threshold: float
321
- float value from gr.Number(). If the average log probability over sampled tokens is
322
- below this value, treat as failed.
323
- no_speech_threshold: float
324
- float value from gr.Number(). If the no_speech probability is higher than this value AND
325
- the average log probability over sampled tokens is below `log_prob_threshold`,
326
- consider the segment as silent.
327
- compute_type: str
328
- compute type from gr.Dropdown().
329
  progress: gr.Progress
330
  Indicator to show progress directly in gradio.
 
 
331
 
332
  Returns
333
  ----------
@@ -337,45 +236,58 @@ class WhisperInference(BaseInterface):
337
  elapsed time for transcription
338
  """
339
  start_time = time.time()
 
340
 
341
  def progress_callback(progress_value):
342
  progress(progress_value, desc="Transcribing..")
343
 
344
- if lang == "Automatic Detection":
345
- lang = None
346
-
347
- translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
348
  segments_result = self.model.transcribe(audio=audio,
349
- language=lang,
350
  verbose=False,
351
- beam_size=beam_size,
352
- logprob_threshold=log_prob_threshold,
353
- no_speech_threshold=no_speech_threshold,
354
- task="translate" if istranslate and self.current_model_size in translatable_model else "transcribe",
355
- fp16=True if compute_type == "float16" else False,
 
 
356
  progress_callback=progress_callback)["segments"]
357
  elapsed_time = time.time() - start_time
358
 
359
  return segments_result, elapsed_time
360
 
361
- def update_model_if_needed(self,
362
- model_size: str,
363
- compute_type: str,
364
- progress: gr.Progress,
365
- ):
366
  """
367
- Initialize model if it doesn't match with current model setting
 
 
 
 
 
 
 
 
 
 
368
  """
369
- if compute_type != self.current_compute_type:
370
- self.current_compute_type = compute_type
371
- if model_size != self.current_model_size or self.model is None:
372
- progress(0, desc="Initializing Model..")
373
- self.current_model_size = model_size
374
- self.model = whisper.load_model(
375
- name=model_size,
376
- device=self.device,
377
- download_root=os.path.join("models", "Whisper")
378
- )
379
 
380
  @staticmethod
381
  def generate_and_write_file(file_name: str,
@@ -384,7 +296,25 @@ class WhisperInference(BaseInterface):
384
  file_format: str,
385
  ) -> str:
386
  """
387
- This method writes subtitle file and returns str to gr.Textbox
 
388
  """
389
  timestamp = datetime.now().strftime("%m%d%H%M%S")
390
  if add_timestamp:
@@ -410,6 +340,18 @@ class WhisperInference(BaseInterface):
410
 
411
  @staticmethod
412
  def format_time(elapsed_time: float) -> str:
413
  hours, rem = divmod(elapsed_time, 3600)
414
  minutes, seconds = divmod(rem, 60)
415
 
 
10
  from .base_interface import BaseInterface
11
  from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
  from modules.youtube_manager import get_ytdata, get_ytaudio
13
+ from modules.whisper_data_class import *
14
 
15
  DEFAULT_MODEL_SIZE = "large-v3"
16
 
 
22
  self.model = None
23
  self.available_models = whisper.available_models()
24
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
+ self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.available_compute_types = ["float16", "float32"]
28
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
  self.default_beam_size = 1
30
 
31
  def transcribe_file(self,
32
+ files: list,
 
 
33
  file_format: str,
 
34
  add_timestamp: bool,
35
+ progress=gr.Progress(),
36
+ *whisper_params
37
+ ) -> list:
 
 
38
  """
39
  Write subtitle file from Files
40
 
41
  Parameters
42
  ----------
43
+ files: list
44
  List of files to transcribe from gr.Files()
 
 
 
 
45
  file_format: str
46
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
 
 
 
47
  add_timestamp: bool
48
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
 
 
 
 
 
 
 
 
 
 
 
49
  progress: gr.Progress
50
  Indicator to show progress directly in gradio.
51
+ *whisper_params: tuple
52
+ Gradio components related to Whisper. see whisper_data_class.py for details.
53
 
54
  Returns
55
  ----------
56
+ result_str:
57
+ Result of transcription to return to gr.Textbox()
58
+ result_file_path:
59
+ Output file path to return to gr.Files()
60
  """
61
  try:
 
 
62
  files_info = {}
63
+ for file in files:
64
  progress(0, desc="Loading Audio..")
65
+ audio = whisper.load_audio(file.name)
66
+
67
+ result, elapsed_time = self.transcribe(audio,
68
+ progress,
69
+ *whisper_params)
 
 
 
 
 
 
70
  progress(1, desc="Completed!")
71
 
72
+ file_name, file_ext = os.path.splitext(os.path.basename(file.name))
73
  file_name = safe_filename(file_name)
74
  subtitle, file_path = self.generate_and_write_file(
75
  file_name=file_name,
 
77
  add_timestamp=add_timestamp,
78
  file_format=file_format
79
  )
80
+ files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time, "path": file_path}
81
 
82
  total_result = ''
83
  total_time = 0
 
87
  total_result += f"{info['subtitle']}"
88
  total_time += info["elapsed_time"]
89
 
90
+ result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
91
+ result_file_path = [info['path'] for info in files_info.values()]
92
 
93
+ return [result_str, result_file_path]
94
  except Exception as e:
95
  print(f"Error transcribing file: {str(e)}")
96
  finally:
97
  self.release_cuda_memory()
98
+ self.remove_input_files([file.name for file in files])
99
 
100
  def transcribe_youtube(self,
101
+ youtube_link: str,
 
 
102
  file_format: str,
 
103
  add_timestamp: bool,
104
+ progress=gr.Progress(),
105
+ *whisper_params) -> list:
 
 
 
106
  """
107
  Write subtitle file from Youtube
108
 
109
  Parameters
110
  ----------
111
+ youtube_link: str
112
+ URL of the Youtube video to transcribe from gr.Textbox()
 
 
 
 
113
  file_format: str
114
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
 
 
 
115
  add_timestamp: bool
116
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
 
 
 
 
 
 
 
 
 
 
 
117
  progress: gr.Progress
118
  Indicator to show progress directly in gradio.
119
+ *whisper_params: tuple
120
+ Gradio components related to Whisper. see whisper_data_class.py for details.
121
 
122
  Returns
123
  ----------
124
+ result_str:
125
+ Result of transcription to return to gr.Textbox()
126
+ result_file_path:
127
+ Output file path to return to gr.Files()
128
  """
129
  try:
 
 
130
  progress(0, desc="Loading Audio from Youtube..")
131
+ yt = get_ytdata(youtube_link)
132
  audio = whisper.load_audio(get_ytaudio(yt))
133
 
134
+ result, elapsed_time = self.transcribe(audio,
135
+ progress,
136
+ *whisper_params)
 
 
 
 
 
137
  progress(1, desc="Completed!")
138
 
139
  file_name = safe_filename(yt.title)
140
+ subtitle, result_file_path = self.generate_and_write_file(
141
  file_name=file_name,
142
  transcribed_segments=result,
143
  add_timestamp=add_timestamp,
144
  file_format=file_format
145
  )
146
 
147
+ result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
148
+ return [result_str, result_file_path]
149
  except Exception as e:
150
  print(f"Error transcribing youtube video: {str(e)}")
151
  finally:
152
  try:
153
  if 'yt' not in locals():
154
+ yt = get_ytdata(youtube_link)
155
  file_path = get_ytaudio(yt)
156
  else:
157
  file_path = get_ytaudio(yt)
 
162
  pass
163
 
164
  def transcribe_mic(self,
165
+ mic_audio: str,
 
 
166
  file_format: str,
167
+ progress=gr.Progress(),
168
+ *whisper_params) -> list:
 
 
 
 
169
  """
170
  Write subtitle file from microphone
171
 
172
  Parameters
173
  ----------
174
+ mic_audio: str
175
  Audio file path from gr.Microphone()
 
 
 
 
176
  file_format: str
177
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  progress: gr.Progress
179
  Indicator to show progress directly in gradio.
180
+ *whisper_params: tuple
181
+ Gradio components related to Whisper. see whisper_data_class.py for details.
182
 
183
  Returns
184
  ----------
185
+ result_str:
186
+ Result of transcription to return to gr.Textbox()
187
+ result_file_path:
188
+ Output file path to return to gr.Files()
189
  """
190
  try:
191
+ progress(0, desc="Loading Audio..")
192
+ result, elapsed_time = self.transcribe(
193
+ mic_audio,
194
+ progress,
195
+ *whisper_params,
196
+ )
 
 
 
 
197
  progress(1, desc="Completed!")
198
 
199
+ subtitle, result_file_path = self.generate_and_write_file(
200
  file_name="Mic",
201
  transcribed_segments=result,
202
  add_timestamp=True,
203
  file_format=file_format
204
  )
205
 
206
+ result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
207
+ return [result_str, result_file_path]
208
  except Exception as e:
209
  print(f"Error transcribing mic: {str(e)}")
210
  finally:
211
  self.release_cuda_memory()
212
+ self.remove_input_files([mic_audio])
213
 
214
  def transcribe(self,
215
  audio: Union[str, np.ndarray, torch.Tensor],
216
+ progress: gr.Progress,
217
+ *whisper_params,
 
 
 
 
 
218
  ) -> Tuple[List[dict], float]:
219
  """
220
+ transcribe method for OpenAI's Whisper implementation.
221
 
222
  Parameters
223
  ----------
224
+ audio: Union[str, np.ndarray, torch.Tensor]
225
  Audio path or file binary or Audio numpy array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  progress: gr.Progress
227
  Indicator to show progress directly in gradio.
228
+ *whisper_params: tuple
229
+ Gradio components related to Whisper. see whisper_data_class.py for details.
230
 
231
  Returns
232
  ----------
 
236
  elapsed time for transcription
237
  """
238
  start_time = time.time()
239
+ params = WhisperGradioComponents.to_values(*whisper_params)
240
+
241
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
242
+ self.update_model(params.model_size, params.compute_type, progress)
243
+
244
+ if params.lang == "Automatic Detection":
245
+ params.lang = None
246
 
247
  def progress_callback(progress_value):
248
  progress(progress_value, desc="Transcribing..")
249
 
 
 
 
 
250
  segments_result = self.model.transcribe(audio=audio,
251
+ language=params.lang,
252
  verbose=False,
253
+ beam_size=params.beam_size,
254
+ logprob_threshold=params.log_prob_threshold,
255
+ no_speech_threshold=params.no_speech_threshold,
256
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_model else "transcribe",
257
+ fp16=True if params.compute_type == "float16" else False,
258
+ best_of=params.best_of,
259
+ patience=params.patience,
260
  progress_callback=progress_callback)["segments"]
261
  elapsed_time = time.time() - start_time
262
 
263
  return segments_result, elapsed_time
264
 
265
+ def update_model(self,
266
+ model_size: str,
267
+ compute_type: str,
268
+ progress: gr.Progress,
269
+ ):
270
  """
271
+ Update current model setting
272
+
273
+ Parameters
274
+ ----------
275
+ model_size: str
276
+ Size of whisper model
277
+ compute_type: str
278
+ Compute type for transcription.
279
+ see more info : https://opennmt.net/CTranslate2/quantization.html
280
+ progress: gr.Progress
281
+ Indicator to show progress directly in gradio.
282
  """
283
+ progress(0, desc="Initializing Model..")
284
+ self.current_compute_type = compute_type
285
+ self.current_model_size = model_size
286
+ self.model = whisper.load_model(
287
+ name=model_size,
288
+ device=self.device,
289
+ download_root=os.path.join("models", "Whisper")
290
+ )
 
 
291
 
292
  @staticmethod
293
  def generate_and_write_file(file_name: str,
 
296
  file_format: str,
297
  ) -> str:
298
  """
299
+ Writes subtitle file
300
+
301
+ Parameters
302
+ ----------
303
+ file_name: str
304
+ Output file name
305
+ transcribed_segments: list
306
+ Text segments transcribed from audio
307
+ add_timestamp: bool
308
+ Determines whether to add a timestamp to the end of the filename.
309
+ file_format: str
310
+ File format to write. Supported formats: [SRT, WebVTT, txt]
311
+
312
+ Returns
313
+ ----------
314
+ content: str
315
+ Result of the transcription
316
+ output_path: str
317
+ output file path
318
  """
319
  timestamp = datetime.now().strftime("%m%d%H%M%S")
320
  if add_timestamp:
 
340
 
341
  @staticmethod
342
  def format_time(elapsed_time: float) -> str:
343
+ """
344
+ Get {hours} {minutes} {seconds} time format string
345
+
346
+ Parameters
347
+ ----------
348
+ elapsed_time: float
349
+ Elapsed time for transcription
350
+
351
+ Returns
352
+ ----------
353
+ Time format string
354
+ """
355
  hours, rem = divmod(elapsed_time, 3600)
356
  minutes, seconds = divmod(rem, 60)
357
 
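Both inference classes now share the same lazy-reload policy: `transcribe()` compares the requested model size and compute type with what is currently loaded and calls `update_model()` only when they differ (or when no model is loaded yet). A condensed, backend-agnostic sketch of that check, with a `load_fn` callback standing in for `faster_whisper.WhisperModel(...)` or `whisper.load_model(...)`:

from typing import Callable, Optional

class LazyModelHolder:
    """Illustrative stand-in for the model-caching state kept on the inference classes."""

    def __init__(self, default_compute_type: str = "float32"):
        self.model = None
        self.current_model_size: Optional[str] = None
        self.current_compute_type = default_compute_type

    def ensure_model(self, model_size: str, compute_type: str, load_fn: Callable):
        # Same condition as the diff: reload only when the requested size or compute
        # type differs from the loaded one, or when nothing is loaded yet.
        if model_size != self.current_model_size or self.model is None \
                or self.current_compute_type != compute_type:
            self.current_model_size = model_size
            self.current_compute_type = compute_type
            self.model = load_fn(model_size, compute_type)
        return self.model

holder = LazyModelHolder()
holder.ensure_model("base", "float32", lambda size, ct: f"loaded {size}/{ct}")  # loads
holder.ensure_model("base", "float32", lambda size, ct: f"loaded {size}/{ct}")  # cached, no reload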
modules/whisper_data_class.py ADDED
@@ -0,0 +1,88 @@
1
+ from dataclasses import dataclass, fields
2
+ import gradio as gr
3
+
4
+
5
+ @dataclass
6
+ class WhisperGradioComponents:
7
+ model_size: gr.Dropdown
8
+ lang: gr.Dropdown
9
+ is_translate: gr.Checkbox
10
+ beam_size: gr.Number
11
+ log_prob_threshold: gr.Number
12
+ no_speech_threshold: gr.Number
13
+ compute_type: gr.Dropdown
14
+ best_of: gr.Number
15
+ patience: gr.Number
16
+ """
17
+ A data class to pass Gradio components to the function before Gradio pre-processing.
18
+ See this documentation for more information about Gradio pre-processing: https://www.gradio.app/docs/components
19
+
20
+ Attributes
21
+ ----------
22
+ model_size: gr.Dropdown
23
+ Whisper model size.
24
+ lang: gr.Dropdown
25
+ Source language of the file to transcribe.
26
+ is_translate: gr.Checkbox
27
+ Boolean value that determines whether to translate to English.
28
+ It's Whisper's feature to translate speech from another language directly into English end-to-end.
29
+ beam_size: gr.Number
30
+ Int value that is used for decoding option.
31
+ log_prob_threshold: gr.Number
32
+ If the average log probability over sampled tokens is below this value, treat as failed.
33
+ no_speech_threshold: gr.Number
34
+ If the no_speech probability is higher than this value AND
35
+ the average log probability over sampled tokens is below `log_prob_threshold`,
36
+ consider the segment as silent.
37
+ compute_type: gr.Dropdown
38
+ compute type for transcription.
39
+ see more info : https://opennmt.net/CTranslate2/quantization.html
40
+ best_of: gr.Number
41
+ Number of candidates when sampling with non-zero temperature.
42
+ patience: gr.Number
43
+ Beam search patience factor.
44
+ """
45
+
46
+ def to_list(self) -> list:
47
+ """
48
+ Converts the data class attributes into a list, to pass parameters to a function before Gradio pre-processing.
49
+
50
+ Returns
51
+ ----------
52
+ A list of Gradio components
53
+ """
54
+ return [getattr(self, f.name) for f in fields(self)]
55
+
56
+ @staticmethod
57
+ def to_values(*params):
58
+ """
59
+ Convert a tuple of parameters into a WhisperValues data class, to use parameters in a function after Gradio pre-processing.
60
+
61
+ Parameters
62
+ ----------
63
+ *params: tuple
64
+ This is provided as a tuple because Gradio does not support arbitrary **kwargs.
65
+ Reference : https://discuss.huggingface.co/t/passing-an-additional-argument-to-a-function/25140/2
66
+
67
+ Returns
68
+ ----------
69
+ A WhisperValues data class
70
+ """
71
+ return WhisperValues(*params)
72
+
73
+
74
+ @dataclass
75
+ class WhisperValues:
76
+ model_size: str
77
+ lang: str
78
+ is_translate: bool
79
+ beam_size: int
80
+ log_prob_threshold: float
81
+ no_speech_threshold: float
82
+ compute_type: str
83
+ best_of: int
84
+ patience: float
85
+ """
86
+ A data class to use Whisper parameters in the function after Gradio pre-processing.
87
+ See this documentation for more information about Gradio pre-processing: https://www.gradio.app/docs/components
88
+ """
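The two dataclasses are mirror images of each other: `WhisperGradioComponents` holds the Gradio components and `to_list()` hands them to an event listener in a fixed field order, while `to_values()` rebuilds that same order into a `WhisperValues` of plain Python values after Gradio pre-processing. A small round-trip check that can run without launching a UI (the component choices and sample values are arbitrary):

import gradio as gr
from modules.whisper_data_class import WhisperGradioComponents, WhisperValues

components = WhisperGradioComponents(
    model_size=gr.Dropdown(choices=["base", "large-v3"], value="large-v3"),
    lang=gr.Dropdown(choices=["Automatic Detection", "english"], value="english"),
    is_translate=gr.Checkbox(value=False),
    beam_size=gr.Number(value=1),
    log_prob_threshold=gr.Number(value=-1.0),
    no_speech_threshold=gr.Number(value=0.6),
    compute_type=gr.Dropdown(choices=["float16", "float32"], value="float16"),
    best_of=gr.Number(value=5),
    patience=gr.Number(value=1),
)

# to_list() preserves the declaration order, which is what makes passing the
# values back positionally as *whisper_params safe.
assert len(components.to_list()) == 9

# After Gradio pre-processing the handler receives plain values in that same order:
values = WhisperGradioComponents.to_values(
    "large-v3", "english", False, 1, -1.0, 0.6, "float16", 5, 1)
assert isinstance(values, WhisperValues)
assert values.model_size == "large-v3" and values.best_of == 5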