andito HF staff committed on
Commit
2d00549
1 Parent(s): 9aea727

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. handler.py +38 -21
  2. s2s_pipeline.py +14 -4
  3. test.py +7 -0
handler.py CHANGED
@@ -2,9 +2,9 @@ from typing import Dict, Any, List, Generator
2
  import torch
3
  import os
4
  import logging
5
- from s2s_pipeline import main, rename_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline
6
  import numpy as np
7
- from queue import Queue
8
  import threading
9
 
10
  class EndpointHandler:
@@ -21,16 +21,19 @@ class EndpointHandler:
21
  self.parler_tts_handler_kwargs,
22
  self.melo_tts_handler_kwargs,
23
  self.chat_tts_handler_kwargs,
24
- ) = get_default_arguments()
25
  setup_logger(self.module_kwargs.log_level)
26
 
27
- rename_args(self.whisper_stt_handler_kwargs, "stt")
28
- rename_args(self.paraformer_stt_handler_kwargs, "paraformer_stt")
29
- rename_args(self.language_model_handler_kwargs, "lm")
30
- rename_args(self.mlx_language_model_handler_kwargs, "mlx_lm")
31
- rename_args(self.parler_tts_handler_kwargs, "tts")
32
- rename_args(self.melo_tts_handler_kwargs, "melo")
33
- rename_args(self.chat_tts_handler_kwargs, "chat_tts")
 
 
 
34
 
35
  self.queues_and_events = initialize_queues_and_events()
36
 
@@ -54,17 +57,21 @@ class EndpointHandler:
54
  # Add a new queue for collecting the final output
55
  self.final_output_queue = Queue()
56
 
57
- # Start a thread to collect the final output
58
- self.output_collector_thread = threading.Thread(target=self._collect_output)
59
- self.output_collector_thread.start()
60
-
61
  def _collect_output(self):
62
  while True:
63
- output = self.queues_and_events['send_audio_chunks_queue'].get()
64
- if output == b"END":
65
- self.final_output_queue.put(b"END")
 
 
 
 
 
 
 
 
 
66
  break
67
- self.final_output_queue.put(output)
68
 
69
  def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
70
  """
@@ -74,6 +81,10 @@ class EndpointHandler:
74
  Returns:
75
  Generator[Dict[str, Any], None, None]: A generator yielding output chunks from the model or pipeline.
76
  """
 
 
 
 
77
  input_type = data.get("input_type", "text")
78
  input_data = data.get("input", "")
79
 
@@ -89,12 +100,18 @@ class EndpointHandler:
89
  else:
90
  raise ValueError(f"Unsupported input type: {input_type}")
91
 
92
- # Stream the output chunks
 
93
  while True:
94
  chunk = self.final_output_queue.get()
95
- if chunk == b"END":
96
  break
97
- yield {"output": chunk}
 
 
 
 
 
98
 
99
  def cleanup(self):
100
  # Stop the pipeline
 
2
  import torch
3
  import os
4
  import logging
5
+ from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline
6
  import numpy as np
7
+ from queue import Queue, Empty
8
  import threading
9
 
10
  class EndpointHandler:
 
21
  self.parler_tts_handler_kwargs,
22
  self.melo_tts_handler_kwargs,
23
  self.chat_tts_handler_kwargs,
24
+ ) = get_default_arguments(device='cpu', mode='none', tts='melo', stt='whisper-mlx')
25
  setup_logger(self.module_kwargs.log_level)
26
 
27
+ prepare_all_args(
28
+ self.module_kwargs,
29
+ self.whisper_stt_handler_kwargs,
30
+ self.paraformer_stt_handler_kwargs,
31
+ self.language_model_handler_kwargs,
32
+ self.mlx_language_model_handler_kwargs,
33
+ self.parler_tts_handler_kwargs,
34
+ self.melo_tts_handler_kwargs,
35
+ self.chat_tts_handler_kwargs,
36
+ )
37
 
38
  self.queues_and_events = initialize_queues_and_events()
39
 
 
57
  # Add a new queue for collecting the final output
58
  self.final_output_queue = Queue()
59
 
 
 
 
 
60
  def _collect_output(self):
61
  while True:
