|
import json
import logging
import os
import random
import string
import uuid
from datetime import datetime, timedelta, timezone

import jwt
import requests
from asgiref.wsgi import WsgiToAsgi
from flask import Flask, Response, redirect, request
|
|
|
|
|
# Root logging configuration: INFO and above, each record prefixed with its
# timestamp and level.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)

# The Flask (WSGI) application; it is wrapped for ASGI servers further down.
app = Flask(__name__)
|
|
|
def get_github_username_zed_userid_list():
    """
    Get GitHub username and Zed user ID from the USER environment variable.

    The variable is expected to hold ``"<github_username>,<zed_user_id>"``.
    Surrounding whitespace around either field is ignored.

    NOTE(review): on many Unix hosts USER is already set to the login name
    (no comma), in which case the defaults below are used unless USER is
    explicitly overridden — confirm this variable name is intended.

    Returns:
        list: A list containing a single tuple with the GitHub username (str)
        and Zed user ID (str). Falls back to ``("default_user", "123456")``
        when the value does not have exactly two comma-separated fields or
        either field is empty.
    """
    fallback = ("default_user", "123456")
    user_env = os.environ.get('USER', 'default_user,123456')

    fields = [field.strip() for field in user_env.split(',')]
    # Exactly two fields, neither empty after stripping; an empty login or ID
    # would otherwise be embedded verbatim in the JWT.
    if len(fields) != 2 or not all(fields):
        logger.warning("Invalid format in USER environment variable. Using default values.")
        return [fallback]

    username, user_id = fields
    logger.info(f"Using GitHub username: {username} and Zed user ID: {user_id}")
    return [(username, user_id)]
|
|
|
def create_jwt(github_user_login: str, user_id: int) -> str:
    """
    Create a JSON Web Token (JWT) for a given GitHub user.

    The token is signed with HS256 using the fixed key ``'llm-secret'`` and
    expires one hour after it is issued. Each token gets a random ``jti``.

    Args:
        github_user_login: Value for the ``githubUserLogin`` claim.
        user_id: Value for the ``userId`` claim, embedded as-is.
            NOTE(review): annotated ``int``, but the caller in this file
            passes the string read from the USER environment variable —
            confirm which type the token consumer expects.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    LLM_TOKEN_LIFETIME = timedelta(hours=1)
    # Use an aware UTC datetime. Calling .timestamp() on the naive value from
    # datetime.utcnow() interprets it as *local* time, shifting iat/exp by the
    # host's UTC offset (east of UTC+1 the token would be born expired).
    now = datetime.now(timezone.utc)

    payload = {
        "iat": int(now.timestamp()),
        "exp": int((now + LLM_TOKEN_LIFETIME).timestamp()),
        "jti": str(uuid.uuid4()),
        "userId": user_id,
        "githubUserLogin": github_user_login,
        "isStaff": False,
        "hasLlmClosedBetaFeatureFlag": False,
        "plan": "Free",
    }

    logger.info(f"Creating JWT for user: {github_user_login}")
    return jwt.encode(payload, 'llm-secret', algorithm='HS256')
|
|
|
# Lower-case header names whose values are masked before logging so that
# credentials never reach the log output.
_SENSITIVE_HEADERS = frozenset({'authorization', 'proxy-authorization', 'cookie'})

# URL prefixes allowed over plain HTTP without an HTTPS redirect (local use).
_LOCAL_HTTP_PREFIXES = ('http://localhost', 'http://127.0.0.1')


@app.before_request
def before_request():
    """
    Log every incoming request and redirect non-local plain-HTTP requests to HTTPS.

    Returns:
        A redirect response for plain-HTTP requests that are not local,
        otherwise None so Flask continues to the matched view.
    """
    logger.info(f"Received request: {request.method} {request.url}")
    safe_headers = {
        name: ('<redacted>' if name.lower() in _SENSITIVE_HEADERS else value)
        for name, value in request.headers.items()
    }
    logger.info(f"Request headers: {safe_headers}")
    if request.data:
        logger.info(f"Request body: {request.get_data(as_text=True)}")

    # NOTE(review): behind a TLS-terminating reverse proxy request.url is
    # http://... even for HTTPS clients, which would cause a redirect loop
    # unless the proxy headers are applied (e.g. ProxyFix) — confirm deployment.
    if not request.url.startswith('https') and not request.url.startswith(_LOCAL_HTTP_PREFIXES):
        url = request.url.replace('http://', 'https://', 1)
        # 301 permits clients to replay a POST as GET (dropping the body);
        # 308 preserves method and body, so use it for anything but GET/HEAD.
        code = 301 if request.method in ('GET', 'HEAD') else 308
        logger.info(f"Redirecting to HTTPS: {url}")
        return redirect(url, code=code)
|
|
|
@app.route('/')
def root():
    """Answer the root path with a fixed plain-text greeting and HTTP 200."""
    greeting = "Welcome to the chat completion API"
    logger.info("Received request to root path")
    return greeting, 200
|
|
|
# Upstream completion endpoint.
_LLM_URL = "https://llm.zed.dev/completion?"

# (connect, read) timeouts in seconds for the upstream call. Without a timeout
# requests may block forever; with stream=True the read timeout bounds the
# wait for each chunk, not the total stream duration.
_LLM_TIMEOUT = (10, 300)


def _json_error(message, status):
    """Build a Flask (body, status, headers) tuple carrying ``{"error": message}``."""
    return json.dumps({"error": message}), status, {'Content-Type': 'application/json'}


def _build_llm_payload(payload, model):
    """Wrap the client's chat request in the llm.zed.dev completion envelope."""
    return {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": payload['messages'],
            "system": "",
        },
    }


