smgc commited on
Commit
12993a7
1 Parent(s): c99e95e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -123
app.py CHANGED
@@ -7,22 +7,25 @@ import uuid
7
  import re
8
  import socket
9
  from concurrent.futures import ThreadPoolExecutor
10
- from functools import lru_cache
11
- from cachetools import TTLCache
12
 
13
  import requests
14
  import tiktoken
15
  from flask import Flask, Response, jsonify, request, stream_with_context
16
  from flask_cors import CORS
17
- from typing import Dict, Any
18
  from requests.adapters import HTTPAdapter
19
  from urllib3.util.connection import create_connection
20
  import urllib3
 
21
 
22
  # Constants
23
  CHAT_COMPLETION_CHUNK = 'chat.completion.chunk'
24
  CHAT_COMPLETION = 'chat.completion'
25
  CONTENT_TYPE_EVENT_STREAM = 'text/event-stream'
 
 
 
26
 
27
  app = Flask(__name__)
28
  logging.basicConfig(level=logging.INFO)
@@ -72,134 +75,144 @@ def create_custom_session():
72
 
73
  class AuthManager:
74
  def __init__(self, email: str, password: str):
75
- self.email = email
76
- self.password = password
77
- self.api_key: str = ""
78
- self.user_info: Dict[str, Any] = {}
79
- self.refresh_token: str = ""
 
 
 
80
 
81
- self.logger = logging.getLogger(__name__)
82
  logging.basicConfig(level=logging.INFO)
83
 
84
- self.session = create_custom_session()
85
- self.initialize_auth()
86
- self.fetch_apikey()
87
- self.log_values()
88
-
89
- def initialize_auth(self):
90
- """初始化认证,优先使用缓存的 refresh_token"""
91
- cache_key = f"{self.email}|{self.password}"
92
- cached_token = refresh_token_cache.get(cache_key)
93
-
94
- if cached_token:
95
- self.refresh_token = cached_token
96
- self.refresh_user_token()
97
- else:
98
- self.login()
99
-
100
- def log_values(self) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  """记录刷新令牌到日志中。"""
102
- self.logger.info(f"\033[92mRefresh Token: {self.refresh_token}\033[0m")
 
103
 
104
- def fetch_apikey(self) -> str:
105
  """获取API密钥。"""
106
- if self.api_key:
107
- return self.api_key
108
-
109
  try:
110
- url = "https://chat.notdiamond.ai/login"
111
- headers = {
112
- 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36'
113
- }
114
- response = self.session.get(url, headers=headers)
115
- response.raise_for_status()
116
-
117
- # 匹配 <script> 标签中的 JS 文件路径
118
  match = re.search(r'<script src="(/_next/static/chunks/app/layout-[^"]+\.js)"', response.text)
119
  if not match:
120
- self.logger.warning("未找到匹配的脚本标签")
121
- return ""
122
-
123
- js_url = f"https://chat.notdiamond.ai{match.group(1)}"
124
- js_response = self.session.get(js_url, headers=headers)
125
- js_response.raise_for_status()
126
-
127
- # 匹配 API key
128
- api_key_match = re.search(r'\("https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"\)', js_response.text)
129
- if api_key_match:
130
- self.api_key = api_key_match.group(1)
131
- logger.info(f"Extracted API key: {self.api_key}")
132
- return self.api_key
133
- else:
134
- self.logger.error("未能匹配API key")
135
- return ""
136
-
137
- except requests.RequestException as e:
138
- self.logger.error(f"请求JS文件时发生错误: {e}")
139
- return ""
140
 
141
- def login(self) -> None:
142
- """使用电子邮件和密码进行用户登录,并获取用户信息。"""
143
- api_key = self.fetch_apikey()
144
- if not api_key:
145
- self.logger.error("API key is missing, cannot proceed with login.")
146
- return
147
-
148
- url = "https://spuckhogycrxcbomznwo.supabase.co/auth/v1/token?grant_type=password"
149
- headers = {
150
- 'apikey': api_key,
151
- 'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
152
- 'Content-Type': 'application/json'
153
- }
154
- data = {
155
- "email": self.email,
156
- "password": self.password,
157
- "gotrue_meta_security": {}
158
- }
159
-
160
- try:
161
- response = self.session.post(url, headers=headers, json=data)
162
- response.raise_for_status()
163
- self.user_info = response.json()
164
- self.refresh_token = self.user_info.get('refresh_token', '')
165
 
