stop sequence
#2
by
LavGadewar
- opened
There is a parameter called "stop sequence" in GPT-3. Is there a similar parameter in the GPT-Neo-2.7B model that stops token generation, so that the output will not contain that sequence?
LavGadewar
changed discussion status to
closed
LavGadewar
changed discussion status to
open
Yes, modify the code as needed:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList

# Run everything on the GPU.
device = torch.device("cuda")
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the most recently produced token is one of
    the given keyword token ids.

    NOTE: this inspects single tokens only — a stop string that encodes to
    more than one BPE token cannot be matched by this criterion. It also
    looks only at batch row 0, so it assumes batch size 1.
    """

    def __init__(self, keywords_ids: list):
        # Store as a set of plain ints for O(1) membership tests on every
        # generation step (the original list scan was O(n) per step).
        self.keywords = set(keywords_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Convert the 0-dim tensor to int before the set lookup: a Tensor
        # hashes by identity, so a raw tensor would never match an int key.
        return int(input_ids[0, -1]) in self.keywords
# Candidate stop strings; only the FIRST BPE token of each is used, because
# the stopping criterion inspects a single trailing token at a time.
sequence = ['\n','\n\n', '.\n', '. ', '. \n', '?', '!']
stop_ids = [tokenizer.encode(w)[0] for w in sequence]
stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(stop_ids)])

# Tokenize the prompt and move the tensors onto the target device.
model_inputs = tokenizer(prompt, return_tensors='pt').to(device)

generated = model.generate(
    **model_inputs,
    top_p=1,
    top_k=0,
    temperature=0.2,
    max_new_tokens=18,
    pad_token_id=50256,
    no_repeat_ngram_size=2,
    stopping_criteria=stop_criteria,
    early_stopping=True,
    do_sample=True,
)

# Decode only the first (and only) sequence in the batch.
output = tokenizer.decode(generated[0], skip_special_tokens=True)