smgc committed on
Commit
edd867f
1 Parent(s): da3156c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -197
app.py CHANGED
@@ -1,18 +1,18 @@
1
  import os
2
  import json
3
  import uuid
4
- import time
5
  from datetime import datetime
6
  from flask import Flask, request, Response, jsonify
7
  import socketio
8
  import requests
9
  import logging
 
10
  import re
11
- import asyncio
12
- from functools import partial
13
 
14
  app = Flask(__name__)
15
 
 
16
  class CustomFormatter(logging.Formatter):
17
  def format(self, record):
18
  log_data = {
@@ -39,20 +39,31 @@ def setup_logging():
39
 
40
  logger = logging.getLogger(__name__)
41
 
 
42
  API_KEY = os.environ.get('PPLX_KEY')
 
 
43
  proxy_url = os.environ.get('PROXY_URL')
44
 
 
45
  if proxy_url:
46
- proxies = {'http': proxy_url, 'https': proxy_url}
 
 
 
47
  transport = requests.Session()
48
  transport.proxies.update(proxies)
49
  else:
50
  transport = None
51
 
52
- sio = socketio.AsyncClient(http_session=transport, logger=False, engineio_logger=False)
53
 
54
- connect_opts = {'transports': ['websocket', 'polling']}
 
 
 
55
 
 
56
  sio_opts = {
57
  'extraHeaders': {
58
  'Cookie': os.environ.get('PPLX_COOKIE'),
@@ -96,7 +107,8 @@ def calculate_tokens(text):
96
  if re.search(r'[^\x00-\x7F]', text):
97
  return len(text)
98
  else:
99
- return len(text.split())
 
100
 
101
  def create_event(event, data):
102
  if isinstance(data, dict):
@@ -118,13 +130,34 @@ def root():
118
  },
119
  "body": {
120
  "messages": "Array of message objects",
121
- "stream": "Boolean (optional, defaults to false)",
122
- "model": "Model to be used (optional, defaults to claude-3-5-sonnet-20240620)"
123
  }
124
  }
125
  }
126
  })
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  @app.route('/ai/v1/messages', methods=['POST'])
129
  def messages():
130
  auth_error = validate_api_key()
@@ -135,34 +168,25 @@ def messages():
135
  json_body = request.json
136
  model = json_body.get('model', 'claude-3-5-sonnet-20240620')
137
  stream = json_body.get('stream', False)
138
-
139
- previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
140
- input_tokens = calculate_tokens(previous_messages)
 
 
 
 
 
 
141
 
142
  msg_id = str(uuid.uuid4())
143
-
144
- if not stream:
145
- return handle_non_stream(previous_messages, msg_id, model, input_tokens)
146
-
147
- log_request(request.remote_addr, request.path, 200)
148
-
149
- async def run_socket_io():
150
- response_event = asyncio.Event()
151
- response_text = []
152
- total_output_tokens = 0
153
- start_time = time.time()
154
- last_activity_time = start_time
155
- timeout = max(300, input_tokens / 100) # 动态设置超时时间,最少300秒
156
-
157
- def send_event(event_type, data):
158
- event = create_event(event_type, data)
159
- logger.info(f"Sending {event_type} event", extra={
160
- 'event_type': event_type,
161
- 'data': {'content': event}
162
- })
163
- return event
164
 
165
- yield send_event("message_start", {
166
  "type": "message_start",
167
  "message": {
168
  "id": msg_id,
@@ -172,148 +196,142 @@ def messages():
172
  "content": [],
173
  "stop_reason": None,
174
  "stop_sequence": None,
175
- "usage": {"input_tokens": input_tokens, "output_tokens": total_output_tokens},
176
  },
177
  })
