# KoQuestionBART / pipeline.py
# Custom transformers pipeline uploaded to the Hugging Face Hub by zzrng76
# (commit 7ac259c, "Create pipeline.py"; 2.19 kB).
import math
import re

import torch
from transformers import Pipeline
class MyPipeline(Pipeline):
    """Question-generation pipeline for a seq2seq model with a KoBART-style tokenizer.

    Calling the pipeline (``Pipeline.__call__``) on one paragraph returns one
    generated question string. :meth:`make_question` slides a word window over
    a longer text and returns one question per window.

    All model/tokenizer access goes through ``self.model`` and
    ``self.tokenizer``, i.e. the objects the pipeline was constructed with.
    """

    # Characters NOT in this set are deleted by ``preprocess``. The set holds
    # ASCII letters, Hangul syllables (가-힣), digits, space and ", < > : & #".
    # Note that '.', '?', tabs and newlines are removed as well.
    _DISALLOWED_CHARS = re.compile(r'[^A-Za-z가-힣,<>0-9:&# ]')

    # Task prefix ("question generation:") prepended to every input.
    _PROMPT = "질문 생성: <unused0>"

    # Generation options accepted at call time and forwarded to ``generate``.
    _GENERATE_KEYS = ("max_length", "num_beams")

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to (preprocess, forward, postprocess).

        Only ``max_length`` and ``num_beams`` are recognised. They go to
        ``preprocess``, which passes them on to ``_forward`` inside its output
        dict. Any other keyword is ignored.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._GENERATE_KEYS if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, **kwargs):
        """Clean *inputs*, prepend the task prompt and tokenize.

        Args:
            inputs: raw paragraph text.
            **kwargs: may contain ``max_length`` / ``num_beams``. Each is
                passed through, or ``None`` when absent. Other keys are ignored.

        Returns:
            dict with ``"inputs"`` (a ``(1, seq_len)`` int64 tensor of token
            ids wrapped in BOS/EOS), ``"max_length"`` and ``"num_beams"``.
        """
        tokenizer = self.tokenizer
        text = self._PROMPT + self._DISALLOWED_CHARS.sub('', inputs)
        # BOS/EOS are added by hand. This assumes tokenizer.encode does not
        # already add them for this checkpoint (TODO confirm).
        input_ids = (
            [tokenizer.bos_token_id]
            + tokenizer.encode(text)
            + [tokenizer.eos_token_id]
        )
        return {
            "inputs": torch.tensor([input_ids]),
            # None means "not given"; _forward then keeps generate() defaults.
            "max_length": kwargs.get("max_length"),
            "num_beams": kwargs.get("num_beams"),
        }

    def _forward(self, model_inputs):
        """Run ``self.model.generate`` on the prepared input ids.

        ``max_length`` / ``num_beams`` are forwarded only when not ``None``.
        Passing ``None`` explicitly would override the model's generation
        config rather than fall back to it. The tokenizer's unknown token is
        banned from the output.

        Returns:
            ``{"logits": ids}``. Despite the key name, these are generated
            token ids (one row per input sequence), not logits.
        """
        generate_kwargs = {
            key: model_inputs[key]
            for key in self._GENERATE_KEYS
            if model_inputs.get(key) is not None
        }
        res_ids = self.model.generate(
            model_inputs["inputs"],
            eos_token_id=self.tokenizer.eos_token_id,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            **generate_kwargs,
        )
        return {"logits": res_ids}

    def postprocess(self, model_outputs):
        """Decode the first generated sequence and strip ``<s>``/``</s>``."""
        decoded = self.tokenizer.batch_decode(model_outputs["logits"].tolist())[0]
        return decoded.replace('<s>', '').replace('</s>', '')

    def _inference(self, paragraph, **kwargs):
        """Generate one question for *paragraph* (preprocess -> forward -> postprocess).

        Uses ``Pipeline.forward`` rather than ``_forward`` directly. The input
        tensor is therefore moved to the pipeline's device and the result back
        to CPU, as ``Pipeline.__call__`` would do.
        """
        model_inputs = self.preprocess(paragraph, **kwargs)
        model_outputs = self.forward(model_inputs)
        return self.postprocess(model_outputs)

    def make_question(self, text, **kwargs):
        """Generate one question per sliding window of words over *text*.

        Args:
            text: input text. It is split on runs of whitespace, so repeated
                spaces do not create empty "words" and newlines separate words.
            frame_size: (keyword, required) words per window.
            hop_length: (keyword, required) words the window advances per step.
            **kwargs: also forwarded to ``preprocess`` (e.g. ``max_length``,
                ``num_beams``).

        Returns:
            list of generated questions, one per window. Windows start at
            0, hop_length, 2*hop_length, ... and together cover every word.
            Text shorter than one frame gives a single window. Empty or
            blank text gives ``[]``.

        Raises:
            KeyError: if ``frame_size`` or ``hop_length`` is not given.
            ValueError: if either is smaller than 1.
        """
        frame_size = kwargs['frame_size']
        hop_length = kwargs['hop_length']
        if frame_size < 1 or hop_length < 1:
            raise ValueError(
                f"frame_size and hop_length must be >= 1, "
                f"got {frame_size} and {hop_length}"
            )

        words = text.split()
        if not words:
            return []

        # ceil (not round) guarantees the last window reaches the final word
        # and that short text still yields one window. Slicing clamps at the end.
        overflow = max(0, len(words) - frame_size)
        steps = math.ceil(overflow / hop_length) + 1
        return [
            self._inference(" ".join(words[start:start + frame_size]), **kwargs)
            for start in range(0, steps * hop_length, hop_length)
        ]