Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,22 +7,25 @@ import uuid
|
|
7 |
import re
|
8 |
import socket
|
9 |
from concurrent.futures import ThreadPoolExecutor
|
10 |
-
from functools import lru_cache
|
11 |
-
from
|
12 |
|
13 |
import requests
|
14 |
import tiktoken
|
15 |
from flask import Flask, Response, jsonify, request, stream_with_context
|
16 |
from flask_cors import CORS
|
17 |
-
from typing import Dict, Any
|
18 |
from requests.adapters import HTTPAdapter
|
19 |
from urllib3.util.connection import create_connection
|
20 |
import urllib3
|
|
|
21 |
|
22 |
# Constants
|
23 |
CHAT_COMPLETION_CHUNK = 'chat.completion.chunk'
|
24 |
CHAT_COMPLETION = 'chat.completion'
|
25 |
CONTENT_TYPE_EVENT_STREAM = 'text/event-stream'
|
|
|
|
|
|
|
26 |
|
27 |
app = Flask(__name__)
|
28 |
logging.basicConfig(level=logging.INFO)
|
@@ -72,134 +75,144 @@ def create_custom_session():
|
|
72 |
|
73 |
class AuthManager:
|
74 |
def __init__(self, email: str, password: str):
|
75 |
-
self.
|
76 |
-
self.
|
77 |
-
self.
|
78 |
-
self.
|
79 |
-
self.
|
|
|
|
|
|
|
80 |
|
81 |
-
self.
|
82 |
logging.basicConfig(level=logging.INFO)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
self.
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
self.
|
96 |
-
self.
|
97 |
-
|
98 |
-
self.
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
"""记录刷新令牌到日志中。"""
|
102 |
-
self.
|
|
|
103 |
|
104 |
-
def
|
105 |
"""获取API密钥。"""
|
106 |
-
if self.
|
107 |
-
return self.
|
108 |
-
|
109 |
try:
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
}
|
114 |
-
response = self.session.get(url, headers=headers)
|
115 |
-
response.raise_for_status()
|
116 |
-
|
117 |
-
# 匹配 <script> 标签中的 JS 文件路径
|
118 |
match = re.search(r'<script src="(/_next/static/chunks/app/layout-[^"]+\.js)"', response.text)
|
119 |
if not match:
|
120 |
-
|
121 |
-
return ""
|
122 |
-
|
123 |
-
js_url = f"https://chat.notdiamond.ai{match.group(1)}"
|
124 |
-
js_response = self.session.get(js_url, headers=headers)
|
125 |
-
js_response.raise_for_status()
|
126 |
-
|
127 |
-
# 匹配 API key
|
128 |
-
api_key_match = re.search(r'\("https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"\)', js_response.text)
|
129 |
-
if api_key_match:
|
130 |
-
self.api_key = api_key_match.group(1)
|
131 |
-
logger.info(f"Extracted API key: {self.api_key}")
|
132 |
-
return self.api_key
|
133 |
-
else:
|
134 |
-
self.logger.error("未能匹配API key")
|
135 |
-
return ""
|
136 |
-
|
137 |
-
except requests.RequestException as e:
|
138 |
-
self.logger.error(f"请求JS文件时发生错误: {e}")
|
139 |
-
return ""
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
api_key = self.fetch_apikey()
|
144 |
-
if not api_key:
|
145 |
-
self.logger.error("API key is missing, cannot proceed with login.")
|
146 |
-
return
|
147 |
-
|
148 |
-
url = "https://spuckhogycrxcbomznwo.supabase.co/auth/v1/token?grant_type=password"
|
149 |
-
headers = {
|
150 |
-
'apikey': api_key,
|
151 |
-
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
|
152 |
-
'Content-Type': 'application/json'
|
153 |
-
}
|
154 |
-
data = {
|
155 |
-
"email": self.email,
|
156 |
-
"password": self.password,
|
157 |
-
"gotrue_meta_security": {}
|
158 |
-
}
|
159 |
-
|
160 |
-
try:
|
161 |
-
response = self.session.post(url, headers=headers, json=data)
|
162 |
-
response.raise_for_status()
|
163 |
-
self.user_info = response.json()
|
164 |
-
self.refresh_token = self.user_info.get('refresh_token', '')
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
except requests.RequestException as e:
|
173 |
-
self.logger.error(f"登录请求错误: {e}")
|
174 |
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
178 |
headers = {
|
179 |
-
'apikey': self.
|
180 |
-
'
|
181 |
-
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36'
|
182 |
}
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
try:
|
186 |
-
response = self.
|
187 |
response.raise_for_status()
|
188 |
-
|
189 |
-
self.refresh_token = self.user_info.get('refresh_token', '')
|
190 |
-
|
191 |
-
# 更新缓存中的 refresh_token
|
192 |
-
cache_key = f"{self.email}|{self.password}"
|
193 |
-
refresh_token_cache[cache_key] = self.refresh_token
|
194 |
-
|
195 |
except requests.RequestException as e:
|
196 |
-
self.
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
203 |
|
204 |
# 全局的 AuthManager 对象,将在每次请求时更新
|
205 |
auth_manager = None
|
@@ -217,9 +230,7 @@ def get_notdiamond_headers():
|
|
217 |
'accept': 'text/event-stream',
|
218 |
'accept-language': 'zh-CN,zh;q=0.9',
|
219 |
'content-type': 'application/json',
|
220 |
-
'user-agent':
|
221 |
-
'AppleWebKit/537.36 (KHTML, like Gecko) '
|
222 |
-
'Chrome/128.0.0.0 Safari/537.36'),
|
223 |
'authorization': f'Bearer {auth_manager.get_jwt_value()}'
|
224 |
}
|
225 |
|
@@ -334,7 +345,7 @@ def handle_non_stream_response(response, model, prompt_tokens):
|
|
334 |
"object": "chat.completion",
|
335 |
"created": int(time.time()),
|
336 |
"model": model,
|
337 |
-
|
338 |
"choices": [
|
339 |
{
|
340 |
"index": 0,
|
@@ -359,6 +370,7 @@ def generate_stream_response(response, model, prompt_tokens):
|
|
359 |
for chunk in stream_notdiamond_response(response, model):
|
360 |
content = chunk['choices'][0]['delta'].get('content', '')
|
361 |
total_completion_tokens += count_tokens(content, model)
|
|
|
362 |
chunk['usage'] = {
|
363 |
"prompt_tokens": prompt_tokens,
|
364 |
"completion_tokens": total_completion_tokens,
|
@@ -372,12 +384,8 @@ def generate_stream_response(response, model, prompt_tokens):
|
|
372 |
def get_auth_credentials():
|
373 |
"""从请求头中获取认证凭据"""
|
374 |
auth_header = request.headers.get('Authorization')
|
375 |
-
if not auth_header:
|
376 |
-
logger.error("Authorization header is missing")
|
377 |
-
return None, None
|
378 |
-
|
379 |
-
if not auth_header.startswith('Bearer '):
|
380 |
-
logger.error(f"Authorization header format is incorrect: {auth_header}")
|
381 |
return None, None
|
382 |
|
383 |
try:
|
@@ -395,6 +403,8 @@ def before_request():
|
|
395 |
email, password = get_auth_credentials()
|
396 |
if email and password:
|
397 |
auth_manager = AuthManager(email, password)
|
|
|
|
|
398 |
else:
|
399 |
auth_manager = None
|
400 |
|
@@ -542,6 +552,7 @@ def build_payload(request_data, model_id):
|
|
542 |
|
543 |
def make_request(payload):
|
544 |
"""发送请求并处理可能的认证刷新。"""
|
|
|
545 |
url = get_notdiamond_url()
|
546 |
headers = get_notdiamond_headers()
|
547 |
response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
|
@@ -562,3 +573,4 @@ def make_request(payload):
|
|
562 |
if __name__ == "__main__":
|
563 |
port = int(os.environ.get("PORT", 3000))
|
564 |
app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
|
|
|
|
7 |
import re
|
8 |
import socket
|
9 |
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
from functools import lru_cache, wraps
|
11 |
+
from typing import Dict, Any, Callable
|
12 |
|
13 |
import requests
|
14 |
import tiktoken
|
15 |
from flask import Flask, Response, jsonify, request, stream_with_context
|
16 |
from flask_cors import CORS
|
|
|
17 |
from requests.adapters import HTTPAdapter
|
18 |
from urllib3.util.connection import create_connection
|
19 |
import urllib3
|
20 |
+
from cachetools import TTLCache
|
21 |
|
22 |
# Constants
|
23 |
CHAT_COMPLETION_CHUNK = 'chat.completion.chunk'
|
24 |
CHAT_COMPLETION = 'chat.completion'
|
25 |
CONTENT_TYPE_EVENT_STREAM = 'text/event-stream'
|
26 |
+
_BASE_URL = "https://chat.notdiamond.ai"
|
27 |
+
_API_BASE_URL = "https://spuckhogycrxcbomznwo.supabase.co"
|
28 |
+
_USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36'
|
29 |
|
30 |
app = Flask(__name__)
|
31 |
logging.basicConfig(level=logging.INFO)
|
|
|
75 |
|
76 |
class AuthManager:
|
77 |
def __init__(self, email: str, password: str):
|
78 |
+
self._email: str = email
|
79 |
+
self._password: str = password
|
80 |
+
self._api_key: str = ""
|
81 |
+
self._user_info: Dict[str, Any] = {}
|
82 |
+
self._refresh_token: str = ""
|
83 |
+
self._access_token: str = ""
|
84 |
+
self._token_expiry: float = 0
|
85 |
+
self._session: requests.Session = create_custom_session()
|
86 |
|
87 |
+
self._logger: logging.Logger = logging.getLogger(__name__)
|
88 |
logging.basicConfig(level=logging.INFO)
|
89 |
|
90 |
+
def login(self) -> bool:
|
91 |
+
"""使用电子邮件和密码进行用户登录,并获取用户信息。"""
|
92 |
+
url = f"{_API_BASE_URL}/auth/v1/token?grant_type=password"
|
93 |
+
headers = self._get_headers(with_content_type=True)
|
94 |
+
data = {
|
95 |
+
"email": self._email,
|
96 |
+
"password": self._password,
|
97 |
+
"gotrue_meta_security": {}
|
98 |
+
}
|
99 |
+
|
100 |
+
try:
|
101 |
+
response = self._make_request('POST', url, headers=headers, json=data)
|
102 |
+
self._user_info = response.json()
|
103 |
+
self._refresh_token = self._user_info.get('refresh_token', '')
|
104 |
+
self._access_token = self._user_info.get('access_token', '')
|
105 |
+
self._token_expiry = time.time() + self._user_info.get('expires_in', 3600)
|
106 |
+
self._log_values()
|
107 |
+
return True
|
108 |
+
except requests.RequestException as e:
|
109 |
+
self._logger.error(f"\033[91m登录请求错误: {e}\033[0m")
|
110 |
+
return False
|
111 |
+
|
112 |
+
def refresh_user_token(self) -> bool:
|
113 |
+
"""使用刷新令牌来请求一个新的访问令牌并更新实例变量。"""
|
114 |
+
url = f"{_API_BASE_URL}/auth/v1/token?grant_type=refresh_token"
|
115 |
+
headers = self._get_headers(with_content_type=True)
|
116 |
+
data = {"refresh_token": self._refresh_token}
|
117 |
+
|
118 |
+
try:
|
119 |
+
response = self._make_request('POST', url, headers=headers, json=data)
|
120 |
+
self._user_info = response.json()
|
121 |
+
self._refresh_token = self._user_info.get('refresh_token', '')
|
122 |
+
self._access_token = self._user_info.get('access_token', '')
|
123 |
+
self._token_expiry = time.time() + self._user_info.get('expires_in', 3600)
|
124 |
+
self._log_values()
|
125 |
+
return True
|
126 |
+
except requests.RequestException as e:
|
127 |
+
self._logger.error(f"刷新令牌请求错误: {e}")
|
128 |
+
return False
|
129 |
+
|
130 |
+
def get_jwt_value(self) -> str:
|
131 |
+
"""返回访问令牌。"""
|
132 |
+
return self._access_token
|
133 |
+
|
134 |
+
def is_token_valid(self) -> bool:
|
135 |
+
"""检查当前的访问令牌是否有效。"""
|
136 |
+
return bool(self._access_token) and time.time() < self._token_expiry
|
137 |
+
|
138 |
+
def ensure_valid_token(self) -> bool:
|
139 |
+
"""确保有一个有效的访问令牌,如果需要则刷新或重新登录。"""
|
140 |
+
if self.is_token_valid():
|
141 |
+
return True
|
142 |
+
if self._refresh_token:
|
143 |
+
if self.refresh_user_token():
|
144 |
+
return True
|
145 |
+
return self.login()
|
146 |
+
|
147 |
+
def clear_auth(self) -> None:
|
148 |
+
"""清除当前的授权信息。"""
|
149 |
+
self._user_info = {}
|
150 |
+
self._refresh_token = ""
|
151 |
+
self._access_token = ""
|
152 |
+
self._token_expiry = 0
|
153 |
+
|
154 |
+
def _log_values(self) -> None:
|
155 |
"""记录刷新令牌到日志中。"""
|
156 |
+
self._logger.info(f"\033[92mRefresh Token: {self._refresh_token}\033[0m")
|
157 |
+
self._logger.info(f"\033[92mAccess Token: {self._access_token}\033[0m")
|
158 |
|
159 |
+
def _fetch_apikey(self) -> str:
|
160 |
"""获取API密钥。"""
|
161 |
+
if self._api_key:
|
162 |
+
return self._api_key
|
163 |
+
|
164 |
try:
|
165 |
+
login_url = f"{_BASE_URL}/login"
|
166 |
+
response = self._make_request('GET', login_url)
|
167 |
+
|
|
|
|
|
|
|
|
|
|
|
168 |
match = re.search(r'<script src="(/_next/static/chunks/app/layout-[^"]+\.js)"', response.text)
|
169 |
if not match:
|
170 |
+
raise ValueError("未找到匹配的脚本标签")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
+
js_url = f"{_BASE_URL}{match.group(1)}"
|
173 |
+
js_response = self._make_request('GET', js_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
api_key_match = re.search(r'$"https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"$', js_response.text)
|
176 |
+
if not api_key_match:
|
177 |
+
raise ValueError("未能匹配API key")
|
178 |
|
179 |
+
self._api_key = api_key_match.group(1)
|
180 |
+
return self._api_key
|
|
|
|
|
181 |
|
182 |
+
except (requests.RequestException, ValueError) as e:
|
183 |
+
self._logger.error(f"获取API密钥时发生错误: {e}")
|
184 |
+
return ""
|
185 |
+
|
186 |
+
def _get_headers(self, with_content_type: bool = False) -> Dict[str, str]:
|
187 |
+
"""生成请求头。"""
|
188 |
headers = {
|
189 |
+
'apikey': self._fetch_apikey(),
|
190 |
+
'user-agent': _USER_AGENT
|
|
|
191 |
}
|
192 |
+
if with_content_type:
|
193 |
+
headers['Content-Type'] = 'application/json'
|
194 |
+
if self._access_token:
|
195 |
+
headers['Authorization'] = f'Bearer {self._access_token}'
|
196 |
+
return headers
|
197 |
+
|
198 |
+
def _make_request(self, method: str, url: str, **kwargs) -> requests.Response:
|
199 |
+
"""发送HTTP请求并处理异常。"""
|
200 |
try:
|
201 |
+
response = self._session.request(method, url, **kwargs)
|
202 |
response.raise_for_status()
|
203 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
except requests.RequestException as e:
|
205 |
+
self._logger.error(f"请求错误 ({method} {url}): {e}")
|
206 |
+
raise
|
207 |
+
|
208 |
+
def require_auth(func: Callable) -> Callable:
|
209 |
+
"""装饰器,确保在调用API之前有有效的token。"""
|
210 |
+
@wraps(func)
|
211 |
+
def wrapper(self, *args, **kwargs):
|
212 |
+
if not self.ensure_valid_token():
|
213 |
+
raise Exception("无法获取有效的授权token")
|
214 |
+
return func(self, *args, **kwargs)
|
215 |
+
return wrapper
|
216 |
|
217 |
# 全局的 AuthManager 对象,将在每次请求时更新
|
218 |
auth_manager = None
|
|
|
230 |
'accept': 'text/event-stream',
|
231 |
'accept-language': 'zh-CN,zh;q=0.9',
|
232 |
'content-type': 'application/json',
|
233 |
+
'user-agent': _USER_AGENT,
|
|
|
|
|
234 |
'authorization': f'Bearer {auth_manager.get_jwt_value()}'
|
235 |
}
|
236 |
|
|
|
345 |
"object": "chat.completion",
|
346 |
"created": int(time.time()),
|
347 |
"model": model,
|
348 |
+
"system_fingerprint": generate_system_fingerprint(),
|
349 |
"choices": [
|
350 |
{
|
351 |
"index": 0,
|
|
|
370 |
for chunk in stream_notdiamond_response(response, model):
|
371 |
content = chunk['choices'][0]['delta'].get('content', '')
|
372 |
total_completion_tokens += count_tokens(content, model)
|
373 |
+
|
374 |
chunk['usage'] = {
|
375 |
"prompt_tokens": prompt_tokens,
|
376 |
"completion_tokens": total_completion_tokens,
|
|
|
384 |
def get_auth_credentials():
|
385 |
"""从请求头中获取认证凭据"""
|
386 |
auth_header = request.headers.get('Authorization')
|
387 |
+
if not auth_header or not auth_header.startswith('Bearer '):
|
388 |
+
logger.error("Authorization header is missing or invalid")
|
|
|
|
|
|
|
|
|
389 |
return None, None
|
390 |
|
391 |
try:
|
|
|
403 |
email, password = get_auth_credentials()
|
404 |
if email and password:
|
405 |
auth_manager = AuthManager(email, password)
|
406 |
+
if not auth_manager.ensure_valid_token():
|
407 |
+
auth_manager = None
|
408 |
else:
|
409 |
auth_manager = None
|
410 |
|
|
|
552 |
|
553 |
def make_request(payload):
|
554 |
"""发送请求并处理可能的认证刷新。"""
|
555 |
+
global auth_manager
|
556 |
url = get_notdiamond_url()
|
557 |
headers = get_notdiamond_headers()
|
558 |
response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
|
|
|
573 |
if __name__ == "__main__":
|
574 |
port = int(os.environ.get("PORT", 3000))
|
575 |
app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
|
576 |
+
|