178
- yield send_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
179
- yield send_event("ping", {"type": "ping"})
180
-
181
- @sio.event
182
- async def query_progress(data):
183
- nonlocal total_output_tokens, response_text, last_activity_time
184
- last_activity_time = time.time()
185
- if 'text' in data:
186
- text = json.loads(data['text'])
187
- chunk = text['chunks'][-1] if text['chunks'] else None
188
- if chunk:
189
- response_text.append(chunk)
190
- chunk_tokens = calculate_tokens(chunk)
191
- total_output_tokens += chunk_tokens
192
- logger.info("Received chunk", extra={
193
- 'event_type': 'chunk_received',
194
- 'data': {
195
- 'chunk': chunk,
196
- 'tokens': chunk_tokens,
197
- 'total_tokens': total_output_tokens
198
- }
199
- })
200
-
201
- if data.get('final', False):
202
- logger.info("Final response received", extra={
203
- 'event_type': 'response_complete',
204
- 'data': {
205
- 'total_tokens': total_output_tokens
206
- }
207
- })
208
- response_event.set()
209
 
210
- @sio.event
211
- async def connect():
212
- logger.info("Connected to Perplexity AI", extra={'event_type': 'connection_established'})
213
- emit_data = {
214
- "version": "2.9",
215
- "source": "default",
216
- "attachments": [],
217
- "language": "en-GB",
218
- "timezone": "Europe/London",
219
- "mode": "concise",
220
- "is_related_query": False,
221
- "is_default_related_query": False,
222
- "visitor_id": str(uuid.uuid4()),
223
- "frontend_context_uuid": str(uuid.uuid4()),
224
- "prompt_source": "user",
225
- "query_source": "home"
226
- }
227
- await sio.emit('perplexity_ask', (previous_messages, emit_data))
228
- logger.info("Sent query to Perplexity AI", extra={
229
- 'event_type': 'query_sent',
230
- 'data': {
231
- 'message': previous_messages[:100] + '...' if len(previous_messages) > 100 else previous_messages
232
- }
233
  })
234
 
235
- async def heartbeat():
236
- while not response_event.is_set():
237
- await sio.emit('ping')
238
- await asyncio.sleep(25)
 
 
 
239
 
240
- try:
241
- await sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
242
- heartbeat_task = asyncio.create_task(heartbeat())
243
-
244
- while not response_event.is_set() and (time.time() - start_time) < timeout:
245
- current_time = time.time()
246
- if current_time - last_activity_time > 60: # 如果60秒内没有活动,记录警告
247
- logger.warning("No activity for 60 seconds", extra={'event_type': 'inactivity_warning'})
248
- await asyncio.sleep(0.1)
249
- while response_text:
250
- chunk = response_text.pop(0)
251
- yield send_event("content_block_delta", {
252
- "type": "content_block_delta",
253
- "index": 0,
254
- "delta": {"type": "text_delta", "text": chunk},
255
- })
256
-
257
- if not response_event.is_set():
258
- logger.warning(f"Request timed out after {timeout} seconds", extra={
259
- 'event_type': 'request_timeout',
260
- 'data': {
261
- 'timeout': timeout,
262
- 'input_tokens': input_tokens,
263
- 'output_tokens': total_output_tokens,
264
- 'elapsed_time': time.time() - start_time
265
- }
266
  })
267
- yield send_event("content_block_delta", {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  "type": "content_block_delta",
269
  "index": 0,
270
- "delta": {"type": "text_delta", "text": f"Request timed out after {timeout} seconds"},
271
  })
272
-
273
- except Exception as e:
274
- logger.error(f"Error during socket connection: {str(e)}", exc_info=True)
275
- yield send_event("content_block_delta", {
276
- "type": "content_block_delta",
277
- "index": 0,
278
- "delta": {"type": "text_delta", "text": f"Error during socket connection: {str(e)}"},
 
 
279
  })
280
- finally:
281
- heartbeat_task.cancel()
282
- if sio.connected:
283
- await sio.disconnect()
284
 
