Mark Duppenthaler commited on
Commit
366edf8
1 Parent(s): 8e82d74
Dockerfile CHANGED
@@ -1,3 +1,5 @@
 
 
1
  FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
  ENV DEBIAN_FRONTEND=noninteractive
3
  RUN apt-get update && \
 
1
+ # TODO: This doesn't work, copied over from M4T but needs an update
2
+
3
  FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
4
  ENV DEBIAN_FRONTEND=noninteractive
5
  RUN apt-get update && \
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
__pycache__/app.cpython-38.pyc ADDED
Binary file (2.47 kB). View file
 
__pycache__/simuleval_transcoder.cpython-310.pyc CHANGED
Binary files a/__pycache__/simuleval_transcoder.cpython-310.pyc and b/__pycache__/simuleval_transcoder.cpython-310.pyc differ
 
__pycache__/simuleval_transcoder.cpython-38.pyc ADDED
Binary file (13.6 kB). View file
 
app.py CHANGED
@@ -6,101 +6,150 @@ import gradio as gr
6
  import numpy as np
7
  import torch
8
  import torchaudio
9
- from seamless_communication.models.inference.translator import Translator
10
 
11
-
12
- from m4t_app import *
13
  from simuleval_transcoder import *
14
- # from simuleval_transcoder import *
15
 
16
  from pydub import AudioSegment
17
  import time
18
  from time import sleep
19
 
20
- # m4t_demo()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- USE_M4T = True
 
 
 
23
 
24
- Transcoder = SimulevalTranscoder()
25
 
26
- def translate_audio_file_segment(audio_file):
27
- print("translate_m4t state")
28
 
29
- return predict(
30
- task_name="S2ST",
31
- audio_source="microphone",
32
- input_audio_mic=audio_file,
33
- input_audio_file=None,
34
- input_text="",
35
- source_language="English",
36
- target_language="Portuguese",
37
- )
38
 
 
 
 
 
39
 
40
- def translate_m4t_callback(
 
 
41
  audio_file, translated_audio_bytes_state, translated_text_state
42
  ):
43
- translated_wav_segment, translated_text = translate_audio_file_segment(audio_file)
44
- print('translated_audio_bytes_state', translated_audio_bytes_state)
45
- print('translated_wav_segment', translated_wav_segment)
 
 
46
 
47
- # combine translated wav into larger..
48
- if type(translated_audio_bytes_state) is not tuple:
49
- translated_audio_bytes_state = translated_wav_segment
50
- else:
51
 
52
- translated_audio_bytes_state = (translated_audio_bytes_state[0], np.append(translated_audio_bytes_state[1], translated_wav_segment[1]))
53
 
54
- # translated_wav_segment[1]
 
 
 
 
 
 
55
 
 
 
56
 
57
- translated_text_state += " | " + str(translated_text)
 
 
 
 
 
 
58
  return [
59
- audio_file,
60
  translated_wav_segment,
61
- translated_audio_bytes_state,
62
- translated_text_state,
63
  translated_audio_bytes_state,
64
  translated_text_state,
65
  ]
66
 
67
 
68
  def clear():
69
- print("Clearing State")
70
  return [bytes(), ""]
71
 
72
 
73
  def blocks():
74
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
75
  translated_audio_bytes_state = gr.State(None)
76
  translated_text_state = gr.State("")
77
 
78
- # input_audio = gr.Audio(label="Input Audio", type="filepath", format="mp3")
79
- if USE_M4T:
80
- input_audio = gr.Audio(
81
- label="Input Audio",
82
- type="filepath",
83
- source="microphone",
84
- streaming=True,
85
- )
86
- else:
87
- input_audio = gr.Audio(
88
- label="Input Audio",
89
- type="filepath",
90
- format="mp3",
91
- source="microphone",
92
- streaming=True,
93
- )
94
 
95
  most_recent_input_audio_segment = gr.Audio(
96
  label="Recent Input Audio Segment segments",
97
- format="bytes",
98
  streaming=True
99
  )
100
- # TODO: Should add combined input audio segments...
101
-
102
- stream_as_bytes_btn = gr.Button("Translate most recent recording segment")
103
 
 
 
104
  output_translation_segment = gr.Audio(
105
  label="Translated audio segment",
106
  autoplay=False,
@@ -119,7 +168,7 @@ def blocks():
119
  stream_output_text = gr.Textbox(label="Translated text")
120
 
121
  stream_as_bytes_btn.click(
122
- translate_m4t_callback,
123
  [input_audio, translated_audio_bytes_state, translated_text_state],
124
  [
125
  most_recent_input_audio_segment,
@@ -131,8 +180,21 @@ def blocks():
131
  ],
132
  )
133
 
134
- input_audio.change(
135
- translate_m4t_callback,
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  [input_audio, translated_audio_bytes_state, translated_text_state],
137
  [
138
  most_recent_input_audio_segment,
@@ -143,8 +205,11 @@ def blocks():
143
  translated_text_state,
144
  ],
145
  )
146
- # input_audio.change(stream_bytes, [input_audio, translated_audio_bytes_state, translated_text_state], [most_recent_input_audio_segment, stream_output_text, translated_audio_bytes_state, translated_text_state])
147
- # input_audio.change(lambda input_audio: recorded_audio, [input_audio], [recorded_audio])
 
 
 
148
  input_audio.clear(
149
  clear, None, [translated_audio_bytes_state, translated_text_state]
150
  )
@@ -154,6 +219,4 @@ def blocks():
154
 
155
  demo.queue().launch()
156
 
157
-
158
- # if __name__ == "__main__":
159
  blocks()
 
6
  import numpy as np
7
  import torch
8
  import torchaudio
 
9
 
 
 
10
  from simuleval_transcoder import *
 
11
 
12
  from pydub import AudioSegment
13
  import time
14
  from time import sleep
15
 
16
+ from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
17
+ TestTimeWaitKUnityS2TM4T,
18
+ )
19
+
20
+ language_code_to_name = {
21
+ "cmn": "Mandarin Chinese",
22
+ "deu": "German",
23
+ "eng": "English",
24
+ "fra": "French",
25
+ "spa": "Spanish",
26
+ }
27
+ S2ST_TARGET_LANGUAGE_NAMES = language_code_to_name.values()
28
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
29
+
30
+ DEFAULT_TARGET_LANGUAGE = "English"
31
+
32
+ # TODO: Update this so it takes in target langs from input, refactor sample rate
33
+ transcoder = SimulevalTranscoder(
34
+ sample_rate=48_000,
35
+ debug=False,
36
+ buffer_limit=1,
37
+ )
38
+
39
+ def start_recording():
40
+ logger.debug(f"start_recording: starting transcoder")
41
+ transcoder.start()
42
+
43
+
44
+ def translate_audio_segment(audio):
45
+ logger.debug(f"translate_audio_segment: incoming audio")
46
+ sample_rate, data = audio
47
+
48
+ transcoder.process_incoming_bytes(data.tobytes(), 'eng', sample_rate)
49
 
