smgc committed on
Commit
6422859
1 Parent(s): 552ed6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -84
app.py CHANGED
@@ -11,18 +11,11 @@ from threading import Event
11
  app = Flask(__name__)
12
  logging.basicConfig(level=logging.INFO)
13
 
14
- # 从环境变量中获取API密钥
15
  API_KEY = os.environ.get('PPLX_KEY')
16
-
17
- # 代理设置
18
  proxy_url = os.environ.get('PROXY_URL')
19
 
20
- # 设置代理
21
  if proxy_url:
22
- proxies = {
23
- 'http': proxy_url,
24
- 'https': proxy_url
25
- }
26
  transport = requests.Session()
27
  transport.proxies.update(proxies)
28
  else:
@@ -30,12 +23,7 @@ else:
30
 
31
  sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)
32
 
33
- # 连接选项
34
- connect_opts = {
35
- 'transports': ['websocket', 'polling'], # 允许回退到轮询
36
- }
37
-
38
- # 其他选项
39
  sio_opts = {
40
  'extraHeaders': {
41
  'Cookie': os.environ.get('PPLX_COOKIE'),
@@ -58,31 +46,24 @@ def validate_api_key():
58
  return None
59
 
60
  def normalize_content(content):
61
- """
62
- 递归处理 msg['content'],确保其为字符串。
63
- 如果 content 是字典或列表,将其转换为字符串。
64
- """
65
  if isinstance(content, str):
66
  return content
67
  elif isinstance(content, dict):
68
- # 将字典转化为 JSON 字符串
69
  return json.dumps(content, ensure_ascii=False)
70
  elif isinstance(content, list):
71
- # 对于列表,递归处理每个元素
72
  return " ".join([normalize_content(item) for item in content])
73
  else:
74
- # 如果是其他类型,返回空字符串
75
  return ""
76
 
77
  def calculate_tokens(text):
78
- """
79
- 计算文本的 token 数量。
80
- 这里我们简单地通过空格分词来模拟 token 计数。
81
- 如果使用 GPT 模型,可以使用 tiktoken 库进行 tokenization。
82
- """
83
- # 使用简单的空格分词计数
84
- tokens = text.split()
85
- return len(tokens)
86
 
87
  @app.route('/')
88
  def root():
@@ -114,13 +95,10 @@ def messages():
114
 
115
  try:
116
  json_body = request.json
117
- model = json_body.get('model', 'claude-3-opus-20240229') # 动态获取模型,默认 claude-3-opus-20240229
118
- stream = json_body.get('stream', True) # 默认为True
119
 
120
- # 使用 normalize_content 递归处理 msg['content']
121
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
122
-
123
- # 动态计算输入的 token 数量
124
  input_tokens = calculate_tokens(previous_messages)
125
 
126
  msg_id = str(uuid.uuid4())
@@ -128,10 +106,8 @@ def messages():
128
  response_text = []
129
 
130
  if not stream:
131
- # 处理 stream 为 false 的情况
132
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
133
 
134
- # 记录日志:此时请求上下文仍然有效
135
  log_request(request.remote_addr, request.path, 200)
136
 
137
  def generate():
@@ -142,10 +118,10 @@ def messages():
142
  "type": "message",
143
  "role": "assistant",
144
  "content": [],
145
- "model": model, # 动态模型
146
  "stop_reason": None,
147
  "stop_sequence": None,
148
- "usage": {"input_tokens": input_tokens, "output_tokens": 1}, # 动态 input_tokens
149
  },
150
  })
151
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
@@ -171,23 +147,16 @@ def messages():
171
 
172
  def on_query_progress(data):
173
  nonlocal response_text
174
- if 'text' in data:
175
- try:
176
- # 解析嵌套的 JSON 字符串
177
- text_content = json.loads(data['text']) # data['text'] 是一个 JSON 字符串
178
- chunk = text_content['chunks'][-1] if 'chunks' in text_content and text_content['chunks'] else None
179
  if chunk:
180
  response_text.append(chunk)
181
- except json.JSONDecodeError as e:
182
- logging.error(f"Failed to decode JSON from 'text' field: {e}")
183
- response_text.append(f"Error decoding response: {e}")
184
-
185
- # 检查是否是最终响应
186
- if data.get('final', False):
187
- response_event.set()
188
-
189
- def on_query_complete(data):
190
- response_event.set()
191
 
192
  def on_disconnect():
193
  logging.info("Disconnected from Perplexity AI")
@@ -200,7 +169,6 @@ def messages():
200
 
201
  sio.on('connect', on_connect)
202
  sio.on('query_progress', on_query_progress)
203
- sio.on('query_complete', on_query_complete)
204
  sio.on('disconnect', on_disconnect)
205
  sio.on('connect_error', on_connect_error)
206
 
@@ -211,11 +179,15 @@ def messages():
211
  sio.sleep(0.1)
