jhj0517 commited on
Commit
4589c43
2 Parent(s): 0685ed9 61ac4a7

Merge pull request #37 from jhj0517/implement_faster-whisper

Browse files
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import argparse
4
 
5
  from modules.whisper_Inference import WhisperInference
 
6
  from modules.nllb_inference import NLLBInference
7
  from ui.htmls import *
8
  from modules.youtube_manager import get_ytmetas
@@ -12,7 +13,11 @@ class App:
12
  def __init__(self, args):
13
  self.args = args
14
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
15
- self.whisper_inf = WhisperInference()
 
 
 
 
16
  self.nllb_inf = NLLBInference()
17
 
18
  @staticmethod
@@ -164,6 +169,7 @@ class App:
164
 
165
  # Create the parser for command-line arguments
166
  parser = argparse.ArgumentParser()
 
167
  parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
168
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
169
  parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
 
3
  import argparse
4
 
5
  from modules.whisper_Inference import WhisperInference
6
+ from modules.faster_whisper_inference import FasterWhisperInference
7
  from modules.nllb_inference import NLLBInference
8
  from ui.htmls import *
9
  from modules.youtube_manager import get_ytmetas
 
13
  def __init__(self, args):
14
  self.args = args
15
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
16
+ self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
17
+ if isinstance(self.whisper_inf, FasterWhisperInference):
18
+ print("Use Faster Whisper implementation")
19
+ else:
20
+ print("Use Open AI Whisper implementation")
21
  self.nllb_inf = NLLBInference()
22
 
23
  @staticmethod
 
169
 
170
def _str2bool(v) -> bool:
    """Parse a command-line boolean value.

    argparse's ``type=bool`` is a well-known trap: any non-empty string is
    truthy, so ``--share False`` would have enabled sharing. This converter
    interprets the string content instead (true/1/yes/y, case-insensitive).
    """
    if isinstance(v, bool):
        return v
    return str(v).strip().lower() in ("true", "1", "yes", "y")


# Create the parser for command-line arguments
parser = argparse.ArgumentParser()
# nargs='?' + const=True keeps the bare-flag form (`--share`) working while
# `--share false` now correctly disables the option.
parser.add_argument('--disable_faster_whisper', type=_str2bool, default=False, nargs='?', const=True,
                    help='Disable the faster_whisper implementation. faster_whisper is implemented by https://github.com/guillaumekln/faster-whisper')
