Completion stuck in a loop
Hi, @ParkerBurchett. In the inference widget it doesn't seem to work well. I didn't test it there because it is slow. Can you test it in a Colab? I will add the recommended generation params.
Sure, I got it running in Colab. It's still stuck in a loop:
https://colab.research.google.com/drive/1YcH40MhJfYiFd39Q361SC6cauq2YFx9v?usp=sharing
prompt = "McDonald's hamburger promotion on a red billboard, white lettering, "
input_ids = tokenizer(prompt, return_tensors="pt").to('cuda')
sample = model.generate(**input_ids, max_new_tokens=50)
tokenizer.decode(sample[0])
McDonald's hamburger promotion on a red billboard, white lettering, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital bill
@mrm8488 which are the generation params you are using?
I had the same issue with the default parameters, but setting a high repetition_penalty fixed it.
In Colab it still gets stuck in a loop
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt = 'mrm8488/bloom-560m-finetuned-sd-prompts'
tokenizer = BloomTokenizerFast.from_pretrained(ckpt)
model = BloomForCausalLM.from_pretrained(ckpt).to(device)
def generate_prompt(text):
    torch.cuda.empty_cache()  # free cached GPU memory between calls
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    output = model.generate(input_ids, attention_mask=attention_mask, max_length=512, eos_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output[0], skip_special_tokens=False)
text = "<s>Prompt: pikachu dining in the Eiffel Tower"
text2 = "<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering, "
generate_prompt(text2)
Returns
<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, … ("digital billboard," repeats until the 512-token limit is hit)
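Besides repetition_penalty, another knob that may be worth trying is no_repeat_ngram_size, which hard-blocks any n-gram from being generated twice, so a literal loop like the one above cannot recur verbatim. A minimal sketch of the config (the value 3 is a guess, not tuned for this checkpoint):

```python
from transformers import GenerationConfig

# Forbid any 3-token sequence from appearing twice in the output,
# which breaks exact repetition loops like "digital billboard, digital billboard,".
gen_cfg = GenerationConfig(
    no_repeat_ngram_size=3,
    max_length=512,
)

# With the model/tokenizer from the snippet above, this would be passed as:
# output = model.generate(input_ids,
#                         attention_mask=attention_mask,
#                         generation_config=gen_cfg,
#                         eos_token_id=tokenizer.eos_token_id)
```

Note that this is a hard constraint: it also blocks legitimate repeats (e.g. a style tag appearing twice), whereas repetition_penalty only down-weights them.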
@undefined2 changing the repetition_penalty to 1.05 worked for me too.
def generate_prompt(text):
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    output = model.generate(input_ids, attention_mask=attention_mask, repetition_penalty=1.05, max_length=512, eos_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output[0], skip_special_tokens=False)

text2 = "<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering,"
generate_prompt(text2)
<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering, advertisement with posters and flyrets in the style of artgerm.</s>
@mrm8488 Maybe change the default repetition_penalty to slightly over 1 in the example code?
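If I understand the suggestion, the default could also be shipped in the checkpoint's generation_config.json so the widget and a plain model.generate() call pick it up automatically. A hedged sketch (the push_to_hub step is commented out because it assumes write access to the repo):

```python
from transformers import GenerationConfig

# The value that fixed the loop above, saved as the checkpoint's default.
gen_cfg = GenerationConfig(
    repetition_penalty=1.05,
    max_length=512,
)

# Writes generation_config.json into the given local directory.
gen_cfg.save_pretrained("bloom-560m-finetuned-sd-prompts")

# Repo owner only:
# gen_cfg.push_to_hub("mrm8488/bloom-560m-finetuned-sd-prompts")
```

Once that file is in the repo, from_pretrained loads it alongside the weights and users no longer need to pass repetition_penalty by hand.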
I hit the same problem when running the code: it generates exactly the same sentence every time. What should I do?
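That is expected with the snippets above: generate() defaults to greedy decoding, which is deterministic, so identical inputs always produce identical outputs. If varied prompts are what you want, enabling sampling should help; a sketch of the parameters (the specific values here are guesses, not tuned for this model):

```python
from transformers import GenerationConfig

sample_cfg = GenerationConfig(
    do_sample=True,          # switch from deterministic greedy decoding to sampling
    top_k=50,                # sample only from the 50 most likely tokens
    top_p=0.95,              # ...further restricted to the 95% probability mass
    temperature=0.9,         # <1.0 keeps outputs fairly focused
    repetition_penalty=1.05, # keep the anti-loop fix from above
    max_length=512,
)

# With the model/tokenizer defined earlier in the thread:
# output = model.generate(input_ids, attention_mask=attention_mask,
#                         generation_config=sample_cfg)
```

Each call then draws a different continuation, at the cost of some run-to-run variance in quality.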