smgc committed on
Commit
02829ce
1 Parent(s): d8375a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -101
app.py CHANGED
@@ -8,6 +8,7 @@ import requests
8
  import logging
9
  from threading import Event, Timer
10
  import re
 
11
 
12
  app = Flask(__name__)
13
 
@@ -98,9 +99,9 @@ def normalize_content(content):
98
  elif isinstance(content, dict):
99
  return json.dumps(content, ensure_ascii=False)
100
  elif isinstance(content, list):
101
- return json.dumps(content, ensure_ascii=False)
102
  else:
103
- return str(content)
104
 
105
  def calculate_tokens(text):
106
  if re.search(r'[^\x00-\x7F]', text):
@@ -136,6 +137,11 @@ def root():
136
  }
137
  })
138
 
 
 
 
 
 
139
  @app.route('/ai/v1/messages', methods=['POST'])
140
  def messages():
141
  auth_error = validate_api_key()
@@ -144,21 +150,14 @@ def messages():
144
 
145
  try:
146
  json_body = request.json
147
- model = json_body.get('model', 'claude-3-5-sonnet-20240620')
148
- stream = json_body.get('stream', False)
149
- messages = json_body.get('messages', [])
150
-
151
- # 规范化所有消息
152
- normalized_messages = [
153
- {**msg, 'content': normalize_content(msg.get('content', ''))}
154
- for msg in messages
155
- ]
156
 
157
- # 计算总输入 tokens
158
- total_input_tokens = sum(calculate_tokens(msg['content']) for msg in normalized_messages)
159
 
160
- # 准备完整的输入文本
161
- full_input = "\n\n".join([msg['content'] for msg in normalized_messages])
162
 
163
  msg_id = str(uuid.uuid4())
164
  response_event = Event()
@@ -166,21 +165,32 @@ def messages():
166
  response_text = []
167
  total_output_tokens = 0
168
 
