from dataclasses import dataclass
from operator import itemgetter
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import math
import warnings

import torch
from torch import nn

from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers.utils import ModelOutput

from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking
from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits

@dataclass
class JointParsingOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    # When a single labels_type is requested, `logits` holds that head's output; it stays None otherwise.
    logits: Optional[Union[SyntaxLogitsOutput, MorphLogitsOutput, torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Per-head outputs, populated only for the heads that are loaded and run.
    syntax_logits: Optional[SyntaxLogitsOutput] = None
    ner_logits: Optional[torch.FloatTensor] = None
    prefix_logits: Optional[torch.FloatTensor] = None
    lex_logits: Optional[torch.FloatTensor] = None
    morph_logits: Optional[MorphLogitsOutput] = None

class ModuleRef:
    """Plain (non-nn.Module) wrapper around a layer. Assigning it as a model attribute does not
    register the layer a second time, so its parameters keep a single name in the state dict."""
    def __init__(self, module: torch.nn.Module):
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)

class BertForJointParsing(BertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config, do_syntax=None, do_ner=None, do_prefix=None, do_lex=None, do_morph=None, syntax_head_size=64):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.syntax, self.ner, self.prefix, self.lex, self.morph = (None,) * 5

        # Explicit constructor arguments override the corresponding config flags
        if do_syntax is not None:
            config.do_syntax = do_syntax
            config.syntax_head_size = syntax_head_size
        if do_ner is not None: config.do_ner = do_ner
        if do_prefix is not None: config.do_prefix = do_prefix
        if do_lex is not None: config.do_lex = do_lex
        if do_morph is not None: config.do_morph = do_morph

        # Instantiate only the heads that are enabled in the config
        if config.do_syntax:
            self.syntax = BertSyntaxParsingHead(config)
        if config.do_ner:
            self.num_labels = config.num_labels
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # token classification head
            self.ner = ModuleRef(self.classifier)
        if config.do_prefix:
            self.prefix = BertPrefixMarkingHead(config)
        if config.do_lex:
            self.cls = BertOnlyMLMHead(config)  # MLM head reused for lexeme prediction
            self.lex = ModuleRef(self.cls)
        if config.do_morph:
            self.morph = BertMorphTaggingHead(config)

        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder if self.lex is not None else None

    def set_output_embeddings(self, new_embeddings):
        if self.lex is not None:
            self.cls.predictions.decoder = new_embeddings

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            prefix_class_id_options: Optional[torch.Tensor] = None,
            labels: Optional[Union[SyntaxLabels, MorphLabels, torch.Tensor]] = None,
            labels_type: Optional[Literal['syntax', 'ner', 'prefix', 'lex', 'morph']] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            compute_syntax_mst: Optional[bool] = None
    ):
        if return_dict is False:
            warnings.warn("Specified `return_dict=False` but the flag is ignored and treated as always True in this model.")

        if labels is not None and labels_type is None:
            raise ValueError("Cannot specify labels without labels_type")

        if labels_type == 'prefix' and prefix_class_id_options is None:
            raise ValueError('Cannot calculate prefix logits without prefix_class_id_options')

        if compute_syntax_mst is not None and self.syntax is None:
            raise ValueError("Cannot compute syntax MST when the syntax head isn't loaded")

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Expand the attention mask to the additive format expected by the syntax head
        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())

        # Apply dropout to the shared contextual embeddings before feeding the task heads
        hidden_states = self.dropout(bert_outputs[0])

        loss = None
        logits = None
        syntax_logits = None
        ner_logits = None
        prefix_logits = None
        lex_logits = None
        morph_logits = None

        # Run every loaded head; when labels are given, only the matching head computes a loss
        if self.syntax is not None and (labels is None or labels_type == 'syntax'):
            loss, syntax_logits = self.syntax(hidden_states, extended_attention_mask, labels, compute_syntax_mst)
            logits = syntax_logits

        if self.ner is not None and (labels is None or labels_type == 'ner'):
            ner_logits = self.ner(hidden_states)
            logits = ner_logits
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(ner_logits.view(-1, self.num_labels), labels.view(-1))

        if self.prefix is not None and (labels is None or labels_type == 'prefix'):
            loss, prefix_logits = self.prefix(hidden_states, prefix_class_id_options, labels)
            logits = prefix_logits

        if self.lex is not None and (labels is None or labels_type == 'lex'):
            lex_logits = self.lex(hidden_states)
            logits = lex_logits
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(lex_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if self.morph is not None and (labels is None or labels_type == 'morph'):
            loss, morph_logits = self.morph(hidden_states, labels)
            logits = morph_logits

        # Without labels every head ran, so a single `logits` field would be ambiguous
        if labels is None: logits = None

        return JointParsingOutput(
            loss,
            logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
            syntax_logits=syntax_logits,
            ner_logits=ner_logits,
            prefix_logits=prefix_logits,
            lex_logits=lex_logits,
            morph_logits=morph_logits
        )

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False):
        """Run the model on raw sentences and return, per sentence, a dict merging the predictions of every loaded head."""
        is_single_sentence = isinstance(sentences, str)
        if is_single_sentence:
            sentences = [sentences]

        # The prefix head needs its own encoding, which also prepares the prefix class options
        if self.prefix is not None:
            inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
        else:
            inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')

        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)

        # Build the per-token entries from the merged wordpieces, then fold each head's predictions in
        final_output = [
            dict(text=sentence, tokens=[dict(token=t) for t in combine_token_wordpieces(ids, tokenizer)])
            for sentence, ids in zip(sentences, inputs['input_ids'])
        ]

        if output.syntax_logits is not None:
            for sent_idx, parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
                final_output[sent_idx]['root_idx'] = parsed['root_idx']

        if output.prefix_logits is not None:
            for sent_idx, parsed in enumerate(prefix_parse_logits(inputs, sentences, tokenizer, output.prefix_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')

        if output.lex_logits is not None:
            for sent_idx, parsed in enumerate(lex_parse_logits(inputs, sentences, tokenizer, output.lex_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')

        if output.morph_logits is not None:
            for sent_idx, parsed in enumerate(morph_parse_logits(inputs, sentences, tokenizer, output.morph_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')

        if output.ner_logits is not None:
            for sent_idx, parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
                if per_token_ner:
                    merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)

        if is_single_sentence:
            final_output = final_output[0]
        return final_output


def aggregate_ner_tokens(predictions):
    # Group (word, BIO tag) pairs into contiguous entity spans
    entities = []
    prev = None
    for word, pred in predictions:
        if pred == 'O':
            prev = None
        # Start a new entity on an explicit B- tag, or when the label type changes mid-span
        elif pred.startswith('B-') or pred[2:] != prev:
            prev = pred[2:]
            entities.append(([word], prev))
        else:
            entities[-1][0].append(word)

    return [dict(phrase=' '.join(words), label=label) for words, label in entities]
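
# Illustrative example of the aggregation above (words and tags are hypothetical, not from a real run):
#   aggregate_ner_tokens([('Jane', 'B-PER'), ('Doe', 'I-PER'), ('visited', 'O'), ('Paris', 'B-LOC')])
#   -> [{'phrase': 'Jane Doe', 'label': 'PER'}, {'phrase': 'Paris', 'label': 'LOC'}]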


def merge_token_list(src, update, key):
    # Attach one head's per-token predictions to the token dicts under `key`
    for token_src, token_update in zip(src, update):
        token_src[key] = token_update


def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
    # Re-join '##' wordpieces into whole words, skipping special tokens
    ret = []
    for token in tokenizer.convert_ids_to_tokens(input_ids):
        if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
            continue
        if token.startswith('##'):
            ret[-1] += token[2:]
        else:
            ret.append(token)
    return ret
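
# Illustrative example (wordpieces are hypothetical): ids decoding to
# ['[CLS]', 'pars', '##ing', 'example', '[SEP]'] are combined into ['parsing', 'example'].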


def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
    input_ids = inputs['input_ids']

    predictions = torch.argmax(logits, dim=-1)
    batch_ret = []
    for batch_idx in range(len(sentences)):
        ret = []
        batch_ret.append(ret)
        for tok_idx in range(input_ids.shape[1]):
            token_id = input_ids[batch_idx, tok_idx]

            # Skip special tokens
            if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
                continue

            token = tokenizer._convert_id_to_token(token_id)

            # Wordpieces are merged back into the previous token, which keeps its label
            if token.startswith('##'):
                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
                continue
            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))
    return batch_ret
def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
    input_ids = inputs['input_ids']

    predictions = torch.argmax(logits, dim=-1)
    batch_ret = []
    for batch_idx in range(len(sentences)):
        ret = []
        batch_ret.append(ret)
        for tok_idx in range(input_ids.shape[1]):
            token_id = input_ids[batch_idx, tok_idx]

            # Skip special tokens
            if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
                continue

            token = tokenizer._convert_id_to_token(token_id)

            # Wordpieces are merged back into the previous token; the lexeme predicted for the first piece is kept
            if token.startswith('##'):
                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
                continue
            ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
    return batch_ret
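
# Minimal usage sketch, kept as a comment so it never runs on import. The checkpoint name is an
# assumption; substitute whichever joint-parsing checkpoint this class was trained for:
#
#   from transformers import AutoModel, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('example/joint-parsing-checkpoint')
#   model = AutoModel.from_pretrained('example/joint-parsing-checkpoint', trust_remote_code=True)
#   model.eval()
#   with torch.no_grad():
#       parsed = model.predict('Example sentence to parse', tokenizer)
#   # `parsed` is a dict with 'text', 'tokens' (per-token predictions from each loaded head),
#   # and, when the NER head is loaded, 'ner_entities'.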