File size: 4,985 Bytes
2b867e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Tokenization classes for ProteinGLM."""

import os
from typing import List, Optional, Union, Dict, Any
from torch import TensorType
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}


def load_vocab_file(vocab_file: str) -> List[str]:
    """Read a vocabulary file and return its tokens in file order.

    The file holds one token per line; surrounding whitespace is stripped
    from each line, and a token's position in the returned list is its id.
    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default (locale) encoding.

    Args:
        vocab_file: Path to the vocabulary text file.

    Returns:
        The list of tokens, one per line of the file.
    """
    with open(vocab_file, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    return [line.strip() for line in lines]


class ProteinGLMTokenizer(PreTrainedTokenizer):
    """
    Constructs a ProteinGLM tokenizer.

    The vocabulary is a plain-text file with one token per line; a token's id
    is its 0-based line number. Every vocabulary entry is passed to the base
    class's ``_update_trie`` so it is treated as a no-split token, and any
    remaining text is split on whitespace by ``_tokenize``.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    # NOTE(review): "position_ids" is listed, but nothing in this class produces
    # it; presumably it is supplied by the model/collator side — confirm.
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        eos_token: str = "<eos>",
        model_max_length: int = 2048,
        additional_special_tokens: Optional[List[str]] = None,
        **kwargs,
    ):
        """Load the vocabulary and initialise the base tokenizer.

        Args:
            vocab_file: Path to a text file with one token per line.
            unk_token: Token used for unknown ids/tokens.
            pad_token: Padding token.
            mask_token: Mask token.
            eos_token: End-of-sequence token appended by
                ``build_inputs_with_special_tokens``.
            model_max_length: Maximum sequence length recorded on the tokenizer.
            additional_special_tokens: Extra special tokens; when ``None`` the
                ProteinGLM control tokens listed below are used.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The vocab lookups are built before super().__init__ because the base
        # initialiser may query get_vocab()/vocab_size while registering tokens.
        self.all_tokens = load_vocab_file(vocab_file)
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<pad>", "<mask>", "<gmask>", "<smask>", "<eod>",
                "<sop>", "<eop>", "<eos>", "<unk>",
            ]

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        # Register every vocabulary entry as an atomic (never split) token.
        self.unique_no_split_tokens = self.all_tokens
        self._update_trie(self.unique_no_split_tokens)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token, falling back to ``unk_token``."""
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the id of ``unk_token``."""
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split *text* on whitespace."""
        return text.split()

    def get_vocab(self) -> dict:
        """Return the file vocabulary merged with any added tokens."""
        base_vocab = self._token_to_id.copy()
        base_vocab.update(self.added_tokens_encoder)
        return base_vocab

    def token_to_id(self, token: str) -> int:
        """Public alias of ``_convert_token_to_id``."""
        return self._convert_token_to_id(token)

    def id_to_token(self, index: int) -> str:
        """Public alias of ``_convert_id_to_token``."""
        return self._convert_id_to_token(index)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append the EOS id after each sequence.

        Single sequence: ``A <eos>`` (or ``A`` unchanged if no EOS is set).
        Pair: ``A <eos> B <eos>``.

        Raises:
            ValueError: if a pair is given but no EOS token is configured.
        """
        eos_id = self.eos_token_id
        if eos_id is None:
            if token_ids_1 is not None:
                raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
            return token_ids_0
        if token_ids_1 is None:
            return token_ids_0 + [eos_id]
        # Multiple inputs always have an EOS token
        return token_ids_0 + [eos_id] + token_ids_1 + [eos_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a 0/1 mask marking special tokens.

        When ``already_has_special_tokens`` is False the mask mirrors
        ``build_inputs_with_special_tokens`` (a trailing 1 for each appended
        EOS), so it has the same length as the encoded ``input_ids``. Without
        this override the base class would return one entry too few per
        sequence.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        eos_mark = [] if self.eos_token_id is None else [1]
        mask = [0] * len(token_ids_0) + eos_mark
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + eos_mark
        return mask

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the vocabulary (one token per line, UTF-8) into *save_directory*.

        Returns:
            A 1-tuple containing the path of the written file.
        """
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @property
    def vocab_size(self) -> int:
        """Number of tokens loaded from the vocabulary file (excludes added tokens)."""
        return len(self.all_tokens)

    def apply_chat_template(
        self,
        query,
        add_generation_prompt: bool = True,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Prefix queries with the ProteinGLM generation prompt and optionally encode them.

        Args:
            query: A single string or a list of queries; a string is wrapped
                in a one-element list.
            add_generation_prompt: If True, prepend ``"<gmask><sop><eos>"``
                to every query; each query must then be a ``str``.
            tokenize: If False, return the (prompted) queries without encoding.
            padding, truncation, max_length, return_tensors: Forwarded to
                ``batch_encode_plus``.
            return_dict: If True (and ``tokenize``), return the full
                ``BatchEncoding``; otherwise only its ``"input_ids"``.
            tokenizer_kwargs: Extra keyword arguments for ``batch_encode_plus``.
                The explicit arguments above take precedence over same-named keys.
            add_special_tokens: Accepted for API compatibility but ignored;
                encoding always uses ``add_special_tokens=False``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            The prompted query list (``tokenize=False``), the ``BatchEncoding``
            (``return_dict=True``), or its ``"input_ids"``.

        Raises:
            TypeError: if ``add_generation_prompt`` is True and a query is not a ``str``.
        """
        generation_prompt = "<gmask><sop><eos>"
        if isinstance(query, str):
            query = [query]

        if add_generation_prompt:
            prompt_query = []
            for each in query:
                # Explicit check instead of `assert`, which is stripped under -O.
                if not isinstance(each, str):
                    raise TypeError(
                        "apply_chat_template expects str queries when "
                        f"add_generation_prompt=True, got {type(each).__name__}"
                    )
                prompt_query.append(generation_prompt + each)
        else:
            # Copy so the caller's list is never returned/aliased.
            prompt_query = list(query)

        if not tokenize:
            return prompt_query

        encode_kwargs = dict(tokenizer_kwargs or {})
        encode_kwargs.update(
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            is_split_into_words=True,
            # The generation prompt already carries the special tokens.
            add_special_tokens=False,
        )
        output = self.batch_encode_plus(prompt_query, **encode_kwargs)
        return output if return_dict else output["input_ids"]