smgc committed
Commit 0d5e69a
Parent(s): bb03080

Update app.py

Files changed (1): app.py (+23 -15)
app.py CHANGED
@@ -1,10 +1,10 @@
 import logging
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException
 from fastapi.responses import StreamingResponse
 import httpx
 import json
 import os
-from helper import create_jwt, generate_random_tuple
+from helper import create_jwt
 
 # Set up logging
 logging.basicConfig(level=logging.DEBUG)
@@ -16,19 +16,31 @@ app = FastAPI()
 async def chat(request: Request):
     logger.debug("Received request")
 
-    # Generate JWT token
-    github_username, user_id = generate_random_tuple()
+    # Get the client's authorization header
+    auth_header = request.headers.get('authorization')
+    if not auth_header or not auth_header.startswith('Bearer '):
+        raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+    # Extract github_username and user_id
+    try:
+        _, auth_data = auth_header.split('Bearer ', 1)
+        github_username, user_id_str = auth_data.split(',')
+        user_id = int(user_id_str)
+    except ValueError:
+        raise HTTPException(status_code=401, detail="Invalid authorization format")
+
+    # Generate JWT token
     jwt_token = create_jwt(github_username, user_id)
     logger.debug(f"Generated JWT token: {jwt_token}")
 
-    # Get the payload from the request
+    # Get the request payload
     payload = await request.json()
     logger.debug(f"Received payload: {payload}")
 
-    # Get the model from the payload, defaulting to "claude-3-5-sonnet-20240620"
+    # Get the model, defaulting to "claude-3-5-sonnet-20240620"
     model = payload.get('model', 'claude-3-5-sonnet-20240620')
 
-    # Prepare the request for the LLM API
+    # Prepare the LLM API request
     url = "https://llm.zed.dev/completion?"
 
     llm_payload = {
@@ -49,12 +61,12 @@ async def chat(request: Request):
         'Host': 'llm.zed.dev',
         'accept': '*/*',
         'content-type': 'application/json',
-        'authorization': f'Bearer {jwt_token}',
+        'authorization': f'Bearer {jwt_token}',  # Use the newly generated JWT token
        'user-agent': 'Zed/0.149.3 (macos; aarch64)'
     }
     logger.debug(f"Request headers: {headers}")
 
-    # Get proxy from environment variable
+    # Get proxy settings
     proxy = os.environ.get('HTTP_PROXY', None)
     proxies = {'http': proxy, 'https': proxy} if proxy else None
     logger.debug(f"Using proxies: {proxies}")
@@ -66,9 +78,9 @@ async def chat(request: Request):
             logger.debug(f"LLM API response status: {response.status_code}")
             logger.debug(f"LLM API response headers: {response.headers}")
             if response.status_code != 200:
-                error_content = await response.read()
+                error_content = await response.aread()
                 logger.error(f"LLM API error response: {error_content}")
-                yield f"Error: {response.status_code} - {error_content}"
+                yield f"Error: {response.status_code} - {error_content.decode()}"
             else:
                 async for chunk in response.aiter_bytes():
                     yield chunk
@@ -78,10 +90,6 @@ async def chat(request: Request):
 
     return StreamingResponse(generate(), media_type='application/octet-stream')
 
-@app.get("/")
-async def root():
-    return {"message": "Welcome to the AI Chat Completions API"}
-
 if __name__ == '__main__':
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
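
For reference, a client of this endpoint now has to supply its own authorization header in the form "Bearer <github_username>,<user_id>"; the handler parses that pair and re-signs it as a JWT for llm.zed.dev. Below is a minimal client sketch, assuming the route is mounted at "/chat" and using illustrative credentials and payload fields (neither the route decorator nor the full payload schema appears in the hunks above):

# Hypothetical client call; the route path, username, user id, and payload fields are assumptions.
import httpx

headers = {
    # Format now expected by the proxy: "Bearer <github_username>,<user_id>"
    "authorization": "Bearer octocat,12345",
    "content-type": "application/json",
}

payload = {
    "model": "claude-3-5-sonnet-20240620",
    "messages": [{"role": "user", "content": "Hello"}],
}

# Stream the proxied completion back chunk by chunk
with httpx.stream("POST", "http://localhost:8000/chat", headers=headers, json=payload) as response:
    for chunk in response.iter_bytes():
        print(chunk.decode(errors="ignore"), end="")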