File size: 5,911 Bytes
4605201 fcbf0b1 fed49ba 651fc5d b05f070 dedefcd 651fc5d fed49ba 306d21c 4605201 b05f070 4605201 b05f070 dedefcd 4605201 b05f070 dedefcd 4605201 b05f070 4605201 b05f070 c2cd73a b05f070 b074ef9 4605201 b074ef9 4605201 b074ef9 4605201 b074ef9 4605201 b074ef9 acd9ed4 4605201 fed49ba b074ef9 651fc5d b074ef9 651fc5d fed49ba 4605201 b074ef9 651fc5d b074ef9 651fc5d b074ef9 4605201 b074ef9 c2cd73a b074ef9 c2cd73a b074ef9 c2cd73a b074ef9 4605201 fed49ba fcbf0b1 651fc5d fcbf0b1 d160f6b 4605201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from flask import Flask, request, Response, redirect, jsonify
from asgiref.wsgi import WsgiToAsgi
import requests
import json
import os
import jwt
import uuid
from datetime import datetime, timedelta
import logging
# Logging setup: INFO and above, every line prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The Flask application object that the routes and request hooks register on.
app = Flask(__name__)
def get_github_username_zed_userid():
    """
    Get the GitHub username and Zed user ID from the USER environment variable.

    USER is expected to hold ``"<github_username>,<zed_user_id>"``. Whitespace
    around either part is ignored. When USER is unset, the defaults below are
    used.

    NOTE(review): USER is also the standard POSIX login-name variable. On a
    typical host it holds something like "root" (no comma), so this falls
    back to the defaults unless USER is explicitly overridden.

    Returns:
        tuple: A tuple containing the GitHub username (str) and Zed user ID (str).
            ``("default_user", "123456")`` is returned when USER is malformed:
            not exactly two comma-separated parts, or either part empty after
            stripping whitespace.
    """
    default_username, default_user_id = "default_user", "123456"
    user_env = os.environ.get('USER', f'{default_username},{default_user_id}')
    parts = [part.strip() for part in user_env.split(',')]
    # Reject both the wrong number of fields and blank fields. A blank user ID
    # would otherwise end up in the JWT's userId claim.
    if len(parts) != 2 or not all(parts):
        logger.warning("Invalid format in USER environment variable. Using default values.")
        return default_username, default_user_id
    username, user_id = parts
    # Log the cleaned values, i.e. exactly what is returned.
    logger.info(f"Using GitHub username: {username} and Zed user ID: {user_id}")
    return username, user_id
