smgc committed on
Commit
9ad452e
1 Parent(s): cd5aca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -27
app.py CHANGED
@@ -59,19 +59,34 @@ def validate_api_key():
59
  return None
60
 
61
def normalize_content(content):
    """Recursively flatten a message's ``content`` field into a plain string.

    Strings pass through unchanged; dicts are serialized as JSON; lists are
    flattened element by element and joined with single spaces; any other
    type collapses to an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # Serialize dicts as JSON, preserving non-ASCII characters verbatim.
        return json.dumps(content, ensure_ascii=False)
    if isinstance(content, list):
        # Recurse into every element and join the resulting pieces.
        return " ".join(normalize_content(item) for item in content)
    # Unsupported types (None, numbers, ...) contribute nothing.
    return ""
70
 
71
def calculate_tokens(text):
    """Rough token estimate for *text*.

    Pure-ASCII input is counted as whitespace-delimited words; input with
    any non-ASCII character (e.g. CJK) is counted per character, since such
    scripts are typically not space-delimited.
    """
    has_non_ascii = re.search(r'[^\x00-\x7F]', text) is not None
    if has_non_ascii:
        # Character-level count for scripts without word spacing.
        return len(text)
    # Whitespace tokenization for ASCII text.
    return len(text.split())
77
 
@@ -105,10 +120,13 @@ def messages():
105
 
106
  try:
107
  json_body = request.json
108
- model = json_body.get('model', 'claude-3-opus-20240229')
109
- stream = json_body.get('stream', True)
110
 
 
111
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
 
 
112
  input_tokens = calculate_tokens(previous_messages)
113
 
114
  msg_id = str(uuid.uuid4())
@@ -116,13 +134,13 @@ def messages():
116
  response_text = []
117
 
118
  if not stream:
 
119
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
120
 
 
121
  log_request(request.remote_addr, request.path, 200)
122
 
123
  def generate():
124
- output_tokens = 0 # 初始化 output_tokens
125
-
126
  yield create_event("message_start", {
127
  "type": "message_start",
128
  "message": {
@@ -130,10 +148,10 @@ def messages():
130
  "type": "message",
131
  "role": "assistant",
132
  "content": [],
133
- "model": model,
134
  "stop_reason": None,
135
  "stop_sequence": None,
136
- "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
137
  },
138
  })
139
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
@@ -158,24 +176,14 @@ def messages():
158
  sio.emit('perplexity_ask', (previous_messages, emit_data))
159
 
160
  def on_query_progress(data):
161
- nonlocal response_text, output_tokens
162
  if 'text' in data:
163
  text = json.loads(data['text'])
164
  chunk = text['chunks'][-1] if text['chunks'] else None
165
  if chunk:
166
  response_text.append(chunk)
167
- chunk_tokens = calculate_tokens(chunk)
168
- output_tokens += chunk_tokens
169
- yield create_event("content_block_delta", {
170
- "type": "content_block_delta",
171
- "index": 0,
172
- "delta": {"type": "text_delta", "text": chunk},
173
- })
174
- yield create_event("message_delta", {
175
- "type": "message_delta",
176
- "delta": {"usage": {"output_tokens": output_tokens}},
177
- })
178
 
 
179
  if data.get('final', False):
180
  response_event.set()
181
 
@@ -204,7 +212,11 @@ def messages():
204
  sio.sleep(0.1)
205
  while response_text:
206
  chunk = response_text.pop(0)
207
- yield from on_query_progress({"text": json.dumps({"chunks": [chunk]})})
 
 
 
 
208
 
209
  except Exception as e:
210
  logging.error(f"Error during socket connection: {str(e)}")
@@ -217,13 +229,16 @@ def messages():
217
  if sio.connected:
218
  sio.disconnect()
219
 
 
 
 
220
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
221
  yield create_event("message_delta", {
222
  "type": "message_delta",
223
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
224
- "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
225
  })
226
- yield create_event("message_stop", {"type": "message_stop"})
227
 
228
  return Response(generate(), content_type='text/event-stream')
229
 
@@ -233,6 +248,9 @@ def messages():
233
  return jsonify({"error": str(e)}), 400
234
 
235
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
 
 
 
236
  try:
237
  response_event = Event()
238
  response_text = []
@@ -263,6 +281,7 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
263
  if chunk:
264
  response_text.append(chunk)
265
 
 
266
  if data.get('final', False):
267
  response_event.set()
268
 
@@ -282,21 +301,24 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
282
 
283
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
284
 
 
285
  response_event.wait(timeout=30)
286
 
 
287
  output_tokens = calculate_tokens(''.join(response_text))
288
 
 
289
  full_response = {
290
- "content": [{"text": ''.join(response_text), "type": "text"}],
291
  "id": msg_id,
292
- "model": model,
293
  "role": "assistant",
294
  "stop_reason": "end_turn",
295
  "stop_sequence": None,
296
  "type": "message",
297
  "usage": {
298
- "input_tokens": input_tokens,
299
- "output_tokens": output_tokens,
300
  },
301
  }
302
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
@@ -321,10 +343,9 @@ def server_error(error):
321
 
322
def create_event(event, data):
    """Format one Server-Sent Event frame for *event* carrying *data*.

    Dict payloads are JSON-encoded with ``ensure_ascii=False`` so non-ASCII
    text is transmitted unescaped; other payloads are emitted as-is.
    """
    payload = json.dumps(data, ensure_ascii=False) if isinstance(data, dict) else data
    return f"event: {event}\ndata: {payload}\n\n"
326
 
327
-
328
  if __name__ == '__main__':
329
  port = int(os.environ.get('PORT', 8081))
330
  logging.info(f"Perplexity proxy listening on port {port}")
 
59
  return None
60
 
61
def normalize_content(content):
    """Recursively reduce a message's ``content`` field to one string.

    Handles the shapes the messages API may deliver: a plain string is
    returned unchanged, a dict becomes its JSON serialization, a list is
    normalized element-wise and space-joined, and anything else maps to "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # JSON-encode dicts; ensure_ascii=False keeps CJK text readable.
        return json.dumps(content, ensure_ascii=False)
    if isinstance(content, list):
        # Normalize each item recursively, then glue with spaces.
        return " ".join(normalize_content(item) for item in content)
    # Any other type yields an empty string.
    return ""
