aadnk committed on
Commit 31f7bdb
1 Parent(s): 4aa12d0

Cache parallel processes

Files changed (4)
  1. app.py +54 -27
  2. cli.py +1 -1
  3. src/vadParallel.py +83 -4
  4. src/whisperContainer.py +31 -3
app.py CHANGED
@@ -1,3 +1,4 @@
+import math
 from typing import Iterator
 import argparse
 
@@ -5,12 +6,11 @@ from io import StringIO
 import os
 import pathlib
 import tempfile
-from src.vadParallel import ParallelTranscription
+from src.vadParallel import ParallelContext, ParallelTranscription
 
-from src.whisperContainer import WhisperContainer
+from src.whisperContainer import WhisperContainer, WhisperModelCache
 
 # External programs
-import whisper
 import ffmpeg
 
 # UI
@@ -50,13 +50,15 @@ LANGUAGES = [
 ]
 
 class WhisperTranscriber:
-    def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
-        self.model_cache = dict()
+    def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
+        self.model_cache = WhisperModelCache()
         self.parallel_device_list = None
+        self.parallel_context = None
+        self.vad_process_timeout = vad_process_timeout
 
         self.vad_model = None
-        self.inputAudioMaxDuration = inputAudioMaxDuration
-        self.deleteUploadedFiles = deleteUploadedFiles
+        self.inputAudioMaxDuration = input_audio_max_duration
+        self.deleteUploadedFiles = delete_uploaded_files
 
     def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
         try:
@@ -66,11 +68,7 @@ class WhisperTranscriber:
             selectedLanguage = languageName.lower() if len(languageName) > 0 else None
             selectedModel = modelName if modelName is not None else "base"
 
-            model = self.model_cache.get(selectedModel, None)
-
-            if not model:
-                model = WhisperContainer(selectedModel)
-                self.model_cache[selectedModel] = model
+            model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
 
             # Execute whisper
             result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
@@ -124,18 +122,34 @@ class WhisperTranscriber:
             result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
 
         else:
-            # Default VAD
-            result = whisperCallable(audio_path, 0, None, None)
+            if (self._has_parallel_devices()):
+                # Use a simple period transcription instead, as we need to use the parallel context
+                periodic_vad = VadPeriodicTranscription()
+                period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
+
+                result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
+            else:
+                # Default VAD
+                result = whisperCallable(audio_path, 0, None, None)
 
         return result
 
     def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
-        if (self.parallel_device_list is None or len(self.parallel_device_list) == 0):
+        if (not self._has_parallel_devices()):
            # No parallel devices, so just run the VAD and Whisper in sequence
            return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
 
-        parallell_vad = ParallelTranscription()
-        return parallell_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable, config=vadConfig, devices=self.parallel_device_list)
+        # Create parallel context if needed
+        if (self.parallel_context is None):
+            # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
+            self.parallel_context = ParallelContext(num_processes=len(self.parallel_device_list), auto_cleanup_timeout_seconds=self.vad_process_timeout)
+
+        parallel_vad = ParallelTranscription()
+        return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
+                                                config=vadConfig, devices=self.parallel_device_list, parallel_context=self.parallel_context)
+
+    def _has_parallel_devices(self):
+        return self.parallel_device_list is not None and len(self.parallel_device_list) > 0
 
     def _concat_prompt(self, prompt1, prompt2):
         if (prompt1 is None):
@@ -177,7 +191,7 @@ class WhisperTranscriber:
         return output_files, text, vtt
 
     def clear_cache(self):
-        self.model_cache = dict()
+        self.model_cache.clear()
         self.vad_model = None
 
     def __get_source(self, urlData, uploadFile, microphoneData):
@@ -229,9 +243,16 @@ class WhisperTranscriber:
 
         return file.name
 
+    def close(self):
+        self.clear_cache()
+
+        if (self.parallel_context is not None):
+            self.parallel_context.close()
+
 
-def create_ui(inputAudioMaxDuration, share=False, server_name: str = None, server_port: int = 7860, vad_parallel_devices: str = None):
-    ui = WhisperTranscriber(inputAudioMaxDuration)
+def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
+              default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None):
+    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
 
     # Specify a list of devices to use for parallel processing
     ui.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
@@ -242,19 +263,19 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None, serve
 
     ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
 
-    if inputAudioMaxDuration > 0:
-        ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
+    if input_audio_max_duration > 0:
+        ui_description += "\n\n" + "Max audio file length: " + str(input_audio_max_duration) + " s"
 
     ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"
 
     demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
-        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
+        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value=default_model_name, label="Model"),
         gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
         gr.Text(label="URL (YouTube, etc.)"),
         gr.Audio(source="upload", type="filepath", label="Upload Audio"),
         gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
         gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
-        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], label="VAD"),
+        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
         gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
         gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
        gr.Number(label="VAD - Padding (s)", precision=None, value=1),
@@ -265,15 +286,21 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None, serve
        gr.Text(label="Segments")
     ])
 
-    demo.launch(share=share, server_name=server_name, server_port=server_port)
+    demo.launch(share=share, server_name=server_name, server_port=server_port)
+
+    # Clean up
+    ui.close()
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("--inputAudioMaxDuration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
+    parser.add_argument("--input_audio_max_duration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
     parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
    parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
     parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
-    parser.add_argument("--vad_parallel_devices", type=str, default="0,1", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
+    parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
+    parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
+    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
+    parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
 
     args = parser.parse_args().__dict__
     create_ui(**args)
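With the new options above, the UI can also be started programmatically. A minimal sketch, not part of this commit; the device list and timeout values are illustrative only:

```python
# Hypothetical invocation of the updated entry point (values are examples only).
from app import create_ui

create_ui(
    input_audio_max_duration=600,   # same default as --input_audio_max_duration
    default_model_name="medium",
    default_vad="silero-vad",
    vad_parallel_devices="0,1",     # run transcription chunks on CUDA devices 0 and 1
    vad_process_timeout=1800,       # terminate idle worker processes after 30 minutes
)
```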
cli.py CHANGED
@@ -74,7 +74,7 @@ def cli():
     vad_prompt_window = args.pop("vad_prompt_window")
 
     model = whisper.load_model(model_name, device=device, download_root=model_dir)
-    transcriber = WhisperTranscriber(deleteUploadedFiles=False)
+    transcriber = WhisperTranscriber(delete_uploaded_files=False)
     transcriber.parallel_device_list = args.pop("vad_parallel_devices")
 
     for audio_path in args.pop("audio"):
src/vadParallel.py CHANGED
@@ -1,4 +1,6 @@
 import multiprocessing
+import threading
+import time
 from src.vad import AbstractTranscription, TranscriptionConfig
 from src.whisperContainer import WhisperCallback
 
@@ -7,6 +9,68 @@ from multiprocessing import Pool
 from typing import List
 import os
 
+
+class ParallelContext:
+    def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
+        self.num_processes = num_processes
+        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
+        self.lock = threading.Lock()
+
+        self.ref_count = 0
+        self.pool = None
+        self.cleanup_timer = None
+
+    def get_pool(self):
+        # Initialize pool lazily
+        if (self.pool is None):
+            context = multiprocessing.get_context('spawn')
+            self.pool = context.Pool(self.num_processes)
+
+        self.ref_count = self.ref_count + 1
+
+        if (self.auto_cleanup_timeout_seconds is not None):
+            self._stop_auto_cleanup()
+
+        return self.pool
+
+    def return_pool(self, pool):
+        if (self.pool == pool and self.ref_count > 0):
+            self.ref_count = self.ref_count - 1
+
+            if (self.ref_count == 0):
+                if (self.auto_cleanup_timeout_seconds is not None):
+                    self._start_auto_cleanup()
+
+    def _start_auto_cleanup(self):
+        if (self.cleanup_timer is not None):
+            self.cleanup_timer.cancel()
+        self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
+        self.cleanup_timer.start()
+
+        print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
+
+    def _stop_auto_cleanup(self):
+        if (self.cleanup_timer is not None):
+            self.cleanup_timer.cancel()
+            self.cleanup_timer = None
+
+            print("Stopped auto cleanup of pool")
+
+    def _execute_cleanup(self):
+        print("Executing cleanup of pool")
+
+        if (self.ref_count == 0):
+            self.close()
+
+    def close(self):
+        self._stop_auto_cleanup()
+
+        if (self.pool is not None):
+            print("Closing pool of " + str(self.num_processes) + " processes")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+
 class ParallelTranscriptionConfig(TranscriptionConfig):
     def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
         super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
@@ -18,7 +82,7 @@ class ParallelTranscription(AbstractTranscription):
         super().__init__(sampling_rate=sampling_rate)
 
 
-    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str]):
+    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str], parallel_context: ParallelContext = None):
         # First, get the timestamps for the original audio
         merged = transcription.get_merged_timestamps(audio, config)
 
@@ -45,12 +109,19 @@ class ParallelTranscription(AbstractTranscription):
             'language': None
         }
 
+        created_context = False
+
         # Spawn a separate process for each device
-        context = multiprocessing.get_context('spawn')
+        try:
+            if (parallel_context is None):
+                parallel_context = ParallelContext(len(devices))
+                created_context = True
+
+            # Get a pool of processes
+            pool = parallel_context.get_pool()
 
-        with context.Pool(len(devices)) as p:
             # Run the transcription in parallel
-            results = p.starmap(self.transcribe, parameters)
+            results = pool.starmap(self.transcribe, parameters)
 
             for result in results:
                 # Merge the results
@@ -61,6 +132,14 @@ class ParallelTranscription(AbstractTranscription):
                 if (result['language'] is not None):
                     merged['language'] = result['language']
 
+        finally:
+            # Return the pool to the context
+            if (parallel_context is not None):
+                parallel_context.return_pool(pool)
+            # Always close the context if we created it
+            if (created_context):
+                parallel_context.close()
+
         return merged
 
     def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
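The pool lifecycle that transcribe_parallel relies on can also be exercised on its own. A minimal sketch, not part of the commit, assuming it is run from the repository root; the work() function and the values passed to ParallelContext are illustrative:

```python
# Hypothetical demonstration of the ParallelContext pool lifecycle (illustrative only).
from src.vadParallel import ParallelContext

def work(device_id: str) -> str:
    # Stand-in for the per-device transcription work
    return "processed on device " + device_id

if __name__ == "__main__":
    context = ParallelContext(num_processes=2, auto_cleanup_timeout_seconds=1800)
    try:
        pool = context.get_pool()                       # lazily creates the 'spawn' pool and increments the ref count
        results = pool.starmap(work, [("0",), ("1",)])  # run work items on the shared pool
        print(results)
    finally:
        context.return_pool(pool)  # decrements the ref count and schedules auto cleanup after the timeout
        context.close()            # or shut the pool down immediately, as transcribe_parallel does for a context it created
```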
src/whisperContainer.py CHANGED
@@ -1,18 +1,44 @@
 # External programs
 import whisper
 
+class WhisperModelCache:
+    def __init__(self):
+        self._cache = dict()
+
+    def get(self, model_name, device: str = None):
+        key = model_name + ":" + (device if device else '')
+
+        result = self._cache.get(key)
+
+        if result is None:
+            print("Loading whisper model " + model_name)
+            result = whisper.load_model(name=model_name, device=device)
+            self._cache[key] = result
+        return result
+
+    def clear(self):
+        self._cache.clear()
+
+# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
+GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
+
 class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None):
+    def __init__(self, model_name: str, device: str = None, cache: WhisperModelCache = None):
         self.model_name = model_name
         self.device = device
+        self.cache = cache
 
         # Will be created on demand
         self.model = None
 
     def get_model(self):
         if self.model is None:
-            print("Loading model " + self.model_name)
-            self.model = whisper.load_model(self.model_name, device=self.device)
+
+            if (self.cache is None):
+                print("Loading whisper model " + self.model_name)
+                self.model = whisper.load_model(self.model_name, device=self.device)
+            else:
+                self.model = self.cache.get(self.model_name, device=self.device)
         return self.model
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
@@ -44,6 +70,8 @@ class WhisperContainer:
         self.model_name = state["model_name"]
         self.device = state["device"]
         self.model = None
+        # Depickled objects must use the global cache
+        self.cache = GLOBAL_WHISPER_MODEL_CACHE
 
 
 class WhisperCallback:
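Taken together with the app.py changes, the cache lets several containers share one loaded model. A minimal sketch, not part of the commit; the model name is illustrative, and loading it requires the whisper package (and a download on first use):

```python
# Hypothetical demonstration of sharing a model through WhisperModelCache (illustrative only).
from src.whisperContainer import WhisperContainer, WhisperModelCache

cache = WhisperModelCache()

# Both containers resolve to the same underlying whisper model,
# because they share the cache and the same "model_name:device" key.
first = WhisperContainer(model_name="base", cache=cache)
second = WhisperContainer(model_name="base", cache=cache)
assert first.get_model() is second.get_model()

cache.clear()  # what WhisperTranscriber.clear_cache() now delegates to
```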