def create_jwt(github_user_login: str, user_id: str) -> str:
    """
    Create a JSON Web Token (JWT) for a given GitHub user.

    The token is signed with HS256 and expires one hour after creation. It
    carries the user's GitHub login and Zed user ID alongside fixed claims
    (``isStaff=False``, ``hasLlmClosedBetaFeatureFlag=False``, ``plan="Free"``).

    Args:
        github_user_login: Value for the ``githubUserLogin`` claim.
        user_id: Value for the ``userId`` claim.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    # Local import: the module header only imports datetime and timedelta.
    from datetime import timezone

    token_lifetime = timedelta(hours=1)
    # Use an aware UTC datetime. The naive value from datetime.utcnow() is
    # interpreted as *local* time by .timestamp(), which skews iat/exp by the
    # host's UTC offset on any machine not running in UTC.
    now = datetime.now(timezone.utc)
    payload = {
        "iat": int(now.timestamp()),
        "exp": int((now + token_lifetime).timestamp()),
        "jti": str(uuid.uuid4()),
        "userId": user_id,
        "githubUserLogin": github_user_login,
        "isStaff": False,
        "hasLlmClosedBetaFeatureFlag": False,
        "plan": "Free"
    }
    # NOTE(review): the signing secret is hard-coded in source.
    token = jwt.encode(payload, 'llm-secret', algorithm='HS256')
    # The token itself is deliberately not logged: it is a bearer credential.
    logger.info(f"Created JWT for user: {github_user_login}")
    return token
@app.before_request
def before_request():
    """
    Log each incoming request and redirect plain-HTTP traffic to HTTPS.

    Runs before every route handler. If the ``X-Forwarded-Proto`` header
    reports ``http``, the client is redirected to the same URL over HTTPS.
    Otherwise ``None`` is returned and Flask continues to the route handler.
    """
    logger.info(f"Received request: {request.method} {request.url}")
    # Redact credentials so bearer tokens and cookies never reach the logs.
    sensitive_headers = {'authorization', 'proxy-authorization', 'cookie'}
    loggable_headers = {
        name: ('<redacted>' if name.lower() in sensitive_headers else value)
        for name, value in request.headers.items()
    }
    logger.info(f"Request headers: {loggable_headers}")
    if request.data:
        logger.info(f"Request body: {request.get_data(as_text=True)}")
    # Check the X-Forwarded-Proto header rather than the URL scheme.
    # Presumably the app sits behind a TLS-terminating proxy, so request.url
    # does not reflect the client's original scheme.
    if request.headers.get('X-Forwarded-Proto') == 'http':
        url = request.url.replace('http://', 'https://', 1)
        logger.info(f"Redirecting to HTTPS: {url}")
        # 308 rather than 301: clients must repeat the same method and body.
        # With 301 they typically downgrade POST to GET, which the POST-only
        # chat route rejects.
        return redirect(url, code=308)
    # No redirect needed; continue handling the request.
    return None
@app.route('/')
def root():
    """Return a status payload listing the endpoints this API exposes."""
    logger.info("Received request to root path")
    return jsonify({
        "status": "ok",
        "message": "Welcome to the chat completion API",
        "endpoints": {
            # Must match the path registered on the chat completion route.
            "/ai/v1/chat/completions": "POST - Send chat completion requests"
        }
    }), 200
@app.route('/ai/v1/chat/completions', methods=['POST'])
def chat():
    """
    Handle chat completion requests by forwarding them to the LLM API.

    Expects a JSON object body with a ``messages`` list and optional
    ``model``, ``max_tokens``, ``temperature`` and ``top_p`` fields. The
    upstream response body is streamed back to the client as
    ``application/octet-stream``.

    Returns:
        - A streaming response carrying the upstream body.
        - 400 with a JSON error when the body is not a JSON object or has no
          ``messages`` list.
        - 500 with a JSON error for any other failure while preparing the
          request.
    """
    # (connect, read) timeouts in seconds. For a streamed response the read
    # timeout bounds the gap between received bytes, not the total duration.
    upstream_timeout = (10, 300)
    logger.info("Processing chat completion request")
    try:
        # silent=True returns None for a missing/non-JSON body instead of
        # raising, so bad input gets a 400 below rather than the generic 500.
        payload = request.get_json(silent=True)
        logger.info(f"Request payload: {payload}")
        if not isinstance(payload, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        messages = payload.get('messages')
        if not isinstance(messages, list):
            return jsonify({"error": "'messages' must be a list"}), 400

        # Get the model from the payload, defaulting to "claude-3-5-sonnet-20240620".
        model = payload.get('model', 'claude-3-5-sonnet-20240620')
        logger.info(f"Using model: {model}")

        # Prepare the request for the LLM API.
        url = "https://llm.zed.dev/completion?"
        llm_payload = {
            "provider": "anthropic",
            "model": model,
            "provider_request": {
                "model": model,
                "max_tokens": payload.get('max_tokens', 8192),
                "temperature": payload.get('temperature', 0),
                "top_p": payload.get('top_p', 0.7),
                "messages": messages,
                "system": ""
            }
        }
        github_username, zed_user_id = get_github_username_zed_userid()
        jwt_token = create_jwt(github_username, zed_user_id)
        headers = {
            'Host': 'llm.zed.dev',
            'accept': '*/*',
            'content-type': 'application/json',
            'authorization': f'Bearer {jwt_token}',
            'user-agent': 'Zed/0.149.3 (macos; aarch64)'
        }
        # Copy of the headers that is safe to log: the bearer token is a credential.
        loggable_headers = {**headers, 'authorization': 'Bearer <redacted>'}

        # Get the proxy from the environment, applied to both schemes.
        proxy = os.environ.get('HTTP_PROXY', None)
        proxies = {'http': proxy, 'https': proxy} if proxy else None
        logger.info(f"Using proxy: {proxy}")

        def generate():
            """Send the upstream request and yield its body chunk by chunk."""
            logger.info("Starting to stream response")
            try:
                logger.info("Sending request to LLM API:")
                logger.info(f"URL: {url}")
                logger.info(f"Headers: {loggable_headers}")
                logger.info(f"Payload: {json.dumps(llm_payload, indent=2)}")
                with requests.post(url, headers=headers, json=llm_payload, stream=True,
                                   proxies=proxies, timeout=upstream_timeout) as response:
                    logger.info(f"LLM API response status: {response.status_code}")
                    logger.info(f"LLM API response headers: {dict(response.headers)}")
                    if not response.ok:
                        # The client already received a 200 status line, so the
                        # upstream error can only be surfaced in the logs.
                        logger.warning(f"LLM API returned error status: {response.status_code}")
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:
                            # A 1024-byte chunk can split a multi-byte UTF-8
                            # character. Decode leniently so logging can never
                            # raise and abort the stream.
                            logger.info(f"Received chunk: {chunk.decode('utf-8', errors='replace')}")
                            yield chunk
            except requests.RequestException as e:
                logger.error(f"Error during LLM API request: {e}")
                yield json.dumps({"error": "Internal server error"}).encode('utf-8')
            logger.info("Finished streaming response")

        return Response(generate(), content_type='application/octet-stream')
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, hide details from clients.
        logger.exception(f"Error processing request: {e}")
        return jsonify({"error": "Internal server error"}), 500
# Wrap the WSGI Flask app in an adapter so it can be hosted by an ASGI server.
asgi_app = WsgiToAsgi(app)

if __name__ == '__main__':
    # Imported only when run as a script; uvicorn is the ASGI server used here.
    import uvicorn

    listen_port = int(os.environ.get("PORT", 8000))
    logger.info(f"Starting the application on port {listen_port}")
    uvicorn.run(asgi_app, port=listen_port, host="0.0.0.0")
|