File size: 5,911 Bytes
4605201
fcbf0b1
fed49ba
651fc5d
b05f070
 
 
 
dedefcd
 
 
 
 
651fc5d
fed49ba
306d21c
4605201
b05f070
 
 
 
4605201
b05f070
 
 
 
dedefcd
4605201
b05f070
dedefcd
4605201
b05f070
4605201
b05f070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2cd73a
 
 
 
b05f070
b074ef9
 
 
4605201
b074ef9
 
 
4605201
 
b074ef9
 
4605201
 
 
 
b074ef9
 
 
 
4605201
 
 
 
 
 
 
b074ef9
acd9ed4
4605201
fed49ba
 
 
b074ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651fc5d
b074ef9
 
 
 
 
 
 
 
651fc5d
fed49ba
4605201
b074ef9
651fc5d
b074ef9
 
 
 
 
 
 
651fc5d
b074ef9
 
 
 
 
4605201
b074ef9
 
c2cd73a
 
 
 
 
b074ef9
 
c2cd73a
 
b074ef9
 
c2cd73a
b074ef9
 
 
 
 
 
 
 
 
 
4605201
fed49ba
fcbf0b1
 
651fc5d
 
fcbf0b1
d160f6b
4605201
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from flask import Flask, request, Response, redirect, jsonify
from asgiref.wsgi import WsgiToAsgi
import requests
import json
import os
import jwt
import uuid
from datetime import datetime, timedelta
import logging

# Configure logging: INFO level, each record formatted as "time - level - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Flask application object that the request hook and routes in this file attach to.
app = Flask(__name__)

def get_github_username_zed_userid():
    """
    Get GitHub username and Zed user ID from the USER environment variable.

    The variable is expected to hold ``"<github_username>,<zed_user_id>"``.
    When it is unset, ``"default_user,123456"`` is used. When it does not
    consist of exactly two comma-separated fields, or either field is blank
    after surrounding whitespace is stripped, a warning is logged and the
    default pair is returned.

    Returns:
        tuple: A tuple containing the GitHub username (str) and Zed user ID (str).
    """
    default_username, default_user_id = "default_user", "123456"
    user_env = os.environ.get('USER', f'{default_username},{default_user_id}')

    # Strip first so validation and logging see the values actually returned.
    fields = [field.strip() for field in user_env.split(',')]
    if len(fields) != 2 or not all(fields):
        logger.warning("Invalid format in USER environment variable. Using default values.")
        return default_username, default_user_id

    username, user_id = fields
    logger.info(f"Using GitHub username: {username} and Zed user ID: {user_id}")
    return username, user_id

def create_jwt(github_user_login: str, user_id: str) -> str:
    """
    Create a JSON Web Token (JWT) for a given GitHub user.

    The token is signed with HS256 and expires one hour after it is issued.

    Args:
        github_user_login: Value stored in the ``githubUserLogin`` claim.
        user_id: Value stored in the ``userId`` claim.

    Returns:
        str: The encoded JWT.
    """
    # Function-scope import: the module header only imports datetime and timedelta.
    from datetime import timezone

    LLM_TOKEN_LIFETIME = timedelta(hours=1)
    # Use an aware UTC datetime. datetime.utcnow() returns a naive value and
    # .timestamp() interprets naive values as local time, which shifts
    # "iat"/"exp" by the UTC offset on hosts that do not run in UTC.
    now = datetime.now(timezone.utc)

    payload = {
        "iat": int(now.timestamp()),
        "exp": int((now + LLM_TOKEN_LIFETIME).timestamp()),
        "jti": str(uuid.uuid4()),
        "userId": user_id,
        "githubUserLogin": github_user_login,
        "isStaff": False,
        "hasLlmClosedBetaFeatureFlag": False,
        "plan": "Free"
    }

    token = jwt.encode(payload, 'llm-secret', algorithm='HS256')
    # PyJWT 1.x returns bytes while 2.x returns str; normalise so the
    # declared return type holds and "Bearer <token>" formats correctly.
    if isinstance(token, bytes):
        token = token.decode('ascii')
    logger.info(f"Created JWT for user: {github_user_login}")
    # The token is a bearer credential, so it is deliberately not logged.
    return token

@app.before_request
def before_request():
    """
    Log every incoming request and redirect plain-HTTP requests to HTTPS.

    Returns:
        A permanent redirect to the ``https://`` form of the URL when the
        ``X-Forwarded-Proto`` header equals ``http``; otherwise ``None``,
        which lets request handling continue.
    """
    logger.info(f"Received request: {request.method} {request.url}")
    # Mask credential-bearing headers so secrets do not end up in the logs.
    sensitive_headers = {'authorization', 'proxy-authorization', 'cookie'}
    logged_headers = {
        name: '<redacted>' if name.lower() in sensitive_headers else value
        for name, value in request.headers.items()
    }
    logger.info(f"Request headers: {logged_headers}")
    if request.data:
        logger.info(f"Request body: {request.get_data(as_text=True)}")

    # Check the X-Forwarded-Proto header rather than the URL
    if request.headers.get('X-Forwarded-Proto') == 'http':
        url = request.url.replace('http://', 'https://', 1)
        logger.info(f"Redirecting to HTTPS: {url}")
        # 308 instead of 301: clients must repeat the request with the same
        # method and body, so a redirected POST stays a POST. With 301 many
        # clients retry as GET, which the POST-only chat route rejects.
        return redirect(url, code=308)

    # No redirect needed; continue processing the request
    return None

