|
from fastapi import FastAPI, Form, HTTPException, Request, File |
|
from fastapi.concurrency import asynccontextmanager |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.templating import Jinja2Templates |
|
from typing import Annotated |
|
from wtpsplit import SaT |
|
|
|
# Registry of loaded SaT segmentation models, keyed by model name.
# `lifespan` fills it at startup and empties it at shutdown; the request
# handlers look models up here by name.
sat_models: dict[str, "SaT"] = dict()
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SaT model at application startup and drop it at shutdown.

    Code before ``yield`` runs once when the app starts; code after it runs
    once when the app shuts down.
    """
    model_name = "sat-3l-sm"
    # Loading happens synchronously here, before any request is handled.
    sat_models[model_name] = SaT(model_name)
    yield
    # Drop the registry's references to the loaded model(s).
    sat_models.clear()
|
|
|
|
|
# Jinja2 template loader. NOTE(review): "templates" is a relative path, so it
# is resolved against the process's working directory, not this file's
# location — confirm the server is always launched from the project root.
templates = Jinja2Templates(directory="templates")

# The ASGI application; `lifespan` loads the SaT model before serving.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def root(request: Request):
    # Render the landing page. No context dict is passed, so index.html is
    # rendered without a `sentences` value.
    page_name = "index.html"
    response = templates.TemplateResponse(request=request, name=page_name)
    return response
|
|
|
|
|
@app.post("/split", response_class=HTMLResponse)
async def split_text(request: Request, text: Annotated[str, Form()] = ""):
    # Handle the HTML form post: segment the submitted `text` into sentences
    # with the preloaded SaT model and re-render index.html with the result.
    #
    # NOTE(review): `model.split` is a synchronous call inside an `async def`
    # handler, so the event loop is blocked (no other requests progress)
    # while inference runs. Consider a plain `def` handler or
    # `run_in_threadpool` once the model's thread-safety under concurrent
    # calls is confirmed.
    model = sat_models["sat-3l-sm"]
    page_context = {"sentences": model.split(text)}
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context=page_context,
    )
|
|
|
@app.post("/api/split")
async def split_file(file: Annotated[bytes, File()]):
    """Split an uploaded UTF-8 text file into sentences.

    Returns ``{"sentences": [...]}``. Responds with 413 if the upload is
    larger than 1.44 MiB and with 400 if it is not valid UTF-8 text.
    """
    # With `bytes` + File(), FastAPI has already read the whole upload into
    # memory before this runs; the check caps the text handed to the model,
    # not the memory used to receive it.
    if len(file) > 1.44 * 1024 * 1024:
        raise HTTPException(status_code=413, detail="File too large")

    try:
        # "utf-8-sig" strips a leading byte-order mark (which "utf-8" would
        # keep as U+FEFF at the start of the first sentence); BOM-less UTF-8
        # decodes identically.
        text = file.decode("utf-8-sig")
    except UnicodeDecodeError as err:
        # A badly encoded upload is a client error; previously this escaped
        # as a 500 Internal Server Error.
        raise HTTPException(
            status_code=400, detail="File must be UTF-8 encoded text"
        ) from err

    # NOTE(review): synchronous model call inside an `async def` handler —
    # it blocks the event loop for the duration of inference.
    sentences = sat_models["sat-3l-sm"].split(text)
    return {"sentences": sentences}