cointegrated
commited on
Commit
•
ee388d7
1
Parent(s):
22a24c6
Create char_tokenizer.py
Browse files- char_tokenizer.py +162 -0
char_tokenizer.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copypasted from
|
3 |
+
https://huggingface.co/IlyaGusev/ru-word-stress-transformer/blob/main/char_tokenizer.py
|
4 |
+
with Apache 2.0 license
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from typing import Optional, Tuple, List
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from transformers import PreTrainedTokenizer, AutoTokenizer
|
13 |
+
|
14 |
+
|
15 |
+
def load_vocab(vocab_file):
|
16 |
+
vocab = OrderedDict()
|
17 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
18 |
+
tokens = reader.readlines()
|
19 |
+
for index, token in enumerate(tokens):
|
20 |
+
token = token.rstrip("\n")
|
21 |
+
vocab[token] = index
|
22 |
+
return vocab
|
23 |
+
|
24 |
+
|
25 |
+
class CharTokenizer(PreTrainedTokenizer):
|
26 |
+
vocab_files_names = {"vocab_file": "vocab.txt"}
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
vocab_file=None,
|
31 |
+
pad_token="[pad]",
|
32 |
+
unk_token="[unk]",
|
33 |
+
bos_token="[bos]",
|
34 |
+
eos_token="[eos]",
|
35 |
+
cls_token="[cls]",
|
36 |
+
sep_token="[sep]",
|
37 |
+
mask_token="[mask]",
|
38 |
+
space_token="▁",
|
39 |
+
do_lower_case=False,
|
40 |
+
*args,
|
41 |
+
**kwargs
|
42 |
+
):
|
43 |
+
if not vocab_file or not os.path.isfile(vocab_file):
|
44 |
+
self.vocab = OrderedDict()
|
45 |
+
self.ids_to_tokens = OrderedDict()
|
46 |
+
else:
|
47 |
+
self.vocab = load_vocab(vocab_file)
|
48 |
+
self.ids_to_tokens = OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
49 |
+
|
50 |
+
super().__init__(
|
51 |
+
pad_token=pad_token,
|
52 |
+
unk_token=unk_token,
|
53 |
+
bos_token=bos_token,
|
54 |
+
eos_token=eos_token,
|
55 |
+
cls_token=cls_token,
|
56 |
+
mask_token=mask_token,
|
57 |
+
do_lower_case=do_lower_case,
|
58 |
+
**kwargs
|
59 |
+
)
|
60 |
+
self.do_lower_case = do_lower_case
|
61 |
+
self.space_token = space_token
|
62 |
+
|
63 |
+
def train(self, file_path):
|
64 |
+
vocab = set()
|
65 |
+
with open(file_path) as r:
|
66 |
+
for line in r:
|
67 |
+
word = line.strip()
|
68 |
+
if self.do_lower_case:
|
69 |
+
word = word.lower()
|
70 |
+
vocab |= set(word)
|
71 |
+
vocab = list(vocab)
|
72 |
+
vocab.sort()
|
73 |
+
special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
|
74 |
+
vocab = special_tokens + vocab
|
75 |
+
|
76 |
+
for i, ch in enumerate(vocab):
|
77 |
+
self.vocab[ch] = i
|
78 |
+
self.ids_to_tokens = vocab
|
79 |
+
|
80 |
+
@property
|
81 |
+
def vocab_size(self):
|
82 |
+
return len(self.vocab)
|
83 |
+
|
84 |
+
def get_vocab(self):
|
85 |
+
return self.vocab
|
86 |
+
|
87 |
+
def _convert_token_to_id(self, token):
|
88 |
+
if self.do_lower_case:
|
89 |
+
token = token.lower()
|
90 |
+
return self.vocab.get(token, self.vocab[self.unk_token])
|
91 |
+
|
92 |
+
def _convert_id_to_token(self, index):
|
93 |
+
return self.ids_to_tokens[index]
|
94 |
+
|
95 |
+
def prepare_for_tokenization(
|
96 |
+
self, text, is_split_into_words: bool = False, spaces=0, **kwargs
|
97 |
+
):
|
98 |
+
if spaces:
|
99 |
+
pad = self.space_token * spaces
|
100 |
+
text = pad + pad.join(text) + pad
|
101 |
+
return (text, kwargs)
|
102 |
+
|
103 |
+
def _tokenize(self, text, spaces=0):
|
104 |
+
if self.do_lower_case:
|
105 |
+
text = text.lower()
|
106 |
+
return list(text)
|
107 |
+
|
108 |
+
def convert_tokens_to_string(self, tokens):
|
109 |
+
return "".join(tokens)
|
110 |
+
|
111 |
+
def build_inputs_with_special_tokens(
|
112 |
+
self,
|
113 |
+
token_ids_0: List[int],
|
114 |
+
token_ids_1: Optional[List[int]] = None
|
115 |
+
) -> List[int]:
|
116 |
+
bos = [self.bos_token_id]
|
117 |
+
eos = [self.eos_token_id]
|
118 |
+
return bos + token_ids_0 + eos
|
119 |
+
|
120 |
+
def get_special_tokens_mask(
|
121 |
+
self,
|
122 |
+
token_ids_0: List[int],
|
123 |
+
token_ids_1: Optional[List[int]] = None
|
124 |
+
) -> List[int]:
|
125 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
126 |
+
|
127 |
+
def create_token_type_ids_from_sequences(
|
128 |
+
self,
|
129 |
+
token_ids_0: List[int],
|
130 |
+
token_ids_1: Optional[List[int]] = None
|
131 |
+
) -> List[int]:
|
132 |
+
return (len(token_ids_0) + 2) * [0]
|
133 |
+
|
134 |
+
def save_vocabulary(
|
135 |
+
self,
|
136 |
+
save_directory: str,
|
137 |
+
filename_prefix: Optional[str] = None
|
138 |
+
) -> Tuple[str]:
|
139 |
+
assert os.path.isdir(save_directory)
|
140 |
+
vocab_file = os.path.join(
|
141 |
+
save_directory,
|
142 |
+
(filename_prefix + "-" if filename_prefix else "") +
|
143 |
+
self.vocab_files_names["vocab_file"]
|
144 |
+
)
|
145 |
+
index = 0
|
146 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
147 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
148 |
+
assert index == token_index
|
149 |
+
writer.write(token + "\n")
|
150 |
+
index += 1
|
151 |
+
return (vocab_file,)
|
152 |
+
|
153 |
+
def clean_up_tokenization(self, text, space='▁'):
|
154 |
+
res = []
|
155 |
+
prev = space
|
156 |
+
for c in text:
|
157 |
+
if c != prev and c != space:
|
158 |
+
res.append(c)
|
159 |
+
prev = c
|
160 |
+
return ''.join(res)
|
161 |
+
|
162 |
+
AutoTokenizer.register("char_tokenizer", CharTokenizer)
|