File size: 4,413 Bytes
fcbf0b1
fed49ba
 
651fc5d
fcbf0b1
b05f070
 
 
 
 
651fc5d
fed49ba
306d21c
b05f070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fb88cc
fcbf0b1
fed49ba
 
fcbf0b1
 
 
 
 
 
 
 
 
b05f070
fcbf0b1
fed49ba
 
2c7972e
651fc5d
fed49ba
651fc5d
fed49ba
 
fcbf0b1
fed49ba
651fc5d
 
 
 
 
 
 
 
 
 
 
 
fed49ba
b05f070
fcbf0b1
651fc5d
 
fed49ba
651fc5d
 
fed49ba
651fc5d
 
 
fed49ba
651fc5d
 
 
fcbf0b1
 
 
 
 
fed49ba
 
 
fcbf0b1
 
651fc5d
 
fcbf0b1
b05f070
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from asgiref.wsgi import WsgiToAsgi
from flask import Flask, request, Response
import requests
import json
import random
import os
import jwt
import uuid
from datetime import datetime, timedelta
import string

# The Flask (WSGI) application object; the route handlers in this file are registered on it.
app = Flask(__name__)

def get_github_username_zed_userid_list():
    """
    Get GitHub username and Zed user ID from the USER environment variable.

    USER is expected to hold "<github_username>,<zed_user_id>". Surrounding
    whitespace around each part is stripped. If USER is unset, the built-in
    default "default_user,123456" is used.

    NOTE(review): USER is also the conventional login-name variable on Unix
    shells. A plain login name (no comma) is treated as invalid and falls back
    to the defaults below, so consider a dedicated variable name.

    Returns:
        list: A list containing a single tuple with the GitHub username (str)
        and Zed user ID (str). If USER does not contain exactly two non-empty
        comma-separated parts, a warning is printed and
        [("default_user", "123456")] is returned.
    """
    user_env = os.environ.get('USER', 'default_user,123456')
    parts = [part.strip() for part in user_env.split(',')]

    # Require exactly two parts, and reject empty ones (e.g. "alice," or ","),
    # which would otherwise produce a token with a blank login or user id.
    if len(parts) == 2 and all(parts):
        username, user_id = parts
        return [(username, user_id)]

    print("Warning: Invalid format in USER environment variable. Using default values.")
    return [("default_user", "123456")]

def create_jwt(github_user_login: str, user_id: int) -> str:
    """
    Create a JSON Web Token (JWT) for a given GitHub user.

    Args:
        github_user_login (str): The GitHub username of the user.
        user_id (int): The user's ID. A string made only of decimal digits
            (such as the ID read from the environment) is converted to int,
            so the userId claim matches this annotation. Any other value is
            embedded unchanged.

    Returns:
        str: The token produced by jwt.encode, signed with HS256 using the
        fixed key 'llm-secret' (str with PyJWT 2.x).

    Note:
        The token has a lifetime of 1 hour and includes the following claims:
        - iat: Issued at time (Unix seconds, UTC)
        - exp: Expiration time (Unix seconds, UTC)
        - jti: Unique token identifier
        - userId: User's ID
        - githubUserLogin: GitHub username
        - isStaff: Boolean indicating staff status (default: False)
        - hasLlmClosedBetaFeatureFlag: Boolean for LLM closed beta feature (default: False)
        - plan: User's plan (default: "Free")
    """
    from datetime import timezone

    LLM_TOKEN_LIFETIME = timedelta(hours=1)
    # Use an aware UTC datetime. datetime.utcnow() is naive, and .timestamp()
    # on a naive value assumes local time, which would shift iat/exp by the
    # host's UTC offset.
    now = datetime.now(timezone.utc)

    # The ID usually arrives as a str from the environment; emit a numeric claim
    # when it is purely decimal, as the `user_id: int` annotation intends.
    if isinstance(user_id, str) and user_id.isdecimal():
        user_id = int(user_id)

    payload = {
        "iat": int(now.timestamp()),
        "exp": int((now + LLM_TOKEN_LIFETIME).timestamp()),
        "jti": str(uuid.uuid4()),
        "userId": user_id,
        "githubUserLogin": github_user_login,
        "isStaff": False,
        "hasLlmClosedBetaFeatureFlag": False,
        "plan": "Free"
    }

    return jwt.encode(payload, 'llm-secret', algorithm='HS256')

