import logging
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import httpx
import json
import os
from helper import create_jwt, generate_random_tuple

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

@app.post('/ai/v1/chat/completions')
async def chat(request: Request):
    logger.debug("Received request")
    
    # Generate JWT token
    github_username, user_id = generate_random_tuple()
    jwt_token = create_jwt(github_username, user_id)
    logger.debug(f"Generated JWT token: {jwt_token}")

    # Get the payload from the request
    payload = await request.json()
    logger.debug(f"Received payload: {payload}")

    # Get the model from the payload, defaulting to "claude-3-5-sonnet-20240620"
    model = payload.get('model', 'claude-3-5-sonnet-20240620')

    # Prepare the request for the LLM API
    url = "https://llm.zed.dev/completion?"
    
    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": payload['messages'],
            "system": ""
        }
    }
    logger.debug(f"LLM payload: {llm_payload}")

    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt_token}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
    }
    logger.debug(f"Request headers: {headers}")

    # Get proxy from the environment; httpx expects scheme-prefixed proxy keys such as "http://"
    proxy = os.environ.get('HTTP_PROXY')
    proxies = {'http://': proxy, 'https://': proxy} if proxy else None
    logger.debug(f"Using proxies: {proxies}")

    # Stream the upstream LLM response back to the caller as it arrives
    async def generate():
        async with httpx.AsyncClient(proxies=proxies) as client:
            try:
                async with client.stream('POST', url, headers=headers, json=llm_payload) as response:
                    logger.debug(f"LLM API response status: {response.status_code}")
                    logger.debug(f"LLM API response headers: {response.headers}")
                    if response.status_code != 200:
                        error_content = await response.aread()
                        logger.error(f"LLM API error response: {error_content}")
                        yield f"Error: {response.status_code} - {error_content}"
                    else:
                        async for chunk in response.aiter_bytes():
                            yield chunk
            except Exception as e:
                logger.error(f"Error during LLM API request: {str(e)}")
                yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type='application/octet-stream')

@app.get("/")
async def root():
    return {"message": "Welcome to the AI Chat Completions API"}

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
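
# Example usage (a minimal sketch, assuming the server is running locally on the
# port configured above; the body fields mirror what the handler reads from the
# incoming payload, and the response is streamed back as raw chunks from upstream):
#
#   curl -N http://localhost:8000/ai/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "claude-3-5-sonnet-20240620",
#          "max_tokens": 1024,
#          "messages": [{"role": "user", "content": "Hello"}]}'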