xyfcc committed on
Commit
e83cc8c
1 Parent(s): 1f0ed9b

Update web_demo.py

Browse files
Files changed (1) hide show
  1. web_demo.py +60 -29
web_demo.py CHANGED
@@ -8,11 +8,9 @@ import requests
8
  from argparse import ArgumentParser
9
 
10
  import torchaudio
11
- from transformers import WhisperFeatureExtractor, AutoTokenizer, AutoModel
12
  from speech_tokenizer.modeling_whisper import WhisperVQEncoder
13
- #import os
14
 
15
- #os.environ["no_proxy"]="localhost,127.0.0.1,::1"
16
 
17
  sys.path.insert(0, "./cosyvoice")
18
  sys.path.insert(0, "./third_party/Matcha-TTS")
@@ -22,19 +20,28 @@ from speech_tokenizer.utils import extract_speech_token
22
  import gradio as gr
23
  import torch
24
 
 
25
  audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
26
 
27
  from flow_inference import AudioDecoder
28
 
 
 
 
 
29
  if __name__ == "__main__":
30
  parser = ArgumentParser()
31
  parser.add_argument("--host", type=str, default="0.0.0.0")
32
  parser.add_argument("--port", type=int, default="8888")
33
  parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
34
  parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
35
- parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
36
  args = parser.parse_args()
37
-
 
 
 
 
38
  flow_config = os.path.join(args.flow_path, "config.yaml")
39
  flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
40
  hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
@@ -42,7 +49,7 @@ if __name__ == "__main__":
42
  device = "cuda"
43
  audio_decoder: AudioDecoder = None
44
  whisper_model, feature_extractor = None, None
45
-
46
 
47
  def initialize_fn():
48
  global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
@@ -61,12 +68,18 @@ if __name__ == "__main__":
61
  whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
62
  feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
63
 
 
 
 
 
 
 
64
 
65
  def clear_fn():
66
  return [], [], '', '', '', None, None
67
 
68
 
69
- def inference_fn(
70
  temperature: float,
71
  top_p: float,
72
  max_new_token: int,
@@ -105,17 +118,26 @@ def inference_fn(
105
  inputs += f"<|system|>\n{system_prompt}"
106
  inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
107
 
 
108
  with torch.no_grad():
109
- response = requests.post(
110
- "http://localhost:10000/generate_stream",
111
- data=json.dumps({
112
- "prompt": inputs,
113
- "temperature": temperature,
114
- "top_p": top_p,
115
- "max_new_tokens": max_new_token,
116
- }),
117
- stream=True
118
- )
 
 
 
 
 
 
 
 
119
  text_tokens, audio_tokens = [], []
120
  audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
121
  end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
@@ -128,7 +150,8 @@ def inference_fn(
128
  prev_mel = None
129
  is_finalize = False
130
  block_size = 10
131
- for chunk in response.iter_lines():
 
132
  token_id = json.loads(chunk)["token_id"]
133
  if token_id == end_token_id:
134
  is_finalize = True
@@ -165,15 +188,15 @@ def inference_fn(
165
  yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
166
 
167
 
168
- def update_input_interface(input_mode):
169
  if input_mode == "audio":
170
  return [gr.update(visible=True), gr.update(visible=False)]
171
  else:
172
  return [gr.update(visible=False), gr.update(visible=True)]
173
 
174
 
175
- # Create the Gradio interface
176
- with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
177
  with gr.Row():
178
  temperature = gr.Number(
179
  label="Temperature",
@@ -200,7 +223,9 @@ with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
200
  with gr.Row():
201
  with gr.Column():
202
  input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
203
- audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
 
 
204
  text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)
205
 
206
  with gr.Column():
@@ -252,10 +277,16 @@ with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
252
  reset_btn.click(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio])
253
  input_mode.input(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]).then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])
254
 
255
- initialize_fn()
256
- # Launch the interface
257
- demo.launch(
258
- server_port=args.port,
259
- server_name=args.host,
260
- share=True
261
- )
 
 
 
 
 
 
 
8
  from argparse import ArgumentParser
9
 
10
  import torchaudio
11
+ from transformers import WhisperFeatureExtractor, AutoTokenizer
12
  from speech_tokenizer.modeling_whisper import WhisperVQEncoder
 
13
 
 
14
 
15
  sys.path.insert(0, "./cosyvoice")
