Commit b1bbf8b by YaTharThShaRma999
Parent(s): 1f583fd
Create utilf.py

utilf.py ADDED
@@ -0,0 +1,373 @@

from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin

from deepseek_vl.models.image_processing_vlm import VLMImageProcessor
from deepseek_vl.utils.conversation import get_conv_template


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self


class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        system="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
        **kwargs,
    ):
        self.system_prompt = system
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # Register the image placeholder as an additional special token if the
        # tokenizer does not already know about it.
        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            image_tag,
            num_image_tokens,
            add_special_token,
            sft_format,
            mask_prompt,
            ignore_id,
            **kwargs,
        )

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = "",
    ):
        """
        Applies the SFT template to the conversation.

        An example of a conversation:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                "images": [
                    "./multi-images/attribute_comparison_1.png",
                    "./multi-images/attribute_comparison_2.png"
                ]
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        Args:
            conversations (List[Dict]): a conversation given as a list of Dict[str, str] messages.
            sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted text.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

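    # Illustrative usage sketch (not part of the original commit): applying the
    # template to the docstring's example conversation. `processor` and
    # `conversation` are assumed names for an instance of this class and the
    # message list shown above.
    #
    #   sft_prompt = processor.apply_sft_template_for_multi_turn_prompts(
    #       conversations=conversation,
    #       sft_format=processor.sft_format,
    #       system_prompt=processor.system_prompt,
    #   )
    #   # sft_prompt is a single string: the system prompt, the User turn with
    #   # its <image_placeholder> tags kept verbatim, and an empty Assistant turn.
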
    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def pad_id(self):
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add image tokens, and set the mask as False
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

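    # Illustrative sketch (not part of the original commit): how add_image_token
    # expands a placeholder position. Assumes image_id == 100 and
    # num_image_tokens == 3 purely for readability.
    #
    #   ids = torch.LongTensor([5, 100, 7])  # one image tag at index 1
    #   out, counts = processor.add_image_token(image_indices=[1], input_ids=ids)
    #   # out    -> tensor([  5, 100, 100, 100,   7])
    #   # counts -> tensor([3], dtype=torch.int32)
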
    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        if prompt is None:
            # apply sft format
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt,
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

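    # Illustrative sketch (not part of the original commit): a single-sample
    # call. `conversation` is a list of {"role", "content"} dicts and
    # `pil_images` a list of PIL images; both names are assumptions.
    #
    #   prepare_inputs = processor(
    #       conversations=conversation,
    #       images=pil_images,
    #       force_batchify=True,
    #   )
    #   # With force_batchify=True the result is a BatchedVLChatProcessorOutput
    #   # of batch size 1, ready for prepare_inputs.to(device).
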
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
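
Usage sketch (not part of the commit): a minimal end-to-end example of how this processor might be driven, assuming it is loaded via the `from_pretrained` classmethod inherited from ProcessorMixin; the repository id and image path below are placeholders.

import torch
import PIL.Image

# Placeholder repository id and image path; adjust to your setup.
processor = VLChatProcessor.from_pretrained("deepseek-ai/deepseek-vl-7b-chat")

conversation = [
    {"role": "User", "content": "<image_placeholder> Describe this image."},
    {"role": "Assistant", "content": ""},
]
pil_images = [PIL.Image.open("example.jpg").convert("RGB")]

# Tokenize the conversation, expand each <image_placeholder> into 576 image
# tokens, preprocess the images, and left-pad everything into a batch of one.
batch = processor(conversations=conversation, images=pil_images)
batch = batch.to("cuda" if torch.cuda.is_available() else "cpu")
# batch.input_ids, batch.attention_mask, batch.pixel_values, batch.images_seq_mask
# and batch.images_emb_mask can then be passed to a DeepSeek-VL model.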