|
import logging |
|
from fastapi import FastAPI, Request, HTTPException |
|
from fastapi.responses import StreamingResponse |
|
import httpx |
|
import json |
|
import os |
|
|
|
from helper import create_jwt |
|
|
|
|
|
# Root logging level. LOG_LEVEL (a level name such as INFO or WARNING) overrides
# the DEBUG default; note that at DEBUG the chat handler logs full request
# payloads, including user messages.
logging.basicConfig(level=os.environ.get('LOG_LEVEL', 'DEBUG').upper())

logger = logging.getLogger(__name__)

app = FastAPI()
|
|
|
# Upstream completion endpoint every request is forwarded to.
_LLM_COMPLETION_URL = 'https://llm.zed.dev/completion?'

# Model used when the client does not name one.
_DEFAULT_MODEL = 'claude-3-5-sonnet-20240620'

# httpx defaults to 5 s for every phase, including each read of a streamed body.
# A model can take longer than that to emit its first token on a large prompt,
# so reads get a generous limit while connect/write/pool stay short.
_UPSTREAM_TIMEOUT = httpx.Timeout(10.0, read=300.0)


def _parse_bearer_credentials(auth_header):
    """Extract ``(github_username, user_id)`` from an ``Authorization`` header.

    The expected form is ``Bearer <github_username>,<user_id>``. The scheme is
    matched case-insensitively, as HTTP auth schemes are (RFC 7235).

    Raises:
        HTTPException: 401 "Invalid authorization header" when the header is
            missing or is not a Bearer header; 401 "Invalid authorization format"
            when the credentials are not ``name,integer`` or the name is empty.
    """
    scheme, separator, credentials = (auth_header or '').partition(' ')
    if not separator or scheme.lower() != 'bearer':
        raise HTTPException(status_code=401, detail="Invalid authorization header")
    try:
        github_username, user_id_str = credentials.split(',')
        user_id = int(user_id_str)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid authorization format") from None
    if not github_username:
        raise HTTPException(status_code=401, detail="Invalid authorization format")
    return github_username, user_id


def _split_system_messages(messages, base_system=''):
    """Lift ``role: "system"`` messages out of *messages* into a system prompt.

    The Anthropic Messages API takes the system prompt as the top-level
    ``system`` field and accepts only ``user``/``assistant`` roles inside
    ``messages``. OpenAI-style system messages are therefore moved instead of
    being forwarded.

    Args:
        messages: The client's message list.
        base_system: System text supplied separately by the client; placed first.

    Returns:
        ``(system_prompt, conversation)``. ``system_prompt`` joins all non-empty
        system texts with blank lines (``""`` when there are none).
        ``conversation`` holds every other message in its original order.
    """
    system_texts = [base_system] if base_system else []
    conversation = []
    for message in messages:
        if not (isinstance(message, dict) and message.get('role') == 'system'):
            conversation.append(message)
            continue
        content = message.get('content', '')
        if isinstance(content, list):
            # Content-part form, e.g. [{"type": "text", "text": "..."}]: keep text parts.
            content = ''.join(
                part['text']
                for part in content
                if isinstance(part, dict) and isinstance(part.get('text'), str)
            )
        if isinstance(content, str) and content:
            system_texts.append(content)
    return '\n\n'.join(system_texts), conversation


def _build_llm_payload(payload):
    """Translate the client's JSON body into the upstream request envelope.

    Sampling defaults: ``max_tokens`` 8192, ``temperature`` 0, ``top_p`` 0.7.
    The model defaults to ``_DEFAULT_MODEL``.

    Raises:
        HTTPException: 400 when *payload* is not a JSON object or has no
            ``messages`` list.
    """
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    messages = payload.get('messages')
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="Request body must contain a 'messages' list")

    client_system = payload.get('system')
    system_prompt, conversation = _split_system_messages(
        messages, client_system if isinstance(client_system, str) else ''
    )
    model = payload.get('model', _DEFAULT_MODEL)
    return {
        "provider": "anthropic",
        "model": model,
        "provider_request": {
            "model": model,
            "max_tokens": payload.get('max_tokens', 8192),
            "temperature": payload.get('temperature', 0),
            "top_p": payload.get('top_p', 0.7),
            "messages": conversation,
            "system": system_prompt,
        },
    }


