import json import logging import os import random import time import uuid import re import socket from concurrent.futures import ThreadPoolExecutor from functools import lru_cache, wraps from typing import Dict, Any, Callable import requests import tiktoken from flask import Flask, Response, jsonify, request, stream_with_context from flask_cors import CORS from requests.adapters import HTTPAdapter from urllib3.util.connection import create_connection import urllib3 from cachetools import TTLCache # Constants CHAT_COMPLETION_CHUNK = 'chat.completion.chunk' CHAT_COMPLETION = 'chat.completion' CONTENT_TYPE_EVENT_STREAM = 'text/event-stream' _BASE_URL = "https://chat.notdiamond.ai" _API_BASE_URL = "https://spuckhogycrxcbomznwo.supabase.co" _USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36' app = Flask(__name__) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) CORS(app, resources={r"/*": {"origins": "*"}}) executor = ThreadPoolExecutor(max_workers=10) proxy_url = os.getenv('PROXY_URL') # 获取环境变量中指定的 IP 地址 NOTDIAMOND_IP = os.getenv('NOTDIAMOND_IP') NOTDIAMOND_DOMAIN = 'not-diamond-workers.t7-cc4.workers.dev' if not NOTDIAMOND_IP: logger.error("NOTDIAMOND_IP environment variable is not set!") raise ValueError("NOTDIAMOND_IP must be set") # 创建一个 TTLCache 来存储 refresh_token refresh_token_cache = TTLCache(maxsize=1000, ttl=3600) # 自定义连接函数 def patched_create_connection(address, *args, **kwargs): host, port = address if host == NOTDIAMOND_DOMAIN: logger.info(f"Connecting to {NOTDIAMOND_DOMAIN} using IP: {NOTDIAMOND_IP}") return create_connection((NOTDIAMOND_IP, port), *args, **kwargs) return create_connection(address, *args, **kwargs) # 替换 urllib3 的默认连接函数 urllib3.util.connection.create_connection = patched_create_connection # 自定义 HTTPAdapter class CustomHTTPAdapter(HTTPAdapter): def init_poolmanager(self, *args, **kwargs): kwargs['socket_options'] = kwargs.get('socket_options', []) kwargs['socket_options'] += [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] return super(CustomHTTPAdapter, self).init_poolmanager(*args, **kwargs) # 创建自定义的 Session def create_custom_session(): session = requests.Session() adapter = CustomHTTPAdapter() session.mount('https://', adapter) session.mount('http://', adapter) return session class AuthManager: def __init__(self, email: str, password: str): self._email: str = email self._password: str = password self._api_key: str = "" self._user_info: Dict[str, Any] = {} self._refresh_token: str = "" self._access_token: str = "" self._token_expiry: float = 0 self._session: requests.Session = create_custom_session() self._logger: logging.Logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) def login(self) -> bool: """使用电子邮件和密码进行用户登录,并获取用户信息。""" url = f"{_API_BASE_URL}/auth/v1/token?grant_type=password" headers = self._get_headers(with_content_type=True) data = { "email": self._email, "password": self._password, "gotrue_meta_security": {} } try: response = self._make_request('POST', url, headers=headers, json=data) self._user_info = response.json() self._refresh_token = self._user_info.get('refresh_token', '') self._access_token = self._user_info.get('access_token', '') self._token_expiry = time.time() + self._user_info.get('expires_in', 3600) self._log_values() return True except requests.RequestException as e: self._logger.error(f"\033[91m登录请求错误: {e}\033[0m") return False def refresh_user_token(self) -> bool: """使用刷新令牌来请求一个新的访问令牌并更新实例变量。""" url = f"{_API_BASE_URL}/auth/v1/token?grant_type=refresh_token" headers = self._get_headers(with_content_type=True) data = {"refresh_token": self._refresh_token} try: response = self._make_request('POST', url, headers=headers, json=data) self._user_info = response.json() self._refresh_token = self._user_info.get('refresh_token', '') self._access_token = self._user_info.get('access_token', '') self._token_expiry = time.time() + self._user_info.get('expires_in', 3600) self._log_values() return True except requests.RequestException as e: self._logger.error(f"刷新令牌请求错误: {e}") return False def get_jwt_value(self) -> str: """返回访问令牌。""" return self._access_token def is_token_valid(self) -> bool: """检查当前的访问令牌是否有效。""" return bool(self._access_token) and time.time() < self._token_expiry def ensure_valid_token(self) -> bool: """确保有一个有效的访问令牌,如果需要则刷新或重新登录。""" if self.is_token_valid(): return True if self._refresh_token: if self.refresh_user_token(): return True return self.login() def clear_auth(self) -> None: """清除当前的授权信息。""" self._user_info = {} self._refresh_token = "" self._access_token = "" self._token_expiry = 0 def _log_values(self) -> None: """记录刷新令牌到日志中。""" self._logger.info(f"\033[92mRefresh Token: {self._refresh_token}\033[0m") self._logger.info(f"\033[92mAccess Token: {self._access_token}\033[0m") def _fetch_apikey(self) -> str: """获取API密钥。""" if self._api_key: return self._api_key try: login_url = f"{_BASE_URL}/login" response = self._make_request('GET', login_url) match = re.search(r'