import logging
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
import httpx
import json
import os
from helper import create_jwt
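# create_jwt(github_username, user_id) is assumed to sign a JWT accepted by
# llm.zed.dev; its implementation lives in helper.py and is not shown here.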
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = FastAPI()
# Upstream Zed LLM endpoint
LLM_API_URL = "https://llm.zed.dev/completion?"
@app.post('/ai/v1/chat/completions')
async def chat(request: Request):
logger.debug("Received request")
    # Read the client's Authorization header
auth_header = request.headers.get('authorization')
if not auth_header or not auth_header.startswith('Bearer '):
raise HTTPException(status_code=401, detail="Invalid authorization header")
    # Extract github_username and user_id (expected: "Bearer <github_username>,<user_id>")
try:
_, auth_data = auth_header.split('Bearer ', 1)
github_username, user_id_str = auth_data.split(',')
user_id = int(user_id_str)
except ValueError:
raise HTTPException(status_code=401, detail="Invalid authorization format")
    # Mint the JWT used to authenticate against the upstream API
jwt_token = create_jwt(github_username, user_id)
logger.debug(f"Generated JWT token: {jwt_token}")
    # Read the request body
payload = await request.json()
logger.debug(f"Received payload: {payload}")
    # Model to request, defaulting to "claude-3-5-sonnet-20240620"
model = payload.get('model', 'claude-3-5-sonnet-20240620')
    # Build the upstream LLM API payload
llm_payload = {
"provider": "anthropic",
"model": model,
"provider_request": {
"model": model,
"max_tokens": payload.get('max_tokens', 8192),
"temperature": payload.get('temperature', 0),
"top_p": payload.get('top_p', 0.7),
"messages": payload['messages'],
"system": ""
}
}
logger.debug(f"LLM payload: {llm_payload}")
headers = {
'accept': '*/*',
'content-type': 'application/json',
'authorization': f'Bearer {jwt_token}',
'user-agent': 'Zed/0.149.3 (macos; aarch64)'
}
logger.debug(f"Request headers: {headers}")
    # Optional proxy from the environment. httpx expects URL-pattern keys
    # ("http://", "https://"), not the bare scheme names used by requests.
    proxy = os.environ.get('HTTP_PROXY')
    proxies = {'http://': proxy, 'https://': proxy} if proxy else None
logger.debug(f"Using proxies: {proxies}")
    async def generate():
        # Disable httpx's default 5-second read timeout; streamed completions
        # can pause between chunks for longer than that.
        async with httpx.AsyncClient(proxies=proxies, timeout=None) as client:
try:
async with client.stream('POST', LLM_API_URL, headers=headers, json=llm_payload) as response:
logger.debug(f"LLM API response status: {response.status_code}")
logger.debug(f"LLM API response headers: {response.headers}")
if response.status_code != 200:
error_content = await response.aread()
logger.error(f"LLM API error response: {error_content}")
yield f"Error: {response.status_code} - {error_content.decode()}"
else:
async for chunk in response.aiter_bytes():
yield chunk
except Exception as e:
logger.error(f"Error during LLM API request: {str(e)}")
yield f"Error: {str(e)}"
return StreamingResponse(generate(), media_type='application/octet-stream')
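
# Note: once streaming has begun, the 200 status is already committed, so
# errors raised mid-stream are reported in-band as plain-text "Error: ..."
# chunks rather than as HTTP error responses.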
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
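
# Usage sketch (assumptions: the server runs locally on port 8000 as
# configured above; "octocat" and 12345 are placeholder credentials in the
# "<github_username>,<user_id>" format the handler parses):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:8000/ai/v1/chat/completions",
#       headers={"Authorization": "Bearer octocat,12345"},
#       json={"model": "claude-3-5-sonnet-20240620",
#             "messages": [{"role": "user", "content": "Hello"}]},
#   )
#   print(r.text)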