File size: 2,929 Bytes
fed49ba
 
 
651fc5d
75615ce
0d5e69a
651fc5d
fed49ba
306d21c
fed49ba
 
 
 
651fc5d
fed49ba
 
 
c90075d
fed49ba
 
9e063bf
fed49ba
 
 
 
 
 
 
651fc5d
fed49ba
651fc5d
 
fed49ba
 
 
 
 
 
 
 
 
 
 
 
 
651fc5d
 
 
 
 
 
 
 
 
 
 
 
fed49ba
 
651fc5d
 
fed49ba
651fc5d
 
fed49ba
651fc5d
 
 
fed49ba
651fc5d
 
 
 
fed49ba
 
 
 
 
 
 
 
 
651fc5d
 
 
fed49ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from asgiref.wsgi import WsgiToAsgi
from flask import Flask, request, Response
import requests
import json
import os
from helper import create_jwt

# The Flask (WSGI) application. The route below is registered on it, and it is
# wrapped with WsgiToAsgi at the bottom of the module so ASGI servers can host it.
app = Flask(__name__)

# (connect, read) timeouts in seconds for the upstream LLM call. For requests,
# the read timeout bounds the wait between received bytes, not the total
# stream duration, so long completions are not cut off while data keeps flowing.
_UPSTREAM_TIMEOUT = (10, 300)


def _parse_credentials(auth_header):
    """Split a ``Bearer <github_username>,<zed_user_id>`` header value.

    Args:
        auth_header: Raw ``Authorization`` header value. Callers must already
            have checked that it starts with ``"Bearer "``.

    Returns:
        tuple[str, int]: The GitHub username (surrounding whitespace removed)
        and the Zed user ID converted to ``int``.

    Raises:
        ValueError: If the credential part is not exactly two comma-separated
            fields, the username is empty, or the user ID is not an integer.
    """
    # Unpacking raises ValueError when there are not exactly two fields.
    github_username, zed_user_id = auth_header[len('Bearer '):].split(',')
    github_username = github_username.strip()
    if not github_username:
        raise ValueError('empty GitHub username')
    return github_username, int(zed_user_id)


def _split_system_messages(messages):
    """Separate ``system``-role messages from the rest of the conversation.

    Anthropic's Messages API takes the system prompt as a top-level ``system``
    field and accepts only ``user``/``assistant`` roles inside ``messages``.
    OpenAI-style system messages are therefore lifted out here.

    Args:
        messages: The client's message list. Entries are expected to be dicts
            with ``role`` and ``content`` keys. Non-dict entries are passed
            through untouched.

    Returns:
        tuple[str, list]: The system prompt and the remaining messages in their
        original order. The prompt is the non-empty system texts joined by
        blank lines, or ``""`` when there are none.
    """
    system_parts = []
    conversation = []
    for message in messages:
        if not (isinstance(message, dict) and message.get('role') == 'system'):
            conversation.append(message)
            continue
        content = message.get('content', '')
        if isinstance(content, str):
            system_parts.append(content)
        elif isinstance(content, list):
            # Content may also be a list of typed blocks; keep only text blocks.
            system_parts.extend(
                block.get('text', '')
                for block in content
                if isinstance(block, dict) and block.get('type') == 'text'
            )
    return '\n\n'.join(part for part in system_parts if part), conversation


@app.route('/ai/v1/chat/completions', methods=['POST'])
def chat():
    """Proxy a chat completion request to Zed's hosted LLM API.

    Handles POST requests to ``/ai/v1/chat/completions``. It validates the
    credentials in the ``Authorization`` header and builds an Anthropic-style
    request for ``https://llm.zed.dev/completion``. It then authenticates that
    request with a JWT from ``create_jwt`` and streams the upstream body back
    to the client.

    The view is deliberately synchronous. It performs blocking network I/O
    with ``requests``; as an ``async`` view (which never awaited anything)
    that blocking work could end up running on an event loop shared with
    other requests.

    Request:
        Header ``Authorization: Bearer <github_username>,<zed_user_id>``.
        The JSON object body has a ``messages`` list plus optional fields:
        ``model`` (default ``"claude-3-5-sonnet-20240620"``),
        ``max_tokens`` (default 8192), ``temperature`` (default 0) and
        ``top_p`` (default 0.7). ``system``-role messages are moved into the
        upstream ``system`` field.

    Returns:
        Response: one of the following.

        - ``401`` for a missing or malformed Authorization header.
        - ``400`` for a body that is not a JSON object with a ``messages`` list.
        - ``502`` if the upstream request cannot be completed.
        - The upstream status and body if upstream answers with an error.
        - Otherwise, a streaming ``application/octet-stream`` response
          carrying the upstream output.

    Note:
        When the ``HTTP_PROXY`` environment variable is set, it is used as the
        proxy for both HTTP and HTTPS upstream traffic.
    """
    # Validate credentials before doing any other work.
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return Response('Invalid Authorization header', status=401)

    try:
        github_username, zed_user_id = _parse_credentials(auth_header)
    except ValueError:
        return Response('Invalid Authorization header format', status=401)

    # silent=True yields None for a missing or undecodable body instead of raising.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return Response('Request body must be a JSON object', status=400)
    messages = payload.get('messages')
    if not isinstance(messages, list):
        return Response("Request body must contain a 'messages' list", status=400)

    model = payload.get('model', 'claude-3-5-sonnet-20240620')
    system_prompt, conversation = _split_system_messages(messages)

    # Prepare the request for the LLM API
    url = "https://llm.zed.dev/completion?"

    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": conversation,
            "system": system_prompt,
        },
    }

    jwt = create_jwt(github_username, zed_user_id)

    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
    }

    # Get proxy from environment variable
    proxy = os.environ.get('HTTP_PROXY')
    proxies = {'http': proxy, 'https': proxy} if proxy else None

    # Issue the request here rather than inside the generator. That way,
    # connection failures and upstream error statuses can still be reported
    # with a real HTTP status before any streamed bytes are sent.
    try:
        upstream = requests.post(
            url,
            headers=headers,
            json=llm_payload,
            stream=True,
            proxies=proxies,
            timeout=_UPSTREAM_TIMEOUT,
        )
    except requests.exceptions.RequestException:
        # Deliberately generic: the exception text can mention proxy details.
        return Response('Upstream request failed', status=502)

    if not upstream.ok:
        try:
            body = upstream.content
        except requests.exceptions.RequestException:
            body = b''
        finally:
            upstream.close()
        return Response(
            body,
            status=upstream.status_code,
            content_type=upstream.headers.get('content-type', 'text/plain'),
        )

    def generate():
        # Forward upstream bytes as they arrive, skipping empty chunks. The
        # finally block releases the connection on normal completion and
        # when the generator is closed early.
        try:
            for chunk in upstream.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
        finally:
            upstream.close()

    response = Response(generate(), content_type='application/octet-stream')
    # Also close the upstream connection when the WSGI server closes the
    # response, which covers a body iterator that was never started.
    response.call_on_close(upstream.close)
    return response

# Expose the WSGI Flask application through an ASGI adapter so it can be
# served by an ASGI server such as uvicorn.
asgi_app = WsgiToAsgi(app)

if __name__ == '__main__':
    # uvicorn is imported lazily: it is only needed when run as a script.
    from uvicorn import run as run_server

    run_server(asgi_app, host="0.0.0.0", port=8000)