77
 
78
def calculate_tokens(text):
    """Improved approximate token count.

    - Text containing only ASCII is split on whitespace and counted by word.
    - Text containing any non-ASCII character (such as Chinese) is counted
      character by character, since those scripts lack word spacing.
    """
    # Detect any character outside the ASCII range.
    if re.search(r'[^\x00-\x7F]', text):
        # Character-level tokenization for non-ASCII text.
        return len(text)
    # Whitespace tokenization otherwise.
    words = text.split()
    return len(words)
92
 
 
120
 
121
  try:
122
  json_body = request.json
123
+ model = json_body.get('model', 'claude-3-opus-20240229') # 动态获取模型,默认 claude-3-opus-20240229
124
+ stream = json_body.get('stream', True) # 默认为True
125
 
126
+ # 使用 normalize_content 递归处理 msg['content']
127
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
128
+
129
+ # 动态计算输入的 token 数量
130
  input_tokens = calculate_tokens(previous_messages)
131
 
132
  msg_id = str(uuid.uuid4())
 
134
  response_text = []
135
 
136
  if not stream:
137
+ # 处理 stream 为 false 的情况
138
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
139
 
140
+ # 记录日志:此时请求上下文仍然有效
141
  log_request(request.remote_addr, request.path, 200)
142
 
143
  def generate():
 
 
144
  yield create_event("message_start", {
145
  "type": "message_start",
146
  "message": {
 
148
  "type": "message",
149
  "role": "assistant",
150
  "content": [],
151
+ "model": model, # 动态模型
152
  "stop_reason": None,
153
  "stop_sequence": None,
154
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1}, # 动态 input_tokens
155
  },
156
  })
157
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
 
176
  sio.emit('perplexity_ask', (previous_messages, emit_data))
177
 
178
  def on_query_progress(data):
179
+ nonlocal response_text
180
  if 'text' in data:
181
  text = json.loads(data['text'])
182
  chunk = text['chunks'][-1] if text['chunks'] else None
183
  if chunk:
184
  response_text.append(chunk)
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ # 检查是否是最终响应
187
  if data.get('final', False):
188
  response_event.set()
189
 
 
212
  sio.sleep(0.1)
213
  while response_text:
214
  chunk = response_text.pop(0)
215
+ yield create_event("content_block_delta", {
216
+ "type": "content_block_delta",
217
+ "index": 0,
218
+ "delta": {"type": "text_delta", "text": chunk},
219
+ })
220
 
221
  except Exception as e:
222
  logging.error(f"Error during socket connection: {str(e)}")
 
229
  if sio.connected:
230
  sio.disconnect()
231
 
232
+ # 动态计算输出的 token 数量
233
+ output_tokens = calculate_tokens(''.join(response_text))
234
+
235
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
236
  yield create_event("message_delta", {
237
  "type": "message_delta",
238
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
239
+ "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, # 动态 output_tokens
240
  })
241
+ yield create_event("message_stop", {"type": "message_stop"}) # 确保发送 message_stop 事件
242
 
243
  return Response(generate(), content_type='text/event-stream')
244
 
 
248
  return jsonify({"error": str(e)}), 400
249
 
250
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
251
+ """
252
+ 处理 stream 为 false 的情况,返回完整的响应。
253
+ """
254
  try:
255
  response_event = Event()
256
  response_text = []
 
281
  if chunk:
282
  response_text.append(chunk)
283
 
284
+ # 检查是否是最终响应
285
  if data.get('final', False):
286
  response_event.set()
287
 
 
301
 
302
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
303
 
304
+ # 等待响应完成
305
  response_event.wait(timeout=30)
306
 
307
+ # 动态计算输出的 token 数量
308
  output_tokens = calculate_tokens(''.join(response_text))
309
 
310
+ # 生成完整的响应
311
  full_response = {
312
+ "content": [{"text": ''.join(response_text), "type": "text"}], # 合并所有文本块
313
  "id": msg_id,
314
+ "model": model, # 动态模型
315
  "role": "assistant",
316
  "stop_reason": "end_turn",
317
  "stop_sequence": None,
318
  "type": "message",
319
  "usage": {
320
+ "input_tokens": input_tokens, # 动态 input_tokens
321
+ "output_tokens": output_tokens, # 动态 output_tokens
322
  },
323
  }
324
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
 
343
 
344
def create_event(event, data):
    """Build a text/event-stream frame: ``event: ...`` plus ``data: ...``.

    When *data* is a dict it is JSON-serialized with ``ensure_ascii=False``
    so Chinese (and other non-ASCII) text is not escaped; non-dict payloads
    are embedded unchanged.
    """
    if isinstance(data, dict):
        data = json.dumps(data, ensure_ascii=False)  # keep non-ASCII readable
    return f"event: {event}\ndata: {data}\n\n"
348
 
 
349
  if __name__ == '__main__':
350
  port = int(os.environ.get('PORT', 8081))
351
  logging.info(f"Perplexity proxy listening on port {port}")