smgc commited on
Commit
a639253
1 Parent(s): 548cbf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -54
app.py CHANGED
@@ -9,16 +9,27 @@ import logging
9
  from threading import Event
10
  import re
11
 
 
12
  app = Flask(__name__)
13
- logging.basicConfig(level=logging.INFO)
14
 
15
- # 从环境变量中获取API密钥
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  API_KEY = os.environ.get('PPLX_KEY')
17
 
18
  # 代理设置
19
  proxy_url = os.environ.get('PROXY_URL')
20
-
21
- # 设置代理
22
  if proxy_url:
23
  proxies = {
24
  'http': proxy_url,
@@ -29,7 +40,7 @@ if proxy_url:
29
  else:
30
  transport = None
31
 
32
- sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)
33
 
34
  # 连接选项
35
  connect_opts = {
@@ -49,7 +60,7 @@ sio_opts = {
49
 
50
  def log_request(ip, route, status):
51
  timestamp = datetime.now().isoformat()
52
- logging.info(f"{timestamp} - {ip} - {route} - {status}")
53
 
54
  def validate_api_key():
55
  api_key = request.headers.get('x-api-key')
@@ -66,13 +77,10 @@ def normalize_content(content):
66
  if isinstance(content, str):
67
  return content
68
  elif isinstance(content, dict):
69
- # 将字典转化为 JSON 字符串
70
  return json.dumps(content, ensure_ascii=False)
71
  elif isinstance(content, list):
72
- # 对于列表,递归处理每个元素
73
  return " ".join([normalize_content(item) for item in content])
74
  else:
75
- # 如果是其他类型,返回空字符串
76
  return ""
77
 
78
  def calculate_tokens(text):
@@ -81,12 +89,9 @@ def calculate_tokens(text):
81
  - 对于英文和有空格的文本,使用空格分词。
82
  - 对于中文等没有空格的文本,使用字符级分词。
83
  """
84
- # 首先判断文本是否包含大量非 ASCII 字符(如中文)
85
  if re.search(r'[^\x00-\x7F]', text):
86
- # 如果包含非 ASCII 字符,使用字符级分词
87
  return len(text)
88
  else:
89
- # 否则使用空格分词
90
  tokens = text.split()
91
  return len(tokens)
92
 
@@ -120,13 +125,10 @@ def messages():
120
 
121
  try:
122
  json_body = request.json
123
- model = json_body.get('model', 'claude-3-opus-20240229') # 动态获取模型,默认 claude-3-opus-20240229
124
- stream = json_body.get('stream', True) # 默认为True
125
 
126
- # 使用 normalize_content 递归处理 msg['content']
127
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
128
-
129
- # 动态计算输入的 token 数量
130
  input_tokens = calculate_tokens(previous_messages)
131
 
132
  msg_id = str(uuid.uuid4())
@@ -134,10 +136,8 @@ def messages():
134
  response_text = []
135
 
136
  if not stream:
137
- # 处理 stream 为 false 的情况
138
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
139
 
140
- # 记录日志:此时请求上下文仍然有效
141
  log_request(request.remote_addr, request.path, 200)
142
 
143
  def generate():
@@ -148,17 +148,17 @@ def messages():
148
  "type": "message",
149
  "role": "assistant",
150
  "content": [],
151
- "model": model, # 动态模型
152
  "stop_reason": None,
153
  "stop_sequence": None,
154
- "usage": {"input_tokens": input_tokens, "output_tokens": 1}, # 动态 input_tokens
155
  },
156
  })
157
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
158
  yield create_event("ping", {"type": "ping"})
159
 
160
  def on_connect():
161
- logging.info("Connected to Perplexity AI")
162
  emit_data = {
163
  "version": "2.9",
164
  "source": "default",
@@ -183,25 +183,20 @@ def messages():
183
  if chunk:
184
  response_text.append(chunk)
185
 
186
- # 检查是否是最终响应
187
  if data.get('final', False):
188
  response_event.set()
189
 
190
- def on_query_complete(data):
191
- response_event.set()
192
-
193
  def on_disconnect():
194
- logging.info("Disconnected from Perplexity AI")
195
  response_event.set()
196
 
197
  def on_connect_error(data):
198
- logging.error(f"Connection error: {data}")
199
  response_text.append(f"Error connecting to Perplexity AI: {data}")
200
  response_event.set()
201
 
202
  sio.on('connect', on_connect)
203
  sio.on('query_progress', on_query_progress)
204
- sio.on('query_complete', on_query_complete)
205
  sio.on('disconnect', on_disconnect)
206
  sio.on('connect_error', on_connect_error)
207
 
@@ -219,7 +214,7 @@ def messages():
219
  })
220
 
221
  except Exception as e:
222
- logging.error(f"Error during socket connection: {str(e)}")
223
  yield create_event("content_block_delta", {
224
  "type": "content_block_delta",
225
  "index": 0,
@@ -229,34 +224,30 @@ def messages():
229
  if sio.connected:
230
  sio.disconnect()
231
 
232
- # 动态计算输出的 token 数量
233
  output_tokens = calculate_tokens(''.join(response_text))
234
 
235
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
236
  yield create_event("message_delta", {
237
  "type": "message_delta",
238
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
239
- "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, # 动态 output_tokens
240
  })
241
- yield create_event("message_stop", {"type": "message_stop"}) # 确保发送 message_stop 事件
242
 
243
  return Response(generate(), content_type='text/event-stream')
244
 
245
  except Exception as e:
246
- logging.error(f"Request error: {str(e)}")
247
  log_request(request.remote_addr, request.path, 400)
248
  return jsonify({"error": str(e)}), 400
249
 
250
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
251
- """
252
- 处理 stream 为 false 的情况,返回完整的响应。
253
- """
254
  try:
255
  response_event = Event()
256
  response_text = []
257
 
258
  def on_connect():
259
- logging.info("Connected to Perplexity AI")
260
  emit_data = {
261
  "version": "2.9",
262
  "source": "default",
@@ -281,16 +272,15 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
281
  if chunk:
282
  response_text.append(chunk)
283
 
284
- # 检查是否是最终响应
285
  if data.get('final', False):
286
  response_event.set()
287
 
288
  def on_disconnect():
289
- logging.info("Disconnected from Perplexity AI")
290
  response_event.set()
291
 
292
  def on_connect_error(data):
293
- logging.error(f"Connection error: {data}")
294
  response_text.append(f"Error connecting to Perplexity AI: {data}")
295
  response_event.set()
296
 
@@ -301,30 +291,26 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
301
 
302
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
303
 
304
- # 等待响应完成
305
  response_event.wait(timeout=30)
306
-
307
- # 动态计算输出的 token 数量
308
  output_tokens = calculate_tokens(''.join(response_text))
309
 
310
- # 生成完整的响应
311
  full_response = {
312
- "content": [{"text": ''.join(response_text), "type": "text"}], # 合并所有文本块
313
  "id": msg_id,
314
- "model": model, # 动态模型
315
  "role": "assistant",
316
  "stop_reason": "end_turn",
317
  "stop_sequence": None,
318
  "type": "message",
319
  "usage": {
320
- "input_tokens": input_tokens, # 动态 input_tokens
321
- "output_tokens": output_tokens, # 动态 output_tokens
322
  },
323
  }
324
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
325
 
326
  except Exception as e:
327
- logging.error(f"Error during socket connection: {str(e)}")
328
  return jsonify({"error": str(e)}), 500
329
  finally:
330
  if sio.connected:
@@ -337,18 +323,18 @@ def not_found(error):
337
 
338
  @app.errorhandler(500)
339
  def server_error(error):
340
- logging.error(f"Server error: {str(error)}")
341
  log_request(request.remote_addr, request.path, 500)
342
  return "Something broke!", 500
343
 
344
  def create_event(event, data):
345
  if isinstance(data, dict):
346
- data = json.dumps(data, ensure_ascii=False) # 确保中文不会被转义
347
  return f"event: {event}\ndata: {data}\n\n"
348
 
349
  if __name__ == '__main__':
350
  port = int(os.environ.get('PORT', 8081))
351
- logging.info(f"Perplexity proxy listening on port {port}")
352
  if not API_KEY:
353
- logging.warning("Warning: PPLX_KEY environment variable is not set. API key validation will fail.")
354
  app.run(host='0.0.0.0', port=port)
 
9
  from threading import Event
10
  import re
11
 
12
+ # 创建 Flask 应用
13
  app = Flask(__name__)
 
14
 
15
+ # 自定义日志格式
16
+ log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
17
+ logging.basicConfig(level=logging.INFO, format=log_format)
18
+
19
+ # 创建不同的日志记录器
20
+ app_logger = logging.getLogger('app')
21
+ sio_logger = logging.getLogger('socketio.client')
22
+ engineio_logger = logging.getLogger('engineio.client')
23
+
24
+ # 调整日志级别,隐藏过多的调试信息
25
+ sio_logger.setLevel(logging.WARNING)
26
+ engineio_logger.setLevel(logging.WARNING)
27
+
28
+ # 从环境变量中获取 API 密钥
29
  API_KEY = os.environ.get('PPLX_KEY')
30
 
31
  # 代理设置
32
  proxy_url = os.environ.get('PROXY_URL')
 
 
33
  if proxy_url:
34
  proxies = {
35
  'http': proxy_url,
 
40
  else:
41
  transport = None
42
 
43
+ sio = socketio.Client(http_session=transport, logger=sio_logger, engineio_logger=engineio_logger)
44
 
45
  # 连接选项
46
  connect_opts = {
 
60
 
61
  def log_request(ip, route, status):
62
  timestamp = datetime.now().isoformat()
63
+ app_logger.info(f"{timestamp} - {ip} - {route} - {status}")
64
 
65
  def validate_api_key():
66
  api_key = request.headers.get('x-api-key')
 
77
  if isinstance(content, str):
78
  return content
79
  elif isinstance(content, dict):
 
80
  return json.dumps(content, ensure_ascii=False)
81
  elif isinstance(content, list):
 
82
  return " ".join([normalize_content(item) for item in content])
83
  else:
 
84
  return ""
85
 
86
  def calculate_tokens(text):
 
89
  - 对于英文和有空格的文本,使用空格分词。
90
  - 对于中文等没有空格的文本,使用字符级分词。
91
  """
 
92
  if re.search(r'[^\x00-\x7F]', text):
 
93
  return len(text)
94
  else:
 
95
  tokens = text.split()
96
  return len(tokens)
97
 
 
125
 
126
  try:
127
  json_body = request.json
128
+ model = json_body.get('model', 'claude-3-opus-20240229')
129
+ stream = json_body.get('stream', True)
130
 
 
131
  previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
 
 
132
  input_tokens = calculate_tokens(previous_messages)
133
 
134
  msg_id = str(uuid.uuid4())
 
136
  response_text = []
137
 
138
  if not stream:
 
139
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
140
 
 
141
  log_request(request.remote_addr, request.path, 200)
142
 
143
  def generate():
 
148
  "type": "message",
149
  "role": "assistant",
150
  "content": [],
151
+ "model": model,
152
  "stop_reason": None,
153
  "stop_sequence": None,
154
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1},
155
  },
