jhj0517 commited on
Commit
c7bfcf2
1 Parent(s): 5fee9a3

Calculate WER between gen result & answer

Browse files
Files changed (1) hide show
  1. tests/test_transcription.py +13 -9
tests/test_transcription.py CHANGED
@@ -1,5 +1,6 @@
1
  from modules.whisper.whisper_factory import WhisperFactory
2
  from modules.whisper.data_classes import *
 
3
  from modules.utils.paths import WEBUI_DIR
4
  from test_config import *
5
 
@@ -28,6 +29,10 @@ def test_transcribe(
28
  if not os.path.exists(audio_path):
29
  download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
30
 
 
 
 
 
31
  whisper_inferencer = WhisperFactory.create_whisper_inference(
32
  whisper_type=whisper_type,
33
  )
@@ -54,7 +59,7 @@ def test_transcribe(
54
  ),
55
  ).to_list()
56
 
57
- subtitle_str, file_path = whisper_inferencer.transcribe_file(
58
  [audio_path],
59
  None,
60
  "SRT",
@@ -62,12 +67,11 @@ def test_transcribe(
62
  gr.Progress(),
63
  *hparams,
64
  )
65
-
66
- assert isinstance(subtitle_str, str) and subtitle_str
67
- assert isinstance(file_path[0], str) and file_path
68
 
69
  if not is_pytube_detected_bot():
70
- whisper_inferencer.transcribe_youtube(
71
  TEST_YOUTUBE_URL,
72
  "SRT",
73
  False,
@@ -75,17 +79,17 @@ def test_transcribe(
75
  *hparams,
76
  )
77
  assert isinstance(subtitle_str, str) and subtitle_str
78
- assert isinstance(file_path[0], str) and file_path
79
 
80
- whisper_inferencer.transcribe_mic(
81
  audio_path,
82
  "SRT",
83
  False,
84
  gr.Progress(),
85
  *hparams,
86
  )
87
- assert isinstance(subtitle_str, str) and subtitle_str
88
- assert isinstance(file_path[0], str) and file_path
89
 
90
 
91
  def download_file(url, save_dir):
 
1
  from modules.whisper.whisper_factory import WhisperFactory
2
  from modules.whisper.data_classes import *
3
+ from modules.utils.subtitle_manager import read_file
4
  from modules.utils.paths import WEBUI_DIR
5
  from test_config import *
6
 
 
29
  if not os.path.exists(audio_path):
30
  download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
31
 
32
+ answer = TEST_ANSWER
33
+ if diarization:
34
+ answer = "SPEAKER_00|"+TEST_ANSWER
35
+
36
  whisper_inferencer = WhisperFactory.create_whisper_inference(
37
  whisper_type=whisper_type,
38
  )
 
59
  ),
60
  ).to_list()
61
 
62
+ subtitle_str, file_paths = whisper_inferencer.transcribe_file(
63
  [audio_path],
64
  None,
65
  "SRT",
 
67
  gr.Progress(),
68
  *hparams,
69
  )
70
+ subtitle = read_file(file_paths[0]).split("\n")
71
+ assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
 
72
 
73
  if not is_pytube_detected_bot():
74
+ subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
75
  TEST_YOUTUBE_URL,
76
  "SRT",
77
  False,
 
79
  *hparams,
80
  )
81
  assert isinstance(subtitle_str, str) and subtitle_str
82
+ assert os.path.exists(file_path)
83
 
84
+ subtitle_str, file_path = whisper_inferencer.transcribe_mic(
85
  audio_path,
86
  "SRT",
87
  False,
88
  gr.Progress(),
89
  *hparams,
90
  )
91
+ subtitle = read_file(file_path).split("\n")
92
+ assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
93
 
94
 
95
  def download_file(url, save_dir):