exbert / server /model_api.py
bhoov's picture
whoops
4a48eda
from typing import List, Union, Tuple
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelWithLMHead, AutoModel
from transformer_formatter import TransformerOutputFormatter
from utils.f import delegates, pick, memoize
@memoize
def get_details(mname):
    """Build the `ModelDetails` wrapper for the model named `mname`.

    The function is wrapped by `memoize` from `utils.f`; presumably this caches
    the result per model name so the model is only loaded once -- confirm
    against the `memoize` implementation.

    Args:
        mname: Name of the pretrained model, passed straight to `ModelDetails`.

    Returns:
        A `ModelDetails` instance for `mname`.
    """
    details = ModelDetails(mname)
    return details
def get_model_tok(mname):
    """Load the pretrained LM-head model and matching tokenizer for `mname`.

    A single config is created with `output_attentions=True` and
    `output_past=False` and shared by both the tokenizer and the model.

    Args:
        mname: Model identifier understood by the `from_pretrained` loaders.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    config = AutoConfig.from_pretrained(
        mname,
        output_attentions=True,
        output_past=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(mname, config=config)
    lm_model = AutoModelWithLMHead.from_pretrained(mname, config=config)
    return lm_model, tokenizer
class ModelDetails:
    """Wraps a transformer model and tokenizer to prepare inputs to the frontend visualization"""

    def __init__(self, mname):
        """Load the model and tokenizer for `mname` and switch the model to eval mode.

        Args:
            mname: Name of the pretrained model, forwarded to `get_model_tok`.
        """
        self.mname = mname
        self.model, self.tok = get_model_tok(self.mname)
        self.model.eval()
        self.config = self.model.config

    def from_sentence(self, sentence: str) -> "TransformerOutputFormatter":
        """Get attentions and word probabilities from a sentence.

        The sentence is tokenized with the wrapped tokenizer and forwarded to
        `from_tokens` with `add_special_tokens=True`.

        Args:
            sentence: The input sentence to tokenize and analyze.

        Returns:
            The `TransformerOutputFormatter` built by `from_tokens`.
        """
        tokens = self.tok.tokenize(sentence)
        return self.from_tokens(tokens, sentence, add_special_tokens=True)

    def from_tokens(
        self,
        tokens: List[str],
        orig_sentence: str,
        add_special_tokens: bool = False,
        mask_attentions: bool = False,
        topk: int = 5,
    ) -> "TransformerOutputFormatter":
        """Get formatted attention and predictions from a list of tokens.

        Args:
            tokens: Tokens to analyze
            orig_sentence: The sentence the tokens came from (needed to help organize the output)
            add_special_tokens: Whether to add special tokens like CLS / <|endoftext|> to the tokens.
                If False, assume the tokens already have the special tokens
            mask_attentions: If True, special tokens get a 0 in the attention mask passed to the model.
            topk: How many top predictions to report

        Returns:
            A `TransformerOutputFormatter` built from the original sentence, the
            tokens as seen by the model, the special-tokens mask, the attentions,
            the top-k words and their probabilities, and the model config.
        """
        ids = self.tok.convert_tokens_to_ids(tokens)

        # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
        bost = self.tok.bos_token_id
        clst = self.tok.cls_token_id
        if (bost is not None) and (bost != clst) and add_special_tokens:
            ids.insert(0, bost)

        inputs = self.tok.prepare_for_model(
            ids, add_special_tokens=add_special_tokens, return_tensors="pt"
        )
        parsed_input = self.parse_inputs(inputs, mask_attentions=mask_attentions)

        # Pure inference: skip building the autograd graph for the forward pass.
        with torch.no_grad():
            output = self.model(
                parsed_input["input_ids"],
                attention_mask=parsed_input["attention_mask"],
            )

        logits, atts = self.choose_logits_att(output)
        words, probs = self.logits2words(logits, topk)
        tokens = self.view_ids(inputs["input_ids"])

        # Read the mask from `parsed_input`: `parse_inputs` fills it in on its
        # returned copy when the tokenizer did not provide one.
        return TransformerOutputFormatter(
            orig_sentence,
            tokens,
            parsed_input["special_tokens_mask"],
            atts,
            words,
            probs.tolist(),
            self.config,
        )

    def choose_logits_att(self, out: Tuple) -> Tuple:
        """Select from the model's output the logits and the attentions, switching on model name

        Models whose name contains 't5' are unpacked as a 3-tuple whose middle
        element is discarded; every other model is unpacked as a 2-tuple.

        Args:
            out: Output from the model's forward pass

        Returns:
            (logits, attentions)
        """
        if 't5' in self.mname:
            logits, _, atts = out
        else:
            logits, atts = out
        return logits, atts

    def logits2words(self, logits, topk):
        """Convert logits into the top-k words from the tokenizer's vocabulary.

        The leading dimension of `logits` is squeezed away (assumed to be a
        batch of size 1 -- TODO confirm against callers), a softmax is taken
        over dimension 1 of the result, and the `topk` largest entries are kept.

        Args:
            logits: Logits tensor from the model.
            topk: Number of highest-probability entries to keep per position.

        Returns:
            (words, probs): `words` holds one list of tokens per row of the
            top-k indices; `probs` is the tensor of matching probabilities.
        """
        probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
        words = [self.tok.convert_ids_to_tokens(i) for i in idxs]
        return words, probs

    def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
        """View what the tokenizer thinks certain ids are for a single input"""
        if isinstance(ids, torch.Tensor):
            # Remove batch dimension
            ids = ids.squeeze(0).tolist()
        return self.tok.convert_ids_to_tokens(ids)

    def parse_inputs(self, inputs, mask_attentions=False):
        """Parse the output from `tokenizer.prepare_for_model` to the desired attention mask from special tokens

        `inputs` itself is left untouched; all additions are made on a shallow copy.

        Args:
            - inputs: The output of `tokenizer.prepare_for_model`.
                A dict that must contain 'input_ids' (indexed as `[0]`, so a
                leading batch dimension is expected) and may contain
                'special_tokens_mask'.
            - mask_attentions: Flag indicating whether to mask the attentions or not

        Returns:
            A copy of `inputs` that additionally holds 'attention_mask' and,
            when it was missing, a 'special_tokens_mask' derived from the
            tokenizer's unk/cls/sep/bos/eos/pad token ids.

        Usage:

            ```
            s = "test sentence"

            # from raw sentence to tokens
            tokens = tokenizer.tokenize(s)

            # From tokens to ids
            ids = tokenizer.convert_tokens_to_ids(tokens)

            # From ids to input
            inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')

            # Parse the input. Optionally mask the special tokens from the analysis.
            parsed_input = parse_inputs(inputs)

            # Run the model, pick from this output whatever inputs you want
            from utils.f import pick
            out = model(**pick(['input_ids'], parse_inputs(inputs)))
            ```
        """
        out = inputs.copy()

        # DEFINE SPECIAL TOKENS MASK
        if "special_tokens_mask" not in out:
            candidate_ids = (
                self.tok.unk_token_id,
                self.tok.cls_token_id,
                self.tok.sep_token_id,
                self.tok.bos_token_id,
                self.tok.eos_token_id,
                self.tok.pad_token_id,
            )
            # Tokenizers leave undefined special tokens as None; those can never match an id.
            special_ids = {tok_id for tok_id in candidate_ids if tok_id is not None}
            out["special_tokens_mask"] = [
                1 if int(i) in special_ids else 0 for i in out["input_ids"][0]
            ]

        special_mask = out["special_tokens_mask"]
        if mask_attentions:
            # Attend everywhere except the special tokens.
            attention = [int(not flag) for flag in special_mask]
        else:
            attention = [1 for _ in special_mask]
        out["attention_mask"] = torch.tensor(attention).unsqueeze(0)

        return out