169
- # 动态调整超时时间
170
- timeout_seconds = max(60, min(600, total_input_tokens // 1000 * 30)) # 每1000 tokens至少60秒,最多10分钟
 
 
 
 
 
171
 
172
- def send_event(event_type, data):
173
- event = create_event(event_type, data)
174
- logger.info(f"Sending {event_type} event", extra={
175
- 'event_type': event_type,
176
- 'data': {'content': event}
177
- })
178
- return event
179
 
180
  def generate():
181
  nonlocal total_output_tokens
182
 
183
- yield send_event("message_start", {
 
 
 
 
 
 
 
 
 
184
  "type": "message_start",
185
  "message": {
186
  "id": msg_id,
@@ -190,127 +200,202 @@ def messages():
190
  "content": [],
191
  "stop_reason": None,
192
  "stop_sequence": None,
193
- "usage": {"input_tokens": total_input_tokens, "output_tokens": 0},
194
  },
195
  })
196
- yield send_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
 
197
 
198
  def on_query_progress(data):
199
- nonlocal response_text, total_output_tokens
200
  if 'text' in data:
201
  text = json.loads(data['text'])
202
- new_chunk = text['chunks'][-1] if text['chunks'] else None
203
- if new_chunk:
204
- response_text.append(new_chunk)
205
- chunk_tokens = calculate_tokens(new_chunk)
206
  total_output_tokens += chunk_tokens
207
- yield send_event("content_block_delta", {
208
- "type": "content_block_delta",
209
- "index": 0,
210
- "delta": {"type": "text_delta", "text": new_chunk},
 
 
 
 
 
 
 
 
 
211
  })
212
 
213
  if data.get('final', False):
 
 
 
 
 
 
214
  response_event.set()
215
 
216
  def on_connect():
217
  logger.info("Connected to Perplexity AI", extra={'event_type': 'connection_established'})
218
- sio.emit('perplexity_ask', (full_input, get_emit_data()))
219
-
220
- def timeout_handler():
221
- logger.warning(f"Request timed out after {timeout_seconds} seconds", extra={
222
- 'event_type': 'request_timeout',
223
- 'data': {'timeout_seconds': timeout_seconds}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  })
225
- timeout_event.set()
226
- response_event.set()
227
 
228
  sio.on('connect', on_connect)
229
  sio.on('query_progress', on_query_progress)
230
 
231
- timer = Timer(timeout_seconds, timeout_handler)
 
 
 
 
 
232
  timer.start()
233
 
234
  try:
235
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
236
-
237
  while not response_event.is_set() and not timeout_event.is_set():
238
  sio.sleep(0.1)
239
- yield from on_query_progress({'text': json.dumps({'chunks': response_text})})
240
- response_text.clear()
241
-
 
 
 
 
 
242
  if timeout_event.is_set():
243
- yield send_event("content_block_delta", {
244
  "type": "content_block_delta",
245
  "index": 0,
246
- "delta": {"type": "text_delta", "text": f"Request timed out after {timeout_seconds} seconds"},
247
  })
248
-
249
  except Exception as e:
250
- logger.error(f"Error during processing: {str(e)}", exc_info=True)
251
- yield send_event("content_block_delta", {
252
  "type": "content_block_delta",
253
  "index": 0,
254
- "delta": {"type": "text_delta", "text": f"Error during processing: {str(e)}"},
255
  })
256
  finally:
257
  timer.cancel()
258
  if sio.connected:
259
  sio.disconnect()
260
 
261
- yield send_event("content_block_stop", {"type": "content_block_stop", "index": 0})
262
- yield send_event("message_delta", {
 
263
  "type": "message_delta",
264
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
265
- "usage": {"input_tokens": total_input_tokens, "output_tokens": total_output_tokens},
266
- })
267
- yield send_event("message_stop", {"type": "message_stop"})
268
-
269
- if stream:
270
- return Response(generate(), content_type='text/event-stream')
271
- else:
272
- # 非流式处理
273
- full_response = []
274
- for event in generate():
275
- if 'content_block_delta' in event:
276
- data = json.loads(event.split('data: ')[1])
277
- full_response.append(data['delta']['text'])
278
-
279
- return jsonify({
280
- "content": [{"text": ''.join(full_response), "type": "text"}],
281
- "id": msg_id,
282
- "model": model,
283
- "role": "assistant",
284
- "stop_reason": "end_turn",
285
- "stop_sequence": None,
286
- "type": "message",
287
- "usage": {
288
- "input_tokens": total_input_tokens,
289
- "output_tokens": total_output_tokens,
290
- },
291
  })
 
 
 
292
 
293
  except Exception as e:
294
  logger.error(f"Request error: {str(e)}", exc_info=True)
295
- logger.error(f"Request body: {json.dumps(request.json, default=str)}")
296
- log_request(request.remote_addr, request.path, 500)
297
- return jsonify({"error": str(e)}), 500
298
 
299
- def get_emit_data():
300
- return {
301
- "version": "2.9",
302
- "source": "default",
303
- "attachments": [],
304
- "language": "en-GB",
305
- "timezone": "Europe/London",
306
- "mode": "concise",
307
- "is_related_query": False,
308
- "is_default_related_query": False,
309
- "visitor_id": str(uuid.uuid4()),
310
- "frontend_context_uuid": str(uuid.uuid4()),
311
- "prompt_source": "user",
312
- "query_source": "home"
313
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  @app.errorhandler(404)
316
  def not_found(error):
 
8
  import logging
9
  from threading import Event, Timer
10
  import re
11
+ import asyncio
12
 
13
  app = Flask(__name__)
14
 
 
99
  elif isinstance(content, dict):
100
  return json.dumps(content, ensure_ascii=False)
101
  elif isinstance(content, list):
102
+ return " ".join([normalize_content(item) for item in content])
103
  else:
104
+ return ""
105
 
106
  def calculate_tokens(text):
107
  if re.search(r'[^\x00-\x7F]', text):
 
137
  }
138
  })
139
 
140
+ async def process_large_request(previous_messages, model, input_tokens):
141
+ # Placeholder for handling very large requests asynchronously (currently a no-op).
142
+ # A real implementation would probably hand this work off to a background task queue.
143
+ pass
144
+
145
  @app.route('/ai/v1/messages', methods=['POST'])
146
  def messages():
147
  auth_error = validate_api_key()
 
150
 
151
  try:
152
  json_body = request.json
153
+ model = json_body.get('model', 'claude-3-opus-20240229')
154
+ stream = json_body.get('stream', True)
 
 
 
 
 
 
 
155
 
156
+ previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
157
+ input_tokens = calculate_tokens(previous_messages)
158
 
