File size: 4,966 Bytes
20fdad1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
"""Character tokenizer for Hugging Face.
"""
from typing import List, Optional, Dict, Sequence, Tuple
from transformers import PreTrainedTokenizer
class CaduceusTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids"]
def __init__(self,
model_max_length: int,
characters: Sequence[str] = ("A", "C", "G", "T", "N"),
complement_map=None,
bos_token="[BOS]",
eos_token="[SEP]",
sep_token="[SEP]",
cls_token="[CLS]",
pad_token="[PAD]",
mask_token="[MASK]",
unk_token="[UNK]",
**kwargs):
"""Character tokenizer for Hugging Face transformers.
Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py
Args:
model_max_length (int): Model maximum sequence length.
characters (Sequence[str]): List of desired characters. Any character which
is not included in this list will be replaced by a special token called
[UNK] with id=6. Following is a list of the special tokens with
their corresponding ids:
"[CLS]": 0
"[SEP]": 1
"[BOS]": 2
"[MASK]": 3
"[PAD]": 4
"[RESERVED]": 5
"[UNK]": 6
an id (starting at 7) will be assigned to each character.
complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character.
"""
if complement_map is None:
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}
self.characters = characters
self.model_max_length = model_max_length
self._vocab_str_to_int = {
"[CLS]": 0,
"[SEP]": 1,
"[BOS]": 2,
"[MASK]": 3,
"[PAD]": 4,
"[RESERVED]": 5,
"[UNK]": 6,
**{ch: i + 7 for i, ch in enumerate(self.characters)},
}
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
add_prefix_space = kwargs.pop("add_prefix_space", False)
padding_side = kwargs.pop("padding_side", "left")
self._complement_map = {}
for k, v in self._vocab_str_to_int.items():
complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v
self._complement_map[self._vocab_str_to_int[k]] = complement_id
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
add_prefix_space=add_prefix_space,
model_max_length=model_max_length,
padding_side=padding_side,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
@property
def complement_map(self) -> Dict[int, int]:
return self._complement_map
def _tokenize(self, text: str, **kwargs) -> List[str]:
return list(text.upper()) # Convert all base pairs to uppercase
def _convert_token_to_id(self, token: str) -> int:
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
def _convert_id_to_token(self, index: int) -> str:
return self._vocab_int_to_str[index]
def convert_tokens_to_string(self, tokens):
return "".join(tokens) # Note: this operation has lost info about which base pairs were originally lowercase
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
result = ([0] * len(token_ids_0)) + [1]
if token_ids_1 is not None:
result += ([0] * len(token_ids_1)) + [1]
return result
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
# cls = [self.cls_token_id]
result = token_ids_0 + sep
if token_ids_1 is not None:
result += token_ids_1 + sep
return result
def get_vocab(self) -> Dict[str, int]:
return self._vocab_str_to_int
# Fixed vocabulary with no vocab file
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
return ()
|