166
- # 缓存 refresh_token
167
- cache_key = f"{self.email}|{self.password}"
168
- refresh_token_cache[cache_key] = self.refresh_token
169
 
170
- logger.info(f"Login successful for email: {self.email}")
171
-
172
- except requests.RequestException as e:
173
- self.logger.error(f"登录请求错误: {e}")
174
 
175
- def refresh_user_token(self) -> None:
176
- """使用刷新令牌来请求一个新的访问令牌并更新实例变量。"""
177
- url = "https://spuckhogycrxcbomznwo.supabase.co/auth/v1/token?grant_type=refresh_token"
 
 
 
178
  headers = {
179
- 'apikey': self.fetch_apikey(),
180
- 'content-type': 'application/json;charset=UTF-8',
181
- 'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36'
182
  }
183
- data = {"refresh_token": self.refresh_token}
184
-
 
 
 
 
 
 
185
  try:
186
- response = self.session.post(url, headers=headers, json=data)
187
  response.raise_for_status()
188
- self.user_info = response.json()
189
- self.refresh_token = self.user_info.get('refresh_token', '')
190
-
191
- # 更新缓存中的 refresh_token
192
- cache_key = f"{self.email}|{self.password}"
193
- refresh_token_cache[cache_key] = self.refresh_token
194
-
195
  except requests.RequestException as e:
196
- self.logger.error(f"刷新令牌请求错误: {e}")
197
- # 如果刷新失败,尝试重新登录
198
- self.login()
199
-
200
- def get_jwt_value(self) -> str:
201
- """返回访问令牌。"""
202
- return self.user_info.get('access_token', '')
 
 
 
 
203
 
204
  # 全局的 AuthManager 对象,将在每次请求时更新
205
  auth_manager = None
@@ -217,9 +230,7 @@ def get_notdiamond_headers():
217
  'accept': 'text/event-stream',
218
  'accept-language': 'zh-CN,zh;q=0.9',
219
  'content-type': 'application/json',
220
- 'user-agent': ('Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
221
- 'AppleWebKit/537.36 (KHTML, like Gecko) '
222
- 'Chrome/128.0.0.0 Safari/537.36'),
223
  'authorization': f'Bearer {auth_manager.get_jwt_value()}'
224
  }
225
 
@@ -334,7 +345,7 @@ def handle_non_stream_response(response, model, prompt_tokens):
334
  "object": "chat.completion",
335
  "created": int(time.time()),
336
  "model": model,