212
  while response_text:
213
  chunk = response_text.pop(0)
214
- yield create_event("content_block_delta", {
215
  "type": "content_block_delta",
216
  "index": 0,
217
  "delta": {"type": "text_delta", "text": chunk},
218
- })
 
 
 
 
219
 
220
  except Exception as e:
221
  logging.error(f"Error during socket connection: {str(e)}")
@@ -228,16 +200,15 @@ def messages():
228
  if sio.connected:
229
  sio.disconnect()
230
 
231
- # 动态计算输出的 token 数量
232
  output_tokens = calculate_tokens(''.join(response_text))
233
 
234
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
235
  yield create_event("message_delta", {
236
  "type": "message_delta",
237
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
238
- "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, # 动态 output_tokens
239
  })
240
- yield create_event("message_stop", {"type": "message_stop"}) # 确保发送 message_stop 事件
241
 
242
  return Response(generate(), content_type='text/event-stream')
243
 
@@ -247,9 +218,6 @@ def messages():
247
  return jsonify({"error": str(e)}), 400
248
 
249
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
250
- """
251
- 处理 stream 为 false 的情况,返回完整的响应。
252
- """
253
  try:
254
  response_event = Event()
255
  response_text = []
@@ -274,20 +242,16 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
274
 
275
  def on_query_progress(data):
276
  nonlocal response_text
277
- if 'text' in data:
278
- try:
279
- # 解析嵌套的 JSON 字符串
280
- text_content = json.loads(data['text']) # data['text'] 是一个 JSON 字符串
281
- chunk = text_content['chunks'][-1] if 'chunks' in text_content and text_content['chunks'] else None
282
  if chunk:
283
  response_text.append(chunk)
284
- except json.JSONDecodeError as e:
285
- logging.error(f"Failed to decode JSON from 'text' field: {e}")
286
- response_text.append(f"Error decoding response: {e}")
287
-
288
- # 检查是否是最终响应
289
- if data.get('final', False):
290
- response_event.set()
291
 
292
  def on_disconnect():
293
  logging.info("Disconnected from Perplexity AI")
@@ -305,26 +269,28 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
305
 
306
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
307
 
308
- # 等待响应完成
309
  response_event.wait(timeout=30)
310
 
311
- # 动态计算输出的 token 数量
312
  output_tokens = calculate_tokens(''.join(response_text))
313
 
314
- # 生成完整的响应
315
  full_response = {
316
- "content": [{"text": ''.join(response_text), "type": "text"}], # 合并所有文本块
317
  "id": msg_id,
318
- "model": model, # 动态模型
319
  "role": "assistant",
320
  "stop_reason": "end_turn",
321
  "stop_sequence": None,
322
  "type": "message",
323
  "usage": {
324
- "input_tokens": input_tokens, # 动态 input_tokens
325
- "output_tokens": output_tokens, # 动态 output_tokens
326
  },
327
  }
 
 
 
 
 
328
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
329
 
330
  except Exception as e:
@@ -346,9 +312,13 @@ def server_error(error):
346
  return "Something broke!", 500
347
 
348
  def create_event(event, data):
349
- if isinstance(data, dict):
350
- data = json.dumps(data, ensure_ascii=False) # 确保中文不会被转义
351
- return f"event: {event}\ndata: {data}\n\n"
 
 
 
 
352
 
353
  if __name__ == '__main__':
354
  port = int(os.environ.get('PORT', 8081))
 
11
  app = Flask(__name__)
12
  logging.basicConfig(level=logging.INFO)
13
 
 
14
  API_KEY = os.environ.get('PPLX_KEY')
 
 
15
  proxy_url = os.environ.get('PROXY_URL')
16
 
 
17
  if proxy_url:
18
+ proxies = {'http': proxy_url, 'https': proxy_url}
 
 
 
19
  transport = requests.Session()
20
  transport.proxies.update(proxies)
21
  else:
 
23
 
24
  sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)
25
 
26
+ connect_opts = {'transports': ['websocket', 'polling']}
 
 
 
 
 