def _make_upstream_client():
    """Create the httpx client used for the upstream call.

    When ``HTTP_PROXY`` is set, it is used for all traffic, including the https
    upstream. It is passed via ``proxy=`` (httpx >= 0.26) for two reasons:
    httpx rejects bare scheme keys like ``{'http': ...}`` in the old
    ``proxies=`` mapping, and ``proxies=`` was removed in httpx 0.28.
    """
    proxy = os.environ.get('HTTP_PROXY')
    if proxy:
        # The URL itself is not logged: proxy URLs can embed credentials.
        logger.debug("Using HTTP_PROXY for upstream requests")
        return httpx.AsyncClient(proxy=proxy, timeout=_UPSTREAM_TIMEOUT)
    logger.debug("No HTTP_PROXY configured")
    return httpx.AsyncClient(timeout=_UPSTREAM_TIMEOUT)


@app.post('/ai/v1/chat/completions')
async def chat(request: Request):
    """Proxy a chat request to the upstream LLM service and stream the reply back.

    The caller authenticates with ``Authorization: Bearer <github_username>,<user_id>``.
    Those values are exchanged for a JWT via ``create_jwt``, and the JWT
    authorizes the upstream request.

    Returns:
        A ``StreamingResponse`` relaying the upstream body bytes unchanged.

    Raises:
        HTTPException: with one of these statuses:
            - 401 for a missing or malformed ``Authorization`` header.
            - 400 for a body that is not a JSON object with a ``messages`` list.
            - The upstream's 4xx/5xx status, or 502 for any other non-200
              status, when the upstream rejects the request.
            - 502 when the upstream cannot be reached.
    """
    logger.debug("Received request")

    github_username, user_id = _parse_bearer_credentials(request.headers.get('authorization'))
    jwt_token = create_jwt(github_username, user_id)
    # The token is a live bearer credential: log that it was minted, never its value.
    logger.debug("Generated JWT token for %s", github_username)

    try:
        payload = await request.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    logger.debug("Received payload: %s", payload)

    llm_payload = _build_llm_payload(payload)
    logger.debug("LLM payload: %s", llm_payload)

    headers = {
        'accept': '*/*',
        'content-type': 'application/json',
        'authorization': f'Bearer {jwt_token}',
        'user-agent': 'Zed/0.149.3 (macos; aarch64)',
    }
    logger.debug("Request headers: %s", {**headers, 'authorization': 'Bearer <redacted>'})

    # Open the upstream stream *before* responding. A failed upstream call can
    # then be reported with a real HTTP status instead of a 200 whose body
    # merely says "Error".
    client = _make_upstream_client()
    try:
        upstream_request = client.build_request(
            'POST', _LLM_COMPLETION_URL, headers=headers, json=llm_payload
        )
        upstream = await client.send(upstream_request, stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        logger.error("Error during LLM API request: %s", exc)
        raise HTTPException(status_code=502, detail=f"Error: {exc}") from exc
    except BaseException:
        await client.aclose()
        raise

    logger.debug("LLM API response status: %s", upstream.status_code)
    logger.debug("LLM API response headers: %s", upstream.headers)

    if upstream.status_code != 200:
        try:
            error_content = await upstream.aread()
        finally:
            await upstream.aclose()
            await client.aclose()
        logger.error("LLM API error response: %s", error_content)
        # Forward upstream 4xx/5xx as-is; any other non-200 status is a bad gateway.
        status_code = upstream.status_code if upstream.status_code >= 400 else 502
        raise HTTPException(
            status_code=status_code,
            detail=f"Error: {upstream.status_code} - {error_content.decode(errors='replace')}",
        )

    async def relay():
        """Yield upstream body chunks, then release the response and the client."""
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        except httpx.HTTPError as exc:
            # Status 200 is already on the wire, so a mid-stream failure can only
            # be reported in-band.
            logger.error("Error while streaming LLM API response: %s", exc)
            yield f"Error: {exc}"
        finally:
            await upstream.aclose()
            await client.aclose()

    return StreamingResponse(relay(), media_type='application/octet-stream')
|
|
|
if __name__ == '__main__':
    import uvicorn

    # HOST and PORT let the bind address change without editing the file. The
    # defaults match the previously hard-coded values.
    uvicorn.run(
        app,
        host=os.environ.get('HOST', '0.0.0.0'),
        port=int(os.environ.get('PORT', '8000')),
    )
|
|