285
- yield send_event("content_block_stop", {"type": "content_block_stop", "index": 0})
286
- yield send_event("message_delta", {
287
  "type": "message_delta",
288
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
289
- "usage": {"output_tokens": total_output_tokens},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  })
291
- yield send_event("message_stop", {"type": "message_stop"})
292
-
293
- def run_async():
294
- loop = asyncio.new_event_loop()
295
- asyncio.set_event_loop(loop)
296
- try:
297
- return loop.run_until_complete(run_socket_io())
298
- finally:
299
- loop.close()
300
-
301
- return Response(run_async(), content_type='text/event-stream')
302
 
303
  except Exception as e:
304
  logger.error(f"Request error: {str(e)}", exc_info=True)
305
  log_request(request.remote_addr, request.path, 400)
306
  return jsonify({"error": str(e)}), 400
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
309
- async def run_non_stream():
 
310
  response_text = []
311
  total_output_tokens = 0
312
- start_time = time.time()
313
- timeout = max(300, input_tokens / 100) # 动态设置超时时间,最少300秒
314
 
315
- @sio.event
316
- async def query_progress(data):
317
  nonlocal total_output_tokens, response_text
318
  if 'text' in data:
319
  text = json.loads(data['text'])
@@ -322,20 +340,11 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
322
  response_text.append(chunk)
323
  chunk_tokens = calculate_tokens(chunk)
324
  total_output_tokens += chunk_tokens
325
- logger.info("Received chunk (non-stream)", extra={
326
- 'event_type': 'chunk_received_non_stream',
327
- 'data': {
328
- 'chunk': chunk,
329
- 'tokens': chunk_tokens,
330
- 'total_tokens': total_output_tokens
331
- }
332
- })
333
 
334
  if data.get('final', False):
335
- return True
336
 
337
- @sio.event
338
- async def connect():
339
  logger.info("Connected to Perplexity AI (non-stream)", extra={'event_type': 'connection_established_non_stream'})
340
  emit_data = {
341
  "version": "2.9",
@@ -351,34 +360,19 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
351
  "prompt_source": "user",
352
  "query_source": "home"
353
  }
354
- await sio.emit('perplexity_ask', (previous_messages, emit_data))
355
-
356
- try:
357
- await sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
358
- await asyncio.wait_for(sio.wait(), timeout=timeout)
359
- except asyncio.TimeoutError:
360
- logger.warning(f"Request timed out after {timeout} seconds (non-stream)", extra={
361
- 'event_type': 'request_timeout_non_stream',
362
- 'data': {
363
- 'timeout': timeout,
364
- 'input_tokens': input_tokens,
365
- 'elapsed_time': time.time() - start_time
366
- }
367
- })
368
- finally:
369
- if sio.connected:
370
- await sio.disconnect()
371
 
372
  if not response_text:
373
- logger.warning(f"No response received (non-stream) after {timeout} seconds", extra={
374
- 'event_type': 'no_response_non_stream',
375
- 'data': {
376
- 'timeout': timeout,
377
- 'input_tokens': input_tokens,
378
- 'elapsed_time': time.time() - start_time
379
- }
380
- })
381
- return jsonify({"error": f"No response received after {timeout} seconds"}), 504
382
 
383
  full_response = {
384
  "content": [{"text": ''.join(response_text), "type": "text"}],
@@ -395,22 +389,16 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
395
  }
396
  logger.info("Sending non-stream response", extra={
397
  'event_type': 'non_stream_response',
398
- 'data': {
399
- 'content': full_response,
400
- 'elapsed_time': time.time() - start_time
401
- }
402
  })
403
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
404
 
405
- loop = asyncio.new_event_loop()
406
- asyncio.set_event_loop(loop)
407
- try:
408
- return loop.run_until_complete(run_non_stream())
409
  except Exception as e:
410
  logger.error(f"Error during non-stream socket connection: {str(e)}", exc_info=True)
411
  return jsonify({"error": str(e)}), 500
412
  finally:
413
- loop.close()
 
414
 
415
  @app.errorhandler(404)
