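"""A small FastAPI proxy exposing POST /ai/v1/chat/completions: it turns the
client's "Bearer <github_username>,<user_id>" header into a JWT and streams the
chat completion back from llm.zed.dev."""
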
import logging
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
import httpx
import json
import os

from helper import create_jwt
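# create_jwt(github_username, user_id) is defined in helper.py (not shown here);
# it is assumed to return a signed JWT string that llm.zed.dev accepts.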

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

LLM_API_URL = "https://llm.zed.dev/completion?"

@app.post('/ai/v1/chat/completions')
async def chat(request: Request):
    logger.debug("Received request")
    
    # Read the client's authorization header
    auth_header = request.headers.get('authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        raise HTTPException(status_code=401, detail="Invalid authorization header")
    
    # Extract github_username and user_id from the header value
    try:
        _, auth_data = auth_header.split('Bearer ', 1)
        github_username, user_id_str = auth_data.split(',')
        user_id = int(user_id_str)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid authorization format")
    
    # Generate a JWT for the upstream LLM API
    jwt_token = create_jwt(github_username, user_id)
    logger.debug(f"Generated JWT token: {jwt_token}")

    # Read the incoming request payload
    payload = await request.json()
    logger.debug(f"Received payload: {payload}")

    # Resolve the model, defaulting to "claude-3-5-sonnet-20240620"
    model = payload.get('model', 'claude-3-5-sonnet-20240620')

    # Build the upstream LLM API request body
    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": payload['messages'],
            "system": ""
        }
    }
    logger.debug(f"LLM payload: {llm_payload}")

    headers = {
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt_token}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
    }
    logger.debug(f"Request headers: {headers}")

    # Read proxy settings from the environment; httpx expects URL-pattern keys
    # such as 'http://' and 'https://' rather than requests-style 'http'/'https'
    proxy = os.environ.get('HTTP_PROXY', None)
    proxies = {'http://': proxy, 'https://': proxy} if proxy else None
    logger.debug(f"Using proxies: {proxies}")

    async def generate():
        async with httpx.AsyncClient(proxies=proxies) as client:
            try:
                async with client.stream('POST', LLM_API_URL, headers=headers, json=llm_payload) as response:
                    logger.debug(f"LLM API response status: {response.status_code}")
                    logger.debug(f"LLM API response headers: {response.headers}")
                    if response.status_code != 200:
                        error_content = await response.aread()
                        logger.error(f"LLM API error response: {error_content}")
                        yield f"Error: {response.status_code} - {error_content.decode()}"
                    else:
                        async for chunk in response.aiter_bytes():
                            yield chunk
            except Exception as e:
                logger.error(f"Error during LLM API request: {str(e)}")
                yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type='application/octet-stream')

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
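
# Example request (illustrative only; "octocat" and "12345" are placeholder values
# for the GitHub username and user id expected in the Authorization header):
#
#   curl -N http://localhost:8000/ai/v1/chat/completions \
#     -H 'Authorization: Bearer octocat,12345' \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "claude-3-5-sonnet-20240620",
#          "messages": [{"role": "user", "content": "Hello"}]}'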