parser.add_argument('--share', type=_str2bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
modules/faster_whisper_inference.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import tqdm
4
+ import time
5
+ import numpy as np
6
+ from typing import BinaryIO, Union, Tuple
7
+ from datetime import datetime, timedelta
8
+
9
+ import faster_whisper
10
+ import whisper
11
+ import torch
12
+ import gradio as gr
13
+
14
+ from .base_interface import BaseInterface
15
+ from modules.subtitle_manager import get_srt, get_vtt, write_file, safe_filename
16
+ from modules.youtube_manager import get_ytdata, get_ytaudio
17
+
18
+
19
class FasterWhisperInference(BaseInterface):
    """
    Transcription backend built on faster-whisper (the CTranslate2 port of
    Whisper). Exposes the same public interface as WhisperInference so the
    UI can swap between the two implementations.
    """

    def __init__(self):
        super().__init__()
        # Model is loaded lazily and re-created whenever the requested size
        # changes (see initialize_model).
        self.current_model_size = None
        self.model = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        # Only the large checkpoints support Whisper's end-to-end
        # speech -> English translation task.
        self.translatable_models = ["large", "large-v1", "large-v2"]
        self.default_beam_size = 5
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def transcribe_file(self,
                        fileobjs: list,
                        model_size: str,
                        lang: str,
                        subformat: str,
                        istranslate: bool,
                        add_timestamp: bool,
                        progress=gr.Progress()
                        ) -> str:
        """
        Write subtitle file from Files

        Parameters
        ----------
        fileobjs: list
            List of files to transcribe from gr.Files()
        model_size: str
            Whisper model size from gr.Dropdown()
        lang: str
            Source language of the file to transcribe from gr.Dropdown()
        subformat: str
            Subtitle format to write from gr.Dropdown(). Supported format: [SRT, WebVTT]
        istranslate: bool
            Boolean value from gr.Checkbox() that determines whether to translate to English.
            It's Whisper's feature to translate speech from another language directly into English end-to-end.
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        String to return to gr.Textbox()
        """
        try:
            if model_size != self.current_model_size or self.model is None:
                self.initialize_model(model_size=model_size, progress=progress)

            if lang == "Automatic Detection":
                lang = None

            files_info = {}
            for fileobj in fileobjs:
                transcribed_segments, time_for_task = self.transcribe(
                    audio=fileobj.name,
                    lang=lang,
                    istranslate=istranslate,
                    progress=progress
                )

                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
                file_name = safe_filename(file_name)
                subtitle = self.generate_and_write_subtitle(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    subformat=subformat
                )
                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task}

            total_result = ''
            total_time = 0
            for file_name, info in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{info["subtitle"]}'
                total_time += info["time_for_task"]

            return f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"

        except Exception as e:
            # Return the error to the textbox like the other entry points do;
            # the original printed a garbled message and returned None.
            return f"Error: {str(e)}"
        finally:
            self.release_cuda_memory()
            self.remove_input_files([fileobj.name for fileobj in fileobjs])

    def transcribe_youtube(self,
                           youtubelink: str,
                           model_size: str,
                           lang: str,
                           subformat: str,
                           istranslate: bool,
                           add_timestamp: bool,
                           progress=gr.Progress()
                           ) -> str:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtubelink: str
            Link of Youtube to transcribe from gr.Textbox()
        model_size: str
            Whisper model size from gr.Dropdown()
        lang: str
            Source language of the file to transcribe from gr.Dropdown()
        subformat: str
            Subtitle format to write from gr.Dropdown(). Supported format: [SRT, WebVTT]
        istranslate: bool
            Boolean value from gr.Checkbox() that determines whether to translate to English.
            It's Whisper's feature to translate speech from another language directly into English end-to-end.
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        String to return to gr.Textbox()
        """
        # Track the downloaded audio path so cleanup doesn't have to
        # re-fetch youtube metadata (which can itself raise and mask the
        # original error).
        audio = None
        try:
            if model_size != self.current_model_size or self.model is None:
                self.initialize_model(model_size=model_size, progress=progress)

            if lang == "Automatic Detection":
                lang = None

            progress(0, desc="Loading Audio from Youtube..")
            yt = get_ytdata(youtubelink)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.transcribe(
                audio=audio,
                lang=lang,
                istranslate=istranslate,
                progress=progress
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle = self.generate_and_write_subtitle(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                subformat=subformat
            )
            return f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
        except Exception as e:
            return f"Error: {str(e)}"
        finally:
            self.release_cuda_memory()
            if audio is not None:
                self.remove_input_files([audio])

    def transcribe_mic(self,
                       micaudio: str,
                       model_size: str,
                       lang: str,
                       subformat: str,
                       istranslate: bool,
                       progress=gr.Progress()
                       ) -> str:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        micaudio: str
            Audio file path from gr.Microphone()
        model_size: str
            Whisper model size from gr.Dropdown()
        lang: str
            Source language of the file to transcribe from gr.Dropdown()
        subformat: str
            Subtitle format to write from gr.Dropdown(). Supported format: [SRT, WebVTT]
        istranslate: bool
            Boolean value from gr.Checkbox() that determines whether to translate to English.
            It's Whisper's feature to translate speech from another language directly into English end-to-end.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        String to return to gr.Textbox()
        """
        try:
            if model_size != self.current_model_size or self.model is None:
                self.initialize_model(model_size=model_size, progress=progress)

            if lang == "Automatic Detection":
                lang = None

            progress(0, desc="Loading Audio..")

            transcribed_segments, time_for_task = self.transcribe(
                audio=micaudio,
                lang=lang,
                istranslate=istranslate,
                progress=progress
            )
            progress(1, desc="Completed!")

            subtitle = self.generate_and_write_subtitle(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=True,
                subformat=subformat
            )
            return f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
        except Exception as e:
            return f"Error: {str(e)}"
        finally:
            self.release_cuda_memory()
            self.remove_input_files([micaudio])

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   lang: str,
                   istranslate: bool,
                   progress: gr.Progress
                   ) -> Tuple[list, float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        lang: str
            Source language of the file to transcribe from gr.Dropdown()
        istranslate: bool
            Boolean value from gr.Checkbox() that determines whether to translate to English.
            It's Whisper's feature to translate speech from another language directly into English end-to-end.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        segments_result: list[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        # faster-whisper returns a lazy generator; decoding happens while
        # iterating `segments` below.
        segments, info = self.model.transcribe(
            audio=audio,
            language=lang,
            beam_size=self.default_beam_size,
            task="translate" if istranslate and self.current_model_size in self.translatable_models else "transcribe"
        )
        progress(0, desc="Loading audio..")
        total_frames = self.get_total_frames(audio=audio, progress=progress)

        segments_result = []
        for segment in segments:
            # segment.seek is the decoder's position in feature frames; use it
            # against total_frames as a progress estimate.
            progress(segment.seek / total_frames, desc="Transcribing..")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def initialize_model(self,
                         model_size: str,
                         progress: gr.Progress
                         ):
        """
        Initialize model if it doesn't match with current model size
        """
        progress(0, desc="Initializing Model..")
        self.current_model_size = model_size
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=model_size,
            download_root=os.path.join("models", "Whisper", "faster-whisper"),
            # float16 is only supported on CUDA devices; fall back to int8 on
            # CPU so CPU-only machines don't crash at load time.
            compute_type="float16" if self.device == "cuda" else "int8"
        )

    def get_total_frames(self,
                         audio: Union[str, BinaryIO],
                         progress: gr.Progress
                         ) -> float:
        """
        This method is only for tracking the progress.
        Returns total frames to track progress.
        """
        progress(0, desc="Loading audio..")
        decoded_audio = faster_whisper.decode_audio(audio)
        features = self.model.feature_extractor(decoded_audio)
        content_frames = features.shape[-1] - self.model.feature_extractor.nb_max_frames
        return content_frames

    @staticmethod
    def generate_and_write_subtitle(file_name: str,
                                    transcribed_segments: list,
                                    add_timestamp: bool,
                                    subformat: str,
                                    ) -> str:
        """
        This method writes subtitle file and returns str to gr.Textbox
        """
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        if add_timestamp:
            output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
        else:
            output_path = os.path.join("outputs", f"{file_name}")

        if subformat == "SRT":
            subtitle = get_srt(transcribed_segments)
            write_file(subtitle, f"{output_path}.srt")
        elif subformat == "WebVTT":
            subtitle = get_vtt(transcribed_segments)
            write_file(subtitle, f"{output_path}.vtt")
        else:
            # Previously fell through to an UnboundLocalError; fail clearly.
            raise ValueError(f"Unsupported subtitle format: {subformat}")
        return subtitle

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Format a duration in seconds as a human-readable "H hours M minutes
        S seconds" string, omitting zero leading units.
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
        # divmod on a float yields floats ("1.0 hours"); cast for display.
        hours, minutes = int(hours), int(minutes)

        time_str = ""
        if hours:
            time_str += f"{hours} hours "
        if minutes:
            time_str += f"{minutes} minutes "
        seconds = round(seconds)
        time_str += f"{seconds} seconds"

        return time_str.strip()
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu117
2
  torch
3
  git+https://github.com/jhj0517/jhj0517-whisper.git
 
4
  gradio==3.37.0
5
  pytube
 
1
  --extra-index-url https://download.pytorch.org/whl/cu117
2
  torch
3
  git+https://github.com/jhj0517/jhj0517-whisper.git
4
+ faster-whisper
5
  gradio==3.37.0
6
  pytube
user-start-webui.bat CHANGED
@@ -1,4 +1,5 @@
1
  :: This batch file is for launching with command line args
 
2
  @echo off
3
 
4
  :: Set values
@@ -8,6 +9,7 @@ set USERNAME=
8
  set PASSWORD=
9
  set SHARE=
10
  set THEME=
 
11
 
12
  :: Set args accordingly
13
  if not "%SERVER_NAME%"=="" (
@@ -28,7 +30,10 @@ if /I "%SHARE%"=="true" (
28
  if not "%THEME%"=="" (
29
  set THEME_ARG=--theme %THEME%
30
  )
 
 
 
31
 
32
  :: Call the original .bat script with optional arguments
33
- start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG%
34
  pause
 
1
  :: This batch file is for launching with command line args
2
+ :: See the wiki for a guide to command line arguments: https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments
3
  @echo off
4
 
5
  :: Set values
 
9
  set PASSWORD=
10
  set SHARE=
11
  set THEME=
12
+ set DISABLE_FASTER_WHISPER=
13
 
14
  :: Set args accordingly
15
  if not "%SERVER_NAME%"=="" (
 
30
  if not "%THEME%"=="" (
31
  set THEME_ARG=--theme %THEME%
32
  )
33
+ if /I "%DISABLE_FASTER_WHISPER%"=="true" (
34
+ set DISABLE_FASTER_WHISPER_ARG=--disable_faster_whisper
35
+ )
36
 
37
  :: Call the original .bat script with optional arguments
38
+ start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %DISABLE_FASTER_WHISPER_ARG%
39
  pause