import logging

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def load_model_and_tokenizer(model_name: str):
    """
    load_model_and_tokenizer - load a seq2seq model and its tokenizer from huggingface

    The model is moved to CUDA when a GPU is available; otherwise it stays on CPU.

    Args:
        model_name (str): the name of the model to load

    Returns:
        tuple: (model, tokenizer) where model is an AutoModelForSeq2SeqLM instance
            and tokenizer is an AutoTokenizer instance
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if torch.cuda.is_available():
        model = model.to("cuda")
    logging.info("Loaded model %s", model_name)
    return model, tokenizer


def summarize(ids, mask, model, tokenizer, **kwargs):
    """
    summarize - generate a summary for a single chunk of token ids

    Args:
        ids (torch.Tensor): token ids for one chunk. A leading batch dimension is
            added here, so this is presumably a 1-D tensor (TODO confirm callers).
        mask (torch.Tensor): the attention mask matching ``ids``
        model: the seq2seq model used for generation
        tokenizer: the tokenizer used to decode the generated ids
        **kwargs: forwarded unchanged to ``model.generate``

    Returns:
        tuple[list[str], int]: the decoded summary (a one-element list, as returned
            by ``tokenizer.batch_decode``) and the length, in tokens, of the
            generated sequence
    """
    # Add a batch dimension of size 1.
    ids = ids[None, :]
    mask = mask[None, :]

    # Put the inputs on the same device as the model. Models that do not expose
    # ``.device`` keep the previous "CUDA if available" behaviour.
    device = getattr(model, "device", None)
    if device is None and torch.cuda.is_available():
        device = "cuda"
    if device is not None:
        ids = ids.to(device)
        mask = mask.to(device)

    # NOTE(review): LED-style models may also want a ``global_attention_mask``;
    # none is built here.
    output = model.generate(
        ids,
        attention_mask=mask,
        return_dict_in_generate=True,
        **kwargs,
    )
    # NOTE(review): ``remove_invalid_values`` is a generate() option and looks like
    # a no-op for batch_decode; kept to preserve current behaviour — confirm intent.
    summary = tokenizer.batch_decode(
        output.sequences,
        skip_special_tokens=True,
        remove_invalid_values=True,
    )
    # Length of the single generated sequence; no host copy needed just to read it.
    summary_length = int(output.sequences.shape[-1])
    return summary, summary_length


def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    **kwargs,
):
    """
    summarize_via_tokenbatches - split text into overlapping token batches and summarize each

    Args:
        input_text (str): the text to summarize
        model: the model to use for summarization
        tokenizer: the tokenizer to use for summarization
        batch_length (int, optional): the number of tokens per batch. Values below
            512 are raised to 512. Defaults to 2048.
        batch_stride (int, optional): the number of tokens that overlap between
            consecutive batches. Defaults to 16.
        **kwargs: forwarded to ``model.generate`` for every batch

    Returns:
        list[dict]: one entry per batch, with keys:
            "input_tokens": the batch's token ids (padded to ``batch_length``),
            "summary": the list[str] returned by ``summarize``,
            "compression_rate": 1 - output_tokens / input_tokens, rounded to 3
                decimals, where input_tokens counts only non-padding tokens
                (0.0 for a batch with no real tokens)
    """
    if batch_length < 512:
        batch_length = 512
        print("WARNING: batch_length was set to 512")
    # Echo the effective parameters so runs are easy to reproduce.
    print(
        f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
    )

    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask

    gen_summaries = []
    # The context manager closes the progress bar even if a batch fails.
    with tqdm(total=len(in_id_arr)) as pbar:
        for _id, _mask in zip(in_id_arr, att_arr):
            result, summary_len = summarize(
                ids=_id,
                mask=_mask,
                model=model,
                tokenizer=tokenizer,
                **kwargs,
            )
            # Every batch is padded to batch_length, so len(_id) would overstate
            # the input size of the final, shorter batch; count real tokens instead.
            n_input = int(_mask.sum())
            rate = round((n_input - summary_len) / n_input, 3) if n_input else 0.0
            gen_summaries.append(
                {
                    "input_tokens": _id,
                    "summary": result,
                    "compression_rate": rate,
                }
            )
            print(f"\t{result[0]}\nCompression:\t{rate}")
            pbar.update()

    return gen_summaries