File size: 861 Bytes
70c8b16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from typing import Dict, List, Any
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX
class PreTrainedPipeline:
    """Callable pipeline around a ``PunctCapSegModelONNX`` instance.

    The constructor builds the model from files in a directory; calling the
    instance runs inference on a single string and returns the result in a
    ``[{"generated_text": ...}]`` structure.
    """

    def __init__(self, path: str):
        """Build the underlying model from the files found under ``path``.

        Args:
            path: Passed as ``directory`` to ``PunctCapSegConfigONNX``; the
                file names ``sp.model``, ``model.onnx`` and ``config.yaml``
                are fixed here.
        """
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(
            PunctCapSegConfigONNX(
                directory=path,
                spe_filename="sp.model",
                model_filename="model.onnx",
                config_filename="config.yaml",
            )
        )

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on one input string.

        Args:
            data: The text to process.

        Returns:
            A one-element list holding a dict whose ``"generated_text"``
            value is the strings predicted for ``data``, joined by a literal
            backslash-n sequence surrounded by spaces.
        """
        # ``infer`` is given a list, so wrap the single input as a batch of
        # size 1 and read back the only entry of the result.
        batch_predictions: List[List[str]] = self._punctuator.infer([data])
        segments = batch_predictions[0]
        # The text generation widget did not render multiple lines, so a
        # visible '\n' marker (not a real newline) separates the segments.
        joined = " \\n ".join(segments)
        return [{"generated_text": joined}]
|