"""FastAPI front end for sentence segmentation with a wtpsplit SaT model.

Serves an HTML form (``/`` and ``/split``) and a JSON endpoint
(``/api/split``) that splits uploaded UTF-8 text into sentences.
"""

from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from wtpsplit import SaT

# SaT checkpoint loaded at startup and used by every endpoint.
MODEL_NAME = "sat-3l-sm"

# Upload size limit for /api/split: 1.44 MB, truncated to whole bytes.
MAX_UPLOAD_BYTES = int(1.44 * 1024 * 1024)

# Loaded models keyed by checkpoint name; populated by ``lifespan``.
sat_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SaT model on startup and release it on shutdown."""
    sat_models[MODEL_NAME] = SaT(MODEL_NAME)
    yield
    # Drop references to the model so its resources can be released.
    sat_models.clear()


app = FastAPI(lifespan=lifespan)
templates = Jinja2Templates(directory="templates")


async def _split_sentences(text: str):
    """Split *text* with the loaded SaT model without blocking the event loop.

    ``SaT.split`` is a synchronous call doing model inference. Calling it
    directly inside an ``async`` endpoint would stall every other request
    until it finished, so it is dispatched to the worker thread pool.
    """
    return await run_in_threadpool(sat_models[MODEL_NAME].split, text)


@app.get("/", response_class=HTMLResponse)
def root(request: Request):
    """Render the input page without any sentences."""
    return templates.TemplateResponse(request=request, name="index.html")


@app.post("/split", response_class=HTMLResponse)
async def split_text(request: Request, text: Annotated[str, Form()] = ""):
    """Split form-submitted *text* and re-render the page with the result."""
    sentences = await _split_sentences(text)
    return templates.TemplateResponse(
        request=request, name="index.html", context={"sentences": sentences}
    )


@app.post("/api/split")
async def split_file(file: Annotated[UploadFile, File()]):
    """Split an uploaded UTF-8 text file and return ``{"sentences": [...]}``.

    Raises:
        HTTPException: 413 if the upload exceeds ``MAX_UPLOAD_BYTES``;
            400 if its contents are not valid UTF-8.
    """
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading the whole file into memory.
    data = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(data) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="File too large")
    try:
        # "utf-8-sig" accepts plain UTF-8 and also strips a leading BOM,
        # which would otherwise be glued onto the first sentence.
        text = data.decode("utf-8-sig")
    except UnicodeDecodeError as err:
        # Client error, not a server fault: report 400 instead of a 500.
        raise HTTPException(
            status_code=400, detail="File must be UTF-8 encoded text"
        ) from err
    sentences = await _split_sentences(text)
    return {"sentences": sentences}