416
  def not_found(error):
 
1
  import os
2
  import json
3
  import uuid
 
4
  from datetime import datetime
5
  from flask import Flask, request, Response, jsonify
6
  import socketio
7
  import requests
8
  import logging
9
+ from threading import Event, Timer
10
  import re
11
+ import math
 
12
 
13
  app = Flask(__name__)
14
 
15
+ # Custom log formatter
16
  class CustomFormatter(logging.Formatter):
17
  def format(self, record):
18
  log_data = {
 
39
 
40
  logger = logging.getLogger(__name__)
41
 
42
+ # Read the API key from environment variables
43
  API_KEY = os.environ.get('PPLX_KEY')
44
+
45
+ # Proxy settings
46
  proxy_url = os.environ.get('PROXY_URL')
47
 
48
+ # Configure the proxy if a URL was provided
49
  if proxy_url:
50
+ proxies = {
51
+ 'http': proxy_url,
52
+ 'https': proxy_url
53
+ }
54
  transport = requests.Session()
55
  transport.proxies.update(proxies)
56
  else:
57
  transport = None
58
 
59
+ sio = socketio.Client(http_session=transport, logger=False, engineio_logger=False)
60
 
61
+ # 连接选项
62
+ connect_opts = {
63
+ 'transports': ['websocket', 'polling'],
64
+ }
65
 
66
+ # 其他选项
67
  sio_opts = {
68
  'extraHeaders': {
69
  'Cookie': os.environ.get('PPLX_COOKIE'),
 
107
  if re.search(r'[^\x00-\x7F]', text):
108
  return len(text)
109
  else:
110
+ tokens = text.split()
111
+ return len(tokens)
112
 
113
  def create_event(event, data):
114
  if isinstance(data, dict):
 
130
  },
131
  "body": {
132
  "messages": "Array of message objects",
133
+ "stream": "Boolean (true for streaming response)",
134
+ "model": "Model to be used (optional, defaults to claude-3-opus-20240229)"
135
  }
136
  }
137
  }
138
  })
139
 
140
def split_messages(messages, max_tokens_per_chunk=8000):
    """Group messages into chunks whose combined token count stays at or
    below *max_tokens_per_chunk*.

    Messages are kept whole and in their original order; a single message
    larger than the limit still becomes its own one-message chunk.

    Args:
        messages: list of message dicts, each with a 'content' key. The
            content may be a plain string or an Anthropic-style list of
            content blocks.
        max_tokens_per_chunk: soft upper bound on tokens per chunk.

    Returns:
        list of lists of message dicts.
    """
    chunks = []
    current_chunk = []
    current_chunk_tokens = 0

    for message in messages:
        # Normalize first: 'content' may be a list of content blocks, and
        # calculate_tokens expects a plain string (it calls re.search on it).
        # This mirrors how content is handled before token counting in the
        # request handler.
        message_tokens = calculate_tokens(normalize_content(message['content']))
        if current_chunk_tokens + message_tokens > max_tokens_per_chunk and current_chunk:
            chunks.append(current_chunk)
            current_chunk = []
            current_chunk_tokens = 0

        current_chunk.append(message)
        current_chunk_tokens += message_tokens

    if current_chunk:
        chunks.append(current_chunk)

    return chunks
160
+
161
  @app.route('/ai/v1/messages', methods=['POST'])
162
  def messages():
163
  auth_error = validate_api_key()
 
168
  json_body = request.json
169
  model = json_body.get('model', 'claude-3-5-sonnet-20240620')
170
  stream = json_body.get('stream', False)
171
+ messages = json_body.get('messages', [])
172
+
173
+ # 分块处理
174
+ chunks = split_messages(messages)
175
+ total_chunks = len(chunks)
176
+ logger.info(f"Input split into {total_chunks} chunks", extra={
177
+ 'event_type': 'input_split',
178
+ 'data': {'total_chunks': total_chunks}
179
+ })
180
 
