Problems with sampling when using left padding and do_sample=True

#40, opened by RaccoonOnion

My code looks like this:

    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pad_token_id=pad_token_id,
        max_length=200,
        repetition_penalty=1.0,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )
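For context, left padding means shorter prompts are padded at the start so that every prompt in the batch ends at the same position, and `attention_mask` is 0 over the padding. The tokenizer normally produces this layout (for example with `tokenizer.padding_side = "left"` and `padding=True`). The pure-Python helper `left_pad_batch` below is a hypothetical illustration of that layout, not library code, and the token ids and `pad_token_id=0` are made up.

```python
from typing import List, Tuple


def left_pad_batch(
    sequences: List[List[int]], pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id lists to a common length (hypothetical helper).

    Returns (input_ids, attention_mask); the mask is 0 at padding positions
    and 1 at real tokens, mirroring a tokenizer with padding_side="left".
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append([pad_token_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Made-up token ids; pad_token_id=0 is an assumption for illustration.
ids, mask = left_pad_batch([[5, 6, 7], [8]], pad_token_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

With a single prompt there is nothing to pad and the mask is all ones, so masked positions only enter the computation once batch_size > 1 and the prompts have different lengths.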

I get the following error, but only when batch_size > 1:

    File */lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
        113 @functools.wraps(func)
        114 def decorate_context(*args, **kwargs):
        115     with ctx_factory():
    --> 116         return func(*args, **kwargs)

    File */lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
       2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
       2017         input_ids=input_ids,
       2018         expand_size=generation_config.num_return_sequences,
       2019         is_encoder_decoder=self.config.is_encoder_decoder,
       2020         **model_kwargs,
       2021     )
       2023     # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False)
    -> 2024     result = self._sample(
       2025         input_ids,
       2026         logits_processor=prepared_logits_processor,
       2027         logits_warper=prepared_logits_warper,
       2028         stopping_criteria=prepared_stopping_criteria,
       2029         generation_config=generation_config,
       2030         synced_gpus=synced_gpus,
       2031         streamer=streamer,
       2032         **model_kwargs,
       2033     )
       2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
       2036     # 11. prepare logits warper
       2037     prepared_logits_warper = (
       2038         self._get_logits_warper(generation_config, device=input_ids.device)
       2039         if generation_config.do_sample
       2040         else None
       2041     )

    File */lib/python3.10/site-packages/transformers/generation/utils.py:3020, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
       3018     probs = nn.functional.softmax(next_token_scores, dim=-1)
       3019     # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
    -> 3020     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
       3021 else:
       3022     next_tokens = torch.argmax(next_token_scores, dim=-1)

    RuntimeError: probability tensor contains either inf, nan or element < 0
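The last frame of the traceback is the final sampling step: the processed scores go through softmax and then torch.multinomial, which refuses any row containing inf, nan, or a negative value, while the greedy branch uses torch.argmax, which does no such check. The pure-Python sketch below mimics that step with temperature scaling and top-k filtering (as in the temperature=1.0, top_k=50 settings above). `softmax`, `top_k_filter`, and `sample_next_token` are hypothetical helpers written for illustration, not transformers internals, and the raised error only imitates the torch message. It shows two ways a row ends up invalid: a nan score coming out of the model, or a row in which every score is -inf.

```python
import math
import random
from typing import List, Optional


def softmax(scores: List[float]) -> List[float]:
    # Subtract the row maximum for numerical stability. If every score is
    # -inf, this computes (-inf) - (-inf) = nan, so the whole row becomes nan.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(scores: List[float], k: int) -> List[float]:
    # Keep scores >= the k-th largest and set the rest to -inf.
    # A nan score is never "< threshold", so it is kept, not filtered out.
    if k >= len(scores):
        return list(scores)
    threshold = sorted(scores, reverse=True)[k - 1]
    return [-math.inf if s < threshold else s for s in scores]


def sample_next_token(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 50,
    rng: Optional[random.Random] = None,
) -> int:
    rng = rng or random.Random()
    scores = [s / temperature for s in logits]  # temperature scaling
    scores = top_k_filter(scores, top_k)        # top-k filtering
    probs = softmax(scores)
    # Reject the same kinds of rows that torch.multinomial rejects.
    if not all(math.isfinite(p) and p >= 0.0 for p in probs):
        raise RuntimeError("probability tensor contains either inf, nan or element < 0")
    # Draw one index with probability proportional to probs.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
print(sample_next_token([2.0, 1.0, 0.5], rng=rng))  # one of 0, 1, 2

for bad_row in ([2.0, float("nan"), 0.5], [-math.inf] * 3):
    try:
        sample_next_token(bad_row)
    except RuntimeError as err:
        print("rejected:", err)
```

In this sketch, temperature scaling and top-k filtering on finite logits always leave at least one finite score, so an invalid row can only come from non-finite values already present in the logits. By the same reasoning, checking the model's raw logits per batch row (for example with torch.isfinite) should show whether the NaNs originate in the forward pass for the padded rows.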

I haven't had this problem with other Gemma models, and the output is fine with a batch size of 1.

Environment: transformers 4.44.0, torch 2.4.0, 2× H100. I load the model with device_map='auto'.

@GopiUppari Could you take a look at this issue and provide some feedback? Thank you very much!