16
  sys.path.insert(0, "./third_party/Matcha-TTS")
 
20
  import gradio as gr
21
  import torch
22
 
23
+
24
  audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
25
 
26
  from flow_inference import AudioDecoder
27
 
28
+ use_local_interface = True
29
+ if use_local_interface :
30
+ from model_server import ModelWorker
31
+
32
  if __name__ == "__main__":
33
  parser = ArgumentParser()
34
  parser.add_argument("--host", type=str, default="0.0.0.0")
35
  parser.add_argument("--port", type=int, default="8888")
36
  parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
37
  parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
38
+ parser.add_argument("--tokenizer-path", type= str, default="THUDM/glm-4-voice-tokenizer")
39
  args = parser.parse_args()
40
+ # --tokenizer-path /home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-tokenizer --model-path /home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-9b --flow-path /home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-decoder
41
+ # args.tokenizer_path = '/home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-tokenizer'
42
+ # args.model_path = '/home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-9b'
43
+ # args.flow_path = '/home/hanrf/llm/voice/model/ZhipuAI/glm-4-voice-decoder'
44
+
45
  flow_config = os.path.join(args.flow_path, "config.yaml")
46
  flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
47
  hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
 
49
  device = "cuda"
50
  audio_decoder: AudioDecoder = None
51
  whisper_model, feature_extractor = None, None
52
+ worker = None
53
 
54
  def initialize_fn():
55
  global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
 
68
  whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
69
  feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
70
 
71
+ global use_local_interface, worker
72
+ if use_local_interface :
73
+ model_path0 = 'THUDM/glm-4-voice-9b '
74
+ # dtype = 'bfloat16'
75
+ device0 = 'cuda:0'
76
+ worker = ModelWorker(model_path0,device0)
77
 
78
  def clear_fn():
79
  return [], [], '', '', '', None, None
80
 
81
 
82
+ def inference_fn(
83
  temperature: float,
84
  top_p: float,
85
  max_new_token: int,
 
118
  inputs += f"<|system|>\n{system_prompt}"
119
  inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
120
 
121
+ global use_local_interface , worker
122
  with torch.no_grad():
123
+ if use_local_interface :
124
+ params = { "prompt": inputs,
125
+ "temperature": temperature,
126
+ "top_p": top_p,
127
+ "max_new_tokens": max_new_token, }
128
+ response = worker.generate_stream( params )
129
+
130
+ else :
131
+ response = requests.post(
132
+ "http://localhost:10000/generate_stream",
133
+ data=json.dumps({
134
+ "prompt": inputs,
135
+ "temperature": temperature,
136
+ "top_p": top_p,
137
+ "max_new_tokens": max_new_token,
138
+ }),
139
+ stream=True
140
+ )
141
  text_tokens, audio_tokens = [], []
142
  audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
143
  end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
 
150
  prev_mel = None
151
  is_finalize = False
152
  block_size = 10
153
+ # for chunk in response.iter_lines():
154
+ for chunk in response :
155
  token_id = json.loads(chunk)["token_id"]
156
  if token_id == end_token_id:
157
  is_finalize = True
 
188
  yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
189
 
190
 
191
+ def update_input_interface(input_mode):
192
  if input_mode == "audio":
193
  return [gr.update(visible=True), gr.update(visible=False)]
194
  else:
195
  return [gr.update(visible=False), gr.update(visible=True)]
196
 
197
 
198
+ # Create the Gradio interface
199
+ with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
200
  with gr.Row():
201
  temperature = gr.Number(
202
  label="Temperature",
 
223
  with gr.Row():
224
  with gr.Column():
225
  input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
226
+ # audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
227
+ audio = gr.Audio(sources=["upload","microphone"], label="Input audio", type='filepath', show_download_button=True, visible=True)
228
+ # audio = gr.Audio(source="microphone", label="Input audio", type='filepath', show_download_button=True, visible=True)
229
  text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)
230
 
231
  with gr.Column():
 
277
  reset_btn.click(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio])
278
  input_mode.input(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]).then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])
279
 
280
+ initialize_fn()
281
+ # Launch the interface
282
+ demo.launch(
283
+ server_port=args.port,
284
+ server_name=args.host,
285
+ ssl_verify=False,
286
+ share=True
287
+ )
288
+
289
+ '''
290
+ server.launch(share=True)
291
+ https://1a9b77cb89ac33f546.gradio.live
292
+ '''