@app.route('/')
def root():
    """
    Return a small JSON document describing the service.

    Returns:
        tuple: The JSON response and the HTTP status code 200.
    """
    logger.info("Received request to root path")
    return jsonify({
        "status": "ok",
        "message": "Welcome to the chat completion API",
        "endpoints": {
            # Advertise the path the chat handler is actually registered under.
            "/ai/v1/chat/completions": "POST - Send chat completion requests"
        }
    }), 200

@app.route('/ai/v1/chat/completions', methods=['POST'])
def chat():
    """
    Handle chat completion requests.

    Reads a JSON object from the request body, wraps it in the payload sent
    to ``https://llm.zed.dev/completion`` and streams the upstream response
    bytes back to the caller unchanged.

    Request body fields:
        messages: Required; forwarded as-is.
        model: Optional; defaults to ``claude-3-5-sonnet-20240620``.
        max_tokens: Optional; defaults to 8192.
        temperature: Optional; defaults to 0.
        top_p: Optional; defaults to 0.7.

    Returns:
        A streaming ``application/octet-stream`` response on success, a 400
        JSON error when the body is not a JSON object containing
        ``messages``, or a 500 JSON error for any other failure while the
        upstream request is being prepared. Failures that happen after
        streaming has started are reported as a JSON error chunk in the body.
    """
    logger.info("Processing chat completion request")

    try:
        # silent=True returns None for a missing or malformed JSON body
        # instead of raising, so bad input is answered with 400 rather than
        # falling into the generic 500 handler below.
        payload = request.get_json(silent=True)
        logger.info(f"Request payload: {payload}")

        if not isinstance(payload, dict) or 'messages' not in payload:
            logger.warning("Rejecting request: body is not a JSON object with 'messages'")
            return jsonify({"error": "Request body must be a JSON object with a 'messages' field"}), 400

        # Get the model from the payload, defaulting to "claude-3-5-sonnet-20240620"
        model = payload.get('model', 'claude-3-5-sonnet-20240620')
        logger.info(f"Using model: {model}")

        # Prepare the request for the LLM API
        url = "https://llm.zed.dev/completion?"

        llm_payload = {
            "provider": "anthropic",
            "model": model,
            "provider_request": {
                "model": model,
                "max_tokens": payload.get('max_tokens', 8192),
                "temperature": payload.get('temperature', 0),
                "top_p": payload.get('top_p', 0.7),
                "messages": payload['messages'],
                "system": ""
            }
        }

        github_username, zed_user_id = get_github_username_zed_userid()
        jwt_token = create_jwt(github_username, zed_user_id)

        headers = {
            'Host': 'llm.zed.dev',
            'accept': '*/*',
            'content-type': 'application/json',
            'authorization': f'Bearer {jwt_token}',
            'user-agent': 'Zed/0.149.3 (macos; aarch64)'
        }
        # Copy used for logging only, so the signed token is not written to the logs.
        logged_headers = {**headers, 'authorization': 'Bearer <redacted>'}

        # Get proxy from environment variable
        proxy = os.environ.get('HTTP_PROXY', None)
        proxies = {'http': proxy, 'https': proxy} if proxy else None
        logger.info(f"Using proxy: {proxy}")

        # (connect, read) timeouts in seconds. Without them a stalled
        # upstream would block this worker indefinitely. The read timeout
        # applies to the gap between received chunks, not the whole stream.
        upstream_timeout = (10, 300)

        def generate():
            """Yield the upstream response body chunk by chunk."""
            logger.info("Starting to stream response")
            try:
                logger.info("Sending request to LLM API:")
                logger.info(f"URL: {url}")
                logger.info(f"Headers: {logged_headers}")
                logger.info(f"Payload: {json.dumps(llm_payload, indent=2)}")

                with requests.post(url, headers=headers, json=llm_payload, stream=True,
                                   proxies=proxies, timeout=upstream_timeout) as response:
                    logger.info(f"LLM API response status: {response.status_code}")
                    logger.info(f"LLM API response headers: {dict(response.headers)}")

                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:
                            # A fixed-size chunk can end in the middle of a
                            # multi-byte UTF-8 sequence; decode leniently so
                            # logging can never abort the stream.
                            logger.info(f"Received chunk: {chunk.decode('utf-8', errors='replace')}")
                            yield chunk
            except requests.RequestException as e:
                logger.error(f"Error during LLM API request: {e}")
                yield json.dumps({"error": "Internal server error"}).encode('utf-8')
            logger.info("Finished streaming response")

        return Response(generate(), content_type='application/octet-stream')

    except Exception as e:
        # Top-level boundary for this handler: log with traceback, answer 500.
        logger.exception(f"Error processing request: {e}")
        return jsonify({"error": "Internal server error"}), 500

# Convert the Flask app to an ASGI app
asgi_app = WsgiToAsgi(app)

if __name__ == '__main__':
    # Executed as a script: serve the ASGI app with uvicorn.
    import uvicorn
    # The listening port is read from the PORT environment variable, defaulting to 8000.
    port = int(os.environ.get("PORT", 8000))
    logger.info(f"Starting the application on port {port}")
    uvicorn.run(asgi_app, host="0.0.0.0", port=port)