File size: 4,394 Bytes
fed49ba
 
651fc5d
75615ce
ae87889
d328605
0d5e69a
651fc5d
fed49ba
306d21c
d328605
ae87889
 
 
 
 
 
 
d328605
 
3fb88cc
fed49ba
 
 
 
ae87889
 
d328605
ae87889
 
 
d328605
fed49ba
3fb88cc
ae87889
651fc5d
fed49ba
651fc5d
ae87889
651fc5d
fed49ba
 
 
d328605
fed49ba
 
 
 
d328605
fed49ba
d328605
fed49ba
 
 
d328605
ae87889
fed49ba
651fc5d
 
 
 
 
 
 
 
 
 
 
 
ae87889
fed49ba
 
d328605
651fc5d
 
fed49ba
651fc5d
 
fed49ba
651fc5d
 
ae87889
651fc5d
fed49ba
651fc5d
 
ae87889
651fc5d
3fb88cc
d328605
ae87889
3fb88cc
d328605
ae87889
d328605
 
ae87889
 
d328605
 
3fb88cc
d328605
ae87889
d328605
 
 
 
fed49ba
ae87889
fed49ba
 
88f8e77
3fb88cc
88f8e77
651fc5d
3fb88cc
 
 
651fc5d
3fb88cc
ae87889
3fb88cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from flask import Flask, request, Response
import requests
import json
import os
import sys
import logging
from helper import create_jwt

app = Flask(__name__)

# Route every log record to stdout (timestamp + level tag) so a process
# supervisor or container runtime captures all output.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[_stdout_handler],
)
logger = logging.getLogger(__name__)

@app.route('/chat/completions', methods=['POST'])
def chat():
    """
    Handle chat completion requests.

    Proxies an OpenAI-style chat payload to the Zed LLM API and streams the
    upstream response bytes back to the caller unchanged.

    Expects:
        * a JSON body with a 'messages' list; optional 'model',
          'max_tokens', 'temperature', 'top_p';
        * an Authorization header of the form
          "Bearer <github_username>,<zed_user_id>".

    Returns:
        A streaming Response (application/octet-stream) on success, or a
        4xx Response for malformed input / credentials.
    """
    logger.info("Received chat completion request")

    # Log request metadata for debugging/auditing.
    logger.info("Request method: %s", request.method)
    logger.info("Request URL: %s", request.url)
    logger.info("Request headers: %s", dict(request.headers))

    # Flask's request.get_json() is synchronous — the previous `await` form
    # only works under ASGI frameworks such as Quart. silent=True returns
    # None (instead of raising) on a non-JSON body.
    payload = request.get_json(silent=True)
    if payload is None:
        logger.error("Request body is not valid JSON")
        return Response('Request body must be JSON', status=400)
    logger.info("Request payload: %s", json.dumps(payload, indent=2))

    # Default model mirrors the Zed client's default.
    model = payload.get('model', 'claude-3-5-sonnet-20240620')
    logger.info("Using model: %s", model)

    # The Authorization header carries "<github_username>,<zed_user_id>"
    # after the "Bearer " prefix.
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        logger.error("Invalid Authorization header")
        return Response('Invalid Authorization header', status=401)

    try:
        github_username, zed_user_id = auth_header[7:].split(',')
        logger.info("GitHub username: %s, Zed user ID: %s", github_username, zed_user_id)
    except ValueError:
        logger.error("Invalid Authorization header format")
        return Response('Invalid Authorization header format', status=401)

    # Fail fast on a missing 'messages' key instead of raising KeyError
    # (which would surface as a 500) while building the upstream payload.
    messages = payload.get('messages')
    if messages is None:
        logger.error("Payload missing 'messages'")
        return Response("Missing 'messages' in request payload", status=400)

    # Prepare the request for the LLM API.
    url = "https://llm.zed.dev/completion"
    logger.info("LLM API URL: %s", url)

    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": messages,
            "system": ""
        }
    }
    logger.info("LLM API payload: %s", json.dumps(llm_payload, indent=2))

    jwt = create_jwt(github_username, int(zed_user_id))
    # Do not log the full token — it is a live bearer credential.
    logger.info("Generated JWT token (prefix): %s...", jwt[:10])

    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
    }
    # Log outbound headers with the credential redacted.
    logger.info("Request headers: %s",
                json.dumps({**headers, 'authorization': 'Bearer <redacted>'}, indent=2))

    # Honor an HTTP proxy if configured in the environment.
    proxy = os.environ.get('HTTP_PROXY')
    proxies = {'http': proxy, 'https': proxy} if proxy else None
    logger.info("Using proxy: %s", proxy)

    def generate():
        """Yield upstream response bytes as they arrive (synchronous stream)."""
        try:
            logger.info("Sending request to LLM API")
            # `requests` is a synchronous library: its Response is not an
            # async context manager or async iterator, so the previous
            # `async with` / `async for` form would fail at runtime. A plain
            # `with` keeps the connection open for the generator's lifetime.
            with requests.post(url, headers=headers, json=llm_payload,
                               stream=True, proxies=proxies,
                               allow_redirects=True) as response:
                logger.info("LLM API response status: %s", response.status_code)
                logger.info("LLM API response headers: %s", dict(response.headers))

                if response.status_code == 301:
                    new_location = response.headers.get('Location')
                    logger.warning("Received 301 redirect. New location: %s", new_location)
                    # Redirect handling could be added here if needed.

                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        logger.debug("Received chunk of size: %d bytes", len(chunk))
                        yield chunk
        except Exception as e:
            # Best-effort error reporting: emit the error text into the
            # stream so the client sees something rather than a silently
            # truncated response (preserves the original behavior).
            logger.error("Error during API request: %s", e)
            yield str(e).encode()

    logger.info("Returning streaming response")
    return Response(generate(), content_type='application/octet-stream')

@app.route('/', methods=['GET'])
def home():
    """Landing / health-check endpoint.

    Defined synchronously: plain Flask runs under WSGI, and `async def`
    views require the flask[async] extra, which this app does not use.
    Returns the greeting body and an explicit 200 status.
    """
    return "Welcome to the Chat Completion API", 200

# Flask is a WSGI application and has no `asgi_app` attribute, so the
# previous `app.asgi_app` raised AttributeError at import time. Keep a
# module-level alias for external servers; NOTE(review): an ASGI server
# such as uvicorn would additionally need a WSGI-to-ASGI adapter
# (e.g. asgiref's WsgiToAsgi) to serve this app — confirm deployment setup.
asgi_app = app

if __name__ == '__main__':
    logger.info("Starting the application")
    # Serve with Flask's built-in WSGI server; uvicorn is ASGI-only and
    # cannot run a Flask app directly.
    app.run(host="0.0.0.0", port=8000)