File size: 2,749 Bytes
66a3123 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import re
import numpy as np
class Formatter():
def __init__(self, replace_tokens, unused_type='[unusedi]'):
self.dict_token_replace = {k: ' ' + unused_type.replace('i', str(i + 1)) + ' ' for i, k in
enumerate(replace_tokens)}
def format(self, path, pattern):
lines = []
re_line = re.compile(pattern)
with open(path, 'r') as f:
for match in re_line.finditer(''.join(f.readlines())):
line = match[0]
# Replace
for k, v in self.dict_token_replace.items():
line = line.replace(k, v)
lines.append(line)
return lines
def unformat(self, sentences):
unformatted_sentences = []
for sent in sentences:
# Replace
for k, v in self.dict_token_replace.items():
sent = sent.replace(v.strip(), k)
unformatted_sentences.append(sent)
return unformatted_sentences
class Encoder():
def __init__(self, tokenizer):
self.set_tokenizer(tokenizer)
def set_tokenizer(self, tokenizer):
self.tokenizer = tokenizer
def set_model(self, model):
self.model = model
def encode(self, lines):
encoded_dict = self.tokenizer.batch_encode_plus(
lines, # Sentence to encode.
padding=True,
add_special_tokens=True, # Add '[CLS]' and '[SEP]'
return_attention_mask=True, # Construct attn. masks.
return_tensors='pt', # Return pytorch tensors.
)
return encoded_dict
class LabelEncoder(Encoder):
def __init__(self, model, tokenizer, classes=[], num_tokens_per_class=3):
super().__init__(tokenizer)
self.set_model(model)
# Preparing special tokens related to labels.
# Each token is of the type 'cls-k' where cls is a class in classes and
# k is an integer value in range(0, num_tokens_per_class)
self.num_tokens_per_class = num_tokens_per_class
self.label_special_tokens_dict = {cls: [f'[{cls}-{i}]' for i in range(num_tokens_per_class)] for cls in classes}
self.label_special_tokens_list = np.concatenate([list(x) for x in self.label_special_tokens_dict.values()]).tolist()
# Addd special tokens and replace vocabulary
self.tokenizer.add_special_tokens({'additional_special_tokens': self.label_special_tokens_list})
self.model.resize_token_embeddings(len(self.tokenizer))
def encode(self, lines, labels):
labeled_lines = [' '.join(self.label_special_tokens_dict[label]) + ' ' + line for line, label in zip(lines, labels)]
return super().encode(labeled_lines)
|