File size: 5,442 Bytes
46ffa30 3f553b1 46ffa30 cdb537e 46ffa30 3f553b1 46ffa30 8175a61 46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 8cd0b56 46ffa30 3f553b1 46ffa30 a19a543 46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 8cd0b56 3f553b1 46ffa30 3f553b1 a19a543 8cd0b56 a19a543 46ffa30 a19a543 46ffa30 bc21832 46ffa30 3f553b1 46ffa30 3f553b1 46ffa30 bc21832 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import re
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Index of the CUDA device to use: the highest-numbered device torch reports,
# or -1 when torch reports none. Code further down checks `device != -1`
# before moving anything to f"cuda:{device}".
device = torch.cuda.device_count() - 1
def get_access_token():
    """Return the access token used for model downloads, or ``None``.

    Lookup order:

    1. the ``"babel"`` entry of ``st.secrets``, consulted only when
       ``.streamlit/secrets.toml`` exists relative to the working directory;
    2. the ``HF_ACCESS_TOKEN`` environment variable.

    The environment variable is also used when the secrets file exists but
    holds no ``"babel"`` entry, so a secrets file containing only unrelated
    settings does not hide a token supplied through the environment.

    Returns:
        The configured token, or ``None`` when neither source provides one.
    """
    access_token = None
    # st.secrets is only touched when the local secrets file is present.
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            access_token = st.secrets.get("babel")
        except FileNotFoundError:
            # A lookup that still reports a missing file is treated as
            # "no secret configured", not as an error.
            access_token = None
    if access_token is None:
        access_token = os.environ.get("HF_ACCESS_TOKEN")
    return access_token
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and seq2seq model for ``model_name``.

    The model is requested in the library's default checkpoint format first.
    If that raises ``EnvironmentError`` it is retried with ``from_flax=True``,
    and if that fails the same way, with ``from_tf=True``. An error from the
    last attempt propagates to the caller.

    Side effects: sets ``TOKENIZERS_PARALLELISM`` to ``"false"`` in the
    process environment, and moves the model to ``cuda:{device}`` unless the
    module-level ``device`` is ``-1``.

    Args:
        model_name: model identifier handed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    def fetch_model(**format_flags):
        # One download attempt; the flags pick the checkpoint format to read.
        return AutoModelForSeq2SeqLM.from_pretrained(
            model_name, use_auth_token=get_access_token(), **format_flags
        )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=get_access_token()
    )
    if tokenizer.pad_token is None:
        # No padding token defined: reuse the end-of-sequence token for it.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model = fetch_model()
    except EnvironmentError:
        try:
            model = fetch_model(from_flax=True)
        except EnvironmentError:
            model = fetch_model(from_tf=True)

    if device == -1:
        return tokenizer, model
    model.to(f"cuda:{device}")
    return tokenizer, model
