import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY

class TBCP(Pipeline):
    """A sequence-classification pipeline for TunBERT."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The pipeline factory normally passes an already-instantiated
        # tokenizer object; reload from the Hub only when an identifier
        # string is supplied instead.
        if isinstance(kwargs.get("tokenizer"), str):
            self.tokenizer = AutoTokenizer.from_pretrained(kwargs["tokenizer"])

    def _sanitize_parameters(self, **kwargs):
        # Forward the optional `top_k` call argument to `postprocess`;
        # `preprocess` and `_forward` take no extra parameters.
        postprocess_kwargs = {}
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, text):
        # Tokenize the raw input text into PyTorch tensors.
        return self.tokenizer(text, return_tensors="pt")

    def _forward(self, model_inputs):
        # Run the model on the tokenized inputs.
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=None):
        logits = model_outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

        if top_k is not None:
            # When `top_k` is requested, return the highest-scoring classes
            # together with their softmax scores.
            scores, indices = probabilities.squeeze().topk(top_k)
            return [
                {"label": f"Label_{i}", "score": s}
                for s, i in zip(scores.tolist(), indices.tolist())
            ]

        best_class = probabilities.argmax().item()
        label = f"Label_{best_class}"
        logits = logits.squeeze().tolist()
        return {"label": label, "logits": logits}

PIPELINE_REGISTRY.register_pipeline(
    "TunBERT-classifier",
    pipeline_class=TBCP,
    pt_model=AutoModelForSequenceClassification,
    default={"pt": ("not-lain/TunBERT", "main")},
    type="text",
)
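
# A minimal usage sketch (not part of the original listing). Because the
# task was just registered in this session, the regular `pipeline()` factory
# can build it; the input sentence is a placeholder, and loading the model
# straight from the Hub may additionally require trust_remote_code=True.
from transformers import pipeline

classifier = pipeline("TunBERT-classifier", model="not-lain/TunBERT")

print(classifier("some input text"))           # {"label": "Label_...", "logits": [...]}
print(classifier("some input text", top_k=2))  # top-2 labels with softmax scores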