50
+ speech_and_text_output = transcoder.get_buffered_output()
51
+ if speech_and_text_output is None:
52
+ logger.debug("No output from transcoder.get_buffered_output()")
53
+ return None, None
54
 
55
+ logger.debug(f"We DID get output from the transcoder! {speech_and_text_output}")
56
 
57
+ text = None
58
+ speech = None
59
 
60
+ if speech_and_text_output.speech_samples:
61
+ speech = (speech_and_text_output.speech_samples, speech_and_text_output.speech_sample_rate)
 
 
 
 
 
 
 
62
 
63
+ if speech_and_text_output.text:
64
+ text = speech_and_text_output.text
65
+ if speech_and_text_output.final:
66
+ text += "\n"
67
 
68
+ return speech, text
69
+
70
+ def streaming_input_callback(
71
  audio_file, translated_audio_bytes_state, translated_text_state
72
  ):
73
+ translated_wav_segment, translated_text = translate_audio_segment(audio_file)
74
+ logger.debug(f'translated_audio_bytes_state {translated_audio_bytes_state}')
75
+ logger.debug(f'translated_wav_segment {translated_wav_segment}')
76
+
77
+ # TODO: accumulate each segment to provide a continuous audio segment
78
 
79
+ if translated_wav_segment is not None:
80
+ sample_rate, audio_bytes = translated_wav_segment
81
+ audio_np_array = np.frombuffer(audio_bytes, dtype=np.float32, count=3)
 
82
 
 
83
 
84
+ # combine translated wav
85
+ if type(translated_audio_bytes_state) is not tuple:
86
+ translated_audio_bytes_state = (sample_rate, audio_np_array)
87
+ # translated_audio_bytes_state = np.array([])
88
+ else:
89
+
90
+ translated_audio_bytes_state = (translated_audio_bytes_state[0], np.append(translated_audio_bytes_state[1], translated_wav_segment[1]))
91
 
92
+ if translated_text is not None:
93
+ translated_text_state += " | " + str(translated_text)
94
 
95
+ # most_recent_input_audio_segment = (most_recent_input_audio_segment[0], np.append(most_recent_input_audio_segment[1], audio_file[1]))
96
+
97
+ # Not necessary but for readability.
98
+ most_recent_input_audio_segment = audio_file
99
+ translated_wav_segment = translated_wav_segment
100
+ output_translation_combined = translated_audio_bytes_state
101
+ stream_output_text = translated_text_state
102
  return [
103
+ most_recent_input_audio_segment,
104
  translated_wav_segment,
105
+ output_translation_combined,
106
+ stream_output_text,
107
  translated_audio_bytes_state,
108
  translated_text_state,
109
  ]
110
 
111
 
112
  def clear():
113
+ logger.debug(f"Clearing State")
114
  return [bytes(), ""]
115
 
116
 
117
  def blocks():
118
  with gr.Blocks() as demo:
119
+
120
+ with gr.Row():
121
+ # Hook this up once supported
122
+ target_language = gr.Dropdown(
123
+ label="Target language",
124
+ choices=S2ST_TARGET_LANGUAGE_NAMES,
125
+ value=DEFAULT_TARGET_LANGUAGE,
126
+ )
127
+
128
  translated_audio_bytes_state = gr.State(None)
129
  translated_text_state = gr.State("")
130
 
131
+ input_audio = gr.Audio(
132
+ label="Input Audio",
133
+ # source="microphone", # gradio==3.41.0
134
+ sources=["microphone"], # new gradio seems to call this less often...
135
+ streaming=True,
136
+ )
137
+
138
+ # input_audio = gr.Audio(
139
+ # label="Input Audio",
140
+ # type="filepath",
141
+ # source="microphone",
142
+ # streaming=True,
143
+ # )
 
 
 
144
 
145
  most_recent_input_audio_segment = gr.Audio(
146
  label="Recent Input Audio Segment segments",
147
+ # format="bytes",
148
  streaming=True
149
  )
 
 
 
150
 
