jhj0517 committed
Commit
58c7e65
Parents: c8ae5e5, ae32f22

Merge pull request #39 from jhj0517/set-default-beam-size

modules/faster_whisper_inference.py CHANGED
@@ -24,7 +24,7 @@ class FasterWhisperInference(BaseInterface):
         self.available_models = whisper.available_models()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2"]
-        self.default_beam_size = 5
+        self.default_beam_size = 1
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     def transcribe_file(self,
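Note: this lowers the faster-whisper default from beam_size=5 to beam_size=1, i.e. single-beam (greedy-style) decoding, which is faster at the cost of a less exhaustive search. A minimal standalone sketch of where that value ends up when calling upstream faster-whisper; this is not the repo's transcribe_file wrapper, and the model size and audio path are placeholders:

# Minimal sketch, not the repo's wrapper code: shows a default_beam_size of 1
# being passed to upstream faster-whisper. "large-v2" and "audio.wav" are
# placeholder values.
from faster_whisper import WhisperModel

model = WhisperModel("large-v2", device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.wav", beam_size=1)  # beam_size=1 behaves like greedy search
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")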
modules/whisper_Inference.py CHANGED
@@ -21,6 +21,7 @@ class WhisperInference(BaseInterface):
         self.model = None
         self.available_models = whisper.available_models()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
+        self.default_beam_size = 1
 
     def transcribe_file(self,
                         fileobjs: list,
@@ -250,6 +251,7 @@ class WhisperInference(BaseInterface):
             segments_result = self.model.transcribe(audio=audio,
                                                     language=lang,
                                                     verbose=False,
+                                                    beam_size=self.default_beam_size,
                                                     task="translate" if istranslate and self.current_model_size in translatable_model else "transcribe",
                                                     progress_callback=progress_callback)["segments"]
             elapsed_time = time.time() - start_time
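The same default is now threaded through the openai-whisper path as well. A minimal sketch assuming stock openai-whisper, which forwards beam_size through its decode options; the repo's own call additionally passes a progress_callback from its patched transcribe, and the model size and audio path below are placeholders:

# Minimal sketch, not the repo's wrapper: openai-whisper accepts beam_size as a
# decode option; beam_size=1 is a single-beam (effectively greedy) search.
# "base" and "audio.wav" are placeholder values.
import whisper

model = whisper.load_model("base")
result = model.transcribe("audio.wav",
                          verbose=False,
                          beam_size=1,
                          task="transcribe")
print(result["text"])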