62
+ try:
63
+ output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=5) # 5-second timeout
64
+ if isinstance(output, (str, bytes)) and output in (b"END", "END"):
65
+ self.final_output_queue.put("END")
66
+ break
67
+ elif isinstance(output, np.ndarray):
68
+ self.final_output_queue.put(output.tobytes())
69
+ else:
70
+ self.final_output_queue.put(output)
71
+ except Empty:
72
+ # If no output for 5 seconds, assume processing is complete
73
+ self.final_output_queue.put("END")
74
  break
 
75
 
76
  def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
77
  """
 
81
  Returns:
82
  Generator[Dict[str, Any], None, None]: A generator yielding output chunks from the model or pipeline.
83
  """
84
+ # Start a thread to collect the final output
85
+ self.output_collector_thread = threading.Thread(target=self._collect_output)
86
+ self.output_collector_thread.start()
87
+
88
  input_type = data.get("input_type", "text")
89
  input_data = data.get("input", "")
90
 
 
100
  else:
101
  raise ValueError(f"Unsupported input type: {input_type}")
102
 
103
+ # Collect all output chunks
104
+ output_chunks = []
105
  while True:
106
  chunk = self.final_output_queue.get()
107
+ if chunk == "END":
108
  break
109
+ output_chunks.append(chunk)
110
+
111
+ # Combine all audio chunks into a single byte string
112
+ combined_audio = b''.join(output_chunks)
113
+
114
+ return {"output": combined_audio}
115
 
116
  def cleanup(self):
117
  # Stop the pipeline
s2s_pipeline.py CHANGED
@@ -65,8 +65,8 @@ def rename_args(args, prefix):
65
 
66
  args.__dict__["gen_kwargs"] = gen_kwargs
67
 
68
- def get_default_arguments():
69
- return (
70
  ModuleArguments(),
71
  SocketReceiverArguments(),
72
  SocketSenderArguments(),
@@ -78,7 +78,14 @@ def get_default_arguments():
78
  ParlerTTSHandlerArguments(),
79
  MeloTTSHandlerArguments(),
80
  ChatTTSHandlerArguments(),
81
- )
 
 
 
 
 
 
 
82
 
83
  def parse_arguments():
84
  parser = HfArgumentParser(
@@ -241,7 +248,7 @@ def build_pipeline(
241
  )
242
  comms_handlers = [local_audio_streamer]
243
  should_listen.set()
244
- else:
245
  from connections.socket_receiver import SocketReceiver
246
  from connections.socket_sender import SocketSender
247
 
@@ -261,6 +268,9 @@ def build_pipeline(
261
  port=socket_sender_kwargs.send_port,
262
  ),
263
  ]
 
 
 
264
 
265
  vad = VADHandler(
266
  stop_event,
 
65
 
66
  args.__dict__["gen_kwargs"] = gen_kwargs
67
 
68
+ def get_default_arguments(**kwargs):
69
+ default_args = [
70
  ModuleArguments(),
71
  SocketReceiverArguments(),
72
  SocketSenderArguments(),
 
78
  ParlerTTSHandlerArguments(),
79
  MeloTTSHandlerArguments(),
80
  ChatTTSHandlerArguments(),
81
+ ]
82
+ # Update arguments with provided kwargs
83
+ for arg_obj in default_args:
84
+ for key, value in kwargs.items():
85
+ if hasattr(arg_obj, key):
86
+ setattr(arg_obj, key, value)
87
+
88
+ return tuple(default_args)
89
 
90
  def parse_arguments():
91
  parser = HfArgumentParser(
 
248
  )
249
  comms_handlers = [local_audio_streamer]
250
  should_listen.set()
251
+ elif module_kwargs.mode == "socket":
252
  from connections.socket_receiver import SocketReceiver
253
  from connections.socket_sender import SocketSender
254
 
 
268
  port=socket_sender_kwargs.send_port,
269
  ),
270
  ]
271
+ else:
272
+ comms_handlers = []
273
+ should_listen.set()
274
 
275
  vad = VADHandler(
276
  stop_event,
test.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ endpoint = EndpointHandler('')
4
+
5
+ for x in endpoint({'text': 'how are you?'}):
6
+ print('passed')
7
+ print(x)