smgc commited on
Commit
1b13226
1 Parent(s): 378dec0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -37
app.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  import socket
9
  from concurrent.futures import ThreadPoolExecutor
10
  from functools import lru_cache
 
11
 
12
  import requests
13
  import tiktoken
@@ -40,6 +41,9 @@ if not NOTDIAMOND_IP:
40
  logger.error("NOTDIAMOND_IP environment variable is not set!")
41
  raise ValueError("NOTDIAMOND_IP must be set")
42
 
 
 
 
43
  # 自定义连接函数
44
  def patched_create_connection(address, *args, **kwargs):
45
  host, port = address
@@ -66,7 +70,6 @@ def create_custom_session():
66
  session.mount('http://', adapter)
67
  return session
68
 
69
- # 修改 AuthManager 类以使用自定义 Session
70
  class AuthManager:
71
  def __init__(self, email: str, password: str):
72
  self.email = email
@@ -79,10 +82,21 @@ class AuthManager:
79
  logging.basicConfig(level=logging.INFO)
80
 
81
  self.session = create_custom_session()
82
- self.login()
83
  self.fetch_apikey()
84
  self.log_values()
 
 
 
 
 
85
 
 
 
 
 
 
 
86
  def log_values(self) -> None:
87
  """记录刷新令牌到日志中。"""
88
  self.logger.info(f"\033[92mRefresh Token: {self.refresh_token}\033[0m")
@@ -109,7 +123,7 @@ class AuthManager:
109
  js_response = self.session.get(js_url, headers=headers)
110
  js_response.raise_for_status()
111
 
112
- api_key_match = re.search(r'\("https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"\)', js_response.text)
113
  if api_key_match:
114
  self.api_key = api_key_match.group(1)
115
  return self.api_key
@@ -140,6 +154,11 @@ class AuthManager:
140
  response.raise_for_status()
141
  self.user_info = response.json()
142
  self.refresh_token = self.user_info.get('refresh_token', '')
 
 
 
 
 
143
  except requests.RequestException as e:
144
  self.logger.error(f"\033[91m登录请求错误: {e}\033[0m")
145
 
@@ -158,18 +177,22 @@ class AuthManager:
158
  response.raise_for_status()
159
  self.user_info = response.json()
160
  self.refresh_token = self.user_info.get('refresh_token', '')
 
 
 
 
 
161
  except requests.RequestException as e:
162
  self.logger.error(f"刷新令牌请求错误: {e}")
 
 
163
 
164
  def get_jwt_value(self) -> str:
165
  """返回访问令牌。"""
166
  return self.user_info.get('access_token', '')
167
 
168
- # 初始化 AuthManager
169
- auth_manager = AuthManager(
170
- os.getenv("AUTH_EMAIL", "[email protected]"),
171
- os.getenv("AUTH_PASSWORD", "default_password"),
172
- )
173
 
174
  NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', 'https://not-diamond-workers.t7-cc4.workers.dev/stream-message').split(',')
175
 
@@ -263,7 +286,6 @@ def create_openai_chunk(content, model, finish_reason=None, usage=None):
263
  chunk["usage"] = usage
264
  return chunk
265
 
266
-
267
  def count_tokens(text, model="gpt-3.5-turbo-0301"):
268
  """计算给定文本的令牌数量。"""
269
  try:
@@ -327,7 +349,6 @@ def generate_stream_response(response, model, prompt_tokens):
327
  for chunk in stream_notdiamond_response(response, model):
328
  content = chunk['choices'][0]['delta'].get('content', '')
329
  total_completion_tokens += count_tokens(content, model)
