"""Streamlit helpers for loading Hugging Face models and generating text with them."""
import os
from typing import List

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# Index of the last visible CUDA device, or -1 when no GPU is available.
device = torch.cuda.device_count() - 1

# NOTE(review): the constant name says NL -> EN, but its value is
# "translation_en_to_nl". The value is what gets compared against task strings
# (see GeneratorFactory.gpt_descs), so it is left unchanged; confirm which
# direction is intended before renaming either side.
TRANSLATION_NL_TO_EN = "translation_en_to_nl"


def _get_access_token():
    """Return the Hugging Face access token, or None when none is configured.

    The ``babel`` entry of the Streamlit secrets is preferred when
    ``.streamlit/secrets.toml`` exists; otherwise, or when that entry is
    missing, the ``HF_ACCESS_TOKEN`` environment variable is used.
    """
    token = None
    # Only consult st.secrets when the local secrets file exists; without it,
    # accessing st.secrets may raise FileNotFoundError.
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            token = st.secrets.get("babel")
        except FileNotFoundError:
            token = None
    if token is None:
        token = os.environ.get("HF_ACCESS_TOKEN", None)
    return token


# NOTE: st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource; kept unchanged here to avoid breaking older versions.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name, task):
    """Load the tokenizer and model for ``model_name``.

    Results are cached with ``st.cache``, keyed on the arguments.

    Args:
        model_name: Hugging Face model id or path, loaded with ``from_flax=True``.
        task: Task name. Any task containing "translation" loads a seq2seq
            model; every other task loads a causal LM.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to ``cuda:{device}``
        when a GPU is available.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    access_token = _get_access_token()

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        # Some tokenizers define no pad token; reuse EOS so one is always set.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token

    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model


class Generator:
    """A loaded model/tokenizer pair for one task, with a human-readable label."""

    def __init__(self, model_name, task, desc):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.tokenizer = None
        self.model = None
        # Text prepended to every input; filled from the model config in load().
        self.prefix = ""
        self.load()

    def load(self):
        """Load the model and tokenizer once and pick up the task's input prefix."""
        if self.model is None:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name, self.task)
            # task_specific_params may be absent or None on the config.
            task_params = (
                getattr(self.model.config, "task_specific_params", None) or {}
            )
            this_task = task_params.get(self.task) or {}
            if "prefix" in this_task:
                self.prefix = this_task["prefix"]

    def _strip_special_tokens(self, text):
        """Remove the tokenizer's pad and EOS tokens from decoded ``text``.

        Only these tokens are removed; ordinary spaces between words are kept.
        """
        pad = self.tokenizer.pad_token
        eos = self.tokenizer.eos_token
        if pad:
            # Remove "pad " first so a pad followed by a space leaves no leading gap.
            text = text.replace(pad + " ", "").replace(pad, "")
        if eos:
            text = text.replace(eos, "")
        return text

    def generate(self, text: str, **generate_kwargs) -> List[str]:
        """Generate output for ``text`` and return the decoded predictions.

        Args:
            text: Input text; ``self.prefix`` is prepended before tokenizing.
            **generate_kwargs: Forwarded unchanged to ``model.generate``.
                ``max_length`` (optional) is also passed to the tokenizer.

        Returns:
            One decoded string per generated sequence, with pad/EOS tokens removed.
        """
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs.get("max_length"),
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        return [self._strip_special_tokens(pred) for pred in decoded_preds]

    def __str__(self):
        return self.desc


class GeneratorFactory:
    """Builds and holds Generator instances, skipping duplicates."""

    def __init__(self, generator_list):
        """Create one Generator per dict (model_name, task, desc) in ``generator_list``."""
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc):
        """Add a Generator unless one with the same attributes already exists."""
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            # Generator.__init__ already loads the model, so no extra load() call.
            self.generators.append(Generator(model_name, task, desc))

    def get_generator(self, **kwargs):
        """Return the first Generator whose attributes match all ``kwargs``, else None."""
        for g in self.generators:
            attrs = vars(g)
            if all(attrs.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def gpt_descs(self):
        """Return the descriptions of generators whose task equals TRANSLATION_NL_TO_EN."""
        return [g.desc for g in self.generators if g.task == TRANSLATION_NL_TO_EN]