import re

from sentencepiece import SentencePieceProcessor


def replace_spaces_with_blank(match: re.Match[str]):
    # Compress a run of spaces into a single <|blank_N|> token.
    return f"<|blank_{len(match.group())}|>"


def replace_blank_with_spaces(match: re.Match[str]):
    # Expand <|blank_N|> back into N literal spaces.
    return " " * int(match.group(1))


class ChatGLMTokenizer:
    def __init__(self, vocab_file):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = SentencePieceProcessor(str(vocab_file))

    def __len__(self):
        return len(self.text_tokenizer)

    def __getitem__(self, key: str):
        return self.text_tokenizer[key]

    def preprocess(self, text: str, linebreak=True, whitespaces=True):
        # Map linebreaks, tabs, and runs of spaces to the dedicated vocab
        # tokens (<n>, <|tab|>, <|blank_N|>) before SentencePiece sees the text.
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = text.replace("\t", "<|tab|>")
            text = re.sub(r" {2,80}", replace_spaces_with_blank, text)
        return text

    def encode(
        self, text: str, text_pair: str | None = None,
        linebreak=True, whitespaces=True,
        add_dummy_prefix=True, special_tokens=True,
    ) -> tuple[list[int], list[int]]:
        """
        text: Text to encode. Bidirectional part with a [gMASK] and an <sop> for causal LM.
        text_pair: causal LM part.
        linebreak: Whether to encode newline (\n) in text.
        whitespaces: Whether to encode multiple whitespaces or tabs in text, useful for source code encoding.
        special_tokens: Whether to encode special tokens ([MASK], [gMASK], etc.) in text.
        add_dummy_prefix: Whether to add a dummy blank space at the beginning.
        """
        text = self.preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            text = "<n>" + text
        tokens = self.text_tokenizer.encode(text)
        # prefix_mask is 1 for the bidirectional prefix, 0 for causal positions.
        prefix_mask = [1] * len(tokens)
        if special_tokens:
            tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
            prefix_mask += [1, 0]

        if text_pair is not None:
            text_pair = self.preprocess(text_pair, linebreak, whitespaces)
            pair_tokens = self.text_tokenizer.encode(text_pair)
            tokens += pair_tokens
            prefix_mask += [0] * len(pair_tokens)
            if special_tokens:
                tokens += [self.text_tokenizer["<eop>"]]
                prefix_mask += [0]

        # When the dummy prefix is suppressed, drop the leading <n> (and the
        # blank piece SentencePiece inserts before it) together with its mask
        # entries, so tokens and prefix_mask stay the same length.
        if not add_dummy_prefix:
            tokens, prefix_mask = tokens[2:], prefix_mask[2:]
        return tokens, prefix_mask

    def decode(self, text_ids: list[int]) -> str:
        # Invert the preprocess() substitutions to restore plain text.
        text = self.text_tokenizer.decode(text_ids)
        text = text.replace("<n>", "\n")
        text = text.replace("<|tab|>", "\t")
        text = re.sub(r"<\|blank_(\d\d?)\|>", replace_blank_with_spaces, text)
        return text
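

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tokenizer proper. Assumptions:
    # "ice_text.model" stands in for the path to a local copy of the ChatGLM
    # SentencePiece vocab file, and that vocab defines <n>, <|tab|>,
    # <|blank_N|>, [gMASK], <sop>, and <eop> as pieces.
    tokenizer = ChatGLMTokenizer("ice_text.model")

    source = "def f(x):\n\treturn  x"
    tokens, prefix_mask = tokenizer.encode(source)
    # The encoded prompt ends with [gMASK] <sop>; prefix_mask marks the
    # bidirectional prefix with 1s and the causal positions with 0s.
    print(len(tokens), prefix_mask[-2:])  # expected tail: [1, 0]

    # Without special tokens, decode() should invert the whitespace
    # substitutions and round-trip the original text.
    plain_tokens, _ = tokenizer.encode(source, special_tokens=False)
    print(tokenizer.decode(plain_tokens) == source)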