File size: 5,854 Bytes
7a3a321 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse, JSONResponse
from loguru import logger
class OpenAIError(Exception):
    """Root of this module's OpenAI-style exception hierarchy."""


class APIError(OpenAIError):
    """Error that the app's handlers turn into an OpenAI-style JSON error body.

    Args:
        message: Client-facing error text; also the exception's ``str()``.
        code: HTTP status code used for the error response (default 500).
        param: Optional name of the request parameter the error relates to.
        internal_message: Extra detail that the handlers log server-side but
            do not include in the JSON response.
    """

    message: str
    code: Optional[int] = None
    param: Optional[str] = None
    type: Optional[str] = None

    def __init__(self, message: str, code: int = 500, param: Optional[str] = None, internal_message: str = ''):
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        # Must be a plain string (the concrete class name). A trailing comma
        # here previously turned it into a 1-tuple.
        self.type = self.__class__.__name__
        self.internal_message = internal_message

    def __repr__(self):
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )
class InternalServerError(APIError):
    """Generic server-side failure; keeps APIError's default code of 500."""
class ServiceUnavailableError(APIError):
    """Server-side error reported with HTTP 503 and a default retry message.

    Args:
        message: Client-facing error text.
        code: HTTP status code for the response (default 503).
        internal_message: Detail that is logged but not sent to the client.
    """

    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        # Pass internal_message by keyword. Positionally it would land in
        # APIError's ``param`` slot, which the error handlers put into the
        # client-facing JSON, and the real internal_message would stay ''.
        super().__init__(message, code, internal_message=internal_message)
class APIStatusError(APIError):
    """Base for errors whose HTTP status is fixed per subclass.

    Subclasses override ``status_code``, which is forwarded as the ``code`` of
    the underlying APIError. The subclasses in this module use 4xx codes.
    """

    status_code: int = 400

    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        super().__init__(
            message,
            code=self.status_code,
            param=param,
            internal_message=internal_message,
        )
class BadRequestError(APIStatusError):
    """Client error answered with HTTP 400 (Bad Request)."""

    status_code: int = 400
class AuthenticationError(APIStatusError):
    """Client error answered with HTTP 401 (Unauthorized)."""

    status_code: int = 401
class PermissionDeniedError(APIStatusError):
    """Client error answered with HTTP 403 (Forbidden)."""

    status_code: int = 403
class NotFoundError(APIStatusError):
    """Client error answered with HTTP 404 (Not Found)."""

    status_code: int = 404
class ConflictError(APIStatusError):
    """Client error answered with HTTP 409 (Conflict)."""

    status_code: int = 409
class UnprocessableEntityError(APIStatusError):
    """Client error answered with HTTP 422 (Unprocessable Entity)."""

    status_code: int = 422
class RateLimitError(APIStatusError):
    """Client error answered with HTTP 429 (Too Many Requests)."""

    status_code: int = 429
class OpenAIStub(FastAPI):
    """FastAPI application exposing a minimal OpenAI-compatible surface.

    On construction it installs:

    - permissive CORS;
    - JSON error handlers for generic exceptions and the APIError hierarchy;
    - debug request/response logging;
    - stub billing, readiness, health and model-discovery endpoints.

    Models are tracked in ``self.models`` (alias -> model name) via
    :meth:`register_model` / :meth:`deregister_model`.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Alias -> model name. An empty dict makes "/" answer 503 and
        # "/health" report "unk".
        self.models = {}

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"]
        )

        @self.exception_handler(Exception)
        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            # Generic server errors: return a fixed body with no details.
            # NOTE(review): Starlette re-raises the exception after this handler
            # runs, so the server presumably logs the traceback already.
            # Confirm this before adding logging here.
            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        @self.exception_handler(APIError)
        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            # Server error: log the full traceback.
            logger.opt(exception=exc).error("Logging exception traceback")
            if exc.internal_message:
                logger.info(exc.internal_message)
            return self._api_error_response(exc)

        @self.exception_handler(APIStatusError)
        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            # Client-side error: a one-line repr is enough, no traceback.
            logger.info(repr(exc))
            if exc.internal_message:
                logger.info(exc.internal_message)
            return self._api_error_response(exc)

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            # Debug tracing of every request/response. The body is awaited
            # inside the f-string, so it is read even when DEBUG is disabled.
            # NOTE(review): headers are logged verbatim, including any
            # Authorization header. Confirm that is acceptable.
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")
            logger.debug(f"Request body: {await request.body()}")
            response = await call_next(request)
            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")
            return response

        @self.get('/v1/billing/usage')
        @self.get('/v1/dashboard/billing/usage')
        async def handle_billing_usage():
            # Stub: usage is always reported as zero.
            return { 'total_usage': 0 }

        @self.get("/", response_class=PlainTextResponse)
        @self.head("/", response_class=PlainTextResponse)
        @self.options("/", response_class=PlainTextResponse)
        async def root():
            # Readiness probe: 200 once any model is registered, else 503.
            return PlainTextResponse(content="", status_code=200 if self.models else 503)

        @self.get("/health")
        async def health():
            return {"status": "ok" if self.models else "unk" }

        @self.get("/v1/models")
        async def get_model_list():
            return self.model_list()

        @self.get("/v1/models/{model}")
        async def get_model_info(model: str):
            # The parameter name must match the "{model}" path placeholder.
            # Otherwise FastAPI treats it as a required query parameter, and
            # GET /v1/models/<id> fails with 422.
            return self.model_info(model)

    @staticmethod
    def _api_error_response(exc: APIError) -> JSONResponse:
        """Build the JSON error response shared by both APIError handlers."""
        return JSONResponse(status_code=exc.code, content={
            'message': exc.message,
            'code': exc.code,
            'type': exc.__class__.__name__,
            'param': exc.param,
        })

    def register_model(self, name: str, model: str = None) -> None:
        """Expose ``name`` as an alias for ``model`` (defaults to ``name`` itself)."""
        self.models[name] = model if model else name

    def deregister_model(self, name: str) -> None:
        """Remove ``name`` from the registry; unknown names are ignored."""
        self.models.pop(name, None)

    def model_info(self, model: str) -> dict:
        """Return an OpenAI-style model object for ``model``.

        The metadata is a fixed placeholder. The id is not checked against
        the registry.
        """
        return {
            "id": model,
            "object": "model",
            "created": 0,
            "owned_by": "user"
        }

    def model_list(self) -> dict:
        """Return an OpenAI-style list covering every registered alias and model name.

        Returns ``{}`` (not an empty list object) when nothing is registered.
        """
        if not self.models:
            return {}
        # Aliases and their target names are listed once each, skipping empty
        # names. Sorting makes the order stable across requests and processes;
        # set iteration order is hash-randomized.
        names = sorted(
            name for name in self.models.keys() | set(self.models.values()) if name
        )
        return {
            "object": "list",
            "data": [self.model_info(name) for name in names],
        }
|