159
+ # 根据 input_tokens 动态调整超时时间
160
+ timeout = max(30, min(300, input_tokens // 1000 * 30)) # 最少30秒,最多300秒
161
 
162
  msg_id = str(uuid.uuid4())
163
  response_event = Event()
 
165
  response_text = []
166
  total_output_tokens = 0
167
 
168
+ if input_tokens > 100000: # 如果 tokens 数量特别大,使用异步处理
169
+ task_id = str(uuid.uuid4())
170
+ asyncio.create_task(process_large_request(previous_messages, model, input_tokens))
171
+ return jsonify({
172
+ "message": "Request is being processed asynchronously",
173
+ "task_id": task_id
174
+ }), 202
175
 
176
+ if not stream:
177
+ return handle_non_stream(previous_messages, msg_id, model, input_tokens, timeout)
178
+
179
+ log_request(request.remote_addr, request.path, 200)
 
 
 
180
 
181
  def generate():
182
  nonlocal total_output_tokens
183
 
184
+ def send_event(event_type, data):
185
+ event = create_event(event_type, data)
186
+ logger.info(f"Sending {event_type} event", extra={
187
+ 'event_type': event_type,
188
+ 'data': {'content': event}
189
+ })
190
+ yield event
191
+
192
+ # Send initial events
193
+ yield from send_event("message_start", {
194
  "type": "message_start",
195
  "message": {
196
  "id": msg_id,
 
200
  "content": [],
201
  "stop_reason": None,
202
  "stop_sequence": None,
203
+ "usage": {"input_tokens": input_tokens, "output_tokens": total_output_tokens},
204
  },
205
  })
206
+ yield from send_event("content_block_start", {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})
207
+ yield from send_event("ping", {"type": "ping"})
208
 
209
  def on_query_progress(data):
210
+ nonlocal total_output_tokens, response_text
211
  if 'text' in data:
212
  text = json.loads(data['text'])
213
+ chunk = text['chunks'][-1] if text['chunks'] else None
214
+ if chunk:
215
+ response_text.append(chunk)
216
+ chunk_tokens = calculate_tokens(chunk)
217
  total_output_tokens += chunk_tokens
218
+ logger.info("Received chunk", extra={
219
+ 'event_type': 'chunk_received',
220
+ 'data': {
221
+ 'chunk': chunk,
222
+ 'tokens': chunk_tokens,
223
+ 'total_tokens': total_output_tokens
224
+ }
225
+ })
226
+ # 发送进度更新
227
+ yield from send_event("progress", {
228
+ "type": "progress",
229
+ "processed_tokens": total_output_tokens,
230
+ "total_tokens": input_tokens
231
  })
232
 
233
  if data.get('final', False):
234
+ logger.info("Final response received", extra={
235
+ 'event_type': 'response_complete',
236
+ 'data': {
237
+ 'total_tokens': total_output_tokens
238
+ }
239
+ })
240
  response_event.set()
241
 
242
  def on_connect():
243
  logger.info("Connected to Perplexity AI", extra={'event_type': 'connection_established'})
244
+ emit_data = {
245
+ "version": "2.9",
246
+ "source": "default",
247
+ "attachments": [],
248
+ "language": "en-GB",
249
+ "timezone": "Europe/London",
250
+ "mode": "concise",
251
+ "is_related_query": False,
252
+ "is_default_related_query": False,
253
+ "visitor_id": str(uuid.uuid4()),
254
+ "frontend_context_uuid": str(uuid.uuid4()),
255
+ "prompt_source": "user",
256
+ "query_source": "home"
257
+ }
258
+ sio.emit('perplexity_ask', (previous_messages, emit_data))
259
+ logger.info("Sent query to Perplexity AI", extra={
260
+ 'event_type': 'query_sent',
261
+ 'data': {
262
+ 'message': previous_messages[:100] + '...' if len(previous_messages) > 100 else previous_messages
263
+ }
264
  })
 
 
265
 
266
  sio.on('connect', on_connect)
267
  sio.on('query_progress', on_query_progress)
268
 
269
+ def timeout_handler():
270
+ logger.warning("Request timed out", extra={'event_type': 'request_timeout'})
271
+ timeout_event.set()
272
+ response_event.set()
273
+
274
+ timer = Timer(timeout, timeout_handler) # 使用动态超时时间
275
  timer.start()
276
 
277
  try:
278
  sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])
279
+
280
  while not response_event.is_set() and not timeout_event.is_set():
281
  sio.sleep(0.1)
282
+ while response_text:
283
+ chunk = response_text.pop(0)
284
+ yield from send_event("content_block_delta", {
285
+ "type": "content_block_delta",
286
+ "index": 0,
287
+ "delta": {"type": "text_delta", "text": chunk},
288
+ })
289
+
290
  if timeout_event.is_set():
291
+ yield from send_event("content_block_delta", {
292
  "type": "content_block_delta",
293
  "index": 0,
294
+ "delta": {"type": "text_delta", "text": "Request timed out. Partial response: " + ''.join(response_text)},
295
  })
296
+
297
  except Exception as e:
298
+ logger.error(f"Error during socket connection: {str(e)}", exc_info=True)
299
+ yield from send_event("content_block_delta", {
300
  "type": "content_block_delta",
301
  "index": 0,
302
+ "delta": {"type": "text_delta", "text": f"Error during socket connection: {str(e)}"},
303
  })
