import os
import re
from typing import Tuple

import streamlit as st
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# Index of the last visible CUDA device; -1 when no GPU is available
# (device_count() is 0 in that case). Checked against -1 below before
# moving tensors/models to the GPU.
device = torch.cuda.device_count() - 1


def get_access_token():
    """Return the Hugging Face access token used to download models.

    The token is read from the Streamlit secret ``babel`` when a
    ``.streamlit/secrets.toml`` file exists; otherwise (or if reading the
    secrets raises ``FileNotFoundError``) it falls back to the
    ``HF_ACCESS_TOKEN`` environment variable. Returns ``None`` if neither
    source provides a token.
    """
    try:
        # Only consult st.secrets when the secrets file is present.
        if not os.path.exists(".streamlit/secrets.toml"):
            raise FileNotFoundError
        access_token = st.secrets.get("babel")
    except FileNotFoundError:
        access_token = os.environ.get("HF_ACCESS_TOKEN", None)
    return access_token


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and seq2seq model for ``model_name``.

    Weights are tried in PyTorch format first, then Flax, then TensorFlow
    (each fallback is taken when the previous attempt raises
    ``EnvironmentError``). The model is moved to ``cuda:{device}`` when a GPU
    is available. Results are memoized by ``st.cache`` on ``model_name``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Resolve the token once rather than re-reading secrets/env per call.
    access_token = get_access_token()

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, use_auth_token=access_token
        )
    except EnvironmentError:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, from_flax=True, use_auth_token=access_token
            )
        except EnvironmentError:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, from_tf=True, use_auth_token=access_token
            )

    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model


class Generator:
    """A seq2seq model plus the generation settings used for one task."""

    def __init__(self, model_name, task, desc, split_sentences):
        """Create the generator and load its model immediately.

        Args:
            model_name: Hugging Face model identifier passed to ``load_model``.
            task: Key looked up in the model config's ``task_specific_params``.
            desc: Human-readable description of this generator.
            split_sentences: If true, multi-line input is generated line by line.
        """
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        # Text prepended to every input; may be set from task params.
        self.prefix = ""
        # Defaults for model.generate(); overridden in load() by the model
        # config and then by task-specific params.
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        """Load the tokenizer/model once and merge generation settings.

        Each key already in ``gen_kwargs`` is overridden first by a matching
        attribute of the model config, then by the entry for ``self.task`` in
        the config's ``task_specific_params``, which may also provide
        ``prefix``. Calling this again after a successful load does nothing.
        """
        if self.model is not None:
            return

        print(f"Loading model {self.model_name}")
        self.tokenizer, self.model = load_model(self.model_name)

        config = self.model.config
        config_values = config.__dict__
        for key in self.gen_kwargs:
            if key in config_values:
                self.gen_kwargs[key] = config_values[key]

        # task_specific_params may be None (or lack this task); treat both as
        # "no overrides" instead of relying on a swallowed TypeError.
        all_task_params = getattr(config, "task_specific_params", None) or {}
        task_params = all_task_params.get(self.task) or {}
        if "prefix" in task_params:
            self.prefix = task_params["prefix"]
        for key in self.gen_kwargs:
            if key in task_params:
                self.gen_kwargs[key] = task_params[key]

    def _strip_special_tokens(self, text):
        """Remove pad and end-of-sequence tokens from decoded ``text``.

        Decoding is done with ``skip_special_tokens=False``, so the pad token
        (together with a following space, if present) and the eos token are
        removed here. Ordinary spaces between words are left intact.
        """
        pad = self.tokenizer.pad_token
        eos = self.tokenizer.eos_token
        if pad:
            text = text.replace(f"{pad} ", "").replace(pad, "")
        if eos:
            text = text.replace(eos, "")
        return text

    def generate(self, text: str, **generate_kwargs) -> Tuple[str, dict]:
        """Generate output text for ``text``.

        Runs of two or more newlines are collapsed into one. When
        ``split_sentences`` is set and the text contains a newline, each line
        is generated separately (blank lines are passed through unchanged)
        and the results are joined with newlines.

        Args:
            text: Input text.
            **generate_kwargs: Overrides for ``gen_kwargs``, forwarded to
                ``model.generate``.

        Returns:
            ``(generated_text, effective_generate_kwargs)``; only the first
            returned sequence is used.
        """
        text = re.sub(r"\n{2,}", "\n", text)
        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}

        if self.split_sentences and "\n" in text:
            # Blank lines would otherwise be sent to the model as empty input.
            outputs = [
                self.generate(line, **generate_kwargs)[0] if line.strip() else line
                for line in text.splitlines()
            ]
            return "\n".join(outputs), generate_kwargs

        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded.to(f"cuda:{device}")

        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        return self._strip_special_tokens(decoded_preds[0]), generate_kwargs

    def __str__(self):
        return self.model_name


class GeneratorFactory:
    """Create and hold a collection of ``Generator`` instances."""

    def __init__(self, generator_list):
        """Load every generator described in ``generator_list``.

        Each entry is a dict of ``add_generator`` keyword arguments
        (``model_name``, ``task``, ``desc``, ``split_sentences``). A Streamlit
        spinner is shown while each one loads.
        """
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        """Add a Generator unless one with the same model, task and desc exists."""
        if self.get_generator(model_name=model_name, task=task, desc=desc) is None:
            # Generator.__init__ already loads the model.
            self.generators.append(Generator(model_name, task, desc, split_sentences))

    @staticmethod
    def _matches(generator, criteria):
        """Return True if every attribute named in ``criteria`` has that value."""
        attributes = generator.__dict__
        return all(attributes.get(k) == v for k, v in criteria.items())

    def get_generator(self, **kwargs):
        """Return the first generator whose attributes match ``kwargs``, else None."""
        return next((g for g in self.generators if self._matches(g, kwargs)), None)

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        """Return all generators whose attributes match ``kwargs``."""
        return [g for g in self.generators if self._matches(g, kwargs)]