Delete preprocesssors.py
preprocesssors.py  +0 -2472
DELETED
@@ -1,2472 +0,0 @@
import hashlib
import json
import math
from functools import reduce
from typing import Mapping, Optional, Sequence

import numpy as np
import tensorflow as tf
import seqio
import gin

from .data_utils import flatten_parts, stateless_permutation, stateless_shuffle
from .. import config


def get_from_dict(data, keys):
    """Iterate nested dictionary"""
    return reduce(dict.get, keys, data)

def get_blank_image():
    image = tf.zeros([224, 224, 3], dtype=tf.uint8)
    image = tf.expand_dims(image, 0)[:1]
    return image


@seqio.utils.map_over_dataset
def rekey(x, key_map=None):
    """Replace the feature keys according to the mapping in `key_map`.

    For example, if the dataset returns examples of the format:
    {'foo': 'something', 'bar': 'something else'}
    and key_map = {'boo': 'foo', 'spar': 'bar'} then this function will return
    examples with the format
    {'boo': 'something', 'spar': 'something else'}

    If a mapping is to an empty key or None, set the new key to an empty string.

    Args:
        x: an example to process.
        key_map: dictionary mapping new keys to original keys

    Returns:
        A preprocessed example with the format listed above.
    """
    if key_map:
        out = {}
        for new_key, old_key in key_map.items():
            if isinstance(old_key, list):
                out[new_key] = get_from_dict(x, old_key)
            else:
                out[new_key] = x[old_key]
        return out
    return x


def rename(**kwargs):
    @seqio.map_over_dataset
    def _fn(x):
        updates = {}
        for new_key, old_key in kwargs.items():
            if isinstance(old_key, list):
                val = x[old_key[0]]
                for k in old_key[1:-1]:
                    val = val[k]
                updates[new_key] = val.pop(old_key[-1])
            else:
                updates[new_key] = x.pop(old_key)
        x.update(updates)
        return x
    return _fn
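
# A minimal usage sketch for the two key-mapping helpers above (toy data,
# illustrative only): `rekey` builds a brand-new example from `key_map`,
# while `rename(...)` returns a mapper that moves (pops) features in place.
# Both operate on a whole tf.data.Dataset, seqio-style:
#
#   ds = tf.data.Dataset.from_tensors({"foo": "a", "bar": "b"})
#   ds = rekey(ds, key_map={"boo": "foo", "spar": "bar"})
#   # -> {"boo": "a", "spar": "b"}
#
#   ds = tf.data.Dataset.from_tensors({"caption": "a cat"})
#   ds = rename(text="caption")(ds)
#   # -> {"text": "a cat"}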

def extract_transcripts(ds):
    ds = flatten_parts(ds, ["transcripts"])
    def _map(ex):
        return dict(
            image=ex["image"],
            text=ex["transcripts"],
            url=ex["url"]
        )
    return ds.map(_map)


@seqio.map_over_dataset
def extract_caption_and_all_transcripts(ex):
    transcripts = tf.random.shuffle(ex["transcripts"])[:3]
    weight = 1.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
    return dict(
        image=ex["image"],
        text=tf.concat([tf.expand_dims(ex["caption"], 0), transcripts], 0),
        url=ex["url"],
        text_weights=tf.pad(
            tf.ones((1,), dtype=tf.float32), [[0, tf.shape(transcripts)[0]]],
            constant_values=weight),
    )


@seqio.map_over_dataset
def extract_all_transcripts(ex):
    transcripts = tf.random.shuffle(ex["transcripts"])[:3]
    weight = 3.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
    return dict(
        image=ex["image"],
        text=transcripts,
        url=ex["url"],
        text_weights=tf.fill((tf.shape(transcripts)[0],), weight),
    )


@seqio.map_over_dataset
def extract_transcript(ex):
    transcripts = tf.random.shuffle(ex["transcripts"])
    return dict(
        image=ex["image"],
        text=transcripts[0],
        url=ex["url"],
    )


@seqio.map_over_dataset
def extract_caption(ex):
    caption = ex["caption"]
    if len(caption.shape) > 0:
        ex["text"] = caption[0]
    else:
        ex["text"] = caption
    return ex


@seqio.map_over_dataset
def extract_joint_captions(ex):
    caption = ex["caption"]
    if len(caption.shape) > 0:
        caption = caption[0]
    _ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
    _ix = _ix % tf.shape(ex["transcripts"])[0]
    return dict(
        image=ex["image"],
        text=tf.stack([caption, ex["mistral_caption"], ex["transcripts"][_ix]], 0),
        url=ex["url"]
    )


@seqio.map_over_dataset(num_seeds=1)
def extract_caption_and_transcript(ex, seed):
    caption = ex["caption"]
    if len(caption.shape) > 0:
        caption = caption[0]
    _ix = tf.random.stateless_uniform((), seed, 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
    return dict(
        image=ex["image"],
        text=tf.stack([caption, ex["transcripts"][_ix]], 0),
        url=ex["url"]
    )


@seqio.map_over_dataset
def caption_transcript_augmented(ex, sequence_length):
    caption = ex["caption"]
    if len(caption.shape) > 0:
        caption = caption[0]
    image = ex["image"]
    properties = []

    do_augmentation = sequence_length["is_training"]
    # do_augmentation = False

    # Keep this off, it screws up OCR
    # do_hflip = (tf.random.uniform(()) > 0.2 and do_augmentation)
    do_hflip = False
    if do_hflip:
        image = image[:, ::-1]

    # Mild color jitter
    do_color = (tf.random.uniform(()) > 0.5 and do_augmentation)
    if do_color:
        image = tf.image.random_hue(image, max_delta=0.05)
        image = tf.image.random_brightness(image, max_delta=0.2)
        image = tf.image.random_saturation(image, 0.7, 1.3)
        image = tf.image.random_contrast(image, 0.7, 1.3)

    # Mild affine transformation
    do_affine = (tf.random.uniform(()) > 0.5 and do_augmentation)
    if do_affine and do_augmentation:
        shift_x = tf.random.uniform((), -10, 10) * 0
        shift_y = tf.random.uniform((), -10, 10) * 0
        shear_x = tf.random.uniform((), -2, 2)
        shear_y = tf.random.uniform((), -2, 2)
        rotation = tf.random.uniform((), -6, 6)
        max_scale = 1.1
        scale = tf.random.uniform((), 0.8, max_scale)
        center = tf.cast(tf.shape(image), tf.float32)/2

        image = tf.keras.ops.image.affine_transform(
            image,
            tf.stack(get_affine_matrix(
                [center[0], center[1]],
                rotation,
                [shift_x, shift_y],
                1/scale,
                [shear_x, shear_y]
            ) + [0., 0.]),
            interpolation='bilinear',
            fill_mode='constant',
            fill_value=1.,
            data_format='channels_last'
        )

    properties = tf.stack([
        ("[hflip]" if do_hflip else ""),
        ("[color]" if do_color else ""),
        ("[affine]" if do_affine else "")
    ])
    properties = tf.boolean_mask(properties, tf.strings.length(properties) > 0)
    prompt = tf.strings.reduce_join(properties, separator=" ")
    ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
    out = dict(
        image=image,
        text=tf.stack([caption, ex["transcripts"][ix]], 0),
        url=ex["url"],
        prompt=prompt,
    )
    # out["metadata/unaugmented_image"] = image
    return out
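
# Worked example of the weighting above (illustrative): in
# `extract_caption_and_all_transcripts`, with 3 sampled transcripts the
# caption keeps weight 1.0 while each transcript gets 1/3, so
# text_weights = [1.0, 1/3, 1/3, 1/3] and the caption carries as much total
# weight as all transcripts combined. `extract_all_transcripts` instead
# assigns 3/n to each of its n <= 3 transcripts, so examples with fewer
# transcripts are not down-weighted overall.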

def extract_caption_and_transcript_hflip(ds):

    # Just in case they are ordered somehow in Matt's data
    @seqio.map_over_dataset
    def _shuffle_transcripts(_ex):
        _ex["transcripts"] = tf.random.shuffle(_ex["transcripts"])
        _ex["hflip"] = tf.random.uniform((), 0, 3, dtype=tf.int32)
        return _ex

    ds = _shuffle_transcripts(ds)

    # Build a 3x long dataset with each individual transcript so we iterate through
    # each transcript
    @seqio.map_over_dataset
    def _with_transcript(ex, _ix):
        caption = ex["caption"]
        if len(caption.shape) > 0:
            caption = caption[0]
        hflip = ex["hflip"] == _ix
        if hflip:
            ex["image"] = ex["image"][:, ::-1]
            style = ["long_caption_flipped", "transcript_flipped"]
        else:
            style = ["long_caption", "transcript"]
        return dict(
            image=ex["image"],
            text=tf.stack([caption, ex["transcripts"][_ix]], 0),
            url=ex["url"],
            style=style
        )

    joint_ds = _with_transcript(ds, 0)
    for i in range(1, 3):
        joint_ds = joint_ds.concatenate(_with_transcript(ds, i))

    return joint_ds


@seqio.map_over_dataset
def extract_llava(ex, sequence_length, output_features):
    tf.assert_equal(tf.shape(ex['conversations']['value'])[0], 2)
    prompt = ex['conversations']['value'][0]
    text = ex['conversations']['value'][1]
    ex.pop('conversations')
    ex["text"] = text
    ex["prompt"] = prompt
    return ex


def extract_localized_narrative(ds):
    ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
    def _map(ex):
        return dict(
            image=ex["image"],
            text=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
        )
    return ds.map(_map)


def float_to_text(val):
    return tf.strings.as_string(tf.cast(val * 100, tf.int32))


@seqio.map_over_dataset
def extract_vqa(ex):
    questions = ex["vqa"]["questions"]
    answers = ex["vqa"]["answers"]
    answers = tf.strings.reduce_join(answers, 1, separator="; ")
    qas = tf.strings.reduce_join(tf.stack([questions, answers], 1), separator=" ")
    return dict(
        image=ex["image"],
        text=tf.strings.reduce_join(qas, separator="\n")
    )


@seqio.map_over_dataset
def coco_image_id_from_path(ex):
    image_id = tf.strings.substr(ex["image/filename"], 0, tf.strings.length(ex["image/filename"])-4)
    ex["image_id"] = tf.strings.to_number(image_id)
    return ex


@seqio.map_over_dataset
def add_coco_url(ex):
    """Turns a COCO path into a URL, which can then be used in visualizations"""
    path = ex["image/filename"]
    if not tf.strings.regex_full_match(path, ".*/.*"):
        prefix = tf.strings.regex_replace(path, "COCO_", "")
        prefix = tf.strings.regex_replace(prefix, "_[0-9]+.jpg", "")
        path = tf.strings.join([prefix, path], separator="/")

    # images are hosted by the COCO website here
    url = tf.strings.join(["https://s3.us-east-1.amazonaws.com/images.cocodataset.org/", path])
    ex["metadata/image_url"] = url
    return ex


def flatten_vqa(ds):
    parts = ["questions", "answers"]
    for k in ["id", "question_id"]:
        if k in ds.element_spec:
            parts.append(k)
    return flatten_parts(ds, parts)


def format_gqa(ds, is_balanced=True, flatten=True):
    if is_balanced:
        ds = ds.filter(lambda x: tf.reduce_any(x["questions"]["is_balanced"]))
        def _filter_qs(ex):
            qs = ex["questions"]
            mask = qs["is_balanced"]
            qs = {k: tf.boolean_mask(v, mask) for k, v in qs.items()}
            ex["questions"] = qs
            return ex
        ds = ds.map(_filter_qs)

    if flatten:
        ds = flatten_parts(ds, ["questions"])

    def _rename(ex):
        out = ex["questions"]
        out["image"] = ex["image"]
        out["image_id"] = ex["image_id"]
        return out
    return ds.map(_rename)


@seqio.map_over_dataset
def fix_doqa_url(x):
    x["image_url"] = tf.strings.regex_replace(x["image_url"], "gs://", "")
    return x


def _add_metadata(ex):
    out = {}
    if "id" in ex:
        out["metadata/example_id"] = ex["id"]
    elif "example_id" in ex:
        out["metadata/example_id"] = ex["example_id"]
    elif "question_id" in ex:
        out["metadata/example_id"] = ex["question_id"]
    if "image_url" in ex:
        out["metadata/image_url"] = ex["image_url"]
    for k, v in ex.items():
        if k.startswith("metadata/"):
            out[k] = v
    return out
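
# Illustrative sketch: `_add_metadata` normalizes id fields and forwards any
# existing "metadata/..." keys, e.g. (toy example)
#
#   _add_metadata({"question_id": 7, "image_url": "http://x",
#                  "metadata/split": "val"})
#   # -> {"metadata/example_id": 7, "metadata/image_url": "http://x",
#   #     "metadata/split": "val"}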

def image_only(ds):
    return ds.filter(lambda x: x["has_image"])


def filter_difficult_direct_answer(ds):
    return ds.filter(lambda x: not x["difficult_direct_answer"])


@seqio.map_over_dataset()
def format_ai2d(ex, variable_style=True):
    abc = tf.constant(list("abcdefg".upper()))
    out = dict(image=ex["image"])
    out.update(_add_metadata(ex))

    options = ex["choices"]
    # >= 3 in case of none-of-the-above-like answers
    n_options = tf.shape(ex["option_is_abc"])[0]
    if ex["abc_label"] and tf.reduce_sum(tf.cast(ex["option_is_abc"], tf.int32)) >= (n_options - 1):
        # The image labels are always upper, so use upper in the answer options
        options = tf.where(
            ex["option_is_abc"],
            tf.strings.upper(options),
            options
        )
        short_options = options
        style = "ai2_diagram_no_letter"
    else:
        short_options = abc[:tf.shape(options)[0]]
        options = tf.stack([short_options, options,], 1)
        options = tf.strings.reduce_join(options, axis=-1, separator=": ")
        style = "ai2_diagram"

    options = tf.strings.reduce_join(options, separator="\n")
    out["question"] = ex["question"]
    out["options"] = options
    if variable_style:
        out["style"] = style
    if ex["answer_idx"] < 0:
        out["text"] = "?"
    else:
        out["text"] = short_options[ex["answer_idx"]]
    out["metadata/answer_idx"] = ex["answer_idx"]
    tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
    out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
    out["metadata/has_transparent_box"] = ex.get("has_transparent_box", tf.constant(False))
    out["metadata/abc_label"] = ex["abc_label"]
    return out


@gin.configurable()
@seqio.map_over_dataset()
def format_multiple_choice_qa(ex, option_format="abc"):
    assert option_format == "abc"
    abc = tf.constant(list("abcdefg".upper()))
    out = dict(image=ex["image"])
    out.update(_add_metadata(ex))
    options = ex["choices"]
    short_options = abc[:tf.shape(options)[0]]
    options = tf.stack([short_options, options,], 1)
    options = tf.strings.reduce_join(options, axis=-1, separator=": ")
    options = tf.strings.reduce_join(options, separator="\n")
    out["question"] = ex["question"]
    out["options"] = options
    if ex["answer_idx"] < 0:
        out["text"] = "?"
    else:
        out["text"] = short_options[ex["answer_idx"]]
    out["metadata/answer_idx"] = ex["answer_idx"]
    tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
    out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
    # out["metadata/option_names"] = tf.RaggedTensor.from_row_lengths(short_options, tf.shape(short_options))
    # out["metadata/option_names"] = short_options
    return out
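
# Illustrative sketch: with choices ["cat", "dog", "bird"] and answer_idx=1,
# `format_multiple_choice_qa` produces
#
#   options = "A: cat\nB: dog\nC: bird"
#   text = "B"
#   metadata/option_names = "A|||B|||C"
#
# The "|||" separator is also why the assert above rejects any option string
# that already contains "|||".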

@seqio.map_over_dataset()
def output_options(ex):
    ex["metadata/options"] = ex["options"]
    return ex


@seqio.map_over_dataset()
def extract_tally_qa(ex):
    questions = ex.pop("questions")
    ex["questions"] = questions["question"]
    ex["answers"] = tf.strings.as_string(questions["answer"])
    ex["question_id"] = questions["question_id"]
    return ex


@seqio.map_over_dataset()
def count_bench_preprocessor(ex):
    return {
        "image": ex["image"],
        "text": tf.strings.as_string(ex["number"]),
        "object": ex["noun"],
        "question": tf.strings.join([
            "How many ", ex["noun"], " are there?"
        ]),
        "metadata/count": ex["number"],
    }


def filter_human(ds):
    return ds.filter(lambda x: x["is_human"])


def filter_aug(ds):
    return ds.filter(lambda x: not x["is_human"])


@seqio.map_over_dataset()
def reweight_chartqa(ex, human, aug):
    is_human = ex["metadata/is_human"]
    ex["text_weights"] = human if is_human else aug
    return ex


@seqio.map_over_dataset()
def chartqa_prompting(ex):
    question = tf.strings.join([ex["question"], " Answer:"])
    return dict(
        image=ex["image"],
        question=question,
        answer=ex["answer"]
    )


@seqio.map_over_dataset()
def chartqa_explanation(ex):
    question = tf.strings.join([ex["question"], " Explanation:"])
    out = {
        "image": ex["image"],
        "question": question,
        "answer": ex["answer"],
    }
    out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
    return out


@seqio.map_over_dataset(num_seeds=1)
def _preprocess_scifi(ex, seed):
    if "qa_pairs" in ex:
        q = ex["qa_pairs"]
    else:
        q = ex["qa"]
    ix = stateless_permutation(tf.shape(q["question"])[0], seed)
    return dict(
        image=ex["image"],
        question=tf.gather(q["question"], ix),
        explanation=tf.gather(q["explanation"], ix),
        answer=tf.gather(q["answer"], ix),
    )

@seqio.map_over_dataset
def scifi_explanation_only(ex):
    return dict(
        image=ex["image"],
        question=ex["question"],
        answer=ex["explanation"],
    )


def filter_named_entity(ds):
    @seqio.map_over_dataset
    def _load_image(ex):
        ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
        return ex

    ds = _load_image(ds)
    return ds.filter(lambda x: tf.reduce_min(tf.shape(x["image"])[:2]) >= 32)


@seqio.map_over_dataset()
def extract_named_entity(ex):
    qs = ex["questions"]
    return {
        "image": ex["image"],
        "metadata/image_url": ex["url"],
        "metadata/entity": ex["entity"],
        "questions": qs["question"],
        "answers": qs["answer"],
    }

@gin.configurable()
def extract_individual_vqa(ds, test=False, answer_mode="best"):

    @seqio.map_over_dataset(num_seeds=1)
    def _extract(ex, seed):
        if "questions" in ex:
            question = ex["questions"]
        else:
            question = ex["question"]
        out = dict(
            image=ex["image"],
            question=question,
        )
        out.update(_add_metadata(ex))
        out["metadata/question"] = question
        if ex.get("answers") is not None:
            out["metadata/references"] = tf.strings.reduce_join(ex["answers"], separator="\n")
        elif ex.get("answer") is not None:
            out["metadata/references"] = ex["answer"]

        if not test:
            if "answer" in ex:
                answer = ex["answer"]
            else:
                answer = ex["answers"]
            if answer.dtype in [tf.int32, tf.int64]:
                answer = tf.strings.as_string(answer)
            if len(answer.shape) == 1 and tf.shape(answer)[0] == 0:
                answer = tf.expand_dims("", 0)
            if len(answer.shape) == len(question.shape):
                pass
            # Handle questions with multiple answers
            elif answer_mode == "random":
                assert len(answer.shape) == 1
                answer = answer[tf.random.stateless_uniform((), seed, 0, tf.shape(answer)[0], dtype=tf.int32)]
            elif answer_mode == "best":
                def _get_best(_answer):
                    vals, _, counts = tf.unique_with_counts(_answer)
                    count_thresh = tf.reduce_max(counts)
                    vals = tf.boolean_mask(vals, counts >= count_thresh)
                    return vals[tf.random.stateless_uniform((), seed, 0, tf.shape(vals)[0], dtype=tf.int32)]
                if len(answer.shape) == 1:
                    answer = _get_best(answer)
                elif isinstance(answer, tf.RaggedTensor):
                    n = tf.shape(answer)[0]
                    answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
                    for i in range(n):
                        answer_arr = answer_arr.write(i, _get_best(answer[i]))
                    answer = answer_arr.stack()
                else:
                    answer = tf.map_fn(_get_best, answer)
            elif answer_mode == "all_segments":
                out["text"] = answer
            elif answer_mode == "all_segments_weighted":
                out["text"] = answer
                out["text_weights"] = 1.0 / tf.cast(tf.shape(answer)[-1], tf.float32)
            elif answer_mode == "all":
                if len(answer.shape) == 1:
                    answer = stateless_shuffle(answer, seed)
                    answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
                elif isinstance(answer, tf.RaggedTensor):
                    n = tf.shape(answer)[0]
                    answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
                    for i in range(n):
                        answer_arr = answer_arr.write(i, tf.strings.reduce_join(tf.random.shuffle(answer[i]), separator="\n", axis=-1))
                    answer = answer_arr.stack()
                else:
                    answer = tf.map_fn(tf.random.shuffle, answer)
                    answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
            else:
                raise NotImplementedError()
            out["text"] = answer
        return out
    return _extract(ds)
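
# Illustrative sketch of answer_mode="best" above: the most frequent answer
# string wins, with a seeded uniform draw breaking ties. E.g. references
# ["cat", "cat", "kitten"] -> "cat", while ["cat", "dog"] (a tie) picks
# either one, deterministically for a fixed seed.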

@seqio.map_over_dataset()
def extract_khan_academy(ex):
    return dict(
        image=ex["image"],
        image_url=ex["image_url"],
        prompt="Answer this question",
        text=ex["gptResponse"]
    )

@seqio.map_over_dataset()
def extract_vaia_qa_latex_image(ex, add_short_answer=False, set_short_answer_first=False):
    if ex["has_image"]:
        image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
        image = tf.expand_dims(image, 0)[:1]
    else:
        # image = get_blank_image() # blank image
        image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
        image = tf.expand_dims(image, 0)[:0]
    img_h = tf.shape(image)[1]
    img_w = tf.shape(image)[2]

    if add_short_answer:
        if set_short_answer_first:
            answer = tf.strings.join(["Answer: ", ex["short_answer"], "\n\n", ex["answer"]])
        else:
            answer = tf.strings.join([ex["answer"], "\n\n", "Answer: ", ex["short_answer"]])
    else:
        answer = ex["answer"]
    out = dict(
        image=image,  # 4-d tensor
        text=answer,
        prompt=tf.strings.join([ex["latex_question"], "\n"]),
    )
    out["metadata/images"] = image
    out.update(_add_metadata(ex))
    out["metadata/batch_id"] = ex["batch_id"]
    out["metadata/image_size"] = [img_w, img_h]
    return out

@seqio.map_over_dataset()
def extract_vqa_online(ex):
    out = dict(
        image=ex["image"],
        prompt=tf.strings.join([ex["question"], "\n"]),
        text=ex["answer"]
    )
    out.update(_add_metadata(ex))
    out["metadata/row_id"] = ex["row_id"]
    return out


@seqio.map_over_dataset()
def extract_scifi_joint(ex):
    if "qa_pairs" in ex:
        q = ex["qa_pairs"]
    else:
        q = ex["qa"]
    prompts = tf.concat([["Describe this image in detail."], q["question"]], 0)
    responses = tf.concat([ex["summary"][None], q["answer"]], 0)
    return dict(
        image=ex["image"],
        prompt=prompts,
        text=responses,
    )


def remove_no_qa(ds):
    def _filter(ex):
        if "qa_pairs" in ex:
            q = ex["qa_pairs"]
        else:
            q = ex["qa"]
        return tf.shape(q["question"])[0] > 0
    return ds.filter(_filter)


@seqio.map_over_dataset()
def extract_scifi_qa_exp(ex):
    return dict(
        image=ex["image"],
        question=ex["question"],  # Array of questions
        answer=tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]]),
    )


@seqio.map_over_dataset(num_seeds=1)
def extract_scifi_qa_demo(ex, seed):
    # if tf.random.stateless_uniform((), 0, 1) > 0.5:
    answer = tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]])
    # else:
    #     answer = ex["explanation"]
    return dict(
        image=ex["image"],
        question=ex["question"],  # Array of questions
        answer=answer,
    )


@seqio.map_over_dataset()
def clock_bench_preprocessor(ex):
    out = dict(
        image=ex["image"],
        prompt="What time is being shown?",
    )
    for k in ["hour", "minute", "second", "answerable"]:
        out[f"metadata/{k}"] = ex[k]
    return out


def deg2rad(x):
    return x*math.pi/180.0


def get_affine_matrix(center, angle, translate, scale, shear):
    # From https://github.com/pytorch/vision/blob/f96c42fca53230057b16941b078a0a9eee06e20f/torchvision/transforms/functional.py#L1006
    rot = deg2rad(angle)
    sx = deg2rad(shear[0])
    sy = deg2rad(shear[1])

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = tf.cos(rot - sy) / tf.cos(sy)
    b = -tf.cos(rot - sy) * tf.tan(sx) / tf.cos(sy) - tf.sin(rot)
    c = tf.sin(rot - sy) / tf.cos(sy)
    d = -tf.sin(rot - sy) * tf.tan(sx) / tf.cos(sy) + tf.cos(rot)

    matrix = [a, b, 0.0, c, d, 0.0]
    matrix = [x * scale for x in matrix]
    # Apply inverse of center translation: RSS * C^-1
    matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
    matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
    # Apply translation and center : T * C * RSS * C^-1
    matrix[2] += cx + tx
    matrix[5] += cy + ty
    return matrix
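
# Sanity check (illustrative): with no rotation, translation, or shear and
# scale=1, the matrix above reduces to the identity: a=1, b=0, c=0, d=1,
# then matrix[2] = 1*(-cx) + 0*(-cy) + cx + tx = tx = 0, and likewise
# matrix[5] = 0.
#
#   get_affine_matrix(center=[50., 50.], angle=0., translate=[0., 0.],
#                     scale=1., shear=[0., 0.])
#   # -> [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]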

def quantize_point(coor, max_dim, mode="percent-precision-1"):
    max_dim = tf.cast(max_dim, tf.float32)
    coor = tf.cast(coor, tf.float32)
    x = (coor / max_dim)
    if mode == "percent-precision-1":
        return tf.strings.as_string(x*100, precision=1)
    elif mode == "zero_to_one":
        return tf.strings.as_string(x, precision=3)
    elif mode == "1k":
        return tf.strings.as_string(x*1000, precision=0)
    else:
        raise NotImplementedError(mode)


def construct_pointing_format(label_text, alt_text, x_str, y_str):
    if alt_text is None:
        alt_text = label_text
    np = tf.shape(x_str)[0]
    if np == 0:
        output = ""
    elif np == 1:
        output = tf.strings.join([
            '<point x="', x_str[0], '" y="', y_str[0], '" alt="',
            alt_text, '">', label_text, '</point>'
        ])
    else:
        ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
        xs = tf.strings.join(["x", ids, '="', x_str, '"'])
        ys = tf.strings.join(["y", ids, '="', y_str, '"'])
        points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
        output = tf.strings.join(
            ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
    return output


def order_points(x, y, seed, point_order):
    if point_order == "natural":
        return x, y

    if point_order == "random":
        ix = stateless_permutation(tf.shape(x)[0], seed)
    elif point_order == "xy":
        x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
        ix = tf.argsort(x_float*100000 + y_float)
    elif point_order == "yx":
        x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
        ix = tf.argsort(y_float*100000 + x_float)
    else:
        raise NotImplementedError(point_order)
    return tf.gather(x, ix), tf.gather(y, ix)


@gin.configurable()
def points_to_text(x, y, w, h, seed, label=None, alt_text=None, point_mode="percent-precision-1",
                   point_order="xy", point_list_mode="tag"):
    """Returns a string encoding of a list of points"""
    x = quantize_point(x, w, point_mode)
    y = quantize_point(y, h, point_mode)
    # Order the quantized points so the order matches what was generated; this can matter
    # when points have the same quantized value, e.g., (10.001, 20) (10.002, 10) should be
    # represented (10, 10), (10, 20), but if we sort before quantization we get (10, 20), (10, 10)
    x, y = order_points(x, y, seed, point_order)
    if point_list_mode == "tag":
        return construct_pointing_format(label, alt_text, x, y)
    elif point_list_mode == "paren":
        n = tf.shape(x)[0]
        return tf.strings.reduce_join(tf.strings.join([
            "(", x, ", ", y, ")"
        ]), separator=", ")
        # if n == 0:
        #     output = ""
        # else:
        #     ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
        #     xs = tf.strings.join(["x", ids, '="', x_str, '"'])
        #     ys = tf.strings.join(["y", ids, '="', y_str, '"'])
        #     points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
        #     output = tf.strings.join(
        #         ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
        # return output
    else:
        raise NotImplementedError(point_list_mode)


def points_to_answer(x, y, w, h, seed, label, is_counting, alt_text=None):
    count = tf.shape(x)[0]
    if is_counting:
        if count == 0:
            return "There are none."
        else:
            point_text = points_to_text(x, y, w, h, seed, label, alt_text)
            return tf.strings.join([
                "Counting the ", point_text,
                " shows a total of ",
                tf.strings.as_string(count),
                "."
            ])
    else:
        if count == 0:
            return "There are none."
        else:
            return points_to_text(x, y, w, h, seed, label, alt_text)
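
# Illustrative output of the point encoders above ("tag" mode, percent
# coordinates): one point renders as
#
#   <point x="10.0" y="20.0" alt="cat">cat</point>
#
# and several points as
#
#   <points x1="10.0" y1="20.0" x2="30.5" y2="40.1" alt="cats">cats</points>
#
# With is_counting=True, `points_to_answer` wraps this as
# "Counting the <points ...>cats</points> shows a total of 2."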

@seqio.map_over_dataset(num_seeds=2)
def extract_point_qa(ex, seeds, answer_type="y_major"):
    ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]

    questions = ex["questions"]
    question = questions["question"]
    n = tf.shape(question)[0]
    answers = tf.TensorArray(tf.string, size=n, element_shape=())
    point_text = questions["annotations"]["point_text"]
    point_seeds = tf.RaggedTensor.from_row_splits(
        row_splits=point_text.row_splits,
        values=tf.random.split(seeds[0], num=tf.shape(point_text.values)[0])
    )
    for question_ix in range(n):
        anno = questions["annotations"]
        answer = questions["answer_with_placeholders"][question_ix]
        n_anno = tf.shape(anno["point_text"][question_ix])[0]
        for anno_ix in range(n_anno):
            points = anno["points"][question_ix, anno_ix]
            point_text = points_to_answer(
                points[:, 0], points[:, 1], 100, 100,
                point_seeds[question_ix, anno_ix],
                anno["point_text"][question_ix, anno_ix],
                False,
                alt_text=anno["alt_text"][question_ix, anno_ix],
            )
            answer_split = tf.strings.split(answer, sep="<|POINT|>", maxsplit=1)
            answer = tf.strings.join([answer_split[0], point_text, answer_split[1]])
        # Make sure all placeholders were used
        tf.debugging.assert_equal(tf.shape(tf.strings.split(answer, sep="<|POINT|>"))[0], 1)
        answers = answers.write(question_ix, answer)

    messages = tf.stack([question, answers.stack()], axis=1)
    messages = tf.reshape(messages, [-1])
    conversation_ids = tf.range(tf.shape(messages)[0] // 2, dtype=tf.int32)
    conversation_ids = tf.repeat(conversation_ids, 2)
    out = dict(
        image=ex["image"],
        messages=tf.RaggedTensor.from_value_rowids(messages, conversation_ids)
    )
    ix = stateless_permutation(tf.shape(messages)[0], seeds[1])
    messages = tf.gather(messages, ix)
    out.update(_add_metadata(ex))
    out["metadata/image_size"] = [img_w, img_h]
    return out
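
# Illustrative note on the loop above: an answer template such as
# "There are two <|POINT|> on the table." (hypothetical) has each <|POINT|>
# placeholder replaced, left to right, with a rendered point string, and the
# final assert checks that no <|POINT|> markers remain unconsumed.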

def select_point(mask):
    bs = tf.shape(mask)[0]
    valid = tf.cast(mask, tf.float32)
    h, w = tf.shape(mask)[1], tf.shape(mask)[2]
    ys = tf.range(h, dtype=tf.int32)
    xs = tf.range(w, dtype=tf.int32)

    n = tf.reduce_sum(valid, [1, 2])
    cy = tf.reduce_sum(tf.cast(ys[None, :, None], tf.float32) * valid, [1, 2]) / n  # [bs]
    cx = tf.reduce_sum(tf.cast(xs[None, None, :], tf.float32) * valid, [1, 2]) / n  # [bs]

    dist_y = tf.square(tf.range(h, dtype=tf.float32)[None, :] - cy[:, None])  # [bs, h]
    dist_x = tf.square(tf.range(w, dtype=tf.float32)[None, :] - cx[:, None])  # [bs, w]
    dist = dist_y[:, :, None] + dist_x[:, None, :]  # [batch, h, w]
    dist = dist + (1 - valid) * 1e12
    min_dist = tf.argmin(tf.reshape(dist, [bs, -1]), axis=-1)  # [batch]
    w = tf.cast(w, min_dist.dtype)
    cy = tf.cast(min_dist // w, tf.float32)
    cx = tf.cast(min_dist % w, tf.float32)
    return cx, cy


@seqio.map_over_dataset
def refexp_pointing(ex):
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    objects = ex["objects"]

    # Shuffle objects so that which object gets truncated, if the sequence gets truncated, is randomized
    refexps = objects['refexp']['raw']
    bbox = objects["bbox"]
    mask = tf.squeeze(objects["mask"], -1)

    ix = tf.range(0, tf.shape(refexps)[0], dtype=tf.int32)
    ix = tf.random.shuffle(ix)
    refexps = tf.gather(refexps, ix)
    bbox = tf.gather(bbox, ix)
    mask = tf.gather(mask, ix)

    cx, cy = select_point(mask)
    answers = points_to_text(img_h, img_w, cx, cy)

    out = {
        "image": ex["image"],
        "refexp": refexps.values,
        "metadata/image_size": tf.stack([img_w, img_h,]),
        "text": tf.repeat(answers, refexps.row_lengths()),
    }
    if "image_url" in ex:
        out["metadata/image_url"] = ex["image_url"]
    return out


@seqio.map_over_dataset
def refexp_pointing_inf(ex):
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]

    objects = ex["objects"]
    mask = tf.squeeze(objects["mask"], -1)
    cx, cy = select_point(mask)
    answers = points_to_text(img_h, img_w, cx, cy)

    refexps = objects["refexp"]["raw"]

    # We can't use `mask` directly since it is variable size, and thus it
    # will break batching. Here we serialize it instead
    serialized_masks = tf.map_fn(tf.io.serialize_tensor, mask, fn_output_signature=tf.string)
    out = {
        "image": ex["image"],
        "refexp": refexps,
        "metadata/bbox": objects["bbox"],
        "metadata/answer": answers,
        "metadata/mask": serialized_masks,
        "metadata/image_size": tf.stack([img_w, img_h]),
    }
    out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
    return out
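
# Note on `select_point` (illustrative): it picks the mask pixel nearest the
# mask's center of mass rather than the center of mass itself, so for a
# non-convex mask (e.g. a U-shape whose centroid falls inside the hole) the
# returned point is still guaranteed to lie on the mask.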

@seqio.map_over_dataset
def extract_andriod_control_inf(ex, mode):
    if mode == "ll":
        prompt = tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]])
    elif mode == "hl_ll":
        prompt = tf.strings.join([
            "high_level: ", ex["metadata/hl_instruction"],
            " low_level: ", ex["metadata/ll_instruction"]
        ])
    elif mode == "hl":
        prompt = tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]])
    elif mode == "hl_cot":
        prompt = tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]])
    else:
        raise NotImplementedError()

    out = dict(
        image=ex["image"],
        prompt=prompt,
        text=ex["metadata/target_action"]
    )
    out.update(_add_metadata(ex))
    return out

@seqio.map_over_dataset
def extract_android_control(ex):
    # Each image has four tasks:
    # low level -> action
    # high+low level -> action
    # high level -> action
    # high level -> low level + action (CoT)
    out = dict(
        image=ex["image"],
        prompt=tf.stack([
            tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]]),
            tf.strings.join([
                "high_level: ", ex["metadata/hl_instruction"],
                " low_level: ", ex["metadata/ll_instruction"]
            ]),
            tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]]),
            tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]]),
        ]),
        text=tf.stack([
            ex["metadata/target_action"],
            ex["metadata/target_action"],
            ex["metadata/target_action"],
            tf.strings.join(["Plan: ", ex["metadata/ll_instruction"], " Action: ", ex["metadata/target_action"]]),
        ])
    )
    # Only needed if visualizing
    # ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    # img_h = tf.shape(ex["image"])[0]
    # img_w = tf.shape(ex["image"])[1]
    # out["metadata/image_size"] = tf.stack([img_w, img_h,])
    out.update(_add_metadata(ex))
    return out
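
# Illustrative sketch (hypothetical instruction strings): with
# ll="Tap the search box", hl="Search for hotels" and target action A,
# `extract_android_control` emits four (prompt, target) pairs:
#
#   ("low_level: Tap the search box", A)
#   ("high_level: Search for hotels low_level: Tap the search box", A)
#   ("high_level: Search for hotels", A)
#   ("high_level_cot: Search for hotels",
#    "Plan: Tap the search box Action: " + A)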

@seqio.map_over_dataset(num_seeds=1)
def refexp(ex, seed):
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    objects = ex["objects"]

    # Shuffle objects so that which object gets truncated, if the sequence gets truncated, is randomized
    refexps = objects['refexp']['raw']
    bbox = objects["bbox"]
    ix = stateless_permutation(tf.shape(refexps)[0], seed)
    refexps = tf.gather(refexps, ix)
    bbox = tf.gather(bbox, ix)

    x2 = bbox[:, 0] + bbox[:, 2]
    y2 = bbox[:, 1] + bbox[:, 3]
    with tf.control_dependencies([
        tf.debugging.assert_equal(tf.reduce_any(x2 <= tf.cast(img_w, tf.float32)), True),
        tf.debugging.assert_equal(tf.reduce_any(y2 <= tf.cast(img_h, tf.float32)), True)
    ]):
        answers = points_to_text(
            img_h, img_w,
            tf.reshape(tf.stack([bbox[:, 0], x2], 1), [-1]),
            tf.reshape(tf.stack([bbox[:, 1], y2], 1), [-1]))
    answers = tf.strings.reduce_join(tf.reshape(answers, [-1, 2]), separator=" ", axis=1)

    out = {
        "image": ex["image"],
        "refexp": refexps.values,
        "metadata/bbox": bbox,
        "metadata/image_size": tf.stack([img_w, img_h,]),
        "text": tf.repeat(answers, refexps.row_lengths()),
    }

    if "image_url" in ex:
        out["image_url"] = ex["image_url"]
    return out


@seqio.map_over_dataset
def refexp_inf(ex):
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    out = {
        "image": ex["image"],
        "refexp": ex["objects"]["refexp"]["raw"],
        "metadata/bbox": ex["objects"]["bbox"],
        "metadata/image_size": tf.stack([img_w, img_h,]),
    }
    out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
    return out


def point_text_interleaved(*args):
    raise NotImplementedError()


@seqio.map_over_dataset
def web_pointing_preprocessor(ex):
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]

    question = point_text_interleaved(
        img_h, img_w, ex["question"], ex["question_points"]["x"], ex["question_points"]["y"])
    answer = point_text_interleaved(
        img_h, img_w, ex["answer"], ex["answer_points"]["x"], ex["answer_points"]["y"])
    answer_points = tf.stack([ex["answer_points"]["x"], ex["answer_points"]["y"]], axis=1)
    return {
        "question": question,
        "answer": answer,
        "image": ex["image"],
        "metadata/image_size": [img_w, img_h],
        "metadata/question_type": ex["question_type"],
        "metadata/answer_points": tf.io.serialize_tensor(answer_points),
        "metadata/answer": answer,
    }


def filter_pointing(ds):
    return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] >= 1)


def filter_qa(ds):
    return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] == 0)

# vaia filtering
def filter_image_only(ds):
    return ds.filter(lambda ex: ex["has_image"])

def filter_mc(ds):
    return ds.filter(lambda ex: ex["is_mc"])

def remove_is_long(ds):
    return ds.filter(lambda ex: not ex["is_long"])

def remove_has_multiple_parts(ds):
    return ds.filter(lambda ex: not ex["has_multiple_parts"])


def _split(ds: tf.data.Dataset, keys, n_splits=2):
    def _map(ex):
        n = tf.shape(ex[keys[0]])[0]
        if n < n_splits:
            return tf.data.Dataset.from_tensors(ex)
        else:
            # import pdb; pdb.set_trace()
            bs = n // n_splits
            remainder = n - bs*n_splits
            lens = tf.concat([
                tf.ones([remainder], dtype=tf.int32),
                tf.zeros([n_splits-remainder], dtype=tf.int32),
            ], axis=0) + bs
            tf.debugging.assert_equal(tf.reduce_sum(lens), n)
            ends = tf.cumsum(lens)

            parts = []
            for split_ix in range(n_splits):
                part_ex = dict(ex)
                e = ends[split_ix]
                s = e - lens[split_ix]
                for k in keys:
                    if isinstance(k, tuple):
                        assert len(k) == 2
                        part_ex[k[0]][k[1]] = ex[k[0]][k[1]][s:e]
                    else:
                        part_ex[k] = ex[k][s:e]
                parts.append(part_ex)

            ds = tf.data.Dataset.from_tensors(parts[0])
            for sub_ds in parts[1:]:
                sub_ds = tf.data.Dataset.from_tensors(sub_ds)
                ds = ds.concatenate(sub_ds)
            return ds

    return ds.flat_map(_map)
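
# Worked example for `_split` (illustrative): with n=5 elements and
# n_splits=2, bs=2 and remainder=1, so lens=[3, 2] and ends=[3, 5]; the
# listed keys are sliced into parts [0:3] and [3:5], while every other
# feature is copied into both output examples.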


def split(ds, n=2):
    # return ds
    return _split(ds, [k for k in [
        "question",
        "label",
        "text",
        "entity",
        "messages"
    ] if k in ds.element_spec], n_splits=n)


def split_points(ds, max_points=50):
    label = "question" if "question" in ds.element_spec else "label"
    return _split(ds, [
        "question", label, "notInImage",
        ("answer_points", "x"),
        ("answer_points", "y"),
    ])


@seqio.map_over_dataset
def fix_count_qa(ex):
    ex["label"] = ex["label"][::2]
    tf.debugging.assert_equal(tf.shape(ex["answer_points"]["x"])[0], tf.shape(ex["label"])[0])
    return ex


def filter_points(ds, max_number=40):

    def _add_valid(ex):
        valid = (
            tf.reduce_all(ex["answer_points"]["x"] >= 0.0, axis=-1) &
            tf.reduce_all(ex["answer_points"]["x"] <= 100.0, axis=-1) &
            tf.reduce_all(ex["answer_points"]["y"] >= 0.0, axis=-1) &
            tf.reduce_all(ex["answer_points"]["y"] <= 100.0, axis=-1) &
            (ex["answer_points"]["y"].row_lengths() <= max_number)
        )
        ex["valid"] = valid
        return ex
    ds = ds.map(_add_valid)
    ds = ds.filter(lambda ex: tf.reduce_any(ex["valid"]))
    return ds


# def filter_points(ds, max_number=30):
#     n_points = ds["answer_points"]["x"].row_lengths()
#     parts = tf.TensorArray(tf.int32, size=tf.shape(n_points[0]), element_shape=tf.TensorShape([None]))
#     total = 0
#     on_row = 0
#     for i in range(n_points):
#         n = n_points[i]
#         if n > max_number:
#             continue
#         if n + total > max_number:
#
#     return ds


@seqio.map_over_dataset(num_seeds=2)
def pointing_preprocessor(ex, sequence_length, seeds, with_count=False):
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(image)[0]
    img_w = tf.shape(image)[1]

    ix = tf.where(ex["valid"])[:, 0]
    ix = stateless_shuffle(ix, seeds[0])
    if "label" in ex:
        question = tf.strings.lower(ex["label"])
    else:
        question = ex["question"]
    question = tf.gather(question, ix)  # [n_question]
    points_x = tf.gather(ex["answer_points"]["x"], ix)  # [n_question, n_points[ragged]]
    points_y = tf.gather(ex["answer_points"]["y"], ix)
    not_in_image = tf.gather(ex["notInImage"], ix)  # [n_question]

    n = tf.shape(points_x)[0]
    point_text = tf.TensorArray(dtype=tf.string, size=n, element_shape=())  # [n_question]
    point_seeds = tf.random.split(seeds[1], n)
    for i in range(n):
        answer = points_to_answer(points_x[i], points_y[i], 100, 100, point_seeds[i], question[i], with_count)
        point_text = point_text.write(i, answer)
    return {
        "image": image,
        "metadata/image_size": [img_w, img_h],
        "entity": question,
        "question": question,
        "text": point_text.stack(),
    }

@seqio.map_over_dataset
def pointing_inf_preprocessor(ex):
    ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]

    question = ex["question"]
    not_in_image = tf.shape(ex["answer_points"]["x"])[0] == 0

    # points are stored in normalized format, de-normalize here
    points_x = ex["answer_points"]["x"] * tf.cast(img_w, tf.float32) / 100.0
    points_y = ex["answer_points"]["y"] * tf.cast(img_h, tf.float32) / 100.0

    out = dict(
        image=ex["image"],
        question=question,
        entity=question,
    )
    out.update(_add_metadata(ex))
    out["metadata/not_in_image"] = not_in_image
    # We can't use `mask` directly since it is variable size, and thus it
    # will break batching. Here we serialize it instead
    serialized_masks = tf.map_fn(tf.io.serialize_tensor, ex["masks"], fn_output_signature=tf.string)
    serialized_masks = tf.strings.reduce_join(serialized_masks, separator="|||")
    out["metadata/mask"] = serialized_masks
    out["metadata/question"] = question
    out["metadata/answer_points"] = tf.io.serialize_tensor(tf.stack([points_x, points_y], 1))
    out["metadata/image_size"] = [img_w, img_h]

    return out

@seqio.map_over_dataset(num_seeds=1)
def count_qa_preprocessor_inf(ex, sequence_length, seed):
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(image)[0]
    img_w = tf.shape(image)[1]

    entity = tf.strings.substr(
        ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
    entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
    entity = tf.strings.lower(entity)
    tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)

    return {
        "image": image,
        "metadata/image_size": [img_w, img_h],
        "metadata/count": tf.strings.to_number(ex["answer"]),
        "question": ex["question"],
        "entity": entity,
    }


@seqio.map_over_dataset(num_seeds=1)
def count_qa_preprocessor(ex, sequence_length, seed, with_count=False,
                          for_inference=False):
    point_answer = ex["point_answer"]
    numbers_str = tf.strings.regex_replace(point_answer, r'\.$', '')
    numbers_str = tf.strings.regex_replace(numbers_str, r'[^\d\.\s]+', '')
    numbers_str = tf.strings.strip(numbers_str)
    numbers = tf.strings.split(numbers_str)
    float_numbers = tf.strings.to_number(numbers, out_type=tf.float32)
    coordinates = tf.reshape(float_numbers, (-1, 3))
    points_x = coordinates[:, 1]
    points_y = coordinates[:, 2]

    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(image)[0]
    img_w = tf.shape(image)[1]
    entity = tf.strings.substr(
        ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
    entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
    entity = tf.strings.lower(entity)
    tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
    count = tf.strings.to_number(ex["answer"], out_type=tf.int32)
    if for_inference:
        return {
            "image": image,
            "metadata/image_size": [img_w, img_h],
            "metadata/count": count,
            "question": ex["question"],
            "entity": entity,
        }
    else:
        tf.debugging.assert_equal(count, tf.shape(points_x)[0])
        # points are already normalized so use w=1, h=1
        answer = points_to_answer(points_x, points_y, 1, 1, seed, entity, with_count)
        return {
            "image": image,
            "metadata/image_size": [img_w, img_h],
            "metadata/count": count,
            "question": ex["question"],
            "entity": entity,
            "text": answer,
        }
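
# Illustrative sketch of the `point_answer` parsing above, assuming a string
# like "1. (10.5, 20.3) 2. (30.1, 40.2).": the first regex strips the final
# period, the second removes everything except digits, dots, and whitespace,
# leaving "1. 10.5 20.3 2. 30.1 40.2"; reshaped to (-1, 3) this gives
# (index, x, y) triples [[1., 10.5, 20.3], [2., 30.1, 40.2]], of which
# columns 1 and 2 are the x/y coordinates.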
|
1382 |
-
|
1383 |
-
|
1384 |
-
@gin.configurable()
|
1385 |
-
@seqio.map_over_dataset
|
1386 |
-
def cleanup_preprocessor(ex, preprocess=False):
|
1387 |
-
if preprocess:
|
1388 |
-
ex["prompt"] = tf.strings.join(
|
1389 |
-
[
|
1390 |
-
"[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
|
1391 |
-
ex["prompt"],
|
1392 |
-
"\n[[Assistant]]: {after}"
|
1393 |
-
]
|
1394 |
-
)
|
1395 |
-
return ex
|
1396 |
-
else:
|
1397 |
-
return ex
|
1398 |
-
|
1399 |
-
|
1400 |
-
@gin.configurable()
|
1401 |
-
@seqio.map_over_dataset
|
1402 |
-
def random_text_preprocessor(ex, preprocess=False):
|
1403 |
-
ex["prompt"] = "What does the text say in this image?"
|
1404 |
-
if preprocess:
|
1405 |
-
ex["prompt"] = tf.strings.join(["[[User]]: ", ex["prompt"], "\n[[Assistant]]:"])
|
1406 |
-
return ex
|
1407 |
-
else:
|
1408 |
-
return ex
|
1409 |
-
|
1410 |
-
|
@seqio.map_over_dataset(num_seeds=25)
def clock_augmentation(ex, seeds):
    seeds = list(seeds)
    image = ex["image"]

    # Apply shear, rotation, and scale through one affine matrix
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)

    _call_id = [0]

    def _rng(_minval=0, _maxval=1, shape=(), dtype=tf.float32):
        return tf.random.stateless_uniform(shape, seeds.pop(), _minval, _maxval, dtype=dtype)

    sel = _rng(0, 1)
    if sel < 0.1:
        # Straight on
        shear_x = 0.
        shear_y = 0.
        rotation = 0.
    elif sel < 0.5:
        # Normal looking
        shear_x = _rng(-10, 10)
        shear_y = _rng(-10, 10)
        rotation = _rng(-25, 25)
    else:
        # Allowed to be very wonky
        # if tf.random.stateless_uniform((), seeds.pop(), 0, 1) > 0.8:
        #     image = image[:, ::-1]

        if _rng() > 0.5:
            shear_x = _rng(-30, 30)
            shear_y = _rng(-30, 30)
        else:
            shear_x = _rng(-10, 10)
            shear_y = _rng(-10, 10)
        rng = _rng(0, 1)
        if rng < 0.2:
            rotation = _rng(-25, 25)
        elif rng < 0.6:
            rotation = _rng(-80, 80)
        else:
            rotation = _rng(-180, 180)

    if _rng() > 0.5:
        scale = _rng(0.3, 2)
    else:
        scale = _rng(0.3, 1)
    # Pad so upscaling/rotation will not move the image out of bounds
    pad = tf.cast(tf.maximum(height, width)*0.5, tf.int32)
    image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)

    image = tf.keras.ops.image.affine_transform(
        image,
        tf.stack(get_affine_matrix(
            [height/2, width/2],
            rotation,
            [0, 0],
            1/scale,
            [shear_x, shear_y]
        ) + [0., 0.]),
        interpolation='bilinear',
        fill_mode='constant',
        fill_value=1.,
        data_format='channels_last'
    )

    # Crop, otherwise it would be impossible to put the image at the corner of the image
    not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
    no_white_ix = tf.where(not_white)
    top_left = tf.reduce_min(no_white_ix, axis=0)
    bottom_right = tf.reduce_max(no_white_ix, axis=0)
    image = tf.image.crop_to_bounding_box(
        image,
        offset_height=tf.cast(top_left[0], tf.int32),
        offset_width=tf.cast(top_left[1], tf.int32),
        target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
        target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
    )

    # Translate
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    translation_seed = _rng(0, 1)
    if translation_seed < 0.2:
        h_pad = _rng(0, height//2, (2,), dtype=tf.int32)
        w_pad = _rng(0, width//2, (2,), dtype=tf.int32)
    else:
        h_pad = _rng(0, height*2, (2,), dtype=tf.int32)
        w_pad = _rng(0, width*2, (2,), dtype=tf.int32)
    image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
                   constant_values=1)

    # Random background color
    # color_rng = tf.random.stateless_uniform((4,), seeds.pop(), 0, 1)
    # random_color = color_rng[:3]
    # valid = tf.reduce_all(tf.reduce_sum(tf.abs(random_color[None, None, :] - image), -1) > 0.03)
    # if color_rng[0] < 0.2 and valid:
    #     image = tf.where(tf.reduce_all(image < 0.99, axis=-1, keepdims=True),
    #                      image, image * 0 + random_color[None, None, :])

    # Mild color jitter
    image = tf.image.stateless_random_hue(image, max_delta=0.05, seed=seeds.pop())
    image = tf.image.stateless_random_brightness(image, max_delta=0.15, seed=seeds.pop())
    image = tf.image.stateless_random_saturation(image, 0.8, 1.2, seed=seeds.pop())
    image = tf.image.stateless_random_contrast(image, 0.8, 1.2, seed=seeds.pop())

    # ex["metadata/unaugmented_image"] = ex["image"]
    ex["image"] = image
    return ex


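# A minimal sketch (assumption, not from the original file) of the
# stateless-seed pattern used above: seqio hands the function `num_seeds`
# independent (2,)-shaped seeds, and each random draw consumes one, e.g.
#     seeds = list(seeds)                                    # 25 seed tensors
#     value = tf.random.stateless_uniform((), seeds.pop(), 0, 1)
# so every augmentation decision is reproducible given the dataset seed.

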
@seqio.map_over_dataset
def clocks_preprocessor(ex):
    time_format = ex["time_format"]
    shows_seconds = ex["shows_seconds"]
    hour, minute, second = [tf.cast(ex[k], tf.int32) for k in ["hour", "minute", "second"]]
    if hour == 0:  # Midnight of the previous day
        am_pm = "PM"
        hour_str = 12
        hour = 24
    elif hour > 12:
        am_pm = "PM"
        hour_str = hour - 12
    else:
        hour_str = hour
        am_pm = "AM"
    hour_str = tf.strings.as_string(hour_str)
    minute_str = tf.strings.as_string(minute)
    if tf.strings.length(minute_str) == 1:
        minute_str = tf.strings.join(["0", minute_str])

    second_str = tf.strings.as_string(second)
    if tf.strings.length(second_str) == 1:
        second_str = tf.strings.join(["0", second_str])

    prefix = "The time shown is "

    if time_format == "The time is not shown":
        text = "The time is not shown in the image."
        hour, minute, second = -1, -1, -1
    else:
        if not shows_seconds:
            second = -1
        if time_format == "12 hour clock (without AM/PM)" and shows_seconds:
            if hour > 12:
                hour = hour - 12
            time = tf.strings.join([hour_str, ":", minute_str, ":", second_str])
        elif time_format == "12 hour clock (with AM/PM)" and shows_seconds:
            time = tf.strings.join([hour_str, ":", minute_str, ":", second_str, " ", am_pm])
        elif time_format == "12 hour clock (with AM/PM)" and not shows_seconds:
            time = tf.strings.join([hour_str, ":", minute_str, " ", am_pm])
        elif time_format == "12 hour clock (without AM/PM)" and not shows_seconds:
            if hour > 12:
                hour = hour - 12
            time = tf.strings.join([hour_str, ":", minute_str])
        else:
            time = ""  # Should never occur, but needed for tf analysis
        tf.debugging.assert_equal(tf.strings.length(time) > 0, True)
        text = tf.strings.join([prefix, time])
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)[:-120]  # remove the black shadow at the bottom
    return {
        "image": image,
        "prompt": "What time is being shown?",
        "text": text,
        "metadata/time_format": time_format,
        "metadata/hour": hour,
        "metadata/minute": minute,
        "metadata/text": text,
        "metadata/second": second,
    }


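# Worked example (derived from the logic above, not in the original file):
# hour=15, minute=5, shows_seconds=False, time_format="12 hour clock (with AM/PM)"
# gives hour_str="3", am_pm="PM", and therefore
#     text == "The time shown is 3:05 PM"

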
@seqio.map_over_dataset()
def atlas_obscura_preprocessor(ex):
    out = dict(
        image=ex["image"],
        prompt="Where was this picture taken?",
        text=tf.strings.join([
            ex["place"],
            " in ",
            ex["city"]
        ])
    )
    out["metadata/image_url"] = ex["image_url"]
    out["metadata/references"] = out["text"]
    return out


@seqio.map_over_dataset()
def famous_birthdays_preprocessor(ex):
    out = dict(
        image=ex["image"],
        image_url=ex["image_url"],
        prompt="Who is this?",
        text=ex["name"]
    )
    out["metadata/references"] = out["text"]
    return out


@seqio.map_over_dataset()
def mild_color_aug_preprocessor(ex):
    if "image_url" in ex:  # URL won't show the augmentations
        del ex["image_url"]
    # ex["metadata/unaugmented_image"] = ex["image"]
    ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    ex["image"] = mild_color_aug(ex["image"])
    return ex


def build_text_with_points(text, points, img_h, img_w):
    points = points_to_text(img_h, img_w, points[:, 0], points[:, 1])
    parts = tf.strings.split(text, sep="<ANS>")
    with_points = tf.strings.reduce_join(tf.reshape(tf.stack([
        parts,
        tf.pad(points, [[0, 1]], constant_values=""),
    ], 1), [-1]), separator="")
    return tf.strings.split(with_points, "\n\n")


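# Illustrative sketch (assumption, not from the original file): `text` is a
# "\n\n"-separated block in which every "<ANS>" placeholder is replaced by the
# point string produced by `points_to_text`, e.g.
#     "How many cats? <ANS>\n\nHow many dogs? <ANS>"   with two points
# becomes ["How many cats? <point 1>", "How many dogs? <point 2>"].

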
@seqio.map_over_dataset()
def synth_count_preprocessor(example):
    image_shape = tf.shape(example["image"])
    h, w = image_shape[0], image_shape[1]
    questions = build_text_with_points(example["questions"], example["question_points"], h, w)
    answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
    keep_q = tf.strings.regex_full_match(questions, "How many.*")
    keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
    keep = tf.logical_and(keep_q, keep_ans)
    questions = tf.boolean_mask(questions, keep)
    answers = tf.boolean_mask(answers, keep)
    ix = tf.range(0, tf.shape(answers)[0], dtype=tf.int32)
    ix = tf.random.shuffle(ix)
    return dict(
        image=example["image"],
        prompt=tf.gather(questions, ix),
        text=tf.gather(answers, ix),
    )


def synth_count_inf_preprocessor(ds):

    @seqio.map_over_dataset(num_seeds=1)
    def get_two(example, seed):
        image_shape = tf.shape(example["image"])
        h, w = image_shape[0], image_shape[1]
        questions = build_text_with_points(example["questions"], example["question_points"], h, w)
        answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
        keep_q = tf.strings.regex_full_match(questions, "How many.*")
        keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
        keep = tf.logical_and(keep_q, keep_ans)
        questions = tf.boolean_mask(questions, keep)
        answers = tf.boolean_mask(answers, keep)

        ix = stateless_permutation(tf.shape(answers)[0], seed)[:2]
        return {
            "image": example["image"],
            "prompt": tf.gather(questions, ix),
            "metadata/references": tf.gather(answers, ix),
        }

    ds = get_two(ds)
    return flatten_parts(ds, ["prompt", "metadata/references"])


def mild_color_aug(image):
    image = tf.image.random_hue(image, max_delta=0.05)
    image = tf.image.random_brightness(image, max_delta=0.15)
    image = tf.image.random_saturation(image, 0.7, 1.3)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    return image


@seqio.map_over_dataset()
def name_entity_augmentation(ex, p_high_color=0.7):
    ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    image = ex["image"]
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Horizontal flip
    if tf.random.uniform((), 0, 1) > 0.85:
        image = image[:, ::-1]

    # Random crop
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)
    crop_rng = tf.random.uniform((), 0, 1)
    if crop_rng < 0.2:
        pass
    else:
        if crop_rng < 0.4:
            h_crop = height * 0.15
            w_crop = width * 0.15
        else:
            h_crop = height * 0.4
            w_crop = width * 0.4
        crop_h = tf.cast(tf.random.uniform((2,), 0, h_crop/2), tf.int32)
        crop_w = tf.cast(tf.random.uniform((2,), 0, w_crop/2), tf.int32)
        image = image[crop_h[0]:-crop_h[1]-1, crop_w[0]:-crop_w[1]-1]
        height = tf.cast(tf.shape(image)[0], tf.float32)
        width = tf.cast(tf.shape(image)[1], tf.float32)

    if tf.random.uniform(()) > p_high_color:
        image = tf.image.random_hue(image, max_delta=0.05)
        image = tf.image.random_brightness(image, max_delta=0.15)
        image = tf.image.random_saturation(image, 0.7, 1.3)
        image = tf.image.random_contrast(image, 0.8, 1.2)
    else:
        image = tf.image.random_hue(image, max_delta=0.1)
        image = tf.image.random_brightness(image, max_delta=0.3)
        image = tf.image.random_saturation(image, 0.0, 2.0)
        image = tf.image.random_contrast(image, 0.2, 1.5)

    # Apply shear, rotation, and scale through one affine matrix
    sel = tf.random.uniform((), 0, 1)
    if sel < 0.1:
        pass
    else:
        if sel < 0.15:  # Scale only
            shear_x = 0
            shear_y = 0
            rotation = 0
        elif sel < 0.7:  # Mild
            shear_x = tf.random.uniform((), -2, 2)
            shear_y = tf.random.uniform((), -2, 2)
            rotation = tf.random.uniform((), -5, 5)
        else:  # Severe
            shear_x = tf.random.uniform((), -10, 10)
            shear_y = tf.random.uniform((), -10, 10)
            rotation = tf.random.uniform((), -20, 20)

        max_scale = 1.2
        scale = tf.random.uniform((), 0.4, max_scale)

        # Pad so upscaling/rotation will not move the image out of bounds
        pad = tf.cast(tf.maximum(height, width)*0.2, tf.int32)
        image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)

        image = tf.keras.ops.image.affine_transform(
            image,
            tf.stack(get_affine_matrix(
                [height/2, width/2],
                rotation,
                [0, 0],
                1/scale,
                [shear_x, shear_y]
            ) + [0., 0.]),
            interpolation='bilinear',
            fill_mode='constant',
            fill_value=1.,
            data_format='channels_last'
        )

        # Crop, otherwise it would be impossible to put the image at the corner of the image
        not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
        no_white_ix = tf.where(not_white)
        top_left = tf.reduce_min(no_white_ix, axis=0)
        bottom_right = tf.reduce_max(no_white_ix, axis=0)

        # Very low chance center crop will get nothing but white space, we just skip
        if (
            (bottom_right[0] - top_left[0]) > 1 and (bottom_right[1] - top_left[1]) > 1
        ):
            image = tf.image.crop_to_bounding_box(
                image,
                offset_height=tf.cast(top_left[0], tf.int32),
                offset_width=tf.cast(top_left[1], tf.int32),
                target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
                target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
            )

    # Translate
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    if tf.random.uniform((), 0, 1) < 0.1:
        h_pad = tf.zeros((2,), dtype=tf.int32)
        w_pad = tf.zeros((2,), dtype=tf.int32)
    elif tf.random.uniform((), 0, 1) < 0.8:
        h_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
        w_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
    else:
        pad = tf.cast(tf.maximum(height, width), tf.int32)
        h_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
        w_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
    image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
                   constant_values=1)

    if "image_url" in ex:  # URL won't show the augmentations
        del ex["image_url"]
    # ex["metadata/unaugmented_image"] = ex["image"]
    ex["image"] = image
    return ex


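# Overview (illustrative summary of the function above, not from the original
# file): the augmentation order is horizontal flip -> random crop -> color
# jitter -> one affine transform (shear/rotation/scale) -> crop back to the
# non-white content -> random translation via padding, with white used as the
# fill color throughout.

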
@seqio.map_over_dataset()
def wiki_art_preprocessor(ex):
    out = dict(
        image=ex["image"],
        prompt="What is this?",
        text=ex["question"]
    )
    out["metadata/title"] = ex["title"]
    out["metadata/gt"] = ex["question"]
    out["metadata/artist"] = ex["artist"]
    out["metadata/painting_url"] = ex["painting_url"]
    # if "metadata/unaugmented_image" in ex:
    #     out["metadata/unaugmented_image"] = ex["metadata/unaugmented_image"]
    return out


@seqio.map_over_dataset()
def oscar_preprocessor(ex):
    out = dict(
        image=ex["image"],
        prompt=ex["question"]
    )
    out.update(_add_metadata(ex))
    out["metadata/question"] = ex["question"]
    out["metadata/answer"] = ex["answer"]
    out["metadata/category"] = ex["category"]
    return out


@seqio.map_over_dataset()
def tulu_preprocessor(ex):
    return {
        "messages": ex["messages"]["content"],
    }
    # logging.info("Debugging tulu")
    # return {"messages": ex["messages"]["content"], "text_weights": 1e-6}


WIKI_DATA_QUESTION = "What is this? Respond with just a proper name."


@seqio.map_over_dataset()
def extract_wiki_data(ex):
    return dict(
        image=ex["image"],
        image_url=ex["image_url"],
        prompt=[
            WIKI_DATA_QUESTION,
            "What is this? Respond with the proper name of the main focus of the image and a few details about it."
        ],
        text=[
            tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", "")),
            ex["gptResponse"],
        ]
    )


@seqio.map_over_dataset()
def extract_wiki_data_name(ex):
    target = tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", ""))
    out = dict(
        image=ex["image"],
        image_url=ex["image_url"],
        prompt=WIKI_DATA_QUESTION,
        text=target,
    )
    out["metadata/references"] = target
    return out


@seqio.map_over_dataset()
def extract_wiki_data_describe(ex):
    out = dict(
        image=ex["image"],
        image_url=ex["image_url"],
        prompt="What is this? Respond with the proper name of the main focus of the image and a few details about it.",
    )
    out["metadata/references"] = ex["gptResponse"]
    return out


@gin.configurable()
def format_multiple_style_qa(ds, types=['multiple_choice', 'short_answer'], styles=['ai2_diagram', 'vqa2'], default_style='vqa2',
                             strip_instruction=False):
    def _extract(ex):
        prompt = ex["question"]
        out = dict(image=ex["image"])
        out.update(_add_metadata(ex))

        out["text"] = ex["answer"]
        out["metadata/references"] = ex["answer"]

        if ex["metadata/question_type"] == 'multiple_choice':
            style = styles[0]
        else:
            style = styles[1]
        if strip_instruction:
            if ex["metadata/question_type"] == "multiple_choice":
                # parts = tf.strings.split(prompt, "\n")
                # part 1 is blank and part -1 is the instruction
                # prompt = tf.strings.reduce_join(tf.concat([parts[:1], parts[2:-1]], 0), separator="\n")
                prompt = prompt
            else:
                prompt = tf.strings.split(prompt, "\n")[0]

        out["style"] = style
        out["prompt"] = prompt
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds


@gin.configurable()
def extract_mmmu(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
    assert option_format == "abc"
    keys_tensor = tf.constant(types, dtype=tf.string)
    values_tensor = tf.constant(styles, dtype=tf.string)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
        default_value=tf.constant(default_style, dtype=tf.string),
    )

    def _extract(ex):
        out = dict(image=tf.expand_dims(ex["image_1"], 0))
        out.update(_add_metadata(ex))
        style = table.lookup(ex["metadata/question_type"])
        out["style"] = style
        out["text"] = ex["answer"]
        out["metadata/references"] = ex["answer"]

        if style == styles[0]:
            abc = tf.constant(list("abcdefghi".upper()))
            options = ex["options"]
            num_options = tf.shape(options)[0]
            dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
            out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
            out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])

            short_options = abc[:num_options]
            options = tf.stack([short_options, options,], 1)
            options = tf.strings.reduce_join(options, axis=-1, separator=": ")
            options = tf.strings.reduce_join(options, separator="\n")
            out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
            if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
                # Following LLaVa, don't use any images if there are multiple image paths
                # I think the rationale is that this means the images are answer-options
                out["image"] = out["image"][:0]
        else:
            out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
            out["prompt"] = ex["question"]
            out["image"] = out["image"][:0]
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds


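# Worked example (derived from the code above, not in the original file): with
# options ["cat", "dog"] the multiple-choice prompt becomes
#     "<question>\nA: cat\nB: dog\n"
# and "metadata/options" is the raw options list padded with "" to length 9.

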
@gin.configurable()
def extract_mmmu_cot(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
    assert option_format == "abc"
    keys_tensor = tf.constant(types, dtype=tf.string)
    values_tensor = tf.constant(styles, dtype=tf.string)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
        default_value=tf.constant(default_style, dtype=tf.string),
    )

    def _extract(ex):
        # out = dict(image=tf.expand_dims(ex["image_with_question"], 0))
        out = dict(image=tf.expand_dims(ex["image_1"], 0))
        out.update(_add_metadata(ex))
        style = table.lookup(ex["metadata/question_type"])
        # out["style"] = style
        out["text"] = ex["answer"]
        out["metadata/question"] = ex["question"]
        out["metadata/references"] = ex["answer"]

        if style == styles[0]:
            abc = tf.constant(list("abcdefghi".upper()))
            options = ex["options"]
            num_options = tf.shape(options)[0]
            dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
            out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
            out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])

            short_options = abc[:num_options]
            options = tf.stack([short_options, options,], 1)
            options = tf.strings.reduce_join(options, axis=-1, separator=": ")
            options = tf.strings.reduce_join(options, separator="\n")
            out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
            # out["prompt"] = ex["question"]
            if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
                # Following LLaVa, don't use any images if there are multiple image paths
                # I think the rationale is that this means the images are answer-options
                out["image"] = out["image"][:0]
        else:
            out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
            out["prompt"] = ex["question"]
            # out["image"] = out["image"][:0]
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds


@seqio.map_over_dataset
def reformat_math_vista(ex):
    query = ex["query"]
    query = tf.strings.split(query, sep="Question:")[-1]
    query = tf.strings.strip(tf.strings.split(query, sep="Hint:")[0])
    ex["query"] = query
    return ex


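# Worked example (derived from the code above, not in the original file):
#     "Question: What is the x-intercept? Hint: read the graph"
# is reduced to "What is the x-intercept?": everything before "Question:" and
# from "Hint:" onwards is discarded, then whitespace is stripped.

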
@seqio.map_over_dataset
def extract_math_vista(ex, styles=['ai2_diagram', 'vqa2']):
    out = dict(image=ex["image"])
    out.update(_add_metadata(ex))

    is_mc = ex["metadata/question_type"] == 'multi_choice'
    if is_mc:
        style = styles[0]
        abc = tf.constant(list("abcdefghi".upper()))
        options = ex["choices"]
        num_options = tf.shape(options)[0]
        dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
        out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
        out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])

        if ex["metadata/split"] != "test":
            short_options = abc[:num_options]
            answer_short_option = tf.boolean_mask(short_options, options == ex["answer"])[0]
            out["text"] = answer_short_option
        else:
            out["text"] = ex["answer"]
    else:
        style = styles[1]
        out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
        out["text"] = ex["answer"]
    out["style"] = style
    out["prompt"] = ex["query"]
    out["metadata/query"] = ex["query"]
    out["metadata/references"] = ex["answer"]
    return out


NO_POINT_PREFIX = [
    "No pointing: ",
    "No pointing: ",
    "no pointing:\n",
    "No pointing:\n",
    "Not pointing:\n",
    "No Points: ",
    "No Points: ",
    "NO POINTING\n",
    "No pontiing\n",
    "No Points:\n ",
    "No pointing\n",
    "Do not point. ",
    "Refrain from pointing. ",
    "Avoid generating points . ",
    "For this question, do not use points. ",
    "Refrain from using points:\n",
    "Don't include points in your response. ",
    "Don't point. ",
    "Don't use points. ",
    "Please don't use points.\n\n",
    "Please don't use points.\n\n",
    "Respond without using points. ",
    "Respond without pointing:\n",
    "Do not generate ponits: ",
    "Do not point. ",
    "Do not point\n",
    "no pointing\n\n",
    "Answer without points: ",
    "Answer this question without pointing: ",
    "Answer without poiints. ",
    "answer without points: ",
    "answer with text only, do not points\n"
]
assert all(x[-1].isspace() for x in NO_POINT_PREFIX)
NO_POINT_PREFIX_TF = tf.constant(NO_POINT_PREFIX)


def prefix_how_many(messages, seed):
    question = messages[0]
    if tf.strings.regex_full_match(tf.strings.lower(question), "how many.*"):
        ix = tf.random.stateless_uniform((), seed, 0, len(NO_POINT_PREFIX), tf.int32)
        question = tf.strings.join([NO_POINT_PREFIX_TF[ix], question])
        return tf.concat([tf.expand_dims(question, 0), messages[1:]], axis=0)
    else:
        return messages


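# Illustrative note (not from the original file): a first message such as
# "How many cats are there?" is randomly prefixed, e.g.
#     "Do not point. How many cats are there?"
# The misspelled entries in NO_POINT_PREFIX look deliberate, adding noisy
# phrasing variants of the same instruction.

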
@seqio.map_over_dataset(num_seeds=1)
def prefix_how_many_messages(ex, seed):
    messages = ex["messages"]
    n = tf.shape(messages)[0]
    seeds = tf.random.split(seed, n)
    message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
    for i in range(n):
        message_arr = message_arr.write(i, prefix_how_many(messages[i], seeds[i]))
    ex["messages"] = tf.RaggedTensor.from_row_splits(
        values=message_arr.concat(), row_splits=messages.row_splits)
    return ex


def filter_single_turn(ds):
    @seqio.map_over_dataset
    def _filter(ex):
        multi_turn = ex["messages"].row_lengths() > 2
        ex["messages"] = tf.ragged.boolean_mask(ex["messages"], multi_turn)
        return ex

    ds = _filter(ds)
    ds = ds.filter(lambda x: tf.shape(x["messages"])[0] > 0)
    return ds


@seqio.map_over_dataset(num_seeds=1)
def extract_cockatoo_qa_v2(ex, seed):
    messages = tf.RaggedTensor.from_value_rowids(ex["messages"], ex["conversation_ids"])
    ix = stateless_permutation(tf.shape(messages)[0], seed)
    messages = tf.gather(messages, ix)
    out = dict(
        image=ex["image"],
        messages=messages
    )
    out.update(_add_metadata(ex))
    return out


def format_mmbench(ds):

    def _trim(ex):
        num_passes = tf.shape(ex["id"])[0]
        ex["choices"] = ex["choices"][:num_passes, :num_passes]
        ex["answer"] = ex["answer"][:num_passes]
        return ex

    ds = ds.map(_trim)
    ds = flatten_parts(ds, ["id", "query", "choices", "answer"])

    def _extract(ex):
        out = dict(image=ex["image"])
        out.update(_add_metadata(ex))
        out["prompt"] = ex["query"]
        out["text"] = ex["answer"]
        options = ex["choices"]
        tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, r".*\|\|\|.*")), False)
        out["metadata/options"] = tf.strings.reduce_join(options, separator="|||")
        out["metadata/question"] = ex["question"]
        out["metadata/references"] = ex["answer"]
        return out

    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds


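# Illustrative note (derived from the code above, not in the original file):
# the flattened options are stored as one "|||"-joined string, e.g.
#     "metadata/options" == "blue|||green|||red"
# which is why the assert rejects any choice already containing "|||".

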
@seqio.map_over_dataset
def extract_lvis(ex, class_name_file="gs://oe-training-chrisc/cockatoo/data/lvis_class_names.json"):
    with tf.io.gfile.GFile(class_name_file) as f:
        class_names = json.load(f)
    class_names_arr = [None]*len(class_names)
    for k, v in class_names.items():
        class_names_arr[int(k)] = v
    assert all(x is not None for x in class_names_arr)
    class_names_arr = tf.constant(class_names_arr)

    return dict(
        image=ex["image"],
        bbox=ex["objects"]["bbox"],
        label=tf.gather(class_names_arr, ex["objects"]["label"]),
    )


def extract_open_images_boxes(ds):
    # ds = ds.filter(lambda ex: tf.logical_or(
    #     tf.shape(ex["cap/cap_caption"])[0] > 0,
    #     tf.shape(ex["detection/bbox"])[0] > 0
    # ))
    ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)

    @seqio.map_over_dataset
    def _map(ex):
        bbox = tf.reshape(ex["detection/bbox"], (-1, 4))
        bbox = tf.stack([
            bbox[:, 2],
            bbox[:, 0],
            bbox[:, 3],
            bbox[:, 1]
        ], 1)
        return dict(
            image=tf.image.decode_jpeg(ex["image"]),
            bbox=bbox,
            label=ex["detection/label"],
            caption=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
        )

    return _map(ds)


@seqio.map_over_dataset
def region_captions_to_dense(ex):
    if "captions" in ex:
        captions = ex["captions"]["text"]
        boxes = ex["captions"]["bbox"]
    else:
        captions = ex["label"]
        boxes = ex["bbox"]

    sh = tf.cast(tf.shape(ex["image"])[:2], tf.float32)
    # image_h, image_w = sh[0], sh[1]
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]

    cx = tf.cast(boxes[:, 0] + w/2, tf.float32)
    cy = tf.cast(boxes[:, 1] + h/2, tf.float32)
    # w = w / image_w
    # h = h / image_h
    coor = tf.strings.reduce_join(
        float_to_text(tf.stack([cx, cy, w, h], 1)), separator=",", axis=1)

    area = w*h
    if tf.random.uniform(()) < 0.5:
        coor_text = "before"
        captions = tf.strings.join([coor, captions], separator=": ")
    else:
        coor_text = "after"
        captions = tf.strings.join([captions, coor], separator=": ")

    ix = tf.random.uniform((), 0, 6, tf.int32)
    center = boxes
    if ix == 0:
        order_text = "left"
        sort_by = boxes[:, 0]
    elif ix == 1:
        order_text = "right"
        sort_by = -boxes[:, 2]
    elif ix == 2:
        order_text = "top"
        sort_by = boxes[:, 1]
    elif ix == 3:
        order_text = "bottom"
        sort_by = -boxes[:, 3]
    elif ix == 4:
        order_text = "largest"
        sort_by = area
    else:
        order_text = "smallest"
        sort_by = -area
    ixs = tf.argsort(sort_by)
    captions = tf.gather(captions, ixs)
    text = tf.strings.join([
        order_text,
        coor_text,
        tf.strings.reduce_join(captions, separator="\n")
    ], separator="; ")

    if "caption" in ex:
        if tf.random.uniform(()) > 0.5:
            text = tf.strings.join([text, "\ncaption: ", ex["caption"]])
        else:
            text = tf.strings.join(["caption: ", ex["caption"], "\n", text])

    return dict(
        image=ex["image"],
        text=text
    )


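# Worked example (derived from the code above, not in the original file): a
# produced target string has the shape
#     "left; before; 10,20,5,8: a red car\n30,40,6,6: a tree"
# i.e. "<sort order>; <coordinate position>; <one region caption per line>",
# optionally with "caption: <image caption>" prepended or appended.

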
@seqio.map_over_dataset()
def join_captions(ex):
    text = tf.random.shuffle(ex['text'])
    ex["text"] = tf.strings.reduce_join(text, separator="\n")
    return ex


@seqio.map_over_dataset(num_seeds=1)
def extract_figureqa(ex, seed):
    questions = ex["questions"]
    n = stateless_permutation(tf.shape(questions["question"])[0], seed)
    return dict(
        image=ex["image"],
        questions=tf.gather(questions["question"], n),
        question_id=tf.gather(questions["question_id"], n),
        answer=tf.gather(tf.strings.as_string(questions["answer"]), n)
    )


@seqio.map_over_dataset
def convert_figureqa_answer(ex):
    keys_tensor = tf.constant(["0", "1"])
    values_tensor = tf.constant(["no", "yes"])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
        default_value=tf.constant("nan", dtype=tf.string),
    )
    answer = table.lookup(ex["answer"])
    ex["answer"] = answer
    return ex


@seqio.map_over_dataset()
def build_question_with_hint(ex):
    hint = ex["hint"]
    if tf.strings.length(hint) > 0:
        ex["question"] = tf.strings.join([hint, ex["question"]], separator="\n")
    return ex


@seqio.map_over_dataset()
def build_question_with_context(ex):
    context = ex["context"]
    if tf.strings.length(context) > 0:
        ex["question"] = tf.strings.join([context, ex["question"]], separator="\n")
    return ex


def max_words(ds, max_words):
    return ds.filter(lambda x: x["n_words"] <= max_words)


@seqio.map_over_dataset
def format_pdfa_eng_wds(example):
    return dict(
        image=example["image"],
        text=tf.strings.reduce_join(example["lines"]["text"], separator="\n"),
    )


@gin.configurable()
def accuracy_conditioned_joint(ds, sequence_length, is_eval=False, eval_quality=17,
                               transcript_quality=None):
    # v2: Transcripts no longer get a quality score
    is_training = sequence_length.get('is_training', True)
    if not is_training:
        if is_eval:
            prompt = f"quality {eval_quality}:"
        else:
            prompt = "quality 17:"

        @seqio.map_over_dataset
        def _with_prompt(ex):
            out = dict(
                image=ex["image"],
                url=ex["url"],
                prompt=prompt,
            )
            if "text" in ex:
                out["text"] = ex["text"]
            elif "caption" in ex:
                out["text"] = ex["caption"]
            return out
        return _with_prompt(ds)

    elif is_eval:
        raise ValueError("is_eval=True requires is_training=False")

    # each transcript
    @seqio.map_over_dataset
    def _with_transcript(ex):
        if tf.shape(ex["edited_captions"]["caption"])[0] > 0:
            edited_caption = ex["edited_captions"]["caption"][0]
            n = ex["edited_captions"]["n_edits"][0]
        else:
            edited_caption = ""
            n = 0
        text = [
            ex["caption"],
            ex["transcripts"][tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)],
            edited_caption
        ]
        edit_quality = 17 - n
        prompt = [
            "quality 17:",
            "" if transcript_quality is None else tf.strings.join(
                ["quality ", tf.strings.as_string(edit_quality), ":"]),
            tf.strings.join(["quality ", tf.strings.as_string(edit_quality), ":"])
        ]
        return dict(
            image=ex["image"],
            text=tf.stack(text, 0),
            url=ex["url"],
            prompt=tf.stack(prompt, 0),
            style=["long_caption", "transcript", "long_caption"]
        )
    return _with_transcript(ds)


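# Illustrative note (derived from the code above, not in the original file): in
# training mode every example yields three (prompt, text, style) triples:
#     ("quality 17:",             caption,            "long_caption")
#     ("",                        random transcript,  "transcript")
#     ("quality <17 - n_edits>:", edited caption,     "long_caption")
# so caption targets are conditioned on an accuracy/quality score.

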
def select_dense_caption_sample(ds, samples=200):
    def compute_hash(string: str) -> str:
        return hashlib.sha256(string.encode("utf-8")).hexdigest()

    with tf.io.gfile.GFile("gs://oe-training-chrisc/cockatoo/data/dense-caption-eval-v0-final-data.json") as f:
        data = json.load(f)
    for ex in data:
        ex["image_id"] = compute_hash(ex["image"])
    data.sort(key=lambda x: x["image_id"])
    np.random.RandomState(12312).shuffle(data)
    keep = tf.constant([x["image"] for x in data[:samples]])

    def _keep(ex):
        return tf.reduce_any(ex["url"] == keep)
    ds = ds.filter(_keep)
    ds = tf.data.experimental.assert_cardinality(samples)(ds)
    return ds


@seqio.map_over_dataset()
def charxiv_preprocessor(ex):
    question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4", "reasoning_q"]
    answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4", "reasoning_a"]

    questions = [ex[name] for name in question_names]
    answers = [ex[name] for name in answer_names]

    return dict(
        image=ex["image"],
        question=tf.stack(questions, 0),
        answer=tf.stack(answers, 0)
    )


@seqio.map_over_dataset()
def charxiv_descriptive_preprocessor(ex):
    question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4"]
    answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4"]

    questions = [ex[name] for name in question_names]
    answers = [ex[name] for name in answer_names]

    return dict(
        image=ex["image"],
        question=tf.stack(questions, 0),
        answer=tf.stack(answers, 0)
    )


@seqio.map_over_dataset()
def charxiv_reasoning_preprocessor(ex):
    return dict(
        image=ex["image"],
        question=ex["reasoning_q"],
        answer=ex["reasoning_a"]
    )


@seqio.map_over_dataset()
def tablevqa_preprocessor(ex):
    return dict(
        image=ex["image"],
        question=ex["question"],
        answer=ex["gt"]
    )


@seqio.map_over_dataset()
def vtabfact_preprocessor(ex):
    return dict(
        image=ex["image"],
        question=tf.strings.join([ex["question"], "Answer with yes or no."], separator="\n"),
        answer=ex["gt"]
    )


@seqio.map_over_dataset()
def nutrition_fact_preprocessor(ex):
    question_names = ["descriptive_q", "reasoning_q"]
    answer_names = ["descriptive_a", "reasoning_a"]

    questions = [ex[name] for name in question_names]
    answers = [ex[name] for name in answer_names]

    return dict(
        image=ex["image"],
        question=tf.stack(questions, 0),
        answer=tf.stack(answers, 0)
    )