smgc committed
Commit fed49ba
1 Parent(s): c90075d

Update app.py

Files changed (1): app.py +49 -58
app.py CHANGED
@@ -1,49 +1,48 @@
- import logging
- from fastapi import FastAPI, Request, HTTPException
- from fastapi.responses import StreamingResponse
- import httpx
  import json
  import os
-
  from helper import create_jwt

- # Set up logging
- logging.basicConfig(level=logging.DEBUG)
- logger = logging.getLogger(__name__)

- app = FastAPI()

- LLM_API_URL = "https://llm.zed.dev/completion?"

- @app.post('/ai/v1/chat/completions')
- async def chat(request: Request):
-     logger.debug("Received request")
-
-     # Get the client's authorization header
-     auth_header = request.headers.get('authorization')
-     if not auth_header or not auth_header.startswith('Bearer '):
-         raise HTTPException(status_code=401, detail="Invalid authorization header")
-
-     # Extract github_username and user_id
-     try:
-         _, auth_data = auth_header.split('Bearer ', 1)
-         github_username, user_id_str = auth_data.split(',')
-         user_id = int(user_id_str)
-     except ValueError:
-         raise HTTPException(status_code=401, detail="Invalid authorization format")
-
-     # Generate the JWT token
-     jwt_token = create_jwt(github_username, user_id)
-     logger.debug(f"Generated JWT token: {jwt_token}")

-     # Get the request payload
-     payload = await request.json()
-     logger.debug(f"Received payload: {payload}")

-     # Get the model, defaulting to "claude-3-5-sonnet-20240620"
      model = payload.get('model', 'claude-3-5-sonnet-20240620')

-     # Prepare the LLM API request
      llm_payload = {
          "provider": "anthropic",
          "model": model,
@@ -56,40 +55,32 @@ async def chat(request: Request):
              "system": ""
          }
      }
-     logger.debug(f"LLM payload: {llm_payload}")

      headers = {
          'accept': '*/*',
          'content-type': 'application/json',
-         'authorization': f'Bearer {jwt_token}',
          'user-agent': 'Zed/0.149.3 (macos; aarch64)'
      }
-     logger.debug(f"Request headers: {headers}")

-     # Get proxy settings
      proxy = os.environ.get('HTTP_PROXY', None)
      proxies = {'http': proxy, 'https': proxy} if proxy else None
-     logger.debug(f"Using proxies: {proxies}")

-     async def generate():
-         async with httpx.AsyncClient(proxies=proxies) as client:
-             try:
-                 async with client.stream('POST', LLM_API_URL, headers=headers, json=llm_payload) as response:
-                     logger.debug(f"LLM API response status: {response.status_code}")
-                     logger.debug(f"LLM API response headers: {response.headers}")
-                     if response.status_code != 200:
-                         error_content = await response.aread()
-                         logger.error(f"LLM API error response: {error_content}")
-                         yield f"Error: {response.status_code} - {error_content.decode()}"
-                     else:
-                         async for chunk in response.aiter_bytes():
-                             yield chunk
-             except Exception as e:
-                 logger.error(f"Error during LLM API request: {str(e)}")
-                 yield f"Error: {str(e)}"
-
-     return StreamingResponse(generate(), media_type='application/octet-stream')

  if __name__ == '__main__':
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
+ from asgiref.wsgi import WsgiToAsgi
+ from flask import Flask, request, Response
+ import requests
  import json
  import os
  from helper import create_jwt

+ app = Flask(__name__)

+ @app.route('/ai/v1/chat/completions', methods=['POST'])
+ async def chat():  # async view; requires Flask's async extra (flask[async])
+     """
+     Handle chat completion requests.
+
+     This function processes incoming POST requests to the '/ai/v1/chat/completions'
+     endpoint. It prepares the payload for the LLM API, generates a JWT for
+     authentication, and streams the response from the LLM API back to the client.
+
+     Returns:
+         Response: A streaming response containing the LLM API's output.
+
+     Note:
+         - The function reads the HTTP_PROXY environment variable for proxy configuration.
+         - It extracts the GitHub username and Zed user ID from the Authorization header.
+         - The LLM model defaults to "claude-3-5-sonnet-20240620" if not specified.
+     """
+     # Get the payload from the request
+     payload = request.json

+     # Get the model from the payload, defaulting to "claude-3-5-sonnet-20240620"
      model = payload.get('model', 'claude-3-5-sonnet-20240620')

+     # Extract the GitHub username and Zed user ID from the Authorization header
+     auth_header = request.headers.get('Authorization')
+     if not auth_header or not auth_header.startswith('Bearer '):
+         return Response('Invalid Authorization header', status=401)
+
+     try:
+         github_username, zed_user_id = auth_header[7:].split(',')
+     except ValueError:
+         return Response('Invalid Authorization header format', status=401)
+
+     # Prepare the request for the LLM API
+     url = "https://llm.zed.dev/completion?"
+
      llm_payload = {
          "provider": "anthropic",
          "model": model,

              "system": ""
          }
      }
+
+     jwt = create_jwt(github_username, int(zed_user_id))

      headers = {
+         'Host': 'llm.zed.dev',
          'accept': '*/*',
          'content-type': 'application/json',
+         'authorization': f'Bearer {jwt}',
          'user-agent': 'Zed/0.149.3 (macos; aarch64)'
      }

+     # Get the proxy from the environment
      proxy = os.environ.get('HTTP_PROXY', None)
      proxies = {'http': proxy, 'https': proxy} if proxy else None

+     def generate():  # plain (sync) generator: Flask iterates the WSGI response synchronously
+         with requests.post(url, headers=headers, json=llm_payload, stream=True, proxies=proxies) as response:
+             for chunk in response.iter_content(chunk_size=1024):
+                 if chunk:
+                     yield chunk
+
+     return Response(generate(), content_type='application/octet-stream')
+
+ # Convert the Flask app to an ASGI app so uvicorn can serve it
+ asgi_app = WsgiToAsgi(app)

  if __name__ == '__main__':
      import uvicorn
+     uvicorn.run(asgi_app, host="0.0.0.0", port=8000)
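
helper.py is not part of this commit, so only the call signature create_jwt(github_username, user_id) is visible in the diff. As a rough sketch of what such a helper might look like, using PyJWT; the claim names, token lifetime, and signing key below are assumptions, not the actual helper.py:

import time
import jwt  # PyJWT

def create_jwt(github_username: str, user_id: int) -> str:
    # Hypothetical claims: the real helper.py may use different
    # claim names, lifetime, and signing key.
    now = int(time.time())
    payload = {
        "iat": now,
        "exp": now + 3600,
        "githubUserLogin": github_username,
        "userId": user_id,
    }
    return jwt.encode(payload, "SECRET_KEY", algorithm="HS256")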
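
For reference, a client calls the new endpoint with a two-part credential, "Bearer <github_username>,<zed_user_id>", which the proxy splits on the comma before signing the JWT. A minimal streaming-client sketch, assuming the proxy runs locally on port 8000; the credential values are placeholders, and the "messages" field is an assumption, since the hunk above elides the middle of llm_payload:

import requests

# Placeholder credential: "<github_username>,<zed_user_id>" after "Bearer "
headers = {
    "Authorization": "Bearer octocat,12345",
    "Content-Type": "application/json",
}
payload = {
    "model": "claude-3-5-sonnet-20240620",  # optional; this is the server-side default
    "messages": [{"role": "user", "content": "Hello"}],  # assumed field, not shown in the diff
}

with requests.post(
    "http://localhost:8000/ai/v1/chat/completions",
    headers=headers,
    json=payload,
    stream=True,
) as resp:
    # Print the raw stream as it arrives, mirroring the proxy's 1024-byte chunks
    for chunk in resp.iter_content(chunk_size=1024):
        if chunk:
            print(chunk.decode(errors="replace"), end="")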