jhj0517 committed on
Commit
a63d493
1 Parent(s): 6726c6a

refactoring

Files changed (1)
modules/whisper_Inference.py +148 -97
modules/whisper_Inference.py CHANGED
@@ -1,7 +1,11 @@
 import whisper
 import gradio as gr
+import time
 import os
+from typing import BinaryIO, Union, Tuple
+import numpy as np
 from datetime import datetime
+import torch
 
 from .base_interface import BaseInterface
 from modules.subtitle_manager import get_srt, get_vtt, write_file, safe_filename
@@ -48,61 +52,45 @@ class WhisperInference(BaseInterface):
             Indicator to show progress directly in gradio.
             I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
         """
-        def progress_callback(progress_value):
-            progress(progress_value, desc="Transcribing..")
 
         try:
             if model_size != self.current_model_size or self.model is None:
-                progress(0, desc="Initializing Model..")
-                self.current_model_size = model_size
-                self.model = whisper.load_model(name=model_size, download_root=os.path.join("models", "Whisper"))
-
-            if lang == "Automatic Detection":
-                lang = None
-
-            progress(0, desc="Loading Audio..")
+                self.initialize_model(model_size=model_size, progress=progress)
 
             files_info = {}
             for fileobj in fileobjs:
-
+                progress(0, desc="Loading Audio..")
                 audio = whisper.load_audio(fileobj.name)
 
-                translatable_model = ["large", "large-v1", "large-v2"]
-                if istranslate and self.current_model_size in translatable_model:
-                    result = self.model.transcribe(audio=audio, language=lang, verbose=False, task="translate",
-                                                   progress_callback=progress_callback)
-                else:
-                    result = self.model.transcribe(audio=audio, language=lang, verbose=False,
-                                                   progress_callback=progress_callback)
-
+                result, elapsed_time = self.transcribe(audio=audio,
+                                                       lang=lang,
+                                                       istranslate=istranslate,
+                                                       progress=progress)
                 progress(1, desc="Completed!")
 
                 file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
                 file_name = safe_filename(file_name)
-                timestamp = datetime.now().strftime("%m%d%H%M%S")
-                if add_timestamp:
-                    output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
-                else:
-                    output_path = os.path.join("outputs", f"{file_name}")
+                subtitle = self.generate_and_write_subtitle(
+                    file_name=file_name,
+                    transcribed_segments=result,
+                    add_timestamp=add_timestamp,
+                    subformat=subformat
+                )
 
-                if subformat == "SRT":
-                    subtitle = get_srt(result["segments"])
-                    write_file(subtitle, f"{output_path}.srt")
-                elif subformat == "WebVTT":
-                    subtitle = get_vtt(result["segments"])
-                    write_file(subtitle, f"{output_path}.vtt")
-
-                files_info[file_name] = subtitle
+                files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time}
 
             total_result = ''
-            for file_name, subtitle in files_info.items():
+            total_time = 0
+            for file_name, info in files_info.items():
                 total_result += '------------------------------------\n'
                 total_result += f'{file_name}\n\n'
-                total_result += f'{subtitle}'
+                total_result += f"{info['subtitle']}"
+                total_time += info["elapsed_time"]
 
-            return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
+            return f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
         except Exception as e:
-            return f"Error: {str(e)}"
+            print(f"Error transcribing file: {str(e)}")
+            return f"Error transcribing file: {str(e)}"
         finally:
             self.release_cuda_memory()
             self.remove_input_files([fileobj.name for fileobj in fileobjs])
@@ -137,49 +125,32 @@ class WhisperInference(BaseInterface):
             Indicator to show progress directly in gradio.
             I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
         """
-        def progress_callback(progress_value):
-            progress(progress_value, desc="Transcribing..")
-
         try:
             if model_size != self.current_model_size or self.model is None:
-                progress(0, desc="Initializing Model..")
-                self.current_model_size = model_size
-                self.model = whisper.load_model(name=model_size, download_root=os.path.join("models", "Whisper"))
-
-            if lang == "Automatic Detection":
-                lang = None
+                self.initialize_model(model_size=model_size, progress=progress)
 
             progress(0, desc="Loading Audio from Youtube..")
             yt = get_ytdata(youtubelink)
             audio = whisper.load_audio(get_ytaudio(yt))
 
-            translatable_model = ["large", "large-v1", "large-v2"]
-            if istranslate and self.current_model_size in translatable_model:
-                result = self.model.transcribe(audio=audio, language=lang, verbose=False, task="translate",
-                                               progress_callback=progress_callback)
-            else:
-                result = self.model.transcribe(audio=audio, language=lang, verbose=False,
-                                               progress_callback=progress_callback)
-
+            result, elapsed_time = self.transcribe(audio=audio,
+                                                   lang=lang,
+                                                   istranslate=istranslate,
+                                                   progress=progress)
             progress(1, desc="Completed!")
 
             file_name = safe_filename(yt.title)
-            timestamp = datetime.now().strftime("%m%d%H%M%S")
-            if add_timestamp:
-                output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
-            else:
-                output_path = os.path.join("outputs", f"{file_name}")
-
-            if subformat == "SRT":
-                subtitle = get_srt(result["segments"])
-                write_file(subtitle, f"{output_path}.srt")
-            elif subformat == "WebVTT":
-                subtitle = get_vtt(result["segments"])
-                write_file(subtitle, f"{output_path}.vtt")
-
-            return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            subtitle = self.generate_and_write_subtitle(
+                file_name=file_name,
+                transcribed_segments=result,
+                add_timestamp=add_timestamp,
+                subformat=subformat
+            )
+
+            return f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
         except Exception as e:
-            return f"Error: {str(e)}"
+            print(f"Error transcribing youtube video: {str(e)}")
+            return f"Error transcribing youtube video: {str(e)}"
         finally:
             yt = get_ytdata(youtubelink)
             file_path = get_ytaudio(yt)
@@ -213,43 +184,123 @@ class WhisperInference(BaseInterface):
             Indicator to show progress directly in gradio.
             I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
         """
-        def progress_callback(progress_value):
-            progress(progress_value, desc="Transcribing..")
 
         try:
             if model_size != self.current_model_size or self.model is None:
-                progress(0, desc="Initializing Model..")
-                self.current_model_size = model_size
-                self.model = whisper.load_model(name=model_size, download_root=os.path.join("models", "Whisper"))
-
-            if lang == "Automatic Detection":
-                lang = None
-
-            progress(0, desc="Loading Audio..")
-
-            translatable_model = ["large", "large-v1", "large-v2"]
-            if istranslate and self.current_model_size in translatable_model:
-                result = self.model.transcribe(audio=micaudio, language=lang, verbose=False, task="translate",
-                                               progress_callback=progress_callback)
-            else:
-                result = self.model.transcribe(audio=micaudio, language=lang, verbose=False,
-                                               progress_callback=progress_callback)
+                self.initialize_model(model_size=model_size, progress=progress)
 
+            result, elapsed_time = self.transcribe(audio=micaudio,
+                                                   lang=lang,
+                                                   istranslate=istranslate,
+                                                   progress=progress)
             progress(1, desc="Completed!")
 
-            timestamp = datetime.now().strftime("%m%d%H%M%S")
-            output_path = os.path.join("outputs", f"Mic-{timestamp}")
+            subtitle = self.generate_and_write_subtitle(
+                file_name="Mic",
+                transcribed_segments=result,
+                add_timestamp=True,
+                subformat=subformat
+            )
 
-            if subformat == "SRT":
-                subtitle = get_srt(result["segments"])
-                write_file(subtitle, f"{output_path}.srt")
-            elif subformat == "WebVTT":
-                subtitle = get_vtt(result["segments"])
-                write_file(subtitle, f"{output_path}.vtt")
-
-            return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            return f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
         except Exception as e:
-            return f"Error: {str(e)}"
+            print(f"Error transcribing mic: {str(e)}")
+            return f"Error transcribing mic: {str(e)}"
         finally:
             self.release_cuda_memory()
             self.remove_input_files([micaudio])
+
+    def transcribe(self,
+                   audio: Union[str, np.ndarray, torch.Tensor],
+                   lang: str,
+                   istranslate: bool,
+                   progress: gr.Progress
+                   ) -> Tuple[list[dict], float]:
+        """
+        transcribe method for OpenAI's Whisper implementation.
+
+        Parameters
+        ----------
+        audio: Union[str, BinaryIO, torch.Tensor]
+            Audio path or file binary or Audio numpy array
+        lang: str
+            Source language of the file to transcribe from gr.Dropdown()
+        istranslate: bool
+            Boolean value from gr.Checkbox() that determines whether to translate to English.
+            It's Whisper's feature to translate speech from another language directly into English end-to-end.
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
+
+        Returns
+        ----------
+        segments_result: list[dict]
+            list of dicts that includes start, end timestamps and transcribed text
+        elapsed_time: float
+            elapsed time for transcription
+        """
+        start_time = time.time()
+
+        def progress_callback(progress_value):
+            progress(progress_value, desc="Transcribing..")
+
+        if lang == "Automatic Detection":
+            lang = None
+
+        translatable_model = ["large", "large-v1", "large-v2"]
+        segments_result = self.model.transcribe(audio=audio,
+                                                language=lang,
+                                                verbose=False,
+                                                task="translate" if istranslate and self.current_model_size in translatable_model else "transcribe",
+                                                progress_callback=progress_callback)["segments"]
+        elapsed_time = time.time() - start_time
+
+        return segments_result, elapsed_time
+
+    def initialize_model(self,
+                         model_size: str,
+                         progress: gr.Progress
+                         ):
+        """
+        Initialize model if it doesn't match with current model size
+        """
+        progress(0, desc="Initializing Model..")
+        self.current_model_size = model_size
+        self.model = whisper.load_model(name=model_size, download_root=os.path.join("models", "Whisper"))
+
+    @staticmethod
+    def generate_and_write_subtitle(file_name: str,
+                                    transcribed_segments: list,
+                                    add_timestamp: bool,
+                                    subformat: str,
+                                    ) -> str:
+        """
+        This method writes subtitle file and returns str to gr.Textbox
+        """
+        timestamp = datetime.now().strftime("%m%d%H%M%S")
+        if add_timestamp:
+            output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
+        else:
+            output_path = os.path.join("outputs", f"{file_name}")
+
+        if subformat == "SRT":
+            subtitle = get_srt(transcribed_segments)
+            write_file(subtitle, f"{output_path}.srt")
+        elif subformat == "WebVTT":
+            subtitle = get_vtt(transcribed_segments)
+            write_file(subtitle, f"{output_path}.vtt")
+        return subtitle
+
+    @staticmethod
+    def format_time(elapsed_time: float) -> str:
+        hours, rem = divmod(elapsed_time, 3600)
+        minutes, seconds = divmod(rem, 60)
+
+        time_str = ""
+        if hours:
+            time_str += f"{hours} hours "
+        if minutes:
+            time_str += f"{minutes} minutes "
+        seconds = round(seconds)
+        time_str += f"{seconds} seconds"
+
+        return time_str.strip()
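
Notes for review. A minimal sketch of how the extracted helpers compose, for trying the refactor locally. It assumes WhisperInference can be constructed with no arguments, that the forked whisper linked in the docstrings is installed so model.transcribe accepts progress_callback, and that sample.wav exists; none of this sketch is part of the commit itself.

    import gradio as gr
    from modules.whisper_Inference import WhisperInference

    inference = WhisperInference()  # assumption: no-arg constructor
    progress = gr.Progress()        # gradio progress indicator

    # Load the model, transcribe, then write an SRT into the outputs folder.
    inference.initialize_model(model_size="base", progress=progress)
    segments, elapsed_time = inference.transcribe(audio="sample.wav",  # hypothetical input file
                                                  lang="Automatic Detection",
                                                  istranslate=False,
                                                  progress=progress)
    subtitle = WhisperInference.generate_and_write_subtitle(file_name="sample",
                                                            transcribed_segments=segments,
                                                            add_timestamp=True,
                                                            subformat="SRT")
    print(f"Done in {WhisperInference.format_time(elapsed_time)}")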
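
One small thing worth flagging in the new format_time: divmod on a float returns floats, so a 65-minute run renders in the Done message as "1.0 hours 5.0 minutes ...". A sketch of a variant that truncates to whole seconds before splitting, not part of the commit:

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        # Truncate to whole seconds so hours/minutes format as ints, not floats.
        hours, rem = divmod(int(elapsed_time), 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""
        if hours:
            time_str += f"{hours} hours "
        if minutes:
            time_str += f"{minutes} minutes "
        time_str += f"{seconds} seconds"
        return time_str.strip()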
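
Also worth noting: the new transcribe picks the task with an inline conditional, so when istranslate is checked but the loaded model is not one of the large checkpoints, it silently falls back to plain transcription. A sketch of the same selection with an explicit notice; the warning text is illustrative, not from the commit:

    translatable_model = ["large", "large-v1", "large-v2"]
    task = "transcribe"
    if istranslate:
        if self.current_model_size in translatable_model:
            task = "translate"
        else:
            # Assumption: surface the fallback instead of silently ignoring the checkbox.
            print(f"Model '{self.current_model_size}' cannot translate; transcribing instead.")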
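
Finally, the new return annotation Tuple[list[dict], float] mixes typing.Tuple with the PEP 585 builtin generic list[dict], so the module now requires Python 3.9+; on 3.8 the annotation raises TypeError when the module is imported. If 3.8 support matters, a sketch using only typing generics:

    from typing import Dict, List, Tuple

    # Same shape, but valid on Python 3.8 as well.
    def transcribe(self, ...) -> Tuple[List[Dict], float]:
        ...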