def _build_llm_payload(payload):
    """
    Translate an incoming chat request body into the llm.zed.dev request body.

    Args:
        payload (dict): Parsed JSON request body. Must contain 'messages';
            'model', 'max_tokens', 'temperature' and 'top_p' are optional.

    Returns:
        dict: Body for the upstream completion endpoint, targeting the
        "anthropic" provider.
    """
    # Default to "claude-3-5-sonnet-20240620" when the client names no model.
    model = payload.get('model', 'claude-3-5-sonnet-20240620')
    return {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            # NOTE(review): any system-role messages are forwarded inside
            # 'messages' unchanged while 'system' is always empty. Confirm the
            # upstream provider accepts that.
            "messages": payload['messages'],
            "system": ""
        }
    }


@app.route('/chat/completions', methods=['POST'])
def chat():
    """
    Handle chat completion requests.

    This function processes incoming POST requests to the '/chat/completions'
    endpoint. It builds the payload for the LLM API, generates a JWT for
    authentication, and streams the LLM API's raw response bytes back to the
    client.

    Returns:
        Response: On success, a streaming 'application/octet-stream' response
        carrying the upstream body and upstream status code. A 400 JSON error
        is returned if the body is not a JSON object with a 'messages' field.
        A 502 JSON error is returned if the upstream API cannot be reached.

    Note:
        - The function uses environment variables for proxy configuration
          (HTTP_PROXY) and user data (USER).
        - The LLM model defaults to "claude-3-5-sonnet-20240620" if not specified.
        - This is a plain (sync) view: nothing is awaited and `requests` is
          blocking, so an async view would add nothing.
    """
    # silent=True turns a missing or malformed JSON body into None instead of
    # raising, so the client gets a clear 400 rather than a 500 from a KeyError.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'messages' not in payload:
        return Response(
            json.dumps({"error": "request body must be a JSON object with a 'messages' field"}),
            status=400,
            content_type='application/json',
        )

    url = "https://llm.zed.dev/completion?"
    llm_payload = _build_llm_payload(payload)

    github_username, zed_user_id = get_github_username_zed_userid_list()[0]
    # Named `token` (not `jwt`) so the PyJWT module name is not shadowed here.
    token = create_jwt(github_username, zed_user_id)

    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {token}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
    }

    # A single proxy from HTTP_PROXY is used for both http and https traffic.
    proxy = os.environ.get('HTTP_PROXY', None)
    proxies = {'http': proxy, 'https': proxy} if proxy else None

    # Open the upstream stream before returning, so that connection failures
    # and upstream error statuses reach the client instead of a 200 with an
    # empty or truncated body.
    try:
        upstream = requests.post(url, headers=headers, json=llm_payload, stream=True, proxies=proxies)
    except requests.RequestException:
        return Response(
            json.dumps({"error": "failed to reach the upstream LLM API"}),
            status=502,
            content_type='application/json',
        )

    def generate():
        # Must be a plain generator: Werkzeug iterates response bodies
        # synchronously and cannot consume an async generator.
        for chunk in upstream.iter_content(chunk_size=1024):
            if chunk:  # skip empty chunks
                yield chunk

    response = Response(generate(), status=upstream.status_code, content_type='application/octet-stream')
    # Release the upstream connection when the response is closed, whether the
    # stream finished or the client went away early.
    response.call_on_close(upstream.close)
    return response

# Expose the (WSGI) Flask application through asgiref's WSGI-to-ASGI adapter,
# so that an ASGI server such as uvicorn can host it.
asgi_app = WsgiToAsgi(app)

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    from uvicorn import run as serve_asgi

    serve_asgi(asgi_app, host="0.0.0.0", port=8000)