smgc commited on
Commit
511f8f6
1 Parent(s): b9e956c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -110
app.py CHANGED
@@ -8,6 +8,7 @@ import requests
8
  import logging
9
  from threading import Event
10
  import re
 
11
 
12
  # 创建 Flask 应用
13
  app = Flask(__name__)
@@ -15,37 +16,21 @@ app = Flask(__name__)
15
  # 自定义日志格式
16
  log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
17
  logging.basicConfig(level=logging.INFO, format=log_format)
18
-
19
- # 创建不同的日志记录器
20
  app_logger = logging.getLogger('app')
21
- sio_logger = logging.getLogger('socketio.client')
22
- engineio_logger = logging.getLogger('engineio.client')
23
-
24
- # 调整日志级别,隐藏过多的调试信息
25
- sio_logger.setLevel(logging.WARNING)
26
- engineio_logger.setLevel(logging.WARNING)
27
 
28
  # 从环境变量中获取 API 密钥
29
  API_KEY = os.environ.get('PPLX_KEY')
30
 
31
  # 代理设置
32
  proxy_url = os.environ.get('PROXY_URL')
 
33
  if proxy_url:
34
- proxies = {
35
- 'http': proxy_url,
36
- 'https': proxy_url
37
- }
38
- transport = requests.Session()
39
- transport.proxies.update(proxies)
40
- else:
41
- transport = None
42
 
43
- sio = socketio.Client(http_session=transport, logger=sio_logger, engineio_logger=engineio_logger)
44
 
45
  # 连接选项
46
- connect_opts = {
47
- 'transports': ['websocket', 'polling'], # 允许回退到轮询
48
- }
49
 
50
  # 其他选项
51
  sio_opts = {
@@ -58,52 +43,45 @@ sio_opts = {
58
  }
59
  }
60
 
61
- def log_request(ip, route, status):
62
- timestamp = datetime.now().isoformat()
63
- app_logger.info(f"{timestamp} - {ip} - {route} - {status}")
64
-
65
- def validate_api_key():
66
- api_key = request.headers.get('x-api-key')
67
- if api_key != API_KEY:
68
- log_request(request.remote_addr, request.path, 401)
69
- return create_json_response({"error": "Invalid API key"}), 401
70
- return None
 
 
 
 
 
 
 
 
 
71
 
72
  def normalize_content(content):
73
- """
74
- 递归处理 msg['content'],确保其为字符串。
75
- 如果 content 是字典或列表,将其转换为字符串。
76
- """
77
  if isinstance(content, str):
78
  return content
79
- elif isinstance(content, dict):
80
  return json.dumps(content, ensure_ascii=False)
81
- elif isinstance(content, list):
82
- return " ".join([normalize_content(item) for item in content])
83
- else:
84
- return ""
85
 
86
  def calculate_tokens(text):
87
- """
88
- 改进的 token 计算方法。
89
- - 对于英文和有空格的文本,使用空格分词。
90
- - 对于中文等没有空格的文本,使用字符级分词。
91
- """
92
- if re.search(r'[^\x00-\x7F]', text):
93
- return len(text)
94
- else:
95
- tokens = text.split()
96
- return len(tokens)
97
-
98
- def create_json_response(data):
99
- """创建一个JSON响应,并在发送之前记录它"""
100
- json_str = json.dumps(data, ensure_ascii=False)
101
- app_logger.info(f"Sending JSON response: {json_str}")
102
- return Response(json_str, content_type='application/json')
103
 
104
  @app.route('/')
 
105
  def root():
