import re import torch from transformers import DonutProcessor from transformers.utils import add_start_docstrings from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING # Inspired on https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706 class NoRepeatNGramLogitsProcessor(LogitsProcessor): def __init__(self, ngram_size: int, skip_tokens = None): if not isinstance(ngram_size, int) or ngram_size <= 0: raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") self.ngram_size = ngram_size self.skip_tokens = skip_tokens @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: num_batch_hypotheses = scores.shape[0] cur_len = input_ids.shape[-1] return _no_repeat_ngram_logits(input_ids, cur_len, scores, batch_size = num_batch_hypotheses, no_repeat_ngram_size=self.ngram_size, skip_tokens = self.skip_tokens) def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None): if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) for batch_idx in range(batch_size): if skip_tokens is not None: logits[batch_idx, [token for token in banned_tokens[batch_idx] if int(token) not in skip_tokens]] = -float("inf") else: logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") return logits def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): # Copied from fairseq for no_repeat_ngram in beam_search""" if cur_len + 1 < no_repeat_ngram_size: # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet return [[] for _ in range(num_hypos)] generated_ngrams = [{} for _ in range(num_hypos)] for idx in range(num_hypos): gen_tokens = prev_input_ids[idx].tolist() generated_ngram = generated_ngrams[idx] for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): prev_ngram_tuple = tuple(ngram[:-1]) generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ ngram[-1] ] def _get_generated_ngrams(hypo_idx): # Before decoding the next token, prevent decoding of ngrams that have already appeared start_idx = cur_len + 1 - no_repeat_ngram_size ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) return generated_ngrams[hypo_idx].get(ngram_idx, []) banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] return banned_tokens def get_table_token_ids(processor): return {token_id for token, token_id in processor.tokenizer.get_added_vocab().items() if token.startswith("