330
-
331
  chunk['usage'] = {
332
  "prompt_tokens": prompt_tokens,
333
  "completion_tokens": total_completion_tokens,
@@ -338,6 +359,28 @@ def generate_stream_response(response, model, prompt_tokens):
338
 
339
  yield "data: [DONE]\n\n"
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  @app.route('/', methods=['GET'])
342
  def root():
343
  return jsonify({
@@ -385,6 +428,10 @@ def proxy_models():
385
  @app.route('/ai/v1/chat/completions', methods=['POST'])
386
  def handle_request():
387
  """处理聊天完成请求。"""
 
 
 
 
388
  try:
389
  request_data = request.get_json()
390
  model_id = request_data.get('model', '')
@@ -480,39 +527,20 @@ def make_request(payload):
480
  """发送请求并处理可能的认证刷新。"""
481
  url = get_notdiamond_url()
482
  headers = get_notdiamond_headers()
483
- session = create_custom_session()
484
- logger.info(f"Sending request to URL: {url}")
485
- try:
486
- response = executor.submit(session.post, url, headers=headers, json=payload, stream=True).result()
487
- logger.info(f"Response status code: {response.status_code}")
488
- if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
489
- return response
490
- except Exception as e:
491
- logger.error(f"Error in make_request: {str(e)}")
492
- raise
493
 
494
- logger.info("Refreshing user token...")
495
  auth_manager.refresh_user_token()
496
  headers = get_notdiamond_headers()
497
- try:
498
- response = executor.submit(session.post, url, headers=headers, json=payload, stream=True).result()
499
- logger.info(f"Response status code after token refresh: {response.status_code}")
500
- if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
501
- return response
502
- except Exception as e:
503
- logger.error(f"Error in make_request after token refresh: {str(e)}")
504
- raise
505
 
506
- logger.info("Logging in again...")
507
  auth_manager.login()
508
  headers = get_notdiamond_headers()
509
- try:
510
- response = executor.submit(session.post, url, headers=headers, json=payload, stream=True).result()
511
- logger.info(f"Response status code after login: {response.status_code}")
512
- return response
513
- except Exception as e:
514
- logger.error(f"Error in make_request after login: {str(e)}")
515
- raise
516
 
517
  if __name__ == "__main__":
518
  port = int(os.environ.get("PORT", 3000))
 
8
  import socket
9
  from concurrent.futures import ThreadPoolExecutor
10
  from functools import lru_cache
11
+ from cachetools import TTLCache
12
 
13
  import requests
14
  import tiktoken
 
41
  logger.error("NOTDIAMOND_IP environment variable is not set!")
42
  raise ValueError("NOTDIAMOND_IP must be set")
43
 
44
+ # 创建一个 TTLCache 来存储 refresh_token
45
+ refresh_token_cache = TTLCache(maxsize=1000, ttl=3600)
46
+
47
  # 自定义连接函数
48
  def patched_create_connection(address, *args, **kwargs):
49
  host, port = address
 
70
  session.mount('http://', adapter)
71
  return session
72
 
 
73
  class AuthManager:
74
  def __init__(self, email: str, password: str):
75
  self.email = email
 
82
  logging.basicConfig(level=logging.INFO)
83
 
84
  self.session = create_custom_session()
85
+ self.initialize_auth()
86
  self.fetch_apikey()
87
  self.log_values()
88
+
89
+ def initialize_auth(self):
90
+ """初始化认证,优先使用缓存的 refresh_token"""
91
+ cache_key = f"{self.email}|{self.password}"
92
+ cached_token = refresh_token_cache.get(cache_key)
93
 
94
+ if cached_token:
95
+ self.refresh_token = cached_token
96
+ self.refresh_user_token()
97
+ else:
98
+ self.login()
99
+
100
  def log_values(self) -> None:
101
  """记录刷新令牌到日志中。"""
102
  self.logger.info(f"\033[92mRefresh Token: {self.refresh_token}\033[0m")
 
123
  js_response = self.session.get(js_url, headers=headers)
124
  js_response.raise_for_status()
125
 
126
+ api_key_match = re.search(r'$"https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"$', js_response.text)
127
  if api_key_match:
128
  self.api_key = api_key_match.group(1)
129
  return self.api_key
 
154
  response.raise_for_status()
155
  self.user_info = response.json()
156
  self.refresh_token = self.user_info.get('refresh_token', '')
157
+
158
+ # 缓存 refresh_token
159
+ cache_key = f"{self.email}|{self.password}"
160
+ refresh_token_cache[cache_key] = self.refresh_token
161
+
162
  except requests.RequestException as e:
163
  self.logger.error(f"\033[91m登录请求错误: {e}\033[0m")
164
 
 
177
  response.raise_for_status()
178
  self.user_info = response.json()
179
  self.refresh_token = self.user_info.get('refresh_token', '')
180
+
181
+ # 更新缓存中的 refresh_token
182
+ cache_key = f"{self.email}|{self.password}"
183
+ refresh_token_cache[cache_key] = self.refresh_token
184
+
185
  except requests.RequestException as e:
186
  self.logger.error(f"刷新令牌请求错误: {e}")
187
+ # 如果刷新失败,尝试重新登录
188
+ self.login()
189
 
190
  def get_jwt_value(self) -> str:
191
  """返回访问令牌。"""
192
  return self.user_info.get('access_token', '')
193
 
194
+ # 全局的 AuthManager 对象,将在每次请求时更新
195
+ auth_manager = None
 
 
 
196
 
197
  NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', 'https://not-diamond-workers.t7-cc4.workers.dev/stream-message').split(',')
198
 
 
286
  chunk["usage"] = usage
287
  return chunk
288
 
 
289
  def count_tokens(text, model="gpt-3.5-turbo-0301"):
290
  """计算给定文本的令牌数量。"""
291
  try:
 
349
  for chunk in stream_notdiamond_response(response, model):
350
  content = chunk['choices'][0]['delta'].get('content', '')
351
  total_completion_tokens += count_tokens(content, model)
 
352
  chunk['usage'] = {
353
  "prompt_tokens": prompt_tokens,
354
  "completion_tokens": total_completion_tokens,
 
359
 
360
  yield "data: [DONE]\n\n"
361
 
362
+ def get_auth_credentials():
363
+ """从请求头中获取认证凭据"""
364
+ auth_header = request.headers.get('Authorization')
365
+ if not auth_header or not auth_header.startswith('Bearer '):
366
+ return None, None
367
+
368
+ try:
369
+ credentials = auth_header.split('Bearer ')[1]
370
+ email, password = credentials.split('|')
371
+ return email.strip(), password.strip()
372
+ except:
373
+ return None, None
374
+
375
+ @app.before_request
376
+ def before_request():
377
+ global auth_manager
378
+ email, password = get_auth_credentials()
379
+ if email and password:
380
+ auth_manager = AuthManager(email, password)
381
+ else:
382
+ auth_manager = None
383
+
384
  @app.route('/', methods=['GET'])
385
  def root():
386
  return jsonify({
 
428
  @app.route('/ai/v1/chat/completions', methods=['POST'])
429
  def handle_request():
430
  """处理聊天完成请求。"""
431
+ global auth_manager
432
+ if not auth_manager:
433
+ return jsonify({'error': 'Unauthorized'}), 401
434
+
435
  try:
436
  request_data = request.get_json()
437
  model_id = request_data.get('model', '')
 
527
  """发送请求并处理可能的认证刷新。"""
528
  url = get_notdiamond_url()
529
  headers = get_notdiamond_headers()
530
+ response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
531
+ if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
532
+ return response
 
 
 
 
 
 
 
533
 
 
534
  auth_manager.refresh_user_token()
535
  headers = get_notdiamond_headers()
536
+ response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
537
+ if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
538
+ return response
 
 
 
 
 
539
 
 
540
  auth_manager.login()
541
  headers = get_notdiamond_headers()
542
+ response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
543
+ return response
 
 
 
 
 
544
 
545
  if __name__ == "__main__":
546
  port = int(os.environ.get("PORT", 3000))