File size: 4,394 Bytes
fed49ba 651fc5d 75615ce ae87889 d328605 0d5e69a 651fc5d fed49ba 306d21c d328605 ae87889 d328605 3fb88cc fed49ba ae87889 d328605 ae87889 d328605 fed49ba 3fb88cc ae87889 651fc5d fed49ba 651fc5d ae87889 651fc5d fed49ba d328605 fed49ba d328605 fed49ba d328605 fed49ba d328605 ae87889 fed49ba 651fc5d ae87889 fed49ba d328605 651fc5d fed49ba 651fc5d fed49ba 651fc5d ae87889 651fc5d fed49ba 651fc5d ae87889 651fc5d 3fb88cc d328605 ae87889 3fb88cc d328605 ae87889 d328605 ae87889 d328605 3fb88cc d328605 ae87889 d328605 fed49ba ae87889 fed49ba 88f8e77 3fb88cc 88f8e77 651fc5d 3fb88cc 651fc5d 3fb88cc ae87889 3fb88cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from flask import Flask, request, Response
import requests
import json
import os
import sys
import logging
from helper import create_jwt
app = Flask(__name__)

# Configure logging: records at INFO and above, timestamped, written to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def _redact_authorization(headers):
    """Return a plain-dict copy of *headers* with any Authorization value masked for logging."""
    return {
        name: ('<redacted>' if name.lower() == 'authorization' else value)
        for name, value in headers.items()
    }


def _parse_zed_credentials(token):
    """Parse ``"<github_username>,<zed_user_id>"`` into ``(str, int)``.

    Raises:
        ValueError: if the token does not contain exactly one comma or the
            user ID is not an integer.
    """
    github_username, zed_user_id = token.split(',')
    return github_username, int(zed_user_id)


def _split_system_messages(messages):
    """Separate OpenAI-style ``system`` messages from the rest of the conversation.

    The Anthropic Messages API takes the system prompt as a top-level
    ``system`` string and only accepts ``user``/``assistant`` roles inside
    ``messages``, so system entries are pulled out and their text joined.

    Returns:
        ``(system_text, remaining_messages)``; ``system_text`` is ``""`` when
        there are no system messages.
    """
    system_parts = []
    remaining = []
    for message in messages:
        if isinstance(message, dict) and message.get('role') == 'system':
            content = message.get('content', '')
            if isinstance(content, list):
                # OpenAI also allows content as a list of typed parts; keep the text ones.
                content = '\n'.join(
                    part.get('text', '')
                    for part in content
                    if isinstance(part, dict) and part.get('type') == 'text'
                )
            system_parts.append(str(content))
        else:
            remaining.append(message)
    return '\n\n'.join(system_parts), remaining