106
- log_request(request.remote_addr, request.path, 200)
107
  return create_json_response({
108
  "message": "Welcome to the Perplexity AI Proxy API",
109
  "endpoints": {
@@ -124,11 +102,9 @@ def root():
124
  })
125
 
126
  @app.route('/ai/v1/messages', methods=['POST'])
 
 
127
  def messages():
128
- auth_error = validate_api_key()
129
- if auth_error:
130
- return auth_error
131
-
132
  try:
133
  json_body = request.json
134
  model = json_body.get('model', 'claude-3-opus-20240229')
@@ -144,11 +120,9 @@ def messages():
144
  if not stream:
145
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
146
 
147
- log_request(request.remote_addr, request.path, 200)
148
-
149
  def generate():
150
  try:
151
- yield create_event("message_start", {
152
  "type": "message_start",
153
  "message": {
154
  "id": msg_id,
@@ -161,9 +135,12 @@ def messages():
161
  "usage": {"input_tokens": input_tokens, "output_tokens": 1},
162
  },
163
  })
164
- yield create_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
165
- yield create_event("ping", {"type": "ping"})
166
 
 
 
 
167
  def on_connect():
168
  app_logger.info("Connected to Perplexity AI")
169
  emit_data = {
@@ -182,73 +159,67 @@ def messages():
182
  }
183
  sio.emit('perplexity_ask', (previous_messages, emit_data))
184
 
 
185
  def on_query_progress(data):
186
- nonlocal response_text
187
  if 'text' in data:
188
  text = json.loads(data['text'])
189
  chunk = text['chunks'][-1] if text['chunks'] else None
190
  if chunk:
191
  response_text.append(chunk)
 
 
 
 
 
192
 
193
  if data.get('final', False):
194
  response_event.set()
195
 
 
196
  def on_disconnect():
197
  app_logger.info("Disconnected from Perplexity AI")
198
  response_event.set()
199
 
 
200
  def on_connect_error(data):
201
  app_logger.error(f"Connection error: {data}")
202
- response_text.append(f"Error connecting to Perplexity AI: {data}")
203
  response_event.set()
204
 
205
- sio.on('connect', on_connect)
206
- sio.on('query_progress', on_query_progress)
207
- sio.on('disconnect', on_disconnect)
208
- sio.on('connect_error', on_connect_error)
209
-
210
- sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
211
-
212
  while not response_event.is_set():
213
  sio.sleep(0.1)
214
- while response_text:
215
- chunk = response_text.pop(0)
216
- app_logger.info(f"Returning chunk to client: {chunk}")
217
- yield create_event("content_block_delta", {
218
- "type": "content_block_delta",
219
- "index": 0,
220
- "delta": {"type": "text_delta", "text": chunk},
221
- })
222
-
223
  output_tokens = calculate_tokens(''.join(response_text))
224
 
225
- yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
226
- yield create_event("message_delta", {
227
  "type": "message_delta",
228
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
229
  "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
230
  })
231
- yield create_event("message_stop", {"type": "message_stop"})
232
 
233
  except Exception as e:
234
  app_logger.error(f"Error in generate function: {str(e)}")
235
- yield create_event("error", {"type": "error", "message": str(e)})
236
  finally:
237
  if sio.connected:
238
  sio.disconnect()
239
 
240
- return Response(generate(), content_type='text/event-stream')
241
 
242
  except Exception as e:
243
  app_logger.error(f"Request error: {str(e)}")
244
- log_request(request.remote_addr, request.path, 400)
245
- return create_json_response({"error": str(e)}), 400
246
 
247
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
248
  try:
249
  response_event = Event()
250
  response_text = []
251
 
 
 
 
252
  def on_connect():
253
  app_logger.info("Connected to Perplexity AI")
254
  emit_data = {
@@ -267,8 +238,8 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
267
  }
268
  sio.emit('perplexity_ask', (previous_messages, emit_data))
269
 
 
270
  def on_query_progress(data):
271
- nonlocal response_text
272
  if 'text' in data:
273
  text = json.loads(data['text'])
274
  chunk = text['chunks'][-1] if text['chunks'] else None
@@ -278,22 +249,17 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
278
  if data.get('final', False):
279
  response_event.set()
280
 
 
281
  def on_disconnect():
282
  app_logger.info("Disconnected from Perplexity AI")
283
  response_event.set()
284
 
 
285
  def on_connect_error(data):
286
  app_logger.error(f"Connection error: {data}")
287
  response_text.append(f"Error connecting to Perplexity AI: {data}")
288
  response_event.set()
289
 
290
- sio.on('connect', on_connect)
291
- sio.on('query_progress', on_query_progress)
292
- sio.on('disconnect', on_disconnect)
293
- sio.on('connect_error', on_connect_error)
294
-
295
- sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
296
-
297
  response_event.wait(timeout=30)
298
  output_tokens = calculate_tokens(''.join(response_text))
299
 
@@ -315,27 +281,24 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
315
 
316
  except Exception as e:
317
  app_logger.error(f"Error during socket connection: {str(e)}")
318
- return create_json_response({"error": str(e)}), 500
319
  finally:
320
  if sio.connected:
321
  sio.disconnect()
322
 
323
  @app.errorhandler(404)
324
  def not_found(error):
325
- log_request(request.remote_addr, request.path, 404)
326
- return create_json_response({"error": "Not Found"}), 404
327
 
328
  @app.errorhandler(500)
329
  def server_error(error):
330
  app_logger.error(f"Server error: {str(error)}")
331
- log_request(request.remote_addr, request.path, 500)
332
- return create_json_response({"error": "Internal Server Error"}), 500
333
-
334
- def create_event(event, data):
335
- if isinstance(data, dict):
336
- data = json.dumps(data, ensure_ascii=False)
337
- event_str = f"event: {event}\ndata: {data}\n\n"
338
- app_logger.info(f"Sending SSE event: {event_str}")
339
  return event_str
340
 
341
  if __name__ == '__main__':
@@ -343,4 +306,4 @@ if __name__ == '__main__':
343
  app_logger.info(f"Perplexity proxy listening on port {port}")
344
  if not API_KEY:
345
  app_logger.warning("Warning: PPLX_KEY environment variable is not set. API key validation will fail.")
346
- app.run(host='0.0.0.0', port=port)
 
8
  import logging
9
  from threading import Event
10
  import re
11
+ from functools import wraps
12
 
13
  # 创建 Flask 应用
14
  app = Flask(__name__)
 
16
  # 自定义日志格式
17
  log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
18
  logging.basicConfig(level=logging.INFO, format=log_format)
 
 
19
  app_logger = logging.getLogger('app')
 
 
 
 
 
 
20
 
21
  # 从环境变量中获取 API 密钥
22
  API_KEY = os.environ.get('PPLX_KEY')
23
 
24
  # 代理设置
25
  proxy_url = os.environ.get('PROXY_URL')
26
+ transport = requests.Session()
27
  if proxy_url:
28
+ transport.proxies.update({'http': proxy_url, 'https': proxy_url})
 
 
 
 
 
 
 
29
 
30
+ sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)
31
 