337
- "system_fingerprint": generate_system_fingerprint(),
338
  "choices": [
339
  {
340
  "index": 0,
@@ -359,6 +370,7 @@ def generate_stream_response(response, model, prompt_tokens):
359
  for chunk in stream_notdiamond_response(response, model):
360
  content = chunk['choices'][0]['delta'].get('content', '')
361
  total_completion_tokens += count_tokens(content, model)
 
362
  chunk['usage'] = {
363
  "prompt_tokens": prompt_tokens,
364
  "completion_tokens": total_completion_tokens,
@@ -372,12 +384,8 @@ def generate_stream_response(response, model, prompt_tokens):
372
  def get_auth_credentials():
373
  """从请求头中获取认证凭据"""
374
  auth_header = request.headers.get('Authorization')
375
- if not auth_header:
376
- logger.error("Authorization header is missing")
377
- return None, None
378
-
379
- if not auth_header.startswith('Bearer '):
380
- logger.error(f"Authorization header format is incorrect: {auth_header}")
381
  return None, None
382
 
383
  try:
@@ -395,6 +403,8 @@ def before_request():
395
  email, password = get_auth_credentials()
396
  if email and password:
397
  auth_manager = AuthManager(email, password)
 
 
398
  else:
399
  auth_manager = None
400
 
@@ -542,6 +552,7 @@ def build_payload(request_data, model_id):
542
 
543
  def make_request(payload):
544
  """发送请求并处理可能的认证刷新。"""
 
545
  url = get_notdiamond_url()
546
  headers = get_notdiamond_headers()
547
  response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
@@ -562,3 +573,4 @@ def make_request(payload):
562
  if __name__ == "__main__":
563
  port = int(os.environ.get("PORT", 3000))
564
  app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
 
 
7
  import re
8
  import socket
9
  from concurrent.futures import ThreadPoolExecutor
10
+ from functools import lru_cache, wraps
11
+ from typing import Dict, Any, Callable
12
 
13
  import requests
14
  import tiktoken
15
  from flask import Flask, Response, jsonify, request, stream_with_context
16
  from flask_cors import CORS
 
17
  from requests.adapters import HTTPAdapter
18
  from urllib3.util.connection import create_connection
19
  import urllib3
20
+ from cachetools import TTLCache
21
 
22
  # Constants
23
  CHAT_COMPLETION_CHUNK = 'chat.completion.chunk'
24
  CHAT_COMPLETION = 'chat.completion'
25
  CONTENT_TYPE_EVENT_STREAM = 'text/event-stream'
26
+ _BASE_URL = "https://chat.notdiamond.ai"
27
+ _API_BASE_URL = "https://spuckhogycrxcbomznwo.supabase.co"
28
+ _USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36'
29
 
30
  app = Flask(__name__)
31
  logging.basicConfig(level=logging.INFO)
 
75
 
76
  class AuthManager:
77
  def __init__(self, email: str, password: str):
78
+ self._email: str = email
79
+ self._password: str = password
80
+ self._api_key: str = ""
81
+ self._user_info: Dict[str, Any] = {}
82
+ self._refresh_token: str = ""
83
+ self._access_token: str = ""
84
+ self._token_expiry: float = 0
85
+ self._session: requests.Session = create_custom_session()
86
 
87
+ self._logger: logging.Logger = logging.getLogger(__name__)
88
  logging.basicConfig(level=logging.INFO)
89
 
90
+ def login(self) -> bool:
91
+ """使用电子邮件和密码进行用户登录,并获取用户信息。"""
92
+ url = f"{_API_BASE_URL}/auth/v1/token?grant_type=password"
93
+ headers = self._get_headers(with_content_type=True)
94
+ data = {
95
+ "email": self._email,
96
+ "password": self._password,
97
+ "gotrue_meta_security": {}
98
+ }
99
+
100
+ try:
101
+ response = self._make_request('POST', url, headers=headers, json=data)
102
+ self._user_info = response.json()
103
+ self._refresh_token = self._user_info.get('refresh_token', '')
104
+ self._access_token = self._user_info.get('access_token', '')
105
+ self._token_expiry = time.time() + self._user_info.get('expires_in', 3600)
106
+ self._log_values()
107
+ return True
108
+ except requests.RequestException as e:
109
+ self._logger.error(f"\033[91m登录请求错误: {e}\033[0m")
110
+ return False
111
+
112
+ def refresh_user_token(self) -> bool:
113
+ """使用刷新令牌来请求一个新的访问令牌并更新实例变量。"""
114
+ url = f"{_API_BASE_URL}/auth/v1/token?grant_type=refresh_token"
115
+ headers = self._get_headers(with_content_type=True)
116
+ data = {"refresh_token": self._refresh_token}
117
+
118
+ try:
119
+ response = self._make_request('POST', url, headers=headers, json=data)
120
+ self._user_info = response.json()
121
+ self._refresh_token = self._user_info.get('refresh_token', '')
122
+ self._access_token = self._user_info.get('access_token', '')
123
+ self._token_expiry = time.time() + self._user_info.get('expires_in', 3600)
124
+ self._log_values()
125
+ return True
126
+ except requests.RequestException as e:
127
+ self._logger.error(f"刷新令牌请求错误: {e}")
128
+ return False
129
+
130
+ def get_jwt_value(self) -> str:
131
+ """返回访问令牌。"""
132
+ return self._access_token
133
+
134
+ def is_token_valid(self) -> bool:
135
+ """检查当前的访问令牌是否有效。"""
136
+ return bool(self._access_token) and time.time() < self._token_expiry
137
+
138
+ def ensure_valid_token(self) -> bool:
139
+ """确保有一个有效的访问令牌,如果需要则刷新或重新登录。"""
140
+ if self.is_token_valid():
141
+ return True
142
+ if self._refresh_token:
143
+ if self.refresh_user_token():
144
+ return True
145
+ return self.login()
146
+
147
+ def clear_auth(self) -> None:
148
+ """清除当前的授权信息。"""
149
+ self._user_info = {}
150
+ self._refresh_token = ""
151
+ self._access_token = ""
152
+ self._token_expiry = 0
153
+
154
+ def _log_values(self) -> None:
155
  """记录刷新令牌到日志中。"""
156
+ self._logger.info(f"\033[92mRefresh Token: {self._refresh_token}\033[0m")
157
+ self._logger.info(f"\033[92mAccess Token: {self._access_token}\033[0m")
158
 
159
+ def _fetch_apikey(self) -> str:
160
  """获取API密钥。"""
161
+ if self._api_key:
162
+ return self._api_key
163
+
164
  try:
165
+ login_url = f"{_BASE_URL}/login"
166
+ response = self._make_request('GET', login_url)
167
+
 
 
 
 
 
168
  match = re.search(r'<script src="(/_next/static/chunks/app/layout-[^"]+\.js)"', response.text)
169
  if not match:
170
+ raise ValueError("未找到匹配的脚本标签")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ js_url = f"{_BASE_URL}{match.group(1)}"
173
+ js_response = self._make_request('GET', js_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
+ api_key_match = re.search(r'$"https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"$', js_response.text)
176
+ if not api_key_match:
177
+ raise ValueError("未能匹配API key")
178
 
179
+ self._api_key = api_key_match.group(1)
180
+ return self._api_key
 
 
181
 
182
+ except (requests.RequestException, ValueError) as e:
183
+ self._logger.error(f"获取API密钥时发生错误: {e}")
184
+ return ""
185
+
186
+ def _get_headers(self, with_content_type: bool = False) -> Dict[str, str]:
187
+ """生成请求头。"""
188
  headers = {
189
+ 'apikey': self._fetch_apikey(),
190
+ 'user-agent': _USER_AGENT
 
191
  }
192
+ if with_content_type:
193
+ headers['Content-Type'] = 'application/json'
194
+ if self._access_token:
195
+ headers['Authorization'] = f'Bearer {self._access_token}'
196
+ return headers
197
+
198
+ def _make_request(self, method: str, url: str, **kwargs) -> requests.Response:
199
+ """发送HTTP请求并处理异常。"""
200
  try:
201
+ response = self._session.request(method, url, **kwargs)
202
  response.raise_for_status()
203
+ return response
 
 
 
 
 
 
204
  except requests.RequestException as e:
205
+ self._logger.error(f"请求错误 ({method} {url}): {e}")
206
+ raise
207
+
208
+ def require_auth(func: Callable) -> Callable:
209
+ """装饰器,确保在调用API之前有有效的token。"""
210
+ @wraps(func)
211
+ def wrapper(self, *args, **kwargs):
212
+ if not self.ensure_valid_token():
213
+ raise Exception("无法获取有效的授权token")
214
+ return func(self, *args, **kwargs)
215
+ return wrapper
216
 
217
  # 全局的 AuthManager 对象,将在每次请求时更新
218
  auth_manager = None
 
230
  'accept': 'text/event-stream',
231
  'accept-language': 'zh-CN,zh;q=0.9',
232
  'content-type': 'application/json',
233
+ 'user-agent': _USER_AGENT,
 
 
234
  'authorization': f'Bearer {auth_manager.get_jwt_value()}'
235
  }
236
 
 
345
  "object": "chat.completion",
346
  "created": int(time.time()),
347
  "model": model,
348
+ "system_fingerprint": generate_system_fingerprint(),
349
  "choices": [
350
  {
351
  "index": 0,
 
370
  for chunk in stream_notdiamond_response(response, model):
371
  content = chunk['choices'][0]['delta'].get('content', '')
372
  total_completion_tokens += count_tokens(content, model)
373
+
374
  chunk['usage'] = {
375
  "prompt_tokens": prompt_tokens,
376
  "completion_tokens": total_completion_tokens,
 
384
  def get_auth_credentials():
385
  """从请求头中获取认证凭据"""
386
  auth_header = request.headers.get('Authorization')
387
+ if not auth_header or not auth_header.startswith('Bearer '):
388
+ logger.error("Authorization header is missing or invalid")
 
 
 
 
389
  return None, None
390
 
391
  try:
 
403
  email, password = get_auth_credentials()
404
  if email and password:
405
  auth_manager = AuthManager(email, password)
406
+ if not auth_manager.ensure_valid_token():
407
+ auth_manager = None
408
  else:
409
  auth_manager = None
410
 
 
552
 
553
  def make_request(payload):
554
  """发送请求并处理可能的认证刷新。"""
555
+ global auth_manager
556
  url = get_notdiamond_url()
557
  headers = get_notdiamond_headers()
558
  response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
 
573
  if __name__ == "__main__":
574
  port = int(os.environ.get("PORT", 3000))
575
  app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
576
+