@app.route('/chat/completions', methods=['POST'])
def chat():
    """
    Handle chat completion requests.

    Reads an OpenAI-style JSON body (``messages`` plus optional ``model``,
    ``max_tokens``, ``temperature``, ``top_p``), re-wraps it as an Anthropic
    request for ``https://llm.zed.dev/completion`` and streams the upstream
    response body back to the client.

    The ``Authorization`` header must be ``Bearer <github_username>,<zed_user_id>``;
    both values are passed to ``create_jwt`` to mint the upstream bearer token.

    Returns:
        401 for a missing or malformed Authorization header, 400 for a body
        that is not a JSON object containing ``messages``, 502 if the upstream
        request cannot be sent, otherwise the upstream status code with the
        upstream body streamed as ``application/octet-stream``.
    """
    logger.info("Received chat completion request")
    # Log request metadata; the Authorization header is masked.
    logger.info(f"Request method: {request.method}")
    logger.info(f"Request URL: {request.url}")
    logger.info(f"Request headers: {_redact_authorization(request.headers)}")

    # silent=True: a missing or unparsable JSON body yields None instead of raising.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'messages' not in payload:
        logger.error("Invalid request payload")
        return Response('Request body must be a JSON object with a "messages" field', status=400)
    logger.info(f"Request payload: {json.dumps(payload, indent=2)}")

    # Get the model from the payload, defaulting to "claude-3-5-sonnet-20240620"
    model = payload.get('model', 'claude-3-5-sonnet-20240620')
    logger.info(f"Using model: {model}")

    # Extract GitHub username and Zed user ID from the Authorization header.
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        logger.error("Invalid Authorization header")
        return Response('Invalid Authorization header', status=401)
    try:
        github_username, zed_user_id = _parse_zed_credentials(auth_header[len('Bearer '):])
    except ValueError:
        logger.error("Invalid Authorization header format")
        return Response('Invalid Authorization header format', status=401)
    logger.info(f"GitHub username: {github_username}, Zed user ID: {zed_user_id}")

    # Prepare the request for the LLM API.
    url = "https://llm.zed.dev/completion"
    logger.info(f"LLM API URL: {url}")
    system_prompt, messages = _split_system_messages(payload['messages'])
    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": messages,
            "system": system_prompt,
        },
    }
    logger.info(f"LLM API payload: {json.dumps(llm_payload, indent=2)}")

    jwt = create_jwt(github_username, zed_user_id)
    # Never log the token itself: it is a bearer credential for the upstream API.
    logger.info("Generated JWT token")
    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)',
    }
    logger.info(f"Request headers: {json.dumps(_redact_authorization(headers), indent=2)}")

    # Get proxy from environment variable (used for both http and https).
    proxy = os.environ.get('HTTP_PROXY', None)
    proxies = {'http': proxy, 'https': proxy} if proxy else None
    logger.info(f"Using proxy: {proxy}")

    # `requests` is synchronous; stream=True defers reading the body so it can
    # be relayed chunk by chunk below.
    logger.info("Sending request to LLM API")
    try:
        upstream = requests.post(
            url,
            headers=headers,
            json=llm_payload,
            stream=True,
            proxies=proxies,
            allow_redirects=True,
        )
    except requests.RequestException as exc:
        logger.exception("Error during API request")
        return Response(str(exc), status=502)

    logger.info(f"LLM API response status: {upstream.status_code}")
    logger.info(f"LLM API response headers: {dict(upstream.headers)}")
    if upstream.status_code == 301:
        new_location = upstream.headers.get('Location')
        logger.warning(f"Received 301 redirect. New location: {new_location}")

    def generate():
        # Once streaming has started the status line is already sent, so a
        # mid-stream failure can only be reported in-band.
        try:
            for chunk in upstream.iter_content(chunk_size=1024):
                if chunk:
                    logger.debug(f"Received chunk of size: {len(chunk)} bytes")
                    yield chunk
        except requests.RequestException as exc:
            logger.exception("Error while streaming LLM API response")
            yield str(exc).encode()
        finally:
            upstream.close()

    logger.info("Returning streaming response")
    response = Response(
        generate(),
        status=upstream.status_code,
        content_type='application/octet-stream',
    )
    # Closing a generator that never started skips its `finally`, so also
    # release the upstream connection when the response itself is closed.
    response.call_on_close(upstream.close)
    return response
@app.route('/', methods=['GET'])
def home():
    """Serve a static plain-text welcome message at the root URL.

    Defined as a plain (synchronous) view: it awaits nothing, and an
    ``async def`` view in Flask requires the optional ``async`` extra.
    """
    return "Welcome to the Chat Completion API", 200
# Create the ASGI application. Flask is a WSGI framework and has no
# `asgi_app` attribute, so adapt the WSGI app with uvicorn's WSGI middleware.
try:
    from uvicorn.middleware.wsgi import WSGIMiddleware
except ImportError:
    # uvicorn is not installed: `app` can still be served by a WSGI server,
    # but no ASGI entry point is available.
    asgi_app = None
else:
    asgi_app = WSGIMiddleware(app)

if __name__ == '__main__':
    import uvicorn

    logger.info("Starting the application")
    # Pass the app object rather than the "app:asgi_app" import string: this
    # works whatever the file is named and avoids importing the module twice.
    uvicorn.run(asgi_app, host="0.0.0.0", port=8000, log_level="info")
|