jhj0517 commited on
Commit
25c9e51
1 Parent(s): 2415a05

add args for local model path

Browse files
app.py CHANGED
@@ -17,8 +17,10 @@ class App:
17
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
18
  self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
19
  if isinstance(self.whisper_inf, FasterWhisperInference):
 
20
  print("Use Faster Whisper implementation")
21
  else:
 
22
  print("Use Open AI Whisper implementation")
23
  print(f"Device \"{self.whisper_inf.device}\" is detected")
24
  self.nllb_inf = NLLBInference()
@@ -296,6 +298,8 @@ parser.add_argument('--password', type=str, default=None, help='Gradio authentic
296
  parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
297
  parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
298
  parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
 
 
299
  _args = parser.parse_args()
300
 
301
  if __name__ == "__main__":
 
17
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
18
  self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
19
  if isinstance(self.whisper_inf, FasterWhisperInference):
20
+ self.whisper_inf.model_dir = args.faster_whisper_model_dir
21
  print("Use Faster Whisper implementation")
22
  else:
23
+ self.whisper_inf.model_dir = args.whisper_model_dir
24
  print("Use Open AI Whisper implementation")
25
  print(f"Device \"{self.whisper_inf.device}\" is detected")
26
  self.nllb_inf = NLLBInference()
 
298
  parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
299
  parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
300
  parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
301
+ parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
302
+ parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
303
  _args = parser.parse_args()
304
 
305
  if __name__ == "__main__":
modules/faster_whisper_inference.py CHANGED
@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
32
  self.available_compute_types = ctranslate2.get_supported_compute_types(
33
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
34
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
35
- self.default_beam_size = 1
36
 
37
  def transcribe_file(self,
38
  files: list,
@@ -311,7 +311,7 @@ class FasterWhisperInference(BaseInterface):
311
  self.model = faster_whisper.WhisperModel(
312
  device=self.device,
313
  model_size_or_path=model_size,
314
- download_root=os.path.join("models", "Whisper", "faster-whisper"),
315
  compute_type=self.current_compute_type
316
  )
317
 
 
32
  self.available_compute_types = ctranslate2.get_supported_compute_types(
33
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
34
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
35
+ self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
36
 
37
  def transcribe_file(self,
38
  files: list,
 
311
  self.model = faster_whisper.WhisperModel(
312
  device=self.device,
313
  model_size_or_path=model_size,
314
+ download_root=self.model_dir,
315
  compute_type=self.current_compute_type
316
  )
317
 
modules/whisper_Inference.py CHANGED
@@ -26,7 +26,7 @@ class WhisperInference(BaseInterface):
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.available_compute_types = ["float16", "float32"]
28
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
- self.default_beam_size = 1
30
 
31
  def transcribe_file(self,
32
  files: list,
@@ -288,7 +288,7 @@ class WhisperInference(BaseInterface):
288
  self.model = whisper.load_model(
289
  name=model_size,
290
  device=self.device,
291
- download_root=os.path.join("models", "Whisper")
292
  )
293
 
294
  @staticmethod
 
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.available_compute_types = ["float16", "float32"]
28
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
+ self.model_dir = os.path.join("models", "Whisper")
30
 
31
  def transcribe_file(self,
32
  files: list,
 
288
  self.model = whisper.load_model(
289
  name=model_size,
290
  device=self.device,
291
+ download_root=self.model_dir
292
  )
293
 
294
  @staticmethod