p1atdev commited on
Commit
0df470a
1 Parent(s): 5af8acc

Upload tokenization_dart.py

Browse files
Files changed (1) hide show
  1. tokenization_dart.py +54 -60
tokenization_dart.py CHANGED
@@ -1,26 +1,60 @@
1
  import logging
2
- import os
3
  import json
4
- from typing import Optional, Dict, List, Tuple, Union
5
  from pydantic.dataclasses import dataclass
6
 
7
- import numpy as np
8
- from numpy.typing import NDArray
9
-
10
  from transformers import PreTrainedTokenizerFast
11
  from tokenizers.decoders import Decoder
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
- VOCAB_FILES_NAMES = {
16
- "category_config": "category_config.json",
17
- }
18
 
19
- PRETRAINED_VOCAB_FILES_MAP = {
20
- "category_config": {
21
- "p1atdev/dart-tokenizer-v1": "https://huggingface.co/p1atdev/dart-tokenizer-v1/resolve/main/tag_category.json"
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  @dataclass
@@ -71,57 +105,17 @@ class DartDecoder:
71
  class DartTokenizer(PreTrainedTokenizerFast):
72
  """Dart tokenizer"""
73
 
74
- vocab_files_names = VOCAB_FILES_NAMES
75
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
76
-
77
- def __init__(self, category_config, **kwargs):
78
  super().__init__(**kwargs)
79
 
80
  self._tokenizer.decoder = Decoder.custom( # type: ignore
81
  DartDecoder(list(self.get_added_vocab().keys()))
82
  )
83
 
84
- self.category_config = load_tag_category_config(category_config)
85
-
86
- self._id_to_category_map = np.zeros(self.vocab_size).astype("uint8")
87
- for (
88
- category_id,
89
- tokens,
90
- ) in self.category_config.category_to_token_ids.items():
91
- self._id_to_category_map[tokens] = int(category_id)
92
-
93
- def create_vocab_mask(self, value: int = 1):
94
- """Create an array of vocab size filled with specified value"""
95
- return np.full(self.vocab_size, value).astype("uint8")
96
-
97
- def get_token_ids_in_category(self, category_id: Union[int, str]):
98
- """Get token ids in the specified category"""
99
- return self.category_config.category_to_token_ids[str(category_id)]
100
-
101
- def get_category(self, category_id: Union[int, str]):
102
- """Get the specified category config"""
103
- return self.category_config.categories[str(category_id)]
104
-
105
- def convert_ids_to_category_ids(self, token_ids: Union[int, List[int]]):
106
- """Get the category ids of specified tokens"""
107
- return self._id_to_category_map[token_ids]
108
-
109
- def get_banned_tokens_mask(self, tokens: Union[str, List[str], int, List[int]]):
110
- if isinstance(tokens, str):
111
- tokens = [tokens]
112
- elif isinstance(tokens, int):
113
- tokens = [tokens]
114
- elif isinstance(tokens, list):
115
- tokens = [ # type: ignore
116
- self.convert_tokens_to_ids(token) if isinstance(token, str) else token
117
- for token in tokens
118
- ]
119
-
120
- assert isinstance(tokens, list) and all(
121
- [isinstance(token, int) for token in tokens]
122
- )
123
-
124
- mask = self.create_vocab_mask(value=1)
125
- mask[tokens] = 0
126
 
127
- return mask
 
1
  import logging
 
2
  import json
3
+ from typing import Dict, List
4
  from pydantic.dataclasses import dataclass
5
 
 
 
 
6
  from transformers import PreTrainedTokenizerFast
7
  from tokenizers.decoders import Decoder
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
 
 
11
 
12
+ # fmt: off
13
+ # https://huggingface.co/docs/transformers/main/en/chat_templating
14
+ PROMPT_TEMPLATE = (
15
+ "{{ '<|bos|>' }}"
16
+
17
+ "{{ '<rating>' }}"
18
+ "{% if 'rating' not in messages or messages['rating'] is none %}"
19
+ "{{ 'rating:sfw, rating:general' }}"
20
+ "{% else %}"
21
+ "{{ messages['rating'] }}"
22
+ "{% endif %}"
23
+ "{{ '</rating>' }}"
24
+
25
+ "{{ '<copyright>' }}"
26
+ "{% if 'copyright' not in messages or messages['copyright'] is none %}"
27
+ "{{ '' }}"
28
+ "{% else %}"
29
+ "{{ messages['copyright'] }}"
30
+ "{% endif %}"
31
+ "{{ '</copyright>' }}"
32
+
33
+ "{{ '<character>' }}"
34
+ "{% if 'character' not in messages or messages['character'] is none %}"
35
+ "{{ '' }}"
36
+ "{% else %}"
37
+ "{{ messages['character'] }}"
38
+ "{% endif %}"
39
+ "{{ '</character>' }}"
40
+
41
+ "{{ '<general>' }}"
42
+ # length token
43
+ "{% if 'length' not in messages or messages['length'] is none %}"
44
+ "{{ '<|long|>' }}"
45
+ "{% else %}"
46
+ "{{ messages['length'] }}"
47
+ "{% endif %}"
48
+
49
+ # general token
50
+ "{% if 'general' not in messages or messages['general'] is none %}"
51
+ "{{ '' }}"
52
+ "{% else %}"
53
+ "{{ messages['general'] }}"
54
+ "{% endif %}"
55
+ "{{ '<|input_end|>' }}"
56
+ ).strip()
57
+ # fmt: on
58
 
59
 
60
  @dataclass
 
105
  class DartTokenizer(PreTrainedTokenizerFast):
106
  """Dart tokenizer"""
107
 
108
+ def __init__(self, **kwargs):
 
 
 
109
  super().__init__(**kwargs)
110
 
111
  self._tokenizer.decoder = Decoder.custom( # type: ignore
112
  DartDecoder(list(self.get_added_vocab().keys()))
113
  )
114
 
115
+ @property
116
+ def default_chat_template(self):
117
+ """
118
+ Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
119
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
+ return PROMPT_TEMPLATE