151
+ # Force translate
152
+ stream_as_bytes_btn = gr.Button("Force translate most recent recording segment (ask for model output)")
153
  output_translation_segment = gr.Audio(
154
  label="Translated audio segment",
155
  autoplay=False,
 
168
  stream_output_text = gr.Textbox(label="Translated text")
169
 
170
  stream_as_bytes_btn.click(
171
+ streaming_input_callback,
172
  [input_audio, translated_audio_bytes_state, translated_text_state],
173
  [
174
  most_recent_input_audio_segment,
 
180
  ],
181
  )
182
 
183
+ # input_audio.change(
184
+ # streaming_input_callback,
185
+ # [input_audio, translated_audio_bytes_state, translated_text_state],
186
+ # [
187
+ # most_recent_input_audio_segment,
188
+ # output_translation_segment,
189
+ # output_translation_combined,
190
+ # stream_output_text,
191
+ # translated_audio_bytes_state,
192
+ # translated_text_state,
193
+ # ],
194
+ # )
195
+
196
+ input_audio.stream(
197
+ streaming_input_callback,
198
  [input_audio, translated_audio_bytes_state, translated_text_state],
199
  [
200
  most_recent_input_audio_segment,
 
205
  translated_text_state,
206
  ],
207
  )
208
+
209
+ input_audio.start_recording(
210
+ start_recording,
211
+ )
212
+
213
  input_audio.clear(
214
  clear, None, [translated_audio_bytes_state, translated_text_state]
215
  )
 
219
 
220
  demo.queue().launch()
221
 
 
 
222
  blocks()
internal_demo_simuleval_transcoder.py DELETED
@@ -1,272 +0,0 @@
1
- from simuleval.utils.agent import build_system_from_dir
2
- from typing import Any, Tuple
3
- import numpy as np
4
- import soundfile
5
- from fairseq.data.audio.audio_utils import convert_waveform
6
- import io
7
- import asyncio
8
- from simuleval.data.segments import SpeechSegment, EmptySegment
9
- import threading
10
- import math
11
- import logging
12
- import sys
13
- from pathlib import Path
14
- import time
15
- from g2p_en import G2p
16
- import torch
17
- import traceback
18
- import time
19
- import random
20
-
21
- from .speech_and_text_output import SpeechAndTextOutput
22
-
23
- MODEL_SAMPLE_RATE = 16_000
24
-
25
- logger = logging.getLogger()
26
- logger.addHandler(logging.StreamHandler(sys.stdout))
27
-
28
-
29
- class SimulevalTranscoder:
30
- def __init__(self, agent, sample_rate, debug, buffer_limit):
31
- self.agent = agent
32
- self.input_queue = asyncio.Queue()
33
- self.output_queue = asyncio.Queue()
34
- self.states = self.agent.build_states()
35
- if debug:
36
- self.states[0].debug = True
37
- self.incoming_sample_rate = sample_rate
38
- self.close = False
39
- self.g2p = G2p()
40
-
41
- # buffer all outgoing translations within this amount of time
42
- self.output_buffer_idle_ms = 5000
43
- self.output_buffer_size_limit = (
44
- buffer_limit # phonemes for text, seconds for speech
45
- )
46
- self.output_buffer_cur_size = 0
47
- self.output_buffer = []
48
- self.speech_output_sample_rate = None
49
-
50
- self.last_output_ts = time.time() * 1000
51
- self.timeout_ms = (
52
- 30000 # close the transcoder thread after this amount of silence
53
- )
54
- self.first_input_ts = None
55
- self.first_output_ts = None
56
- self.output_data_type = None # speech or text
57
- self.debug = debug
58
- self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
59
- if self.debug:
60
- debug_folder = Path(__file__).resolve().parent.parent / "debug"
61
- self.test_incoming_wav = soundfile.SoundFile(
62
- debug_folder / f"{self.debug_ts}_test_incoming.wav",
63
- mode="w+",
64
- format="WAV",
65
- subtype="PCM_16",
66
- samplerate=self.incoming_sample_rate,
67
- channels=1,
68
- )
69
- self.states[0].test_input_segments_wav = soundfile.SoundFile(
70
- debug_folder / f"{self.debug_ts}_test_input_segments.wav",
71
- mode="w+",
72
- format="WAV",
73
- samplerate=MODEL_SAMPLE_RATE,
74
- channels=1,
75
- )
76
-
77
- def debug_log(self, *args):
78
- if self.debug:
79
- logger.info(*args)
80
-
81
- @classmethod
82
- def build_agent(cls, model_path):
83
- logger.info(f"Building simuleval agent: {model_path}")
84
- agent = build_system_from_dir(
85
- Path(__file__).resolve().parent.parent / f"models/{model_path}",
86
- config_name="vad_main.yaml",
87
- )
88
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
- agent.to(device, fp16=True)
90
- logger.info(
91
- f"Successfully built simuleval agent {model_path} on device {device}"
92
- )
93
-
94
- return agent
95
-
96
- def process_incoming_bytes(self, incoming_bytes):
97
- segment, _sr = self._preprocess_wav(incoming_bytes)
98
- # # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
99
- self.input_queue.put_nowait(segment)
100
-
101
- def get_input_segment(self):
102
- if self.input_queue.empty():
103
- return None
104
- chunk = self.input_queue.get_nowait()
105
- self.input_queue.task_done()
106
- return chunk
107
-
108
- def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
109
- segment, sample_rate = soundfile.read(
110
- io.BytesIO(data),
111
- dtype="float32",
112
- always_2d=True,
113
- frames=-1,
114
- start=0,
115
- format="RAW",
116
- subtype="PCM_16",
117
- samplerate=self.incoming_sample_rate,
118
- channels=1,
119
- )
120
- if self.debug:
121
- self.test_incoming_wav.seek(0, soundfile.SEEK_END)
122
- self.test_incoming_wav.write(segment)
123
-
124
- segment = segment.T
125
- segment, new_sample_rate = convert_waveform(
126
- segment,
127
- sample_rate,
128
- normalize_volume=False,
129
- to_mono=True,
130
- to_sample_rate=MODEL_SAMPLE_RATE,
131
- )
132
-
133
- assert MODEL_SAMPLE_RATE == new_sample_rate
134
- segment = segment.squeeze(axis=0)
135
- return segment, new_sample_rate
136
-
137
- def process_pipeline_impl(self, input_segment):
138
- try:
139
- output_segment = self.agent.pushpop(input_segment, self.states)
140
- if (
141
- self.states[0].first_input_ts is not None
142
- and self.first_input_ts is None
143
- ):
144
- # TODO: this is hacky
145
- self.first_input_ts = self.states[0].first_input_ts
146
-
147
- if not output_segment.is_empty:
148
- self.output_queue.put_nowait(output_segment)
149
-
150
- if output_segment.finished:
151
- self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")
152
-
153
- for state in self.states:
154
- state.reset()
155
-
156
- if self.debug:
157
- # when we rebuild states, this value is reset to whatever
158
- # is in the system dir config, which defaults debug=False.
159
- self.states[0].debug = True
160
- except Exception as e:
161
- logger.error(f"Got exception while processing pipeline: {e}")
162
- traceback.print_exc()
163
- return input_segment
164
-
165
- def process_pipeline_loop(self):
166
- if self.close:
167
- return # closes the thread
168
-
169
- self.debug_log("processing_pipeline")
170
- while not self.close:
171
- input_segment = self.get_input_segment()
172
- if input_segment is None:
173
- if self.states[0].is_fresh_state: # TODO: this is hacky
174
- time.sleep(0.3)
175
- else:
176
- time.sleep(0.03)
177
- continue
178
- self.process_pipeline_impl(input_segment)
179
- self.debug_log("finished processing_pipeline")
180
-
181
- def process_pipeline_once(self):
182
- if self.close:
183
- return
184
-
185
- self.debug_log("processing pipeline once")
186
- input_segment = self.get_input_segment()
187
- if input_segment is None:
188
- return
189
- self.process_pipeline_impl(input_segment)
190
- self.debug_log("finished processing_pipeline_once")
191
-
192
- def get_output_segment(self):
193
- if self.output_queue.empty():
194
- return None
195
-
196
- output_chunk = self.output_queue.get_nowait()
197
- self.output_queue.task_done()
198
- return output_chunk
199
-
200
- def start(self):
201
- self.debug_log("starting transcoder in a thread")
202
- threading.Thread(target=self.process_pipeline_loop).start()
203
-
204
- def first_translation_time(self):
205
- return round((self.first_output_ts - self.first_input_ts) / 1000, 2)
206
-
207
- def get_buffered_output(self) -> SpeechAndTextOutput:
208
- now = time.time() * 1000
209
- self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
210
- while not self.output_queue.empty():
211
- tmp_out = self.get_output_segment()
212
- if tmp_out and len(tmp_out.content) > 0:
213
- if not self.output_data_type:
214
- self.output_data_type = tmp_out.data_type
215
- if len(self.output_buffer) == 0:
216
- self.last_output_ts = now
217
- self._populate_output_buffer(tmp_out)
218
- self._increment_output_buffer_size(tmp_out)
219
-
220
- if tmp_out.finished:
221
- res = self._gather_output_buffer_data(final=True)
222
- self.output_buffer = []
223
- self.increment_output_buffer_size = 0
224
- self.last_output_ts = now
225
- self.first_output_ts = now
226
- return res
227
-
228
- if len(self.output_buffer) > 0 and (
229
- now - self.last_output_ts >= self.output_buffer_idle_ms
230
- or self.output_buffer_cur_size >= self.output_buffer_size_limit
231
- ):
232
- self.last_output_ts = now
233
- res = self._gather_output_buffer_data(final=False)
234
- self.output_buffer = []
235
- self.output_buffer_phoneme_count = 0
236
- self.first_output_ts = now
237
- return res
238
- else:
239
- return None
240
-
241
- def _gather_output_buffer_data(self, final):
242
- if self.output_data_type == "text":
243
- return SpeechAndTextOutput(text=" ".join(self.output_buffer), final=final)
244
- elif self.output_data_type == "speech":
245
- return SpeechAndTextOutput(
246
- speech_samples=self.output_buffer,
247
- speech_sample_rate=MODEL_SAMPLE_RATE,
248
- final=final,
249
- )
250
- else:
251
- raise ValueError(
252
- f"Invalid output buffer data type: {self.output_data_type}"
253
- )
254
-
255
- def _increment_output_buffer_size(self, segment):
256
- if segment.data_type == "text":
257
- self.output_buffer_cur_size += self._compute_phoneme_count(segment.content)
258
- elif segment.data_type == "speech":
259
- self.output_buffer_cur_size += (
260
- len(segment.content) / MODEL_SAMPLE_RATE
261
- ) # seconds
262
-
263
- def _populate_output_buffer(self, segment):
264
- if segment.data_type == "text":
265
- self.output_buffer.append(segment.content)
266
- elif segment.data_type == "speech":
267
- self.output_buffer += segment.content
268
- else:
269
- raise ValueError(f"Invalid segment data type: {segment.data_type}")
270
-
271
- def _compute_phoneme_count(self, string: str) -> int:
272
- return len([x for x in self.g2p(string) if x != " "])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,18 +1,21 @@
 
 
1
  # fairseq2==0.1.0
2
 
3
- # Temp to skip
4
- # git+https://github.com/mduppes/fairseq2.git@93420c86ba01349ee8f90d7adda439b666b50557
5
  # git+https://github.com/facebookresearch/seamless_communication
6
- ./seamless_communication
 
7
  # comment this out to test fairseq1 first
8
  # git+https://github.com/facebookresearch/SimulEval.git
9
  gradio==3.41.0
10
  huggingface_hub==0.16.4
11
- torch==2.0.1
12
- torchaudio==2.0.2
13
- transformers==4.32.1
14
  pydub
15
-
 
 
16
 
17
  # Can't import fairseq1 together.. causes conflict:
18
  #The conflict is caused by:
 
1
+ # TODO: fairseq2 install is complicated so currently done outside
2
+
3
  # fairseq2==0.1.0
4
 
 
 
5
  # git+https://github.com/facebookresearch/seamless_communication
6
+ # ./fairseq2
7
+ # ./seamless_communication
8
  # comment this out to test fairseq1 first
9
  # git+https://github.com/facebookresearch/SimulEval.git
10
  gradio==3.41.0
11
  huggingface_hub==0.16.4
12
+ # torch==2.1.0
13
+ # torchaudio==2.0.2
14
+ # transformers==4.32.1
15
  pydub
16
+ g2p_en
17
+ colorlog
18
+ # git+ssh://[email protected]/facebookresearch/SimulEval.git
19
 
20
  # Can't import fairseq1 together.. causes conflict:
21
  #The conflict is caused by:
seamless_communication DELETED
@@ -1 +0,0 @@
1
- Subproject commit 02405dfd0c187d625aa66255ff8c39f98031a091
 
 
simuleval_transcoder.py CHANGED
@@ -1,225 +1,455 @@
1
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from pathlib import Path
3
- from typing import Callable, Dict, List, Optional, Tuple, Union
4
-
5
  import torch
6
- import torch.nn as nn
7
- from fairseq2.assets.card import AssetCard
8
- from fairseq2.data import Collater
9
- from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
10
- from fairseq2.data.text.text_tokenizer import TextTokenizer
11
- from fairseq2.data.typing import StringLike
12
- from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
13
- from fairseq2.memory import MemoryBlock
14
- from fairseq2.typing import DataType, Device
15
- from torch import Tensor
16
- from enum import Enum, auto
17
- from seamless_communication.models.inference.ngram_repeat_block_processor import (
18
- NGramRepeatBlockProcessor,
19
  )
20
 
21
- from seamless_communication.models.unity import (
22
- UnitTokenizer,
23
- UnitYGenerator,
24
- UnitYModel,
25
- load_unity_model,
26
- load_unity_text_tokenizer,
27
- load_unity_unit_tokenizer,
 
 
 
 
 
 
 
 
 
 
28
  )
29
- from seamless_communication.models.unity.generator import SequenceToUnitOutput
30
- from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
 
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # from seamless_communication.models.streaming.agents import (
35
- # SileroVADAgent,
36
- # TestTimeWaitKS2TVAD,
37
- # TestTimeWaitKUnityV1M4T
38
- # )
 
39
 
40
- from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
41
- TestTimeWaitKUnityS2TM4T,
42
- )
43
 
