import json
import logging
import os

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse

from helper import create_jwt

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()


@app.post('/ai/v1/chat/completions')
async def chat(request: Request):
    logger.debug("Received request")

    # Read the client's authorization header
    auth_header = request.headers.get('authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        raise HTTPException(status_code=401, detail="Invalid authorization header")

    # Extract github_username and user_id from "Bearer <github_username>,<user_id>"
    try:
        _, auth_data = auth_header.split('Bearer ', 1)
        github_username, user_id_str = auth_data.split(',')
        user_id = int(user_id_str)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid authorization format")

    # Generate a JWT token for the upstream API
    jwt_token = create_jwt(github_username, user_id)
    logger.debug(f"Generated JWT token: {jwt_token}")

    # Read the request payload
    payload = await request.json()
    logger.debug(f"Received payload: {payload}")

    # Resolve the model, defaulting to "claude-3-5-sonnet-20240620"
    model = payload.get('model', 'claude-3-5-sonnet-20240620')

    # Prepare the LLM API request
    url = "https://llm.zed.dev/completion?"
    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": payload['messages'],
            "system": ""
        }
    }
    logger.debug(f"LLM payload: {llm_payload}")

    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt_token}',  # use the freshly generated JWT token
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
    }
    logger.debug(f"Request headers: {headers}")

    # Pick up proxy settings from the environment (a single proxy URL, or None)
    proxy = os.environ.get('HTTP_PROXY', None)
    logger.debug(f"Using proxy: {proxy}")

    async def generate():
        # httpx expects a proxy URL string (or None) here, not a requests-style
        # {'http': ..., 'https': ...} dict; on httpx >= 0.26 prefer the `proxy=` keyword.
        # timeout=None avoids the default read timeout cutting off long streamed responses.
        async with httpx.AsyncClient(proxies=proxy, timeout=None) as client:
            try:
                async with client.stream('POST', url, headers=headers, json=llm_payload) as response:
                    logger.debug(f"LLM API response status: {response.status_code}")
                    logger.debug(f"LLM API response headers: {response.headers}")
                    if response.status_code != 200:
                        error_content = await response.aread()
                        logger.error(f"LLM API error response: {error_content}")
                        yield f"Error: {response.status_code} - {error_content.decode()}"
                    else:
                        async for chunk in response.aiter_bytes():
                            yield chunk
            except Exception as e:
                logger.error(f"Error during LLM API request: {str(e)}")
                yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type='application/octet-stream')


if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
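
# Example client request (for illustration only; adjust host/port to your deployment).
# The Authorization value follows the "Bearer <github_username>,<user_id>" format parsed above:
#
#   curl -N http://localhost:8000/ai/v1/chat/completions \
#     -H 'Authorization: Bearer <github_username>,<user_id>' \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "claude-3-5-sonnet-20240620", "messages": [{"role": "user", "content": "Hello"}]}'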