import os
import re
from typing import Tuple

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Use the last available GPU; with no GPUs this evaluates to -1, meaning CPU
device = torch.cuda.device_count() - 1


def get_access_token():
    """Return the Hugging Face access token from the Streamlit secrets file,
    falling back to the HF_ACCESS_TOKEN environment variable."""
    if os.path.exists(".streamlit/secrets.toml"):
        return st.secrets.get("babel")
    return os.environ.get("HF_ACCESS_TOKEN", None)


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    # Disable tokenizer parallelism to avoid fork-related warnings from the
    # tokenizers library under Streamlit
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_auth_token=get_access_token()
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
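    # Prefer PyTorch weights; fall back to Flax and then TensorFlow checkpoints
    # when the repository does not provide a PyTorch model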
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, use_auth_token=get_access_token()
        )
    except EnvironmentError:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, from_flax=True, use_auth_token=get_access_token()
            )
        except EnvironmentError:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, from_tf=True, use_auth_token=get_access_token()
            )
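    # Move the model to the selected GPU, if any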
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model


class Generator:
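    """Wraps a seq2seq model and its tokenizer for a single task, applying
    generation settings from the model config and task-specific parameters."""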
    def __init__(self, model_name, task, desc, split_sentences):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        self.prefix = ""
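        # Default generation settings; load() overrides them with values from
        # the model config and any task-specific parameters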
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        if not self.model:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name)

            # Take generation defaults from the model config where available
            for key in self.gen_kwargs:
                if key in self.model.config.__dict__:
                    self.gen_kwargs[key] = self.model.config.__dict__[key]
            try:
                if self.task in self.model.config.task_specific_params:
                    task_specific_params = self.model.config.task_specific_params[
                        self.task
                    ]
                    if "prefix" in task_specific_params:
                        self.prefix = task_specific_params["prefix"]
                    for key in self.gen_kwargs:
                        if key in task_specific_params:
                            self.gen_kwargs[key] = task_specific_params[key]
            except TypeError:
                # task_specific_params is None for models that define none
                pass

    def generate(self, text: str, **generate_kwargs) -> Tuple[str, dict]:
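        """Generate a prediction for a single text and return it together with
        the generation kwargs that were actually used."""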
        # Replace two or more newlines with a single newline in text
        text = re.sub(r"\n{2,}", "\n", text)
        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}

        # If the text contains newlines and the model requires line-by-line
        # input, generate for each line separately and rejoin
        if re.search(r"\n", text) and self.split_sentences:
            lines = text.splitlines()
            translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
            return "\n".join(translated), generate_kwargs

        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded.to(f"cuda:{device}")
        generated_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            generated_ids.cpu().numpy(), skip_special_tokens=False
        )
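        # Strip the pad and end-of-sequence markers by hand, since decoding
        # keeps special tokens (skip_special_tokens=False)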
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds[0], generate_kwargs

    def __str__(self):
        return self.model_name


class GeneratorFactory:
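    """Builds and caches Generator instances so each model is loaded once."""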
    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        # If the generator is not yet present, add it
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            # Generator.__init__ already loads the model via load()
            g = Generator(model_name, task, desc, split_sentences)
            self.generators.append(g)

    def get_generator(self, **kwargs):
        for g in self.generators:
            if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        return [
            g
            for g in self.generators
            if all(g.__dict__.get(k) == v for k, v in kwargs.items())
        ]
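

# Minimal usage sketch (not part of the app itself). The model name below is
# only an example; any seq2seq checkpoint on the Hugging Face Hub with a
# matching task works the same way.
#
#     generators = GeneratorFactory(
#         [
#             {
#                 "model_name": "Helsinki-NLP/opus-mt-en-nl",  # example checkpoint
#                 "task": "translation",
#                 "desc": "English to Dutch",
#                 "split_sentences": True,
#             }
#         ]
#     )
#     translator = generators.get_generator(task="translation")
#     if translator is not None:
#         text, used_kwargs = translator.generate("How are you today?")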