181
  msg_id = str(uuid.uuid4())
182
+ total_input_tokens = sum(calculate_tokens(msg['content']) for msg in messages)
183
+ total_output_tokens = 0
184
+ full_response = []
185
+
186
+ def generate():
187
+ nonlocal total_output_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ yield from send_event("message_start", {
190
  "type": "message_start",
191
  "message": {
192
  "id": msg_id,
 
196
  "content": [],
197
  "stop_reason": None,
198
  "stop_sequence": None,
199
+ "usage": {"input_tokens": total_input_tokens, "output_tokens": 0},
200
  },
201
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ for chunk_index, chunk in enumerate(chunks):
204
+ chunk_input = "\n\n".join([normalize_content(msg['content']) for msg in chunk])
205
+ chunk_input_tokens = calculate_tokens(chunk_input)
206
+
207
+ response_event = Event()
208
+ timeout_event = Event()
209
+ response_text = []
210
+
211
+ # 动态调整超时时间
212
+ timeout_seconds = max(30, min(300, chunk_input_tokens // 1000 * 30))
213
+
214
+ yield from send_event("chunk_start", {
215
+ "type": "chunk_start",
216
+ "chunk_index": chunk_index,
217
+ "total_chunks": total_chunks,
 
 
 
 
 
 
 
 
218
  })
219
 
220
+ def on_query_progress(data):
221
+ nonlocal response_text
222
+ if 'text' in data:
223
+ text = json.loads(data['text'])
224
+ new_chunk = text['chunks'][-1] if text['chunks'] else None
225
+ if new_chunk:
226
+ response_text.append(new_chunk)
227
 
228
+ if data.get('final', False):
229
+ response_event.set()
230
+
231
+ sio.on('query_progress', on_query_progress)
232
+
233
+ def timeout_handler():
234
+ logger.warning(f"Chunk {chunk_index + 1}/{total_chunks} timed out after {timeout_seconds} seconds", extra={
235
+ 'event_type': 'chunk_timeout',
236
+ 'data': {'chunk_index': chunk_index, 'total_chunks': total_chunks, 'timeout_seconds': timeout_seconds}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  })
238
+ timeout_event.set()
239
+ response_event.set()
240
+
241
+ timer = Timer(timeout_seconds, timeout_handler)
242
+ timer.start()
243
+
244
+ try:
245
+ sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
246
+ sio.emit('perplexity_ask', (chunk_input, get_emit_data()))
247
+
248
+ while not response_event.is_set() and not timeout_event.is_set():
249
+ sio.sleep(0.1)
250
+ while response_text:
251
+ new_chunk = response_text.pop(0)
252
+ full_response.append(new_chunk)
253
+ chunk_tokens = calculate_tokens(new_chunk)
254
+ total_output_tokens += chunk_tokens
255
+ yield from send_event("content_block_delta", {
256
+ "type": "content_block_delta",
257
+ "index": 0,
258
+ "delta": {"type": "text_delta", "text": new_chunk},
259
+ })
260
+
261
+ except Exception as e:
262
+ logger.error(f"Error during chunk {chunk_index + 1}/{total_chunks} processing: {str(e)}", exc_info=True)
263
+ yield from send_event("content_block_delta", {
264
  "type": "content_block_delta",
265
  "index": 0,
266
+ "delta": {"type": "text_delta", "text": f"Error processing chunk {chunk_index + 1}/{total_chunks}: {str(e)}"},
267
  })
268
+ finally:
269
+ timer.cancel()
270
+ if sio.connected:
271
+ sio.disconnect()
272
+
273
+ yield from send_event("chunk_end", {
274
+ "type": "chunk_end",
275
+ "chunk_index": chunk_index,
276
+ "total_chunks": total_chunks,
277
  })
 
 
 
 
278
 
279
+ yield from send_event("content_block_stop", {"type": "content_block_stop", "index": 0})
280
+ yield from send_event("message_delta", {
281
  "type": "message_delta",
282
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
283
+ "usage": {"input_tokens": total_input_tokens, "output_tokens": total_output_tokens},
284
+ })
285
+ yield from send_event("message_stop", {"type": "message_stop"})
286
+
287
+ if stream:
288
+ return Response(generate(), content_type='text/event-stream')
289
+ else:
290
+ # 非流式处理
291
+ for _ in generate():
292
+ pass # 处理所有生成的事件,但不发送
293
+ return jsonify({
294
+ "content": [{"text": ''.join(full_response), "type": "text"}],
295
+ "id": msg_id,
296
+ "model": model,
297
+ "role": "assistant",
298
+ "stop_reason": "end_turn",
299
+ "stop_sequence": None,
300
+ "type": "message",
301
+ "usage": {
302
+ "input_tokens": total_input_tokens,
303
+ "output_tokens": total_output_tokens,
304
+ },
305
  })
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  except Exception as e:
308
  logger.error(f"Request error: {str(e)}", exc_info=True)
309
  log_request(request.remote_addr, request.path, 400)
310
  return jsonify({"error": str(e)}), 400
311
 
312
def get_emit_data():
    """Return the metadata payload sent with each 'perplexity_ask' emit.

    Fresh random visitor / frontend-context UUIDs are generated on every
    call; all other fields are fixed.
    """
    payload = dict(
        version="2.9",
        source="default",
        attachments=[],
        language="en-GB",
        timezone="Europe/London",
        mode="concise",
        is_related_query=False,
        is_default_related_query=False,
        visitor_id=str(uuid.uuid4()),
        frontend_context_uuid=str(uuid.uuid4()),
        prompt_source="user",
        query_source="home",
    )
    return payload
327
+
328
  def handle_non_stream(previous_messages, msg_id, model, input_tokens):
329
+ try:
330
+ response_event = Event()
331
  response_text = []
332
  total_output_tokens = 0
 
 
333
 
334
+ def on_query_progress(data):
 
335
  nonlocal total_output_tokens, response_text
336
  if 'text' in data:
337
  text = json.loads(data['text'])
 
340
  response_text.append(chunk)
341
  chunk_tokens = calculate_tokens(chunk)
342
  total_output_tokens += chunk_tokens
 
 
 
 
 
 
 
 
343
 
344
  if data.get('final', False):
345
+ response_event.set()
346
 
347
+ def on_connect():
 
348
  logger.info("Connected to Perplexity AI (non-stream)", extra={'event_type': 'connection_established_non_stream'})
349
  emit_data = {
350
  "version": "2.9",
 
360
  "prompt_source": "user",
361
  "query_source": "home"
362
  }
363
+ sio.emit('perplexity_ask', (previous_messages, emit_data))
364
+
365
+ sio.on('connect', on_connect)
366
+ sio.on('query_progress', on_query_progress)
367
+
368
+ sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
369
+
370
+ # Wait for response with timeout
371
+ response_event.wait(timeout=30)
 
 
 
 
 
 
 
 
372
 
373
  if not response_text:
374
+ logger.warning("No response received (non-stream)", extra={'event_type': 'no_response_non_stream'})
375
+ return jsonify({"error": "No response received"}), 504
 
 
 
 
 
 
 
376
 
377
  full_response = {
378
  "content": [{"text": ''.join(response_text), "type": "text"}],
 
389
  }
390
  logger.info("Sending non-stream response", extra={
391
  'event_type': 'non_stream_response',
392
+ 'data': {'content': full_response}
 
 
 
393
  })
394
  return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')
395
 
 
 
 
 
396
  except Exception as e:
397
  logger.error(f"Error during non-stream socket connection: {str(e)}", exc_info=True)
398
  return jsonify({"error": str(e)}), 500
399
  finally:
400
+ if sio.connected:
401
+ sio.disconnect()
402
 
403
  @app.errorhandler(404)
404
  def not_found(error):