import re
import unicodedata

# Map curly single/double quotes, the acute accent, and em/en dashes to
# their plain-ASCII counterparts, keyed by code point for str.translate().
TRANS_TABLE = {ord(x): ord(y) for x, y in zip(u"‘’´“”—–-", u"'''\"\"---")}
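
# Illustrative example of the table's effect (not in the original file):
#   u"‘tokenizer’ — test".translate(TRANS_TABLE) -> "'tokenizer' - test"
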
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
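
# Example behavior (doctest-style, for illustration only):
#   _is_punctuation("$")  -> True   (ASCII range 33..47, though Unicode says Sc)
#   _is_punctuation("^")  -> True   (ASCII but not in the Unicode "P" classes)
#   _is_punctuation("。") -> True   (Unicode category Po)
#   _is_punctuation("a")  -> False
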
def _handle_single_quote(tokens):
    """Re-attach common English contractions ('s, 'll, 're, 've, ...) that
    the punctuation splitter separated into standalone tokens."""
    line = ' '.join(tokens)
    line = re.sub(r"' ([smdSMDtT])\b", r"'\1", line)
    line = re.sub(r"' ll\b", "'ll", line)
    line = re.sub(r"' re\b", "'re", line)
    line = re.sub(r"' ve\b", "'ve", line)
    line = re.sub(r"' LL\b", "'LL", line)
    line = re.sub(r"' RE\b", "'RE", line)
    line = re.sub(r"' VE\b", "'VE", line)
    return line.split()
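
# Illustrative examples on tokens as produced by the splitter below:
#   _handle_single_quote(["it", "'", "s"])         -> ["it", "'s"]
#   _handle_single_quote(["we", "'", "ve", "won"]) -> ["we", "'ve", "won"]
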
def _split_on_cont_punc(tokens):
    """Split each token wherever a run of punctuation characters meets a
    run of non-punctuation characters."""
    new_tokens = []
    for token in tokens:
        if len(token) > 1:
            last_j = 0
            pre_is_punc = _is_punctuation(token[0])
            for j, ch in enumerate(token):
                is_punc = _is_punctuation(ch)
                if is_punc != pre_is_punc:
                    new_tokens.append(token[last_j:j])
                    last_j = j
                    pre_is_punc = is_punc
            if last_j < len(token):
                new_tokens.append(token[last_j:])
        else:
            new_tokens.append(token)
    return new_tokens
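
# Illustrative example. Note this helper is defined here but not called by
# GuokeTokenizer.encode below:
#   _split_on_cont_punc(["hello!!", "--ok"]) -> ["hello", "!!", "--", "ok"]
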
def _split_pre_and_post_punc(tokens):
    """Split leading and trailing punctuation runs off each token, leaving
    internal punctuation (e.g. the apostrophe in "It's") untouched."""

    def pre_punc(token):
        # Find the first non-punctuation character and split before it.
        last_j = 0
        for j in range(1, len(token)):
            if not _is_punctuation(token[j]):
                last_j = j
                break
        return token[:last_j], token[last_j:]

    def post_punc(token):
        # Find the last non-punctuation character and split after it.
        last_j = len(token)
        for j in range(len(token) - 2, -1, -1):
            if not _is_punctuation(token[j]):
                last_j = j + 1
                break
        return token[:last_j], token[last_j:]

    new_tokens = []
    for token in tokens:
        if len(token) > 1 and _is_punctuation(token[0]):
            a, b = pre_punc(token)
            if a:
                new_tokens.append(a)
            if b:
                if _is_punctuation(b[-1]):
                    c, d = post_punc(b)
                    if c:
                        new_tokens.append(c)
                    if d:
                        new_tokens.append(d)
                else:
                    new_tokens.append(b)
        elif len(token) > 1 and _is_punctuation(token[-1]):
            a, b = post_punc(token)
            if a:
                new_tokens.append(a)
            if b:
                new_tokens.append(b)
        else:
            new_tokens.append(token)
    return new_tokens
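
# Illustrative example:
#   _split_pre_and_post_punc(["(hello)", "world!"])
#   -> ["(", "hello", ")", "world", "!"]
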
class GuokeTokenizer(object):
    """Rule-based whitespace-and-punctuation tokenizer. `cfg` is expected
    to provide a boolean `lower` attribute."""

    def __init__(self, cfg):
        self.cfg = cfg

    def encode(self, x: str) -> str:
        # Normalize LaTeX-style `` '' quotes and curly punctuation to ASCII,
        # split punctuation runs off word boundaries, then re-attach English
        # contractions such as 's and 'll.
        x = x.strip()
        x = x.replace("``", '"').replace("''", '"')
        x = x.translate(TRANS_TABLE)
        tokens = x.split()
        tokens = _split_pre_and_post_punc(tokens)
        tokens = _handle_single_quote(tokens)
        x = " ".join(tokens)
        if self.cfg.lower:
            x = x.lower()
        return x

    def decode(self, x: str) -> str:
        raise NotImplementedError()
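
# A minimal usage sketch (not part of the original file). The cfg object here
# is hypothetical; encode() only requires that it expose a boolean `lower`
# attribute, so a SimpleNamespace stands in for the real config.
if __name__ == "__main__":
    from types import SimpleNamespace

    tok = GuokeTokenizer(SimpleNamespace(lower=True))
    # Expected output: he said , " it's fine ."
    print(tok.encode("He said, ``It's fine.''"))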