27
  sio_opts = {
28
  'extraHeaders': {
29
  'Cookie': os.environ.get('PPLX_COOKIE'),
 
46
  return None
47
 
48
  def normalize_content(content):
 
 
 
 
49
  if isinstance(content, str):
50
  return content
51
  elif isinstance(content, dict):
 
52
  return json.dumps(content, ensure_ascii=False)
53
  elif isinstance(content, list):
 
54
  return " ".join([normalize_content(item) for item in content])
55
  else:
 
56
  return ""
57
 
58
  def calculate_tokens(text):
59
+ return len(text.split())
60
+
61
+ def validate_json(data):
62
+ try:
63
+ json.loads(json.dumps(data))
64
+ return True
65
+ except json.JSONDecodeError:
66
+ return False
67
 
68
  @app.route('/')
69
  def root():
 
95
 
96
  try:
97
  json_body = request.json
98
+ model = json_body.get('model', 'claude-3-opus-20240229')
99
+ stream = json_body.get('stream', True)
100
 
 
101
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
 
 
102
  input_tokens = calculate_tokens(previous_messages)
103
 
104
  msg_id = str(uuid.uuid4())
 
106
  response_text = []
107
 
108
  if not stream:
 
109
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
110
 
 
111
  log_request(request.remote_addr, request.path, 200)
112
 
113
  def generate():
 
118
  "type": "message",
119
  "role": "assistant",
120
  "content": [],
121
+ "model": model,
122
  "stop_reason": None,
123
  "stop_sequence": None,
124
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1},
125
  },
126
  })
127
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
 
147
 
148
  def on_query_progress(data):
149
  nonlocal response_text
150
+ try:
151
+ if 'text' in data:
152
+ text = json.loads(data['text'])
153
+ chunk = text['chunks'][-1] if text['chunks'] else None
 
154
  if chunk:
155
  response_text.append(chunk)
156
+ if data.get('final', False):
157
+ response_event.set()
158
+ except json.JSONDecodeError:
159
+ logging.error(f"Failed to parse query progress data: {data}")
 
 
 
 
 
 
160
 
161
  def on_disconnect():
162
  logging.info("Disconnected from Perplexity AI")
 
169
 
170
  sio.on('connect', on_connect)
171
  sio.on('query_progress', on_query_progress)
 
172
  sio.on('disconnect', on_disconnect)
173
  sio.on('connect_error', on_connect_error)
174
 
 
179
  sio.sleep(0.1)
180
  while response_text:
181
  chunk = response_text.pop(0)
182
+ event_data = {
183
  "type": "content_block_delta",
184
  "index": 0,
185
  "delta": {"type": "text_delta", "text": chunk},
186
+ }
187
+ if validate_json(event_data):
188
+ yield create_event("content_block_delta", event_data)
189
+ else:
190
+ logging.error(f"Invalid JSON for content_block_delta: {event_data}")
191
 
192
  except Exception as e:
193
  logging.error(f"Error during socket connection: {str(e)}")
 
200
  if sio.connected:
201
  sio.disconnect()
202
 
 
203
  output_tokens = calculate_tokens(''.join(response_text))
204
 
205
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
206
  yield create_event("message_delta", {
207
  "type": "message_delta",
208
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
209
+ "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
210
  })
211
+ yield create_event("message_stop", {"type": "message_stop"})
212
 
213
  return Response(generate(), content_type='text/event-stream')
214
 
 
218
  return jsonify({"error": str(e)}), 400
219
 
220
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
 
 
 
221
  try:
222
  response_event = Event()
223
  response_text = []
 
242
 
243
  def on_query_progress(data):
244
  nonlocal response_text
245
+ try:
246
+ if 'text' in data:
247
+ text = json.loads(data['text'])
248
+ chunk = text['chunks'][-1] if text['chunks'] else None
 
249
  if chunk:
250
  response_text.append(chunk)
251
+ if data.get('final', False):
252
+ response_event.set()
253
+ except json.JSONDecodeError:
254
+ logging.error(f"Failed to parse query progress data: {data}")
 
 
 
255
 
256
  def on_disconnect():
257
  logging.info("Disconnected from Perplexity AI")
 
269
 
270
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
271
 
 
272
  response_event.wait(timeout=30)
273
 
 
274
  output_tokens = calculate_tokens(''.join(response_text))
275
 
 
276
  full_response = {
277
+ "content": [{"text": ''.join(response_text), "type": "text"}],
278
  "id": msg_id,
279
+ "model": model,
280
  "role": "assistant",
281
  "stop_reason": "end_turn",
282
  "stop_sequence": None,
283
  "type": "message",
284
  "usage": {
285
+ "input_tokens": input_tokens,
286
+ "output_tokens": output_tokens,
287
  },
288
  }
289
+
290
+ if not validate_json(full_response):
291
+ logging.error(f"Invalid JSON response: {full_response}")
292
+ return jsonify({"error": "Invalid response format"}), 500
293
+
294
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
295
 
296
  except Exception as e:
 
312
  return "Something broke!", 500
313
 
314
  def create_event(event, data):
315
+ try:
316
+ if isinstance(data, dict):
317
+ data = json.dumps(data, ensure_ascii=False)
318
+ return f"event: {event}\ndata: {data}\n\n"
319
+ except json.JSONDecodeError:
320
+ logging.error(f"Failed to serialize event data: {data}")
321
+ return f"event: {event}\ndata: {json.dumps({'error': 'Data serialization failed'})}\n\n"
322
 
323
  if __name__ == '__main__':
324
  port = int(os.environ.get('PORT', 8081))