32
  # 连接选项
33
+ connect_opts = {'transports': ['websocket', 'polling']}
 
 
34
 
35
  # 其他选项
36
  sio_opts = {
 
43
  }
44
  }
45
 
46
+ def log_request(func):
47
+ @wraps(func)
48
+ def wrapper(*args, **kwargs):
49
+ start_time = datetime.now()
50
+ response = func(*args, **kwargs)
51
+ duration = (datetime.now() - start_time).total_seconds()
52
+ app_logger.info(f"{request.remote_addr} - {request.method} {request.path} - {response.status_code} - {duration:.2f}s")
53
+ return response
54
+ return wrapper
55
+
56
+ def validate_api_key(func):
57
+ @wraps(func)
58
+ def wrapper(*args, **kwargs):
59
+ api_key = request.headers.get('x-api-key')
60
+ if api_key != API_KEY:
61
+ app_logger.warning(f"Invalid API key attempt from {request.remote_addr}")
62
+ return jsonify({"error": "Invalid API key"}), 401
63
+ return func(*args, **kwargs)
64
+ return wrapper
65
 
66
  def normalize_content(content):
 
 
 
 
67
  if isinstance(content, str):
68
  return content
69
+ elif isinstance(content, (dict, list)):
70
  return json.dumps(content, ensure_ascii=False)
71
+ return str(content)
 
 
 
72
 
73
  def calculate_tokens(text):
74
+ return len(re.findall(r'\w+|[^\w\s]', text, re.UNICODE))
75
+
76
+ def create_json_response(data, status_code=200):
77
+ response = jsonify(data)
78
+ response.status_code = status_code
79
+ app_logger.debug(f"Sending JSON response: {json.dumps(data, ensure_ascii=False)}")
80
+ return response
 
 
 
 
 
 
 
 
 
81
 
82
  @app.route('/')
83
+ @log_request
84
  def root():
 
