import re


class Formatter:
    """Replaces reserved substrings in raw text with numbered placeholder
    tokens (e.g. BERT-style '[unused1]', '[unused2]', ...) and restores
    them afterwards."""

    def __init__(self, replace_tokens, unused_type='[unusedi]'):
        # Map each token to a space-padded placeholder; the 'i' in
        # `unused_type` is replaced with the token's 1-based index.
        self.dict_token_replace = {
            k: ' ' + unused_type.replace('i', str(i + 1)) + ' '
            for i, k in enumerate(replace_tokens)
        }

    def format(self, path, pattern):
        """Extract every match of `pattern` from the file at `path` and
        substitute the reserved tokens with their placeholders."""
        lines = []
        re_line = re.compile(pattern)
        with open(path, 'r') as f:
            for match in re_line.finditer(f.read()):
                line = match[0]
                for k, v in self.dict_token_replace.items():
                    line = line.replace(k, v)
                lines.append(line)
        return lines

    def unformat(self, sentences):
        """Invert `format`: map each placeholder back to its original token."""
        unformatted_sentences = []
        for sent in sentences:
            for k, v in self.dict_token_replace.items():
                sent = sent.replace(v.strip(), k)
            unformatted_sentences.append(sent)
        return unformatted_sentences
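# A minimal usage sketch; the file path, regex pattern, and tokens below are
# illustrative, not defined by this module:
#
#   formatter = Formatter(replace_tokens=['<NAME>', '<DATE>'])
#   sentences = formatter.format('corpus.txt', pattern=r'[^\n]+')
#   restored = formatter.unformat(sentences)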


class Encoder:
    """Thin wrapper around a tokenizer exposing `batch_encode_plus`
    (e.g. a Hugging Face `transformers` tokenizer)."""

    def __init__(self, tokenizer):
        self.set_tokenizer(tokenizer)

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def set_model(self, model):
        self.model = model

    def encode(self, lines):
        """Tokenize a batch of strings into padded PyTorch tensors."""
        encoded_dict = self.tokenizer.batch_encode_plus(
            lines,
            padding=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return encoded_dict
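# A minimal usage sketch, assuming a Hugging Face `transformers` tokenizer
# (the checkpoint name is illustrative):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   encoder = Encoder(tokenizer)
#   batch = encoder.encode(['first sentence', 'second sentence'])
#   batch['input_ids']       # padded token-id tensor, shape (2, seq_len)
#   batch['attention_mask']  # 1 for real tokens, 0 for padding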


class LabelEncoder(Encoder):
    """Encoder that prepends class-specific special tokens to each line.

    For every class, `num_tokens_per_class` special tokens of the form
    '[<class>-<i>]' are registered with the tokenizer, and the model's
    token embeddings are resized to make room for them."""

    def __init__(self, model, tokenizer, classes=(), num_tokens_per_class=3):
        super().__init__(tokenizer)
        self.set_model(model)

        self.num_tokens_per_class = num_tokens_per_class
        self.label_special_tokens_dict = {
            cls: [f'[{cls}-{i}]' for i in range(num_tokens_per_class)]
            for cls in classes
        }
        # Flatten the per-class token lists into one list for registration.
        self.label_special_tokens_list = [
            token for tokens in self.label_special_tokens_dict.values()
            for token in tokens
        ]

        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': self.label_special_tokens_list})
        self.model.resize_token_embeddings(len(self.tokenizer))

    def encode(self, lines, labels):
        """Prefix each line with its label's special tokens, then batch-encode."""
        labeled_lines = [
            ' '.join(self.label_special_tokens_dict[label]) + ' ' + line
            for line, label in zip(lines, labels)
        ]
        return super().encode(labeled_lines)
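

# A minimal end-to-end sketch, assuming the Hugging Face `transformers`
# package is installed; the checkpoint name and class labels below are
# illustrative only, not prescribed by this module.
if __name__ == '__main__':
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')

    # Registers '[pos-0]'..'[pos-2]' and '[neg-0]'..'[neg-2]' as special
    # tokens and resizes the model's embedding matrix to fit them.
    label_encoder = LabelEncoder(model, tokenizer, classes=['pos', 'neg'])

    batch = label_encoder.encode(['great film', 'dull plot'], ['pos', 'neg'])
    print(batch['input_ids'].shape)  # (2, padded sequence length)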