156
  })
157
  yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
158
  yield create_event("ping", {"type": "ping"})
159
 
160
  def on_connect():
161
+ app_logger.info("Connected to Perplexity AI")
162
  emit_data = {
163
  "version": "2.9",
164
  "source": "default",
 
183
  if chunk:
184
  response_text.append(chunk)
185
 
 
186
  if data.get('final', False):
187
  response_event.set()
188
 
 
 
 
189
  def on_disconnect():
190
+ app_logger.info("Disconnected from Perplexity AI")
191
  response_event.set()
192
 
193
  def on_connect_error(data):
194
+ app_logger.error(f"Connection error: {data}")
195
  response_text.append(f"Error connecting to Perplexity AI: {data}")
196
  response_event.set()
197
 
198
  sio.on('connect', on_connect)
199
  sio.on('query_progress', on_query_progress)
 
200
  sio.on('disconnect', on_disconnect)
201
  sio.on('connect_error', on_connect_error)
202
 
 
214
  })
215
 
216
  except Exception as e:
217
+ app_logger.error(f"Error during socket connection: {str(e)}")
218
  yield create_event("content_block_delta", {
219
  "type": "content_block_delta",
220
  "index": 0,
 
224
  if sio.connected:
225
  sio.disconnect()
226
 
 
227
  output_tokens = calculate_tokens(''.join(response_text))
228
 
229
  yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
230
  yield create_event("message_delta", {
231
  "type": "message_delta",
232
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
233
+ "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
234
  })
235
+ yield create_event("message_stop", {"type": "message_stop"})
236
 
237
  return Response(generate(), content_type='text/event-stream')
238
 
239
  except Exception as e:
240
+ app_logger.error(f"Request error: {str(e)}")
241
  log_request(request.remote_addr, request.path, 400)
242
  return jsonify({"error": str(e)}), 400
243
 
244
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
 
 
 
245
  try:
246
  response_event = Event()
247
  response_text = []
248
 
249
  def on_connect():
250
+ app_logger.info("Connected to Perplexity AI")
251
  emit_data = {
252
  "version": "2.9",
253
  "source": "default",
 
272
  if chunk:
273
  response_text.append(chunk)
274
 
 
275
  if data.get('final', False):
276
  response_event.set()
277
 
278
  def on_disconnect():
279
+ app_logger.info("Disconnected from Perplexity AI")
280
  response_event.set()
281
 
282
  def on_connect_error(data):
283
+ app_logger.error(f"Connection error: {data}")
284
  response_text.append(f"Error connecting to Perplexity AI: {data}")
285
  response_event.set()
286
 
 
291
 
292
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
293
 
 
294
  response_event.wait(timeout=30)
 
 
295
  output_tokens = calculate_tokens(''.join(response_text))
296
 
 
297
  full_response = {
298
+ "content": [{"text": ''.join(response_text), "type": "text"}],
299
  "id": msg_id,
300
+ "model": model,
301
  "role": "assistant",
302
  "stop_reason": "end_turn",
303
  "stop_sequence": None,
304
  "type": "message",
305
  "usage": {
306
+ "input_tokens": input_tokens,
307
+ "output_tokens": output_tokens,
308
  },
309
  }
310
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
311
 
312
  except Exception as e:
313
+ app_logger.error(f"Error during socket connection: {str(e)}")
314
  return jsonify({"error": str(e)}), 500
315
  finally:
316
  if sio.connected:
 
323
 
324
  @app.errorhandler(500)
325
  def server_error(error):
326
+ app_logger.error(f"Server error: {str(error)}")
327
  log_request(request.remote_addr, request.path, 500)
328
  return "Something broke!", 500
329
 
330
  def create_event(event, data):
331
  if isinstance(data, dict):
332
+ data = json.dumps(data, ensure_ascii=False)
333
  return f"event: {event}\ndata: {data}\n\n"
334
 
335
  if __name__ == '__main__':
336
  port = int(os.environ.get('PORT', 8081))
337
+ app_logger.info(f"Perplexity proxy listening on port {port}")
338
  if not API_KEY:
339
+ app_logger.warning("Warning: PPLX_KEY environment variable is not set. API key validation will fail.")
340
  app.run(host='0.0.0.0', port=port)