diff --git "a/custom_bart/bart_generation_mixin.py" "b/custom_bart/bart_generation_mixin.py" new file mode 100644--- /dev/null +++ "b/custom_bart/bart_generation_mixin.py" @@ -0,0 +1,3272 @@ +import inspect +import warnings +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch import nn + +from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint +from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from transformers.generation_logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, + LogitNormalization, + LogitsProcessorList, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) +from transformers.generation_stopping_criteria import ( + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) +from transformers.pytorch_utils import torch_int_div +from transformers.utils import ModelOutput + +from transformers.generation_utils import ( + SampleOutput, + BeamSearchOutput, + BeamSampleOutput, + GreedySearchOutput, GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, + BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, BeamSampleDecoderOnlyOutput, + BeamSampleEncoderDecoderOutput, SampleEncoderDecoderOutput, +) +from utils import get_jump_chunks +from torch.nn.utils.rnn import pad_sequence + +class GenerationMixin: + """ + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + + The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], + if `constraints!=None` or `force_words_ids!=None`. + """ + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if ( + self.config.is_encoder_decoder + and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name + ): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside " + f"{input_name} which is not allowed." + f"Make sure to either pass {inputs} or {input_name}=..." + ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. models with `input_ids` can also make use of `inputs_embeds` + if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. Only encoder-decoder models can have non `input_ids` input format + if not self.config.is_encoder_decoder and input_name != "input_ids": + raise ValueError( + f"If {input_name} is passed as model-specific keyword " + "input then model has to be an encoder-decoder and not a " + f"{self.__class__.__name__}." + ) + + # 5. if `inputs` is still None, try to create `input_ids` from BOS token + if inputs is None: + inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) + + return inputs, input_name, model_kwargs + + def _can_retrieve_inputs_from_name( + self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor] + ) -> torch.Tensor: + """ + If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved + from name + """ + can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( + inspect.signature(self.forward).parameters.keys() + ) + + if can_retrieve_inputs and inputs is not None: + raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}") + + return can_retrieve_inputs + + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + """ + Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. + """ + return {"input_ids": input_ids} + + def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: + """ + Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. + """ + return logits + + def _prepare_input_ids_for_generation( + self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] + ) -> torch.LongTensor: + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.last_hidden_state.size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_attention_mask_for_generation( + self, + inputs: torch.Tensor, + pad_token_id: int, + eos_token_id: int, + ) -> torch.LongTensor: + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + # Check if input is input_ids and padded -> only then is attention_mask defined + if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: + return inputs.ne(pad_token_id).long() + else: + return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None + ) -> Dict[str, Any]: + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + print('encoder_kwargs:', encoder_kwargs) + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) + + return model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + device: torch.device = None, + ) -> torch.LongTensor: + + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + return model_kwargs.pop("decoder_input_ids") + else: + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None + ): + return self.config.decoder.decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None + ): + return self.config.decoder.bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) + ) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + @staticmethod + def _update_model_kwargs_for_generation( + outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False + ) -> Dict[str, Any]: + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past"] = outputs.past_buckets_states + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + def _reorder_cache(self, past, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" + ) + + def _get_logits_warper( + self, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + temperature: Optional[float] = None, + num_beams: Optional[int] = None, + renormalize_logits: Optional[bool] = None, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # init warp parameters + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + typical_p = typical_p if typical_p is not None else self.config.typical_p + temperature = temperature if temperature is not None else self.config.temperature + # instantiate warpers list + warpers = LogitsProcessorList() + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if temperature is not None and temperature != 1.0: + warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if typical_p is not None and typical_p < 1.0: + warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + # `LogitNormalization` should always be the last logit processor, when present + if renormalize_logits is True: + warpers.append(LogitNormalization()) + return warpers + + def _get_logits_processor( + self, + repetition_penalty: float, + no_repeat_ngram_size: int, + encoder_no_repeat_ngram_size: int, + input_ids_seq_length: int, + encoder_input_ids: torch.LongTensor, + bad_words_ids: List[List[int]], + min_length: int, + max_length: int, + eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], + num_beams: int, + num_beam_groups: int, + diversity_penalty: float, + remove_invalid_values: bool, + exponential_decay_length_penalty: Tuple, + logits_processor: Optional[LogitsProcessorList], + renormalize_logits: Optional[bool], + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. + """ + processors = LogitsProcessorList() + + # init warp parameters + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + encoder_no_repeat_ngram_size = ( + encoder_no_repeat_ngram_size + if encoder_no_repeat_ngram_size is not None + else self.config.encoder_no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty + forced_bos_token_id = ( + forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id + ) + remove_invalid_values = ( + remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values + ) + exponential_decay_length_penalty = ( + exponential_decay_length_penalty + if exponential_decay_length_penalty is not None + else self.config.exponential_decay_length_penalty + ) + # instantiate processors list + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if diversity_penalty is not None and diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + ) + if repetition_penalty is not None and repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) + if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) + if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: + if self.config.is_encoder_decoder: + processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) + else: + raise ValueError( + "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" + ) + if bad_words_ids is not None: + processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) + if min_length is not None and eos_token_id is not None and min_length > 0: + processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) + if prefix_allowed_tokens_fn is not None: + processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) + if forced_bos_token_id is not None: + processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + if forced_eos_token_id is not None: + processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) + if remove_invalid_values is True: + processors.append(InfNanRemoveLogitsProcessor()) + if exponential_decay_length_penalty is not None: + processors.append( + ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) + ) + processors = self._merge_criteria_processor_list(processors, logits_processor) + # `LogitNormalization` should always be the last logit processor, when present + if renormalize_logits is True: + processors.append(LogitNormalization()) + return processors + + def _get_stopping_criteria( + self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if max_length is not None: + criteria.append(MaxLengthCriteria(max_length=max_length)) + if max_time is not None: + criteria.append(MaxTimeCriteria(max_time=max_time)) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _merge_criteria_processor_list( + self, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, " + f"but it has already been created with the values {default}. {default} has been created by passing the " + "corresponding arguments to generate or by the model's config default values. " + f"If you just want to change the default values of {object_type} consider passing them as arguments " + f"to `generate` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + + def compute_transition_beam_scores( + self, + sequences: torch.Tensor, + scores: Tuple[torch.Tensor], + beam_indices: torch.Tensor, + eos_token_id: int = None, + ): + """compute the transition probabilities of sequences given generation + scores and beam indices""" + + # reshape scores as [vocab_size * batch_size, # generation steps] + # with batch_size being 2 * vocab_size and # generation steps being + # seq_len - input_length + scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) + + # start of generated tokens + cut_idx = sequences.shape[-1] - scores.shape[-1] + # adjust for beam indices + beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size + # compute real indices + indices = sequences[:, cut_idx:] + beam_sequence_indices + # gather scores and run + transition_scores = scores.gather(0, indices) + # make sure that if EOS token was used before length of sequence `sequence.shape[-1]` + # get first occurence of EOS token + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + if eos_token_id is not None: + is_eos_token_id = sequences[:, cut_idx:] == eos_token_id + # make sure first eos token still contributes to transition probs + is_eos_token_id[:, -1] = False + is_eos_token_id = is_eos_token_id.roll(1, -1) + # all indices after eos shoud be masked + zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool() + # zero out padded probs + transition_scores.masked_fill_(zero_transition_prob_mask, 0.0) + + return transition_scores + + # ADDED FRED + def remove_subsets(self, l): + #l = [[1, 2, 4, 8], [1, 2, 4, 5, 6], [1, 2, 3], [2, 3, 21], [1, 2, 3, 4], [1, 2, 3, 4, 5, 6, 7]] + l2 = l[:] + for m in l: + for n in l: + if set(m).issubset(set(n)) and m != n: + l2.remove(m) + break + return l2 + + # ADDED FRED + @torch.no_grad() + def cs_generate( + self, + inputs: Optional[torch.Tensor] = None, + contexts:List[str]=None, #input data + model_input:Dict=None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + renormalize_logits: Optional[bool] = None, + stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, + use_kg:bool=False, #added + relation_mapper_builder=None, + tokenizer=None, + max_neig_per_concept=1, #it slows down quite a lot + **model_kwargs, + ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: + # print(model_input) + input_ids = model_input['input_ids'] + if "input_commonsense_relations" in model_input: + # print(model_input['input_commonsense_relations'].sum()) + model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(input_ids.device) + if use_kg: + all_constraints = [] + print('contexts:', contexts[:3]) + for context in contexts: + constraints = [] + print('+++++++') + concepts_from_context = relation_mapper_builder.get_concepts_from_context(context=context, + clear_common_wds=True, alignment=1) + print('concepts_from_context:', concepts_from_context) + useful_concepts = [relation_mapper_builder.swow_knowledge.get_related_concepts(concept) for concept in + concepts_from_context] + if not useful_concepts: + useful_concepts = [relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context] + useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts] # add spaces + # useful_concepts = [[phrase for phrase in concepts if len(phrase.split(' ')) == 1] for concepts in useful_concepts] + # useful_concepts = list(itertools.chain.from_iterable(useful_concepts)) + # print('useful_concepts:', useful_concepts[:5]) + print('-------') + print('useful_concepts:', useful_concepts) + if concepts_from_context and useful_concepts: + for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts): + print('neighbour:', neighbour_concepts[:5]) + # flexible_words = self.most_similar_words(context_concept, neighbour_concepts) # limit the upperbound + # flexible_words = [word for word in flexible_words if word not in context_concept] # remove input concepts + flexible_words = [word for word in neighbour_concepts if + word not in context_concept] # remove input concepts + print('flexible_words:', flexible_words[:5]) + if not flexible_words: + continue + flexible_words_ids: List[List[int]] = tokenizer(flexible_words, add_special_tokens=False).input_ids #add_prefix_space=True, + flexible_words_ids = self.remove_subsets(flexible_words_ids) + # add_prefix_space=True + # flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets + flexible_words_ids = flexible_words_ids[:max_neig_per_concept] + #print('flexible_words_ids:', flexible_words_ids[:3]) + constraint = DisjunctiveConstraint(flexible_words_ids) + constraints.append(constraint) + all_constraints.extend(constraints) + else: + all_constraints = None + + generated_answers_encoded = self.generate(input_ids=input_ids, + #attention_mask=model_input["attention_mask"].to(input_ids.device), + constraints=all_constraints, + min_length=min_length, + #max_length=max_length, + do_sample=do_sample, + early_stopping=early_stopping, + #num_beams=num_beams, + temperature=temperature, + top_k=top_k, + top_p=top_p, + # eos_token_id=tokenizer.eos_token_id, + no_repeat_ngram_size=no_repeat_ngram_size, + num_return_sequences=num_return_sequences, + return_dict_in_generate=return_dict_in_generate, + output_attentions=output_attentions, + output_scores=output_scores, + **model_kwargs, + ) + return generated_answers_encoded + + # ADDED FRED + @torch.no_grad() + def cs_simple_generate( + self, + inputs: Optional[torch.Tensor] = None, + neighbours_contexts:List[List[str]]=None, #input data + model_input:Dict=None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + renormalize_logits: Optional[bool] = None, + stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, + use_kg:bool=False, #added + relation_mapper_builder=None, + tokenizer=None, + max_concepts=2, #it slows down quite a lot + **model_kwargs, + ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: + # print(model_input) + input_ids = model_input['input_ids'] + if use_kg: + all_constraints = [] + for context_neighbours in neighbours_contexts: + # context_neighbours is a collection of concepts + # lets create sub collections of concepts + context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3] + n_size_chuncks = len(context_neighbours) // max_concepts + n_size_chuncks = n_size_chuncks if n_size_chuncks > 0 else 1 + sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chuncks)) + constraints = [] + for sub_concepts in sub_concepts_collection[:max_concepts]: + flexible_words_ids: List[List[int]] = tokenizer(sub_concepts, add_special_tokens=False).input_ids #add_prefix_space=True, + #flexible_words_ids = self.remove_subsets(flexible_words_ids) + flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids] + disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids)))) + + # add_prefix_space=True + # flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets + #flexible_words_ids = flexible_words_ids[:max_neig_per_concept] + #print('flexible_words_ids:', flexible_words_ids[:3]) + if not any(disjunctive_set): + continue + constraint = DisjunctiveConstraint(disjunctive_set) + constraints.append(constraint) + if not any(constraints): + constraints=None + all_constraints.append(constraints) + else: + all_constraints = None + if not all_constraints: + all_constraints = None + + generated_answers_encoded = [] + #print('all_constraints:', all_constraints) + for i, contraints in enumerate(all_constraints): + #print('contraints.token_ids:', [x.token_ids for x in contraints]) + if "input_commonsense_relations" in model_input: + # print(model_input['input_commonsense_relations'].sum()) + model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations")[i].unsqueeze(0).to(input_ids.device) + #print('model_kwargs.get("attention_mask"):', model_kwargs.get("attention_mask")) + model_kwargs["attention_mask"] = model_input.get("attention_mask")[i].unsqueeze(0).to(input_ids.device) + gen = self.generate(input_ids=input_ids[i].unsqueeze(0), + constraints=contraints, + min_length=min_length, + #max_length=max_length, + do_sample=do_sample, + early_stopping=early_stopping, + #num_beams=num_beams, + temperature=temperature, + top_k=top_k, + top_p=top_p, + # eos_token_id=tokenizer.eos_token_id, + no_repeat_ngram_size=no_repeat_ngram_size, + num_return_sequences=num_return_sequences, + return_dict_in_generate=return_dict_in_generate, + output_attentions=output_attentions, + output_scores=output_scores, + **model_kwargs) + #print('[gen]:', gen) + #print(tokenizer.batch_decode(gen)) + generated_answers_encoded.append(gen[0].detach().cpu()) + #torch.LongTensor(generated_answers_encoded) + #print('generated_answers_encoded:', generated_answers_encoded) + return torch.LongTensor(pad_sequence(generated_answers_encoded, batch_first=True, padding_value=tokenizer.pad_token_id)).to(input_ids.device) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + renormalize_logits: Optional[bool] = None, + stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, + **model_kwargs, + ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling + [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or + `force_words_ids!=None`. + + + + Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as + defined in the model's config (`config.json`) which in turn defaults to the + [`~modeling_utils.PretrainedConfig`] of the model. + + + + Most of these parameters are explained in more detail in [this blog + post](https://huggingface.co/blog/how-to-generate). + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + max_length (`int`, *optional*, defaults to `model.config.max_length`): + The maximum length of the sequence to be generated. + max_new_tokens (`int`, *optional*, defaults to None): + The maximum numbers of tokens to generate, ignore the current number of tokens. Use either + `max_new_tokens` or `max_length` but not both, they serve the same purpose. + min_length (`int`, *optional*, defaults to 10): + The minimum length of the sequence to be generated. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + early_stopping (`bool`, *optional*, defaults to `False`): + Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + num_beams (`int`, *optional*, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (`float`, *optional*, defaults to 1.0): + The value used to module the next token probabilities. + top_k (`int`, *optional*, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher + are kept for generation. + repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + bos_token_id (`int`, *optional*): + The id of the *beginning-of-sequence* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length. + 0.0 means no penalty. Set to values < 0.0 in order to encourage the model to generate longer + sequences, to a value > 0.0 in order to encourage the model to produce shorter sequences. + no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the + `decoder_input_ids`. + bad_words_ids(`List[List[int]]`, *optional*): + List of token ids that are not allowed to be generated. In order to get the token ids of the words that + should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, + add_special_tokens=False).input_ids`. + force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): + List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple + list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, + this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), + where one can allow different forms of each word. + num_return_sequences(`int`, *optional*, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + max_time(`float`, *optional*, defaults to None): + The maximum amount of time you allow the computation to run for in seconds. generation will still + finish the current pass after allocated time has been passed. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens + that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape + as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) + decoder_start_token_id (`int`, *optional*): + If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + use_cache: (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + num_beam_groups (`int`, *optional*, defaults to 1): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of + beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + diversity_penalty (`float`, *optional*, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is + enabled. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and a + model's config. If a logit processor is passed that is already created with the arguments or a model's + config an error is thrown. This feature is intended for advanced users. + renormalize_logits: (`bool`, *optional*, defaults to `False`): + Whether to renormalize the logits after applying all the logits processors or warpers (including the + custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the + score logits are normalized but some logit processors or warpers break the normalization. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + model's config. If a stopping criteria is passed that is already created with the arguments or a + model's config an error is thrown. This feature is intended for advanced users. + constraints (`List[Constraint]`, *optional*): + Custom constraints that can be added to the generation to ensure that the output will contain the use + of certain tokens as defined by `Constraint` objects, in the most sensible way possible. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + forced_bos_token_id (`int`, *optional*): + The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful + for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be + the target language token. + forced_eos_token_id (`int`, *optional*): + The id of the token to force as the last generated token when `max_length` is reached. + remove_invalid_values (`bool`, *optional*): + Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to + crash. Note that using `remove_invalid_values` can slow down generation. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + exponential_decay_length_penalty (`tuple(int, float)`, *optional*): + This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been + generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates + where penalty starts and `decay_factor` represents the factor of exponential decay + + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model + is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs + should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation_utils.GreedySearchDecoderOnlyOutput`], + - [`~generation_utils.SampleDecoderOnlyOutput`], + - [`~generation_utils.BeamSearchDecoderOnlyOutput`], + - [`~generation_utils.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation_utils.GreedySearchEncoderDecoderOutput`], + - [`~generation_utils.SampleEncoderDecoderOutput`], + - [`~generation_utils.BeamSearchEncoderDecoderOutput`], + - [`~generation_utils.BeamSampleEncoderDecoderOutput`] + + Examples: + + Greedy Decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` + + Multinomial Sampling: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # sample up to 30 tokens + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] + ``` + + Beam-search decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids + + >>> outputs = model.generate(input_ids) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] + ```""" + # 1. Set generation parameters if not already defined + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + num_beams = num_beams if num_beams is not None else self.config.num_beams + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups + do_sample = do_sample if do_sample is not None else self.config.do_sample + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + if eos_token_id is None and hasattr(self.config, "decoder"): + eos_token_id = self.config.decoder.eos_token_id + + if pad_token_id is None and eos_token_id is not None: + # special case if pad_token_id is not defined + print(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + pad_token_id = eos_token_id + + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # 2. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + batch_size = inputs_tensor.shape[0] + + # 3. Define other model kwargs + model_kwargs["output_attentions"] = output_attentions + model_kwargs["output_hidden_states"] = output_hidden_states + model_kwargs["use_cache"] = use_cache + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, pad_token_id, eos_token_id + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 4. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + input_ids_seq_length = input_ids.shape[-1] + + # 5. Prepare `max_length` depending on other stopping criteria + # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` + if max_length is None and max_new_tokens is not None: + max_length = max_new_tokens + input_ids_seq_length + elif max_length is not None and max_new_tokens is not None: + # Both are set, this is odd, raise a warning + warnings.warn( + "Both `max_length` and `max_new_tokens` have been set " + f"but they serve the same purpose. `max_length` {max_length} " + f"will take priority over `max_new_tokens` {max_new_tokens}.", + UserWarning, + ) + # default to config if still None + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + + if min_length is not None and min_length > max_length: + raise ValueError( + f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " + f"length ({max_length})" + ) + if input_ids_seq_length >= max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + print( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " + "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." + ) + + # 6. determine generation mode + is_constraint_gen_mode = constraints is not None or force_words_ids is not None + is_greedy_gen_mode = ( + (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + ) + is_sample_gen_mode = ( + (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + ) + is_beam_gen_mode = ( + (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + ) + is_beam_sample_gen_mode = ( + (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + ) + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode + + if num_beam_groups > num_beams: + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") + if is_group_beam_gen_mode and do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + # 7. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + bad_words_ids=bad_words_ids, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, + remove_invalid_values=remove_invalid_values, + exponential_decay_length_penalty=exponential_decay_length_penalty, + logits_processor=logits_processor, + renormalize_logits=renormalize_logits, + ) + + # 8. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria + ) + + # 9. go into different generation modes + if is_greedy_gen_mode: + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + ) + + # 10. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 10. prepare logits warper + logits_warper = self._get_logits_warper( + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + num_beams=num_beams, + renormalize_logits=renormalize_logits, + ) + + # 11. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 12. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_gen_mode: + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + # 10. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=inputs_tensor.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 10. prepare logits warper + logits_warper = self._get_logits_warper( + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + num_beams=num_beams, + renormalize_logits=renormalize_logits, + ) + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * num_return_sequences, + num_beams=num_beams, + device=inputs_tensor.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + ) + + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_beams * num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if num_beams % num_beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + # 10. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=num_beam_groups, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + if num_beams <= 1: + raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.") + + if do_sample: + raise ValueError("`do_sample` needs to be false for constrained generation.") + + if num_beam_groups is not None and num_beam_groups > 1: + raise ValueError("`num_beam_groups` not supported yet for constrained generation.") + + final_constraints = [] + if constraints is not None: + final_constraints = constraints + + if force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {force_words_ids}." + ) + + if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: + typeerror() + + for word_ids in force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 10. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=num_beams, + device=inputs_tensor.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + # 11. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + # 12. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[GreedySearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be + used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> outputs = model.greedy_search( + ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + cur_len = input_ids.shape[-1] + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids, next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + cur_len = cur_len + 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GreedySearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return GreedySearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + def sample( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + cur_len = input_ids.shape[-1] + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + cur_len = cur_len + 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return SampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + def beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[BeamSearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + #Normal execution + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores_processed,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + else: + num_return_sequences = beam_scorer.num_beam_hyps_to_keep + # return only as many indices as sequences + beam_indices = tuple( + (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) + ) + beam_indices = sum(beam_indices, ()) + + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + def beam_sample( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[BeamSampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search multinomial + sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> outputs = model.beam_sample( + ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (logits_warper(input_ids, next_token_scores_processed),) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + probs = nn.functional.softmax(next_token_scores, dim=-1) + + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + else: + num_return_sequences = beam_scorer.num_beam_hyps_to_keep + # return only as many indices as sequences + beam_indices = tuple( + (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) + ) + beam_indices = sum(beam_indices, ()) + + if self.config.is_encoder_decoder: + return BeamSampleEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSampleDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + def group_beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **diverse beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + + model_kwargs: + Additional model specific kwargs that will be forwarded to the `forward` function of the model. If + model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSearchDecoderOnlyOutput`] if [`~generation_utils.BeamSearchDecoderOnlyOutput`] if + `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a + [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... HammingDiversityLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run diverse beam search using 6 beams + >>> num_beams = 6 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... num_beam_groups=3, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model.group_beam_search( + ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + device = input_ids.device + + batch_beam_size, cur_len = input_ids.shape + + if return_dict_in_generate and output_scores: + beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] + else: + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits = outputs.logits[batch_group_indices, -1, :] + + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + if output_scores: + processed_score[batch_group_indices] = next_token_scores_processed + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + if return_dict_in_generate and output_scores: + beam_indices[beam_group_idx] = tuple( + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + ) + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size) + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (processed_score,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + else: + beam_indices = sum(beam_indices, ()) + num_return_sequences = beam_scorer.num_beam_hyps_to_keep + # return only as many indices as sequences + beam_indices = tuple( + (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) + ) + beam_indices = sum(beam_indices, ()) + + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=beam_indices, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + def constrained_beam_search( + self, + input_ids: torch.LongTensor, + constrained_beam_scorer: ConstrainedBeamSearchScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = None, + **model_kwargs, + ) -> Union[BeamSearchOutput, torch.LongTensor]: + + r""" + Generates sequences of token ids for models with a language modeling head using **constrained beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + constrained_beam_scorer (`ConstrainedBeamSearchScorer`): + A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation, while satisfying a list of positive constraints. For more information, the + documentation of [`ConstrainedBeamSearchScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... ConstrainedBeamSearchScorer, + ... PhrasalConstraint, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> constraint_str = "Sie" + >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token + >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] + + + >>> # instantiate beam scorer + >>> beam_scorer = ConstrainedBeamSearchScorer( + ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model.constrained_beam_search( + ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt sind Sie?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + batch_size = len(constrained_beam_scorer._beam_hyps) + num_beams = constrained_beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `nn.functional.log_softmax` operation. + next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + + scores_for_all_vocab = next_token_scores_processed.clone() + + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = (next_tokens / vocab_size).long() + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = constrained_beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + scores_for_all_vocab, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + + # increase cur_len + cur_len = cur_len + 1 + + if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = constrained_beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] + + +def top_k_top_p_filtering( + logits: torch.FloatTensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + + Args: + logits: logits distribution shape (batch size, vocabulary size) + top_k (`int`, *optional*, defaults to 0): + If > 0, only keep the top k tokens with highest probability (top-k filtering) + top_p (`float`, *optional*, defaults to 1.0): + If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus + filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimumber of tokens we keep per batch example in the output. + + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) + + if 0 <= top_p <= 1.0: + logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) + + return logits