44
- from seamless_communication.cli.streaming.dataloader import Fairseq2SpeechToTextDataloader
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- ### From test_pipeline
47
- import math
48
- import soundfile
49
- from argparse import Namespace, ArgumentParser
50
- from simuleval.data.segments import SpeechSegment, EmptySegment
51
- from simuleval.utils import build_system_from_dir
52
- from pathlib import Path
53
- import numpy as np
54
 
55
- class AudioFrontEnd:
56
- def __init__(self, wav_file, segment_size) -> None:
57
- self.samples, self.sample_rate = soundfile.read(wav_file)
58
- # print(len(self.samples), self.samples[:100])
59
- self.samples = self.samples.tolist()
60
- self.segment_size = segment_size
61
- self.step = 0
62
- def send_segment(self):
63
- """
64
- This is the front-end logic in simuleval instance.py
65
- """
66
- num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
67
- print("self.segment_size", self.segment_size)
68
- print('num_samples is', num_samples)
69
- print('self.sample_rate is', self.sample_rate)
70
- if self.step < len(self.samples):
71
- if self.step + num_samples >= len(self.samples):
72
- samples = self.samples[self.step :]
73
- is_finished = True
74
- else:
75
- samples = self.samples[self.step : self.step + num_samples]
76
- is_finished = False
77
- self.step = min(self.step + num_samples, len(self.samples))
78
- # print("len(samples) is", len(samples))
79
- # import pdb
80
- # pdb.set_trace()
81
- segment = SpeechSegment(
82
- index=self.step / self.sample_rate * 1000,
83
- content=samples,
84
- sample_rate=self.sample_rate,
85
- finished=is_finished,
86
- )
87
- else:
88
- # Finish reading this audio
89
- segment = EmptySegment(
90
- index=self.step / self.sample_rate * 1000,
91
- finished=True,
92
- )
93
- return segment
94
 
 
 
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- def load_model_for_inference(
98
- load_model_fn: Callable[..., nn.Module],
99
- model_name_or_card: Union[str, AssetCard],
100
- device: Device,
101
- dtype: DataType,
102
- ) -> nn.Module:
103
- model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
104
- model.eval()
105
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- def load_model_fairseq2():
108
- data_configs = dict(
109
- dataloader="fairseq2_s2t",
110
- data_file="/large_experiments/seamless/ust/abinesh/data/s2st50_manifests/50-10/simuleval/dev_mtedx_filt_50-10_debug.tsv",
111
- )
112
 