def _build_llm_headers():
    """Mint a fresh JWT for the configured user and build the upstream request headers."""
    github_username, zed_user_id = get_github_username_zed_userid_list()[0]
    jwt_token = create_jwt(github_username, zed_user_id)
    return {
        'Host': 'llm.zed.dev',
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt_token}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)',
    }


def _get_proxies():
    """Return a requests ``proxies`` mapping from HTTP_PROXY, or None when it is unset."""
    proxy = os.environ.get('HTTP_PROXY', None)
    logger.info(f"Using proxy: {proxy}")
    return {'http': proxy, 'https': proxy} if proxy else None


@app.route('/chat/completions', methods=['POST'])
async def chat():
    """
    Handle chat completion requests.

    Forwards the JSON body's ``messages`` (plus optional ``model``,
    ``max_tokens``, ``temperature`` and ``top_p``) to llm.zed.dev and streams
    the upstream response bytes back unchanged, with the upstream status code.

    Returns:
        - 400 JSON error if the body is not a JSON object with ``messages``;
        - 500 JSON error if preparing the upstream request fails;
        - 502 JSON error if the upstream request cannot be sent;
        - otherwise a streaming ``application/octet-stream`` response.
    """
    logger.info("Processing chat completion request")

    # silent=True: a missing/malformed JSON body yields None instead of
    # raising, so it is reported as a client error rather than a 500.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'messages' not in payload:
        logger.warning("Rejected chat request: body must be a JSON object with a 'messages' field")
        return _json_error("Request body must be a JSON object with a 'messages' field", 400)

    try:
        logger.info(f"Request payload: {payload}")
        model = payload.get('model', 'claude-3-5-sonnet-20240620')
        logger.info(f"Using model: {model}")
        llm_payload = _build_llm_payload(payload, model)
        headers = _build_llm_headers()
        proxies = _get_proxies()
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        return _json_error("Internal server error", 500)

    # Send the request before building the response so the upstream status
    # can be forwarded instead of always answering 200.
    try:
        upstream = requests.post(
            _LLM_URL,
            headers=headers,
            json=llm_payload,
            stream=True,
            proxies=proxies,
            timeout=_LLM_TIMEOUT,
        )
    except requests.RequestException as e:
        logger.error(f"Error during LLM API request: {e}")
        return _json_error("Bad gateway", 502)
    logger.info(f"LLM API response status: {upstream.status_code}")

    def generate():
        # Must be a plain (sync) generator: Werkzeug iterates response bodies
        # synchronously, so an async generator raises TypeError here.
        logger.info("Starting to stream response")
        try:
            for chunk in upstream.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
        except requests.RequestException as e:
            # Status and headers are already sent; append an error marker.
            logger.error(f"Error during LLM API request: {e}")
            yield json.dumps({"error": "Internal server error"}).encode('utf-8')
        logger.info("Finished streaming response")

    response = Response(generate(), status=upstream.status_code, content_type='application/octet-stream')
    # Release the upstream connection once the body is sent or the client
    # disconnects, even if the generator never ran to completion.
    response.call_on_close(upstream.close)
    return response
|
|
|
|
|
# ASGI-compatible wrapper around the Flask (WSGI) app, for ASGI servers.
asgi_app = WsgiToAsgi(app)


def _run_dev_server(host="0.0.0.0", port=8000):
    """Serve ``asgi_app`` with uvicorn; uvicorn is imported only when run directly."""
    import uvicorn

    logger.info("Starting the application")
    uvicorn.run(asgi_app, host=host, port=port)


if __name__ == '__main__':
    _run_dev_server()
|
|