Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import re
|
|
8 |
import socket
|
9 |
from concurrent.futures import ThreadPoolExecutor
|
10 |
from functools import lru_cache
|
|
|
11 |
|
12 |
import requests
|
13 |
import tiktoken
|
@@ -40,6 +41,9 @@ if not NOTDIAMOND_IP:
|
|
40 |
logger.error("NOTDIAMOND_IP environment variable is not set!")
|
41 |
raise ValueError("NOTDIAMOND_IP must be set")
|
42 |
|
|
|
|
|
|
|
43 |
# 自定义连接函数
|
44 |
def patched_create_connection(address, *args, **kwargs):
|
45 |
host, port = address
|
@@ -66,7 +70,6 @@ def create_custom_session():
|
|
66 |
session.mount('http://', adapter)
|
67 |
return session
|
68 |
|
69 |
-
# 修改 AuthManager 类以使用自定义 Session
|
70 |
class AuthManager:
|
71 |
def __init__(self, email: str, password: str):
|
72 |
self.email = email
|
@@ -79,10 +82,21 @@ class AuthManager:
|
|
79 |
logging.basicConfig(level=logging.INFO)
|
80 |
|
81 |
self.session = create_custom_session()
|
82 |
-
self.
|
83 |
self.fetch_apikey()
|
84 |
self.log_values()
|
|
|
|
|
|
|
|
|
|
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
def log_values(self) -> None:
|
87 |
"""记录刷新令牌到日志中。"""
|
88 |
self.logger.info(f"\033[92mRefresh Token: {self.refresh_token}\033[0m")
|
@@ -109,7 +123,7 @@ class AuthManager:
|
|
109 |
js_response = self.session.get(js_url, headers=headers)
|
110 |
js_response.raise_for_status()
|
111 |
|
112 |
-
api_key_match = re.search(r'
|
113 |
if api_key_match:
|
114 |
self.api_key = api_key_match.group(1)
|
115 |
return self.api_key
|
@@ -140,6 +154,11 @@ class AuthManager:
|
|
140 |
response.raise_for_status()
|
141 |
self.user_info = response.json()
|
142 |
self.refresh_token = self.user_info.get('refresh_token', '')
|
|
|
|
|
|
|
|
|
|
|
143 |
except requests.RequestException as e:
|
144 |
self.logger.error(f"\033[91m登录请求错误: {e}\033[0m")
|
145 |
|
@@ -158,18 +177,22 @@ class AuthManager:
|
|
158 |
response.raise_for_status()
|
159 |
self.user_info = response.json()
|
160 |
self.refresh_token = self.user_info.get('refresh_token', '')
|
|
|
|
|
|
|
|
|
|
|
161 |
except requests.RequestException as e:
|
162 |
self.logger.error(f"刷新令牌请求错误: {e}")
|
|
|
|
|
163 |
|
164 |
def get_jwt_value(self) -> str:
|
165 |
"""返回访问令牌。"""
|
166 |
return self.user_info.get('access_token', '')
|
167 |
|
168 |
-
#
|
169 |
-
auth_manager =
|
170 |
-
os.getenv("AUTH_EMAIL", "[email protected]"),
|
171 |
-
os.getenv("AUTH_PASSWORD", "default_password"),
|
172 |
-
)
|
173 |
|
174 |
NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', 'https://not-diamond-workers.t7-cc4.workers.dev/stream-message').split(',')
|
175 |
|
@@ -263,7 +286,6 @@ def create_openai_chunk(content, model, finish_reason=None, usage=None):
|
|
263 |
chunk["usage"] = usage
|
264 |
return chunk
|
265 |
|
266 |
-
|
267 |
def count_tokens(text, model="gpt-3.5-turbo-0301"):
|
268 |
"""计算给定文本的令牌数量。"""
|
269 |
try:
|
@@ -327,7 +349,6 @@ def generate_stream_response(response, model, prompt_tokens):
|
|
327 |
for chunk in stream_notdiamond_response(response, model):
|
328 |
content = chunk['choices'][0]['delta'].get('content', '')
|
329 |
total_completion_tokens += count_tokens(content, model)
|
330 |
-
|
331 |
chunk['usage'] = {
|
332 |
"prompt_tokens": prompt_tokens,
|
333 |
"completion_tokens": total_completion_tokens,
|
@@ -338,6 +359,28 @@ def generate_stream_response(response, model, prompt_tokens):
|
|
338 |
|
339 |
yield "data: [DONE]\n\n"
|
340 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
@app.route('/', methods=['GET'])
|
342 |
def root():
|
343 |
return jsonify({
|
@@ -385,6 +428,10 @@ def proxy_models():
|
|
385 |
@app.route('/ai/v1/chat/completions', methods=['POST'])
|
386 |
def handle_request():
|
387 |
"""处理聊天完成请求。"""
|
|
|
|
|
|
|
|
|
388 |
try:
|
389 |
request_data = request.get_json()
|
390 |
model_id = request_data.get('model', '')
|
@@ -480,39 +527,20 @@ def make_request(payload):
|
|
480 |
"""发送请求并处理可能的认证刷新。"""
|
481 |
url = get_notdiamond_url()
|
482 |
headers = get_notdiamond_headers()
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
response = executor.submit(session.post, url, headers=headers, json=payload, stream=True).result()
|
487 |
-
logger.info(f"Response status code: {response.status_code}")
|
488 |
-
if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
|
489 |
-
return response
|
490 |
-
except Exception as e:
|
491 |
-
logger.error(f"Error in make_request: {str(e)}")
|
492 |
-
raise
|
493 |
|
494 |
-
logger.info("Refreshing user token...")
|
495 |
auth_manager.refresh_user_token()
|
496 |
headers = get_notdiamond_headers()
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
|
501 |
-
return response
|
502 |
-
except Exception as e:
|
503 |
-
logger.error(f"Error in make_request after token refresh: {str(e)}")
|
504 |
-
raise
|
505 |
|
506 |
-
logger.info("Logging in again...")
|
507 |
auth_manager.login()
|
508 |
headers = get_notdiamond_headers()
|
509 |
-
|
510 |
-
|
511 |
-
logger.info(f"Response status code after login: {response.status_code}")
|
512 |
-
return response
|
513 |
-
except Exception as e:
|
514 |
-
logger.error(f"Error in make_request after login: {str(e)}")
|
515 |
-
raise
|
516 |
|
517 |
if __name__ == "__main__":
|
518 |
port = int(os.environ.get("PORT", 3000))
|
|
|
8 |
import socket
|
9 |
from concurrent.futures import ThreadPoolExecutor
|
10 |
from functools import lru_cache
|
11 |
+
from cachetools import TTLCache
|
12 |
|
13 |
import requests
|
14 |
import tiktoken
|
|
|
41 |
logger.error("NOTDIAMOND_IP environment variable is not set!")
|
42 |
raise ValueError("NOTDIAMOND_IP must be set")
|
43 |
|
44 |
+
# 创建一个 TTLCache 来存储 refresh_token
|
45 |
+
refresh_token_cache = TTLCache(maxsize=1000, ttl=3600)
|
46 |
+
|
47 |
# 自定义连接函数
|
48 |
def patched_create_connection(address, *args, **kwargs):
|
49 |
host, port = address
|
|
|
70 |
session.mount('http://', adapter)
|
71 |
return session
|
72 |
|
|
|
73 |
class AuthManager:
|
74 |
def __init__(self, email: str, password: str):
|
75 |
self.email = email
|
|
|
82 |
logging.basicConfig(level=logging.INFO)
|
83 |
|
84 |
self.session = create_custom_session()
|
85 |
+
self.initialize_auth()
|
86 |
self.fetch_apikey()
|
87 |
self.log_values()
|
88 |
+
|
89 |
+
def initialize_auth(self):
|
90 |
+
"""初始化认证,优先使用缓存的 refresh_token"""
|
91 |
+
cache_key = f"{self.email}|{self.password}"
|
92 |
+
cached_token = refresh_token_cache.get(cache_key)
|
93 |
|
94 |
+
if cached_token:
|
95 |
+
self.refresh_token = cached_token
|
96 |
+
self.refresh_user_token()
|
97 |
+
else:
|
98 |
+
self.login()
|
99 |
+
|
100 |
def log_values(self) -> None:
|
101 |
"""记录刷新令牌到日志中。"""
|
102 |
self.logger.info(f"\033[92mRefresh Token: {self.refresh_token}\033[0m")
|
|
|
123 |
js_response = self.session.get(js_url, headers=headers)
|
124 |
js_response.raise_for_status()
|
125 |
|
126 |
+
api_key_match = re.search(r'$"https://spuckhogycrxcbomznwo\.supabase\.co","([^"]+)"$', js_response.text)
|
127 |
if api_key_match:
|
128 |
self.api_key = api_key_match.group(1)
|
129 |
return self.api_key
|
|
|
154 |
response.raise_for_status()
|
155 |
self.user_info = response.json()
|
156 |
self.refresh_token = self.user_info.get('refresh_token', '')
|
157 |
+
|
158 |
+
# 缓存 refresh_token
|
159 |
+
cache_key = f"{self.email}|{self.password}"
|
160 |
+
refresh_token_cache[cache_key] = self.refresh_token
|
161 |
+
|
162 |
except requests.RequestException as e:
|
163 |
self.logger.error(f"\033[91m登录请求错误: {e}\033[0m")
|
164 |
|
|
|
177 |
response.raise_for_status()
|
178 |
self.user_info = response.json()
|
179 |
self.refresh_token = self.user_info.get('refresh_token', '')
|
180 |
+
|
181 |
+
# 更新缓存中的 refresh_token
|
182 |
+
cache_key = f"{self.email}|{self.password}"
|
183 |
+
refresh_token_cache[cache_key] = self.refresh_token
|
184 |
+
|
185 |
except requests.RequestException as e:
|
186 |
self.logger.error(f"刷新令牌请求错误: {e}")
|
187 |
+
# 如果刷新失败,尝试重新登录
|
188 |
+
self.login()
|
189 |
|
190 |
def get_jwt_value(self) -> str:
|
191 |
"""返回访问令牌。"""
|
192 |
return self.user_info.get('access_token', '')
|
193 |
|
194 |
+
# 全局的 AuthManager 对象,将在每次请求时更新
|
195 |
+
auth_manager = None
|
|
|
|
|
|
|
196 |
|
197 |
NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', 'https://not-diamond-workers.t7-cc4.workers.dev/stream-message').split(',')
|
198 |
|
|
|
286 |
chunk["usage"] = usage
|
287 |
return chunk
|
288 |
|
|
|
289 |
def count_tokens(text, model="gpt-3.5-turbo-0301"):
|
290 |
"""计算给定文本的令牌数量。"""
|
291 |
try:
|
|
|
349 |
for chunk in stream_notdiamond_response(response, model):
|
350 |
content = chunk['choices'][0]['delta'].get('content', '')
|
351 |
total_completion_tokens += count_tokens(content, model)
|
|
|
352 |
chunk['usage'] = {
|
353 |
"prompt_tokens": prompt_tokens,
|
354 |
"completion_tokens": total_completion_tokens,
|
|
|
359 |
|
360 |
yield "data: [DONE]\n\n"
|
361 |
|
362 |
+
def get_auth_credentials():
|
363 |
+
"""从请求头中获取认证凭据"""
|
364 |
+
auth_header = request.headers.get('Authorization')
|
365 |
+
if not auth_header or not auth_header.startswith('Bearer '):
|
366 |
+
return None, None
|
367 |
+
|
368 |
+
try:
|
369 |
+
credentials = auth_header.split('Bearer ')[1]
|
370 |
+
email, password = credentials.split('|')
|
371 |
+
return email.strip(), password.strip()
|
372 |
+
except:
|
373 |
+
return None, None
|
374 |
+
|
375 |
+
@app.before_request
|
376 |
+
def before_request():
|
377 |
+
global auth_manager
|
378 |
+
email, password = get_auth_credentials()
|
379 |
+
if email and password:
|
380 |
+
auth_manager = AuthManager(email, password)
|
381 |
+
else:
|
382 |
+
auth_manager = None
|
383 |
+
|
384 |
@app.route('/', methods=['GET'])
|
385 |
def root():
|
386 |
return jsonify({
|
|
|
428 |
@app.route('/ai/v1/chat/completions', methods=['POST'])
|
429 |
def handle_request():
|
430 |
"""处理聊天完成请求。"""
|
431 |
+
global auth_manager
|
432 |
+
if not auth_manager:
|
433 |
+
return jsonify({'error': 'Unauthorized'}), 401
|
434 |
+
|
435 |
try:
|
436 |
request_data = request.get_json()
|
437 |
model_id = request_data.get('model', '')
|
|
|
527 |
"""发送请求并处理可能的认证刷新。"""
|
528 |
url = get_notdiamond_url()
|
529 |
headers = get_notdiamond_headers()
|
530 |
+
response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
|
531 |
+
if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
|
532 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
533 |
|
|
|
534 |
auth_manager.refresh_user_token()
|
535 |
headers = get_notdiamond_headers()
|
536 |
+
response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
|
537 |
+
if response.status_code == 200 and response.headers.get('Content-Type') == 'text/event-stream':
|
538 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
539 |
|
|
|
540 |
auth_manager.login()
|
541 |
headers = get_notdiamond_headers()
|
542 |
+
response = executor.submit(requests.post, url, headers=headers, json=payload, stream=True).result()
|
543 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
544 |
|
545 |
if __name__ == "__main__":
|
546 |
port = int(os.environ.get("PORT", 3000))
|