# zed2api / app.py
import logging
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
import httpx
import os
from helper import create_jwt
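# create_jwt(github_username, user_id) comes from helper.py; it is expected to
# return a JWT string that llm.zed.dev accepts as a Bearer token (the exact
# claims and signing key live in helper.py, not here).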
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = FastAPI()
@app.post('/ai/v1/chat/completions')
async def chat(request: Request):
    logger.debug("Received request")

    # Read the client's Authorization header
    auth_header = request.headers.get('authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        raise HTTPException(status_code=401, detail="Invalid authorization header")

    # Extract github_username and user_id from the bearer value
    try:
        _, auth_data = auth_header.split('Bearer ', 1)
        github_username, user_id_str = auth_data.split(',')
        user_id = int(user_id_str)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid authorization format")
    # Generate a JWT for the upstream LLM API
    jwt_token = create_jwt(github_username, user_id)
    logger.debug(f"Generated JWT token: {jwt_token}")

    # Read the request payload
    payload = await request.json()
    logger.debug(f"Received payload: {payload}")

    # Pick the model, defaulting to "claude-3-5-sonnet-20240620"
    model = payload.get('model', 'claude-3-5-sonnet-20240620')

    # Prepare the LLM API request
    url = "https://llm.zed.dev/completion?"
    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": payload['messages'],
            "system": ""
        }
    }
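    # Note: the "provider_request" body mirrors the shape of Anthropic's Messages
    # API (model, max_tokens, temperature, top_p, messages, system), which
    # llm.zed.dev appears to forward to the provider unchanged.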
logger.debug(f"LLM payload: {llm_payload}")
headers = {
'Host': 'llm.zed.dev',
'accept': '*/*',
'content-type': 'application/json',
'authorization': f'Bearer {jwt_token}', # 使用新生成的 JWT token
'user-agent': 'Zed/0.149.3 (macos; aarch64)'
}
logger.debug(f"Request headers: {headers}")
# 获取代理设置
proxy = os.environ.get('HTTP_PROXY', None)
proxies = {'http': proxy, 'https': proxy} if proxy else None
logger.debug(f"Using proxies: {proxies}")
    async def generate():
        # No read timeout: streamed completions can take much longer than
        # httpx's 5-second default
        async with httpx.AsyncClient(proxies=proxies, timeout=None) as client:
            try:
                async with client.stream('POST', url, headers=headers, json=llm_payload) as response:
                    logger.debug(f"LLM API response status: {response.status_code}")
                    logger.debug(f"LLM API response headers: {response.headers}")
                    if response.status_code != 200:
                        error_content = await response.aread()
                        logger.error(f"LLM API error response: {error_content}")
                        yield f"Error: {response.status_code} - {error_content.decode()}"
                    else:
                        async for chunk in response.aiter_bytes():
                            yield chunk
            except Exception as e:
                logger.error(f"Error during LLM API request: {str(e)}")
                yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type='application/octet-stream')
if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
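# Example client call (illustrative; the host, port, and the username/user_id
# credentials below are placeholders):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:8000/ai/v1/chat/completions",
#       headers={"Authorization": "Bearer octocat,12345"},
#       json={
#           "model": "claude-3-5-sonnet-20240620",
#           "messages": [{"role": "user", "content": "Hello"}],
#       },
#       timeout=None,
#   )
#   print(resp.text)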