|
from asgiref.wsgi import WsgiToAsgi |
|
from flask import Flask, request, Response |
|
import requests |
|
import json |
|
import os |
|
from helper import create_jwt |
|
|
|
# Flask (WSGI) application object; the route handlers below register on it.
app = Flask(__name__)
|
|
|
@app.route('/ai/v1/chat/completions', methods=['POST'])
async def chat():
    """
    Handle chat completion requests.

    Processes POST requests to the '/ai/v1/chat/completions' endpoint:
    builds the payload for the LLM API, generates a JWT for authentication,
    and streams the LLM API's response body back to the client.

    Returns:
        Response: A streaming response relaying the LLM API's output, or a
        plain-text error response:
        - 400 when the JSON body is not an object or lacks 'messages';
        - 401 when the Authorization header is missing or malformed.

    Note:
        - The HTTP_PROXY environment variable, when set, is used as the
          proxy for both http and https upstream traffic.
        - The Authorization header must have the form
          'Bearer <github_username>,<zed_user_id>' where the user id is an
          integer.
        - The LLM model defaults to "claude-3-5-sonnet-20240620" if not
          specified.
        - The upstream status code is not propagated; the upstream body is
          relayed as-is with content type 'application/octet-stream'.
    """
    payload = request.json
    # A JSON array/scalar/null body would otherwise crash on `.get` below.
    if not isinstance(payload, dict):
        return Response('Request body must be a JSON object', status=400)

    model = payload.get('model', 'claude-3-5-sonnet-20240620')

    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return Response('Invalid Authorization header', status=401)

    try:
        # Strip the 'Bearer ' prefix (7 characters) and split the two fields.
        github_username, raw_user_id = auth_header[7:].split(',')
        # Converted inside the try so that a non-numeric id is reported as a
        # malformed header rather than surfacing as an unhandled ValueError.
        zed_user_id = int(raw_user_id)
    except ValueError:
        return Response('Invalid Authorization header format', status=401)

    # Checked after authentication so unauthenticated callers still get 401.
    if 'messages' not in payload:
        return Response("Missing 'messages' in request body", status=400)

    url = "https://llm.zed.dev/completion?"

    llm_payload = {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": payload['messages'],
            "system": "",
        },
    }

    token = create_jwt(github_username, zed_user_id)

    headers = {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {token}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)',
    }

    proxy = os.environ.get('HTTP_PROXY', None)
    proxies = {'http': proxy, 'https': proxy} if proxy else None

    # (connect, read) timeouts in seconds. The read timeout bounds the gap
    # between received bytes, not the total stream duration, so long
    # completions still work while a stalled upstream cannot hang forever.
    upstream_timeout = (10, 300)

    def generate():
        # Must be a plain (synchronous) generator: the WSGI response body is
        # consumed with a regular `for` loop, which cannot iterate an async
        # generator. The upstream call is blocking `requests` I/O anyway.
        with requests.post(
            url,
            headers=headers,
            json=llm_payload,
            stream=True,
            proxies=proxies,
            timeout=upstream_timeout,
        ) as upstream:
            for chunk in upstream.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk

    return Response(generate(), content_type='application/octet-stream')
|
|
|
|
|
# Wrap the Flask WSGI application in an ASGI adapter so ASGI servers can
# serve it.
asgi_app = WsgiToAsgi(app)


if __name__ == '__main__':
    # uvicorn is imported here because it is only needed when this module is
    # executed directly as a script.
    import uvicorn

    # Listen on every network interface, port 8000.
    uvicorn.run(
        asgi_app,
        host="0.0.0.0",
        port=8000,
    )
|
|