class Generator:
    """Wrap one seq2seq model and its tokenizer for text generation.

    The model is loaded when the object is constructed. Generation defaults
    start from the values hard-coded in ``__init__``; entries with the same
    name in the model config, and then in the config's task-specific
    parameters for ``task``, override them.
    """

    def __init__(self, model_name, task, desc, split_sentences):
        """Store the settings and load the model.

        Args:
            model_name: identifier passed to ``load_model``.
            task: key looked up in the model config's ``task_specific_params``.
            desc: description of the generator; stored but not used by this
                class itself.
            split_sentences: when truthy, ``generate`` handles multi-line
                input one line at a time.
        """
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        """Load the model once and take prefix / generation defaults from it.

        Does nothing when a model is already attached, so calling it again
        after construction is harmless. "Already attached" is decided with an
        identity check against ``None``, not truthiness, so a model object
        that happens to be falsy is never reloaded.
        """
        if self.model is not None:
            return
        print(f"Loading model {self.model_name}")
        self.tokenizer, self.model = load_model(self.model_name)

        config = self.model.config
        config_values = config.__dict__
        # Only keys already present in gen_kwargs are overridden.
        for key in self.gen_kwargs:
            if key in config_values:
                self.gen_kwargs[key] = config_values[key]

        # The config may carry no task-specific parameters at all, or none
        # for this task; both cases leave the defaults as they are.
        all_task_params = getattr(config, "task_specific_params", None)
        if not isinstance(all_task_params, dict):
            return
        task_params = all_task_params.get(self.task)
        if not isinstance(task_params, dict):
            return
        if "prefix" in task_params:
            self.prefix = task_params["prefix"]
        for key in self.gen_kwargs:
            if key in task_params:
                self.gen_kwargs[key] = task_params[key]

    def generate(self, text: str, **generate_kwargs) -> "tuple[str, dict]":
        """Generate output text for ``text``.

        Runs of two or more newlines are collapsed to one first. When
        ``split_sentences`` is set and the text still contains a newline,
        every line is generated separately and the results are joined with
        ``"\\n"``; lines that are empty or whitespace-only are passed through
        unchanged instead of being sent to the model.

        Args:
            text: input text.
            **generate_kwargs: overrides for the stored generation defaults.

        Returns:
            ``(generated_text, kwargs_used)`` where ``kwargs_used`` is the
            stored defaults merged with the overrides.
        """
        # Replace two or more newlines with a single newline in text
        text = re.sub(r"\n{2,}", "\n", text)
        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
        if self.split_sentences and "\n" in text:
            translated = [
                self._generate_single(line, generate_kwargs) if line.strip() else line
                for line in text.splitlines()
            ]
            return "\n".join(translated), generate_kwargs
        return self._generate_single(text, generate_kwargs), generate_kwargs

    def _generate_single(self, text, generate_kwargs):
        """Run the model on one piece of text; return the first decoded sequence."""
        # NOTE(review): with padding=False and truncation=False, max_length
        # presumably has no effect on the encoding - confirm against the
        # tokenizer documentation before relying on it.
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            # The return value is ignored; this relies on .to() updating the
            # batch in place - NOTE(review): confirm for the installed version.
            batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Special tokens are kept while decoding; only the padding and
        # end-of-sequence markers are stripped by hand afterwards.
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds[0]

    def __str__(self):
        """Return the model name."""
        return self.model_name
class GeneratorFactory:
    """Collection of ``Generator`` objects built from a list of specs."""

    def __init__(self, generator_list):
        """Create one generator per spec, showing a Streamlit spinner for each.

        Args:
            generator_list: iterable of dicts; each dict is unpacked into
                ``add_generator`` and its ``"desc"`` entry is shown in the
                spinner text.
        """
        self.generators = []
        for spec in generator_list:
            spinner_text = f"Loading the model {spec['desc']} ..."
            with st.spinner(text=spinner_text):
                self.add_generator(**spec)

    @staticmethod
    def _matches(generator, criteria):
        """Tell whether every criterion equals the generator's attribute of that name."""
        attributes = generator.__dict__
        return all(attributes.get(name) == wanted for name, wanted in criteria.items())

    def add_generator(self, model_name, task, desc, split_sentences):
        """Create, load and register a generator unless an equal one exists.

        An existing generator counts as equal when its ``model_name``,
        ``task`` and ``desc`` all match; ``split_sentences`` is not compared.
        """
        already_there = self.get_generator(model_name=model_name, task=task, desc=desc)
        if already_there:
            return
        new_generator = Generator(model_name, task, desc, split_sentences)
        new_generator.load()
        self.generators.append(new_generator)

    def get_generator(self, **kwargs):
        """Return the first generator whose attributes match all ``kwargs``, else ``None``."""
        return next(
            (
                candidate
                for candidate in self.generators
                if self._matches(candidate, kwargs)
            ),
            None,
        )

    def __iter__(self):
        """Iterate over the registered generators in insertion order."""
        return iter(self.generators)

    def filter(self, **kwargs):
        """Return a list of every generator whose attributes match all ``kwargs``."""
        selected = []
        for candidate in self.generators:
            if self._matches(candidate, kwargs):
                selected.append(candidate)
        return selected
|