113
- model_configs = dict(
114
- model_name="seamlessM4T_v2_large",
115
- device="cuda:0",
116
- source_segment_size=320,
117
- waitk_lagging=7,
118
- fixed_pre_decision_ratio=2,
119
- init_target_tokens="</s> __eng__",
120
- max_len_a=0,
121
- max_len_b=200,
122
- agent_class="seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4T",
123
- task="s2st",
124
- tgt_lang="eng",
125
- )
126
 
127
- eval_configs = dict(
128
- latency_metrics="StartOffset EndOffset AL",
129
- output=f"{TestTimeWaitKUnityS2TM4T.__name__}-wait{model_configs['waitk_lagging']}-debug",
130
- )
131
 
132
- model = TestTimeWaitKUnityS2TM4T({**data_configs, **model_configs, **eval_configs})
133
- print("model", model)
 
 
 
 
 
 
 
 
 
134
 
135
- evaluate(
136
- TestTimeWaitKUnityS2TM4T, {**data_configs, **model_configs, **eval_configs}
137
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  class SimulevalTranscoder:
140
- # def __init__(self, agent, sample_rate, debug, buffer_limit):
141
- def __init__(self):
142
- # print("MDUPPES in here", SileroVADAgent, TestTimeWaitKS2TVAD)
143
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
 
 
 
 
 
 
 
145
 
 
 
 
 
 
 
 
146
 
147
- load_model_fairseq2()
 
 
148
 
149
- device = "cpu"
150
- print("DEVICE", device)
151
- model_name_or_card="seamlessM4T_medium"
152
- vocoder_name_or_card="vocoder_36langs"
153
- # dtype=torch.float16,
154
- # For CPU Mode need to use 32, float16 causes errors downstream
155
- dtype=dtype=torch.float32
156
 
157
- model: UnitYModel = load_model_for_inference(
158
- load_unity_model, model_name_or_card, device, dtype
159
  )
 
 
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- print(model, type(model))
163
- parser = ArgumentParser()
164
- source_segment_size = 320 # milliseconds
165
- audio_frontend = AudioFrontEnd(
166
- wav_file="/checkpoint/mduppes/samples/marta.wav",
167
- segment_size=source_segment_size,
 
168
  )
169
 
170
- # mostly taken from S2S first agent: OnlineFeatureExtractorAgent defaults
171
- SHIFT_SIZE = 10
172
- WINDOW_SIZE = 25
173
- SAMPLE_RATE = 16000
174
- FEATURE_DIM = 80
175
-
176
- # args and convert to namespace so it can be accesed via .
177
- args = {
178
- "shift_size": SHIFT_SIZE,
179
- "window_size": WINDOW_SIZE,
180
- "sample_rate": audio_frontend.sample_rate,
181
- "feature_dim": 160, # from Wav2Vec2Frontend
182
- "denormalize": False, # not sure..
183
- "global_stats": None, # default file path containing cmvn stats..
184
- }
185
- print(args)
186
- args = Namespace(**args)
187
-
188
- pipeline = TestTimeWaitKUnityV1M4T(model, args)
189
- system_states = pipeline.build_states()
190
- print('system states:')
191
- for state in system_states:
192
- print(state, vars(state))
193
-
194
- input_segment = np.empty(0, dtype=np.int16)
195
- segments = []
196
- while True:
197
- speech_segment = audio_frontend.send_segment()
198
- input_segment = np.concatenate((input_segment, np.array(speech_segment.content)))
199
- # Translation happens here
200
- output_segment = pipeline.pushpop(speech_segment, system_states)
201
- print('pushpop result')
202
- print(output_segment)
203
- print('system states after pushpop:')
204
- for state in system_states:
205
- print(state, vars(state))
206
  if output_segment.finished:
207
- segments.append(input_segment)
208
- input_segment = np.empty(0, dtype=np.int16)
209
- print("Resetting states")
210
- for state in system_states:
211
- state.reset()
212
- if speech_segment.finished:
213
- break
214
- # The VAD-segmented samples from the full input audio
215
- for i, seg in enumerate(segments):
216
- with soundfile.SoundFile(
217
- Path("/checkpoint/mduppes/samples") / f"marta_{i}.wav",
218
- mode="w+",
219
- format="WAV",
220
- samplerate=16000,
221
- channels=1,
222
- ) as f:
223
- f.seek(0, soundfile.SEEK_END)
224
- f.write(seg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
 
 
 
1
 
2
+ from typing import Any, List, Tuple, Union, Optional
3
+ import numpy as np
4
+ import soundfile
5
+ import io
6
+ import asyncio
7
+ from simuleval.agents.pipeline import TreeAgentPipeline
8
+ from simuleval.agents.states import AgentStates
9
+ from simuleval.data.segments import Segment, EmptySegment, SpeechSegment
10
+ import threading
11
+ import math
12
+ import logging
13
+ import sys
14
  from pathlib import Path
15
+ import time
16
+ from g2p_en import G2p
17
  import torch
18
+ import traceback
19
+ import time
20
+ import random
21
+ import colorlog
22
+
23
+ # Sanity check that pipeline is loadable
24
+ from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
25
+ # TestTimeWaitKUnityS2TM4T,
26
+ TestTimeWaitKUnityS2TM4TVAD
 
 
 
 
27
  )
28
 
29
+ from simuleval.utils.agent import build_system_args
30
+
31
+ MODEL_SAMPLE_RATE = 16_000
32
+
33
+ logger = logging.getLogger(__name__)
34
+ logger.propagate = False
35
+ handler = colorlog.StreamHandler(stream=sys.stdout)
36
+ formatter = colorlog.ColoredFormatter(
37
+ "%(log_color)s[%(asctime)s][%(levelname)s][%(module)s]:%(reset)s %(message)s",
38
+ reset=True,
39
+ log_colors={
40
+ "DEBUG": "cyan",
41
+ "INFO": "green",
42
+ "WARNING": "yellow",
43
+ "ERROR": "red",
44
+ "CRITICAL": "red,bg_white",
45
+ },
46
  )
47
+ handler.setFormatter(formatter)
48
+ logger.addHandler(handler)
49
+ logger.setLevel(logging.DEBUG)
50
 
51
 
52
+ # TODO: Integrate this better so target lang and others can be changed. Also currently dependent on devserver internals
53
+ def build_agent():
54
+ config = {
55
+ 'dataloader': 'fairseq2_s2t',
56
+ 'data_file': '/large_experiments/seamless/ust/abinesh/data/s2st50_manifests/50-10/simuleval/dev_mtedx_filt_50-10_debug.tsv',
57
+ 'model_name': 'seamlessM4T_v2_large',
58
+ 'device': 'cuda:0',
59
+ 'source_segment_size': 320,
60
+ 'waitk_lagging': 7,
61
+ 'fixed_pre_decision_ratio': 2,
62
+ 'init_target_tokens': '</s> __eng__',
63
+ 'max_len_a': 0,
64
+ 'max_len_b': 200,
65
+ 'agent_class': 'seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4TVAD',
66
+ 'task': 's2st',
67
+ 'tgt_lang': 'eng',
68
+ 'latency_metrics': 'StartOffset EndOffset AL',
69
+ 'output': 'TestTimeWaitKUnityS2TM4TVAD-wait7-debug'
70
+ }
71
 
72
+ agent , _ = build_system_args(config)
73
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ # agent.to(device, fp16=True)
75
+ logger.info(
76
+ f"Successfully built simuleval agent"
77
+ )
78
 
79
+ return agent
 
 
80
 
81
+ class SpeechAndTextOutput:
82
+ def __init__(
83
+ self,
84
+ text: str = None,
85
+ speech_samples: list = None,
86
+ speech_sample_rate: float = None,
87
+ final: bool = False,
88
+ ):
89
+ self.text = text
90
+ self.speech_samples = speech_samples
91
+ self.speech_sample_rate = speech_sample_rate
92
+ self.final = final
93
 
94
+ class OutputSegments:
95
+ def __init__(self, segments: Union[List[Segment], Segment]):
96
+ if isinstance(segments, Segment):
97
+ segments = [segments]
98
+ self.segments: List[Segment] = [s for s in segments]
 
 
 
99
 
100
+ @property
101
+ def is_empty(self):
102
+ return all(segment.is_empty for segment in self.segments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ @property
105
+ def finished(self):
106
+ return all(segment.finished for segment in self.segments)
107
 
108
+ def compute_length(self, g2p):
109
+ lengths = []
110
+ for segment in self.segments:
111
+ if segment.data_type == "text":
112
+ lengths.append(len([x for x in g2p(segment.content) if x != " "]))
113
+ elif segment.data_type == "speech":
114
+ lengths.append(len(segment.content) / MODEL_SAMPLE_RATE)
115
+ elif isinstance(segment, EmptySegment):
116
+ continue
117
+ else:
118
+ logger.warning(
119
+ f"Unexpected data_type: {segment.data_type} not in 'speech', 'text'"
120
+ )
121
+ return max(lengths)
122
 
123
+ @classmethod
124
+ def join_output_buffer(
125
+ cls, buffer: List[List[Segment]], output: SpeechAndTextOutput
126
+ ):
127
+ num_segments = len(buffer[0])
128
+ for i in range(num_segments):
129
+ segment_list = [
130
+ buffer[j][i]
131
+ for j in range(len(buffer))
132
+ if buffer[j][i].data_type is not None
133
+ ]
134
+ if len(segment_list) == 0:
135
+ continue
136
+ if len(set(segment.data_type for segment in segment_list)) != 1:
137
+ logger.warning(
138
+ f"Data type mismatch at {i}: {set(segment.data_type for segment in segment_list)}"
139
+ )
140
+ continue
141
+ data_type = segment_list[0].data_type
142
+ if data_type == "text":
143
+ if output.text is not None:
144
+ logger.warning("Multiple text outputs, overwriting!")
145
+ output.text = " ".join([segment.content for segment in segment_list])
146
+ elif data_type == "speech":
147
+ if output.speech_samples is not None:
148
+ logger.warning("Multiple speech outputs, overwriting!")
149
+ speech_out = []
150
+ for segment in segment_list:
151
+ speech_out += segment.content
152
+ output.speech_samples = speech_out
153
+ output.speech_sample_rate = MODEL_SAMPLE_RATE
154
+ elif isinstance(segment_list[0], EmptySegment):
155
+ continue
156
+ else:
157
+ logger.warning(
158
+ f"Invalid output buffer data type: {data_type}, expected 'speech' or 'text"
159
+ )
160
 
161
+ return output
 
 
 
 
162
 
163
+ def __repr__(self) -> str:
164
+ repr_str = str(self.segments)
165
+ return f"{self.__class__.__name__}(\n\t{repr_str}\n)"
 
 
 
 
 
 
 
 
 
 
166
 
 
 
 
 
167
 
168
+ def convert_waveform(
169
+ waveform: Union[np.ndarray, torch.Tensor],
170
+ sample_rate: int,
171
+ normalize_volume: bool = False,
172
+ to_mono: bool = False,
173
+ to_sample_rate: Optional[int] = None,
174
+ ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
175
+ """convert a waveform:
176
+ - to a target sample rate
177
+ - from multi-channel to mono channel
178
+ - volume normalization
179
 
180
+ Args:
181
+ waveform (numpy.ndarray or torch.Tensor): 2D original waveform
182
+ (channels x length)
183
+ sample_rate (int): original sample rate
184
+ normalize_volume (bool): perform volume normalization
185
+ to_mono (bool): convert to mono channel if having multiple channels
186
+ to_sample_rate (Optional[int]): target sample rate
187
+ Returns:
188
+ waveform (numpy.ndarray): converted 2D waveform (channels x length)
189
+ sample_rate (float): target sample rate
190
+ """
191
+ try:
192
+ import torchaudio.sox_effects as ta_sox
193
+ except ImportError:
194
+ raise ImportError("Please install torchaudio: pip install torchaudio")
195
+
196
+ effects = []
197
+ if normalize_volume:
198
+ effects.append(["gain", "-n"])
199
+ if to_sample_rate is not None and to_sample_rate != sample_rate:
200
+ effects.append(["rate", f"{to_sample_rate}"])
201
+ if to_mono and waveform.shape[0] > 1:
202
+ effects.append(["channels", "1"])
203
+ if len(effects) > 0:
204
+ is_np_input = isinstance(waveform, np.ndarray)
205
+ _waveform = torch.from_numpy(waveform) if is_np_input else waveform
206
+ converted, converted_sample_rate = ta_sox.apply_effects_tensor(
207
+ _waveform, sample_rate, effects
208
+ )
209
+ if is_np_input:
210
+ converted = converted.numpy()
211
+ return converted, converted_sample_rate
212
+ return waveform, sample_rate
213
 
214
  class SimulevalTranscoder:
215
+ def __init__(self, sample_rate, debug, buffer_limit):
216
+ self.agent = build_agent()
217
+ self.input_queue = asyncio.Queue()
218
+ self.output_queue = asyncio.Queue()
219
+ self.states = self.agent.build_states()
220
+ if debug:
221
+ self.get_states_root().debug = True
222
+ self.incoming_sample_rate = sample_rate
223
+ self.close = False
224
+ self.g2p = G2p()
225
+
226
+ # buffer all outgoing translations within this amount of time
227
+ self.output_buffer_idle_ms = 5000
228
+ self.output_buffer_size_limit = (
229
+ buffer_limit # phonemes for text, seconds for speech
230
+ )
231
+ self.output_buffer_cur_size = 0
232
+ self.output_buffer: List[List[Segment]] = []
233
+ self.speech_output_sample_rate = None
234
+
235
+ self.last_output_ts = time.time() * 1000
236
+ self.timeout_ms = (
237
+ 30000 # close the transcoder thread after this amount of silence
238
+ )
239
+ self.first_input_ts = None
240
+ self.first_output_ts = None
241
+ self.debug = debug
242
+ self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
243
+ if self.debug:
244
+ debug_folder = Path(__file__).resolve().parent.parent / "debug"
245
+ self.test_incoming_wav = soundfile.SoundFile(
246
+ debug_folder / f"{self.debug_ts}_test_incoming.wav",
247
+ mode="w+",
248
+ format="WAV",
249
+ subtype="PCM_16",
250
+ samplerate=self.incoming_sample_rate,
251
+ channels=1,
252
+ )
253
+ self.get_states_root().test_input_segments_wav = soundfile.SoundFile(
254
+ debug_folder / f"{self.debug_ts}_test_input_segments.wav",
255
+ mode="w+",
256
+ format="WAV",
257
+ samplerate=MODEL_SAMPLE_RATE,
258
+ channels=1,
259
+ )
260
 
261
+ def get_states_root(self) -> AgentStates:
262
+ if isinstance(self.agent, TreeAgentPipeline):
263
+ # self.states is a dict
264
+ return self.states[self.agent.source_module]
265
+ else:
266
+ # self.states is a list
267
+ return self.states[0]
268
 
269
+ def reset_states(self):
270
+ if isinstance(self.agent, TreeAgentPipeline):
271
+ states_iter = self.states.values()
272
+ else:
273
+ states_iter = self.states
274
+ for state in states_iter:
275
+ state.reset()
276
 
277
+ def debug_log(self, *args):
278
+ if self.debug:
279
+ logger.info(*args)
280
 
281
+ def process_incoming_bytes(self, incoming_bytes, target_language, sample_rate):
282
+ # TODO: currently just taking sample rate here, refactor sample rate
283
+ # bytes is 16bit signed int
284
+ self.incoming_sample_rate = sample_rate
285
+ segment, sr = self._preprocess_wav(incoming_bytes)
 
 
286
 
287
+ segment = SpeechSegment(
288
+ content=segment, sample_rate=sr, tgt_lang=target_language
289
  )
290
+ # # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
291
+ self.input_queue.put_nowait(segment)
292
 
293
+ def get_input_segment(self):
294
+ if self.input_queue.empty():
295
+ return None
296
+ chunk = self.input_queue.get_nowait()
297
+ self.input_queue.task_done()
298
+ return chunk
299
+
300
+ def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
301
+ segment, sample_rate = soundfile.read(
302
+ io.BytesIO(data),
303
+ dtype="float32",
304
+ always_2d=True,
305
+ frames=-1,
306
+ start=0,
307
+ format="RAW",
308
+ subtype="PCM_16",
309
+ samplerate=self.incoming_sample_rate,
310
+ channels=1,
311
+ )
312
+ if self.debug:
313
+ self.test_incoming_wav.seek(0, soundfile.SEEK_END)
314
+ self.test_incoming_wav.write(segment)
315
 
316
+ segment = segment.T
317
+ segment, new_sample_rate = convert_waveform(
318
+ segment,
319
+ sample_rate,
320
+ normalize_volume=False,
321
+ to_mono=True,
322
+ to_sample_rate=MODEL_SAMPLE_RATE,
323
  )
324
 
325
+ assert MODEL_SAMPLE_RATE == new_sample_rate
326
+ segment = segment.squeeze(axis=0)
327
+ return segment, new_sample_rate
328
+
329
+ def process_pipeline_impl(self, input_segment):
330
+ try:
331
+ with torch.no_grad():
332
+ output_segment = OutputSegments(
333
+ self.agent.pushpop(input_segment, self.states)
334
+ )
335
+ if (
336
+ self.get_states_root().first_input_ts is not None
337
+ and self.first_input_ts is None
338
+ ):
339
+ # TODO: this is hacky
340
+ self.first_input_ts = self.get_states_root().first_input_ts
341
+
342
+ if not output_segment.is_empty:
343
+ self.output_queue.put_nowait(output_segment)
344
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  if output_segment.finished:
346
+ self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")
347
+
348
+ self.reset_states()
349
+
350
+ if self.debug:
351
+ # when we rebuild states, this value is reset to whatever
352
+ # is in the system dir config, which defaults debug=False.
353
+ self.get_states_root().debug = True
354
+ except Exception as e:
355
+ logger.error(f"Got exception while processing pipeline: {e}")
356
+ traceback.print_exc()
357
+ return input_segment
358
+
359
+ def process_pipeline_loop(self):
360
+ if self.close:
361
+ return # closes the thread
362
+
363
+ self.debug_log("processing_pipeline")
364
+ while not self.close:
365
+ input_segment = self.get_input_segment()
366
+ if input_segment is None:
367
+ # if self.get_states_root().is_fresh_state: # TODO: this is hacky
368
+ # time.sleep(0.3)
369
+ # else:
370
+ time.sleep(0.03)
371
+ continue
372
+ self.process_pipeline_impl(input_segment)
373
+ self.debug_log("finished processing_pipeline")
374
+
375
+ def process_pipeline_once(self):
376
+ if self.close:
377
+ return
378
+
379
+ self.debug_log("processing pipeline once")
380
+ input_segment = self.get_input_segment()
381
+ if input_segment is None:
382
+ return
383
+ self.process_pipeline_impl(input_segment)
384
+ self.debug_log("finished processing_pipeline_once")
385
+
386
+ def get_output_segment(self):
387
+ if self.output_queue.empty():
388
+ return None
389
+
390
+ output_chunk = self.output_queue.get_nowait()
391
+ self.output_queue.task_done()
392
+ return output_chunk
393
+
394
+ def start(self):
395
+ self.debug_log("starting transcoder in a thread")
396
+ threading.Thread(target=self.process_pipeline_loop).start()
397
+
398
+ def first_translation_time(self):
399
+ return round((self.first_output_ts - self.first_input_ts) / 1000, 2)
400
+
401
+ def get_buffered_output(self) -> SpeechAndTextOutput:
402
+ now = time.time() * 1000
403
+ self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
404
+ while not self.output_queue.empty():
405
+ tmp_out = self.get_output_segment()
406
+ if tmp_out and tmp_out.compute_length(self.g2p) > 0:
407
+ if len(self.output_buffer) == 0:
408
+ self.last_output_ts = now
409
+ self._populate_output_buffer(tmp_out)
410
+ self._increment_output_buffer_size(tmp_out)
411
+
412
+ if tmp_out.finished:
413
+ self.debug_log("tmp_out.finished")
414
+ res = self._gather_output_buffer_data(final=True)
415
+ self.debug_log(f"gathered output data: {res}")
416
+ self.output_buffer = []
417
+ self.increment_output_buffer_size = 0
418
+ self.last_output_ts = now
419
+ self.first_output_ts = now
420
+ return res
421
+ else:
422
+ self.debug_log("tmp_out.compute_length is not > 0")
423
+
424
+ if len(self.output_buffer) > 0 and (
425
+ now - self.last_output_ts >= self.output_buffer_idle_ms
426
+ or self.output_buffer_cur_size >= self.output_buffer_size_limit
427
+ ):
428
+ self.debug_log(
429
+ "[get_buffered_output] output_buffer is not empty. getting res to return."
430
+ )
431
+ self.last_output_ts = now
432
+ res = self._gather_output_buffer_data(final=False)
433
+ self.debug_log(f"gathered output data: {res}")
434
+ self.output_buffer = []
435
+ self.output_buffer_phoneme_count = 0
436
+ self.first_output_ts = now
437
+ return res
438
+ else:
439
+ self.debug_log("[get_buffered_output] output_buffer is empty...")
440
+ return None
441
+
442
+ def _gather_output_buffer_data(self, final):
443
+ output = SpeechAndTextOutput()
444
+ output.final = final
445
+ output = OutputSegments.join_output_buffer(self.output_buffer, output)
446
+ return output
447
+
448
+ def _increment_output_buffer_size(self, segment: OutputSegments):
449
+ self.output_buffer_cur_size += segment.compute_length(self.g2p)
450
+
451
+ def _populate_output_buffer(self, segment: OutputSegments):
452
+ self.output_buffer.append(segment.segments)
453
 
454
+ def _compute_phoneme_count(self, string: str) -> int:
455
+ return len([x for x in self.g2p(string) if x != " "])