85
  return create_json_response({
86
  "message": "Welcome to the Perplexity AI Proxy API",
87
  "endpoints": {
 
102
  })
103
 
104
  @app.route('/ai/v1/messages', methods=['POST'])
105
+ @log_request
106
+ @validate_api_key
107
  def messages():
 
 
 
 
108
  try:
109
  json_body = request.json
110
  model = json_body.get('model', 'claude-3-opus-20240229')
 
120
  if not stream:
121
  return handle_non_stream(previous_messages, msg_id, model, input_tokens)
122
 
 
 
123
  def generate():
124
  try:
125
+ yield create_sse_event("message_start", {
126
  "type": "message_start",
127
  "message": {
128
  "id": msg_id,
 
135
  "usage": {"input_tokens": input_tokens, "output_tokens": 1},
136
  },
137
  })
138
+ yield create_sse_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
139
+ yield create_sse_event("ping", {"type": "ping"})
140
 
141
+ sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
142
+
143
+ @sio.on('connect')
144
  def on_connect():
145
  app_logger.info("Connected to Perplexity AI")
146
  emit_data = {
 
159
  }
160
  sio.emit('perplexity_ask', (previous_messages, emit_data))
161
 
162
+ @sio.on('query_progress')
163
  def on_query_progress(data):
 
164
  if 'text' in data:
165
  text = json.loads(data['text'])
166
  chunk = text['chunks'][-1] if text['chunks'] else None
167
  if chunk:
168
  response_text.append(chunk)
169
+ yield create_sse_event("content_block_delta", {
170
+ "type": "content_block_delta",
171
+ "index": 0,
172
+ "delta": {"type": "text_delta", "text": chunk},
173
+ })
174
 
175
  if data.get('final', False):
176
  response_event.set()
177
 
178
+ @sio.on('disconnect')
179
  def on_disconnect():
180
  app_logger.info("Disconnected from Perplexity AI")
181
  response_event.set()
182
 
183
+ @sio.on('connect_error')
184
  def on_connect_error(data):
185
  app_logger.error(f"Connection error: {data}")
186
+ yield create_sse_event("error", {"type": "error", "message": f"Error connecting to Perplexity AI: {data}"})
187
  response_event.set()
188
 
 
 
 
 
 
 
 
189
  while not response_event.is_set():
190
  sio.sleep(0.1)
191
+
 
 
 
 
 
 
 
 
192
  output_tokens = calculate_tokens(''.join(response_text))
193
 
194
+ yield create_sse_event("content_block_stop", {"type": "content_block_stop", "index": 0})
195
+ yield create_sse_event("message_delta", {
196
  "type": "message_delta",
197
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
198
  "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
199
  })
200
+ yield create_sse_event("message_stop", {"type": "message_stop"})
201
 
202
  except Exception as e:
203
  app_logger.error(f"Error in generate function: {str(e)}")
204
+ yield create_sse_event("error", {"type": "error", "message": str(e)})
205
  finally:
206
  if sio.connected:
207
  sio.disconnect()
208
 
209
+ return Response(generate(), content_type='text/event-stream; charset=utf-8')
210
 
211
  except Exception as e:
212
  app_logger.error(f"Request error: {str(e)}")
213
+ return create_json_response({"error": str(e)}, 400)
 
214
 
215
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
216
  try:
217
  response_event = Event()
218
  response_text = []
219
 
220
+ sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
221
+
222
+ @sio.on('connect')
223
  def on_connect():
224
  app_logger.info("Connected to Perplexity AI")
225
  emit_data = {
 
238
  }
239
  sio.emit('perplexity_ask', (previous_messages, emit_data))
240
 
241
+ @sio.on('query_progress')
242
  def on_query_progress(data):
 
243
  if 'text' in data:
244
  text = json.loads(data['text'])
245
  chunk = text['chunks'][-1] if text['chunks'] else None
 
249
  if data.get('final', False):
250
  response_event.set()
251
 
252
+ @sio.on('disconnect')
253
  def on_disconnect():
254
  app_logger.info("Disconnected from Perplexity AI")
255
  response_event.set()
256
 
257
+ @sio.on('connect_error')
258
  def on_connect_error(data):
259
  app_logger.error(f"Connection error: {data}")
260
  response_text.append(f"Error connecting to Perplexity AI: {data}")
261
  response_event.set()
262
 
 
 
 
 
 
 
 
263
  response_event.wait(timeout=30)
264
  output_tokens = calculate_tokens(''.join(response_text))
265
 
 
281
 
282
  except Exception as e:
283
  app_logger.error(f"Error during socket connection: {str(e)}")
284
+ return create_json_response({"error": str(e)}, 500)
285
  finally:
286
  if sio.connected:
287
  sio.disconnect()
288
 
289
  @app.errorhandler(404)
290
  def not_found(error):
291
+ return create_json_response({"error": "Not Found"}, 404)
 
292
 
293
  @app.errorhandler(500)
294
  def server_error(error):
295
  app_logger.error(f"Server error: {str(error)}")
296
+ return create_json_response({"error": "Internal Server Error"}, 500)
297
+
298
+ def create_sse_event(event, data):
299
+ json_data = json.dumps(data, ensure_ascii=False)
300
+ event_str = f"event: {event}\ndata: {json_data}\n\n"
301
+ app_logger.debug(f"Sending SSE event: {event_str}")
 
 
302
  return event_str
303
 
304
  if __name__ == '__main__':
 
306
  app_logger.info(f"Perplexity proxy listening on port {port}")
307
  if not API_KEY:
308
  app_logger.warning("Warning: PPLX_KEY environment variable is not set. API key validation will fail.")
309
+ app.run(host='0.0.0.0', port=port, debug=False)