304
  finally:
305
  timer.cancel()
306
  if sio.connected:
307
  sio.disconnect()
308
 
309
+ # Send final events
310
+ yield from send_event("content_block_stop", {"type": "content_block_stop", "index": 0})
311
+ yield from send_event("message_delta", {
312
  "type": "message_delta",
313
  "delta": {"stop_reason": "end_turn", "stop_sequence": None},
314
+ "usage": {"output_tokens": total_output_tokens},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  })
316
+ yield from send_event("message_stop", {"type": "message_stop"})
317
+
318
+ return Response(generate(), content_type='text/event-stream')
319
 
320
  except Exception as e:
321
  logger.error(f"Request error: {str(e)}", exc_info=True)
322
+ log_request(request.remote_addr, request.path, 400)
323
+ return jsonify({"error": str(e)}), 400
 
324
 
325
def handle_non_stream(previous_messages, msg_id, model, input_tokens, timeout):
    """Send the prompt over Socket.IO and return one complete JSON message.

    Args:
        previous_messages: Prompt text emitted with the ``perplexity_ask``
            event.
        msg_id: Identifier placed in the response's ``id`` field.
        model: Model name echoed back in the response's ``model`` field.
        input_tokens: Input token count reported under ``usage``.
        timeout: Maximum number of seconds to wait for the final update.

    Returns:
        A JSON ``Response`` holding the assembled message on success, a
        ``(response, 504)`` pair when no text arrived before the timeout,
        or a ``(response, 500)`` pair when an exception was raised.
    """
    # NOTE(review): `sio` looks like a client shared at module level; confirm
    # that concurrent requests cannot overwrite each other's handlers.
    try:
        response_event = Event()
        response_text = []
        total_output_tokens = 0

        def on_query_progress(data):
            """Buffer the newest chunk and flag completion on the final update."""
            nonlocal total_output_tokens
            if 'text' in data:
                text = json.loads(data['text'])
                # Tolerate payloads without a 'chunks' key rather than
                # raising KeyError inside the Socket.IO callback.
                chunks = text.get('chunks') if isinstance(text, dict) else None
                chunk = chunks[-1] if chunks else None
                if chunk:
                    response_text.append(chunk)
                    total_output_tokens += calculate_tokens(chunk)

            if data.get('final', False):
                response_event.set()

        def on_connect():
            """Send the query as soon as the connection is established."""
            logger.info("Connected to Perplexity AI (non-stream)", extra={'event_type': 'connection_established_non_stream'})
            emit_data = {
                "version": "2.9",
                "source": "default",
                "attachments": [],
                "language": "en-GB",
                "timezone": "Europe/London",
                "mode": "concise",
                "is_related_query": False,
                "is_default_related_query": False,
                "visitor_id": str(uuid.uuid4()),
                "frontend_context_uuid": str(uuid.uuid4()),
                "prompt_source": "user",
                "query_source": "home"
            }
            sio.emit('perplexity_ask', (previous_messages, emit_data))

        sio.on('connect', on_connect)
        sio.on('query_progress', on_query_progress)

        sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])

        # Event.wait returns False when the timeout elapsed before the final
        # update arrived; keep the result so a timeout is not silent.
        completed = response_event.wait(timeout=timeout)

        if not response_text:
            logger.warning("No response received (non-stream)", extra={'event_type': 'no_response_non_stream'})
            return jsonify({"error": "No response received"}), 504

        if not completed:
            # Partial text is still returned, as before, but the truncation
            # is now recorded in the log.
            logger.warning("Request timed out (non-stream); returning partial response", extra={
                'event_type': 'request_timeout_non_stream',
                'data': {'timeout_seconds': timeout}
            })

        full_response = {
            "content": [{"text": ''.join(response_text), "type": "text"}],
            "id": msg_id,
            "model": model,
            "role": "assistant",
            "stop_reason": "end_turn",
            "stop_sequence": None,
            "type": "message",
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": total_output_tokens,
            },
        }
        logger.info("Sending non-stream response", extra={
            'event_type': 'non_stream_response',
            'data': {'content': full_response}
        })
        # Serialised by hand so ensure_ascii=False keeps non-ASCII text readable.
        return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')

    except Exception as e:
        logger.error(f"Error during non-stream socket connection: {str(e)}", exc_info=True)
        return jsonify({"error": str(e)}), 500
    finally:
        # Always release the connection, whichever path was taken above.
        if sio.connected:
            sio.disconnect()
399
 
400
  @app.errorhandler(404)
401
  def not_found(error):