# Speech/openedai.py
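"""OpenAI-compatible API stub helpers.

Defines an OpenAI-style exception hierarchy and OpenAIStub, a FastAPI subclass
that wires up CORS, exception handlers, request/response debug logging, and the
/v1/models, billing-usage, and health endpoints expected by OpenAI API clients.
"""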
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse, JSONResponse
from loguru import logger
class OpenAIError(Exception):
    pass

class APIError(OpenAIError):
    message: str
    code: int = None
    param: str = None
    type: str = None

    def __init__(self, message: str, code: int = 500, param: str = None, internal_message: str = ''):
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        self.type = self.__class__.__name__
        self.internal_message = internal_message
    def __repr__(self):
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )

class InternalServerError(APIError):
    pass

class ServiceUnavailableError(APIError):
    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        super().__init__(message, code, internal_message=internal_message)

class APIStatusError(APIError):
    status_code: int = 400

    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        super().__init__(message, self.status_code, param, internal_message)

class BadRequestError(APIStatusError):
    status_code: int = 400

class AuthenticationError(APIStatusError):
    status_code: int = 401

class PermissionDeniedError(APIStatusError):
    status_code: int = 403

class NotFoundError(APIStatusError):
    status_code: int = 404

class ConflictError(APIStatusError):
    status_code: int = 409

class UnprocessableEntityError(APIStatusError):
    status_code: int = 422

class RateLimitError(APIStatusError):
    status_code: int = 429
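
# Each APIStatusError subclass carries its HTTP status in `status_code`; route
# handlers can raise these, and the exception handlers registered in OpenAIStub
# below convert them into JSON error responses with the matching status code.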

class OpenAIStub(FastAPI):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.models = {}

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @self.exception_handler(Exception)
        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            # Generic server errors
            #logger.opt(exception=exc).error("Logging exception traceback")
            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        @self.exception_handler(APIError)
        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            # Server error
            logger.opt(exception=exc).error("Logging exception traceback")
            if exc.internal_message:
                logger.info(exc.internal_message)

            return JSONResponse(status_code=exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.exception_handler(APIStatusError)
        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            # Client side error
            logger.info(repr(exc))
            if exc.internal_message:
                logger.info(exc.internal_message)

            return JSONResponse(status_code=exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")
            logger.debug(f"Request body: {await request.body()}")

            response = await call_next(request)

            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")
            return response

        @self.get('/v1/billing/usage')
        @self.get('/v1/dashboard/billing/usage')
        async def handle_billing_usage():
            return {'total_usage': 0}

        @self.get("/", response_class=PlainTextResponse)
        @self.head("/", response_class=PlainTextResponse)
        @self.options("/", response_class=PlainTextResponse)
        async def root():
            return PlainTextResponse(content="", status_code=200 if self.models else 503)
@self.get("/health")
async def health():
return {"status": "ok" if self.models else "unk" }
@self.get("/v1/models")
async def get_model_list():
return self.model_list()
@self.get("/v1/models/{model}")
async def get_model_info(model_id: str):
return self.model_info(model_id)

    def register_model(self, name: str, model: str = None) -> None:
        self.models[name] = model if model else name

    def deregister_model(self, name: str) -> None:
        if name in self.models:
            del self.models[name]

    def model_info(self, model: str) -> dict:
        result = {
            "id": model,
            "object": "model",
            "created": 0,
            "owned_by": "user"
        }
        return result

    def model_list(self) -> dict:
        if not self.models:
            return {}

        result = {
            "object": "list",
            "data": [self.model_info(model) for model in list(set(self.models.keys() | self.models.values())) if model]
        }
        return result
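

# Example usage (an illustrative sketch, not part of the uploaded module): serve
# an OpenAIStub app with uvicorn. The app title, the model name "tts-1", and the
# host/port values below are assumptions chosen purely for demonstration.
if __name__ == "__main__":
    import uvicorn

    app = OpenAIStub(title="openedai stub example")
    # Registering at least one model makes "/" return 200 and /health report "ok".
    app.register_model("tts-1")

    uvicorn.run(app, host="127.0.0.1", port=8000)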