chrisc36 committed on
Commit 59a83f1
1 Parent(s): 8d085c7

Delete preprocesssors.py

Files changed (1)
  1. preprocesssors.py +0 -2472
preprocesssors.py DELETED
@@ -1,2472 +0,0 @@
1
- import hashlib
2
- import json
3
- import math
4
- from functools import reduce
5
- from typing import Mapping, Optional, Sequence
6
-
7
- import numpy as np
8
- import tensorflow as tf
9
- import seqio
10
- import gin
11
-
12
- from .data_utils import flatten_parts, stateless_permutation, stateless_shuffle
13
- from .. import config
14
-
15
-
16
- def get_from_dict(data, keys):
17
- """Iterate nested dictionary"""
18
- return reduce(dict.get, keys, data)
19
-
20
- def get_blank_image():
21
- image = tf.zeros([224, 224, 3], dtype=tf.uint8)
22
- image = tf.expand_dims(image, 0)[:1]
23
- return image
24
-
25
-
26
- @seqio.utils.map_over_dataset
27
- def rekey(x, key_map=None):
28
- """Replace the feature keys according to the mapping in `key_map`.
29
- For example, if the dataset returns examples of the format:
30
- {'foo': 'something', 'bar': 'something else'}
31
- and key_map = {'boo': 'foo', 'spar': 'bar'} then this function will return
32
- examples with the format
33
- {'boo': 'something', 'spar': 'something else'}
34
- If a mapping is to an empty key or None, set the new key to an empty string.
35
- Args:
36
- x: an example to process.
37
- key_map: dictionary mapping new keys to original keys
38
- Returns:
39
- A preprocessed example with the format listed above.
40
- """
41
- if key_map:
42
- out = {}
43
- for new_key, old_key in key_map.items():
44
- if isinstance(old_key, list):
45
- out[new_key] = get_from_dict(x, old_key)
46
- else:
47
- out[new_key] = x[old_key]
48
- return out
49
- return x
50
-
51
-
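As a quick illustration of the key remapping `rekey` performs (a sketch, not part of the deleted file; the feature names are made up, and the nested-key case relies on the `get_from_dict` helper above):

import tensorflow as tf

# hypothetical dataset with a flat and a nested feature
ds = tf.data.Dataset.from_tensors(
    {"foo": "something", "bar": {"baz": "something else"}})
# "boo" takes the value of "foo"; "spar" follows the key path ["bar", "baz"]
ds = rekey(ds, key_map={"boo": "foo", "spar": ["bar", "baz"]})
# -> {"boo": "something", "spar": "something else"}
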
52
- def rename(**kwargs):
53
- @seqio.map_over_dataset
54
- def _fn(x):
55
- updates = {}
56
- for new_key, old_key in kwargs.items():
57
- if isinstance(old_key, list):
58
- val = x[old_key[0]]
59
- for k in old_key[1:-1]:
60
- val = val[k]
61
- updates[new_key] = val.pop(old_key[-1])
62
- else:
63
- updates[new_key] = x.pop(old_key)
64
- x.update(updates)
65
- return x
66
- return _fn
67
-
68
-
69
- def extract_transcripts(ds):
70
- ds = flatten_parts(ds, ["transcripts"])
71
- def _map(ex):
72
- return dict(
73
- image=ex["image"],
74
- text=ex["transcripts"],
75
- url=ex["url"]
76
- )
77
- return ds.map(_map)
78
-
79
-
80
- @seqio.map_over_dataset
81
- def extract_caption_and_all_transcripts(ex):
82
- transcripts = tf.random.shuffle(ex["transcripts"])[:3]
83
- weight = 1.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
84
- return dict(
85
- image=ex["image"],
86
- text=tf.concat([tf.expand_dims(ex["caption"], 0), transcripts], 0),
87
- url=ex["url"],
88
- text_weights=tf.pad(
89
- tf.ones((1,), dtype=tf.float32), [[0, tf.shape(transcripts)[0]]],
90
- constant_values=weight),
91
- )
92
-
93
-
94
- @seqio.map_over_dataset
95
- def extract_all_transcripts(ex):
96
- transcripts = tf.random.shuffle(ex["transcripts"])[:3]
97
- weight = 3.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
98
- return dict(
99
- image=ex["image"],
100
- text=transcripts,
101
- url=ex["url"],
102
- text_weights=tf.fill((tf.shape(transcripts)[0],), weight),
103
- )
104
-
105
-
106
- @seqio.map_over_dataset
107
- def extract_transcript(ex):
108
- transcripts = tf.random.shuffle(ex["transcripts"])
109
- return dict(
110
- image=ex["image"],
111
- text=transcripts[0],
112
- url=ex["url"],
113
- )
114
-
115
-
116
- @seqio.map_over_dataset
117
- def extract_caption(ex):
118
- caption = ex["caption"]
119
- if len(caption.shape) > 0:
120
- ex["text"] = caption[0]
121
- else:
122
- ex["text"] = caption
123
- return ex
124
-
125
-
126
- @seqio.map_over_dataset
127
- def extract_joint_captions(ex):
128
- caption = ex["caption"]
129
- if len(caption.shape) > 0:
130
- caption = caption[0]
131
- _ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
132
- _ix = _ix % tf.shape(ex["transcripts"])[0]
133
- return dict(
134
- image=ex["image"],
135
- text=tf.stack([caption, ex["mistral_caption"], ex["transcripts"][_ix]], 0),
136
- url=ex["url"]
137
- )
138
-
139
-
140
- @seqio.map_over_dataset(num_seeds=1)
141
- def extract_caption_and_transcript(ex, seed):
142
- caption = ex["caption"]
143
- if len(caption.shape) > 0:
144
- caption = caption[0]
145
- _ix = tf.random.stateless_uniform((), seed, 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
146
- return dict(
147
- image=ex["image"],
148
- text=tf.stack([caption, ex["transcripts"][_ix]], 0),
149
- url=ex["url"]
150
- )
151
-
152
-
153
- @seqio.map_over_dataset
154
- def caption_transcript_augmented(ex, sequence_length):
155
- caption = ex["caption"]
156
- if len(caption.shape) > 0:
157
- caption = caption[0]
158
- image = ex["image"]
159
- properties = []
160
-
161
- do_augmentation = sequence_length["is_training"]
162
- # do_augmentation = False
163
-
164
- # Keep this off, it screws up OCR
165
- # do_hflip = (tf.random.uniform(()) > 0.2 and do_augmentation)
166
- do_hflip = False
167
- if do_hflip:
168
- image = image[:, ::-1]
169
-
170
- # Mild color jitter
171
- do_color = (tf.random.uniform(()) > 0.5 and do_augmentation)
172
- if do_color:
173
- image = tf.image.random_hue(image, max_delta=0.05)
174
- image = tf.image.random_brightness(image, max_delta=0.2)
175
- image = tf.image.random_saturation(image, 0.7, 1.3)
176
- image = tf.image.random_contrast(image, 0.7, 1.3)
177
-
178
- # Mild affine transformation
179
- do_affine = (tf.random.uniform(()) > 0.5 and do_augmentation)
180
- if do_affine and do_augmentation:
181
- shift_x = tf.random.uniform((), -10, 10) * 0
182
- shift_y = tf.random.uniform((), -10, 10) * 0
183
- shear_x = tf.random.uniform((), -2, 2)
184
- shear_y = tf.random.uniform((), -2, 2)
185
- rotation = tf.random.uniform((), -6, 6)
186
- max_scale = 1.1
187
- scale = tf.random.uniform((), 0.8, max_scale)
188
- center = tf.cast(tf.shape(image), tf.float32)/2
189
-
190
- image = tf.keras.ops.image.affine_transform(
191
- image,
192
- tf.stack(get_affine_matrix(
193
- [center[0], center[1]],
194
- rotation,
195
- [shift_x, shift_y],
196
- 1/scale,
197
- [shear_x, shear_y]
198
- ) + [0., 0.]),
199
- interpolation='bilinear',
200
- fill_mode='constant',
201
- fill_value=1.,
202
- data_format='channels_last'
203
- )
204
-
205
- properties = tf.stack([
206
- ("[hflip]" if do_hflip else ""),
207
- ("[color]" if do_color else ""),
208
- ("[affine]" if do_affine else "")
209
- ])
210
- properties = tf.boolean_mask(properties, tf.strings.length(properties) > 0)
211
- prompt = tf.strings.reduce_join(properties, separator=" ")
212
- ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
213
- out = dict(
214
- image=image,
215
- text=tf.stack([caption, ex["transcripts"][ix]], 0),
216
- url=ex["url"],
217
- prompt=prompt,
218
- )
219
- # out["metadata/unaugmented_image"] = image
220
- return out
221
-
222
-
223
- def extract_caption_and_transcript_hflip(ds):
224
-
225
- # Just in case they are ordered somehow in Matt's data
226
- @seqio.map_over_dataset
227
- def _shuffle_transcripts(_ex):
228
- _ex["transcripts"] = tf.random.shuffle(_ex["transcripts"])
229
- _ex["hflip"] = tf.random.uniform((), 0, 3, dtype=tf.int32)
230
- return _ex
231
-
232
- ds = _shuffle_transcripts(ds)
233
-
234
- # Build a 3x long dataset with each individual transcript so we iterate through
235
- # each transcript
236
- @seqio.map_over_dataset
237
- def _with_transcript(ex, _ix):
238
- caption = ex["caption"]
239
- if len(caption.shape) > 0:
240
- caption = caption[0]
241
- hflip = ex["hflip"] == _ix
242
- if hflip:
243
- ex["image"] = ex["image"][:, ::-1]
244
- style = ["long_caption_flipped", "transcript_flipped"]
245
- else:
246
- style = ["long_caption", "transcript"]
247
- return dict(
248
- image=ex["image"],
249
- text=tf.stack([caption, ex["transcripts"][_ix]], 0),
250
- url=ex["url"],
251
- style=style
252
- )
253
-
254
- joint_ds = _with_transcript(ds, 0)
255
- for i in range(1, 3):
256
- joint_ds = joint_ds.concatenate(_with_transcript(ds, i))
257
-
258
- return joint_ds
259
-
260
-
261
- @seqio.map_over_dataset
262
- def extract_llava(ex, sequence_length, output_features):
263
- tf.assert_equal(tf.shape(ex['conversations']['value'])[0], 2)
264
- prompt = ex['conversations']['value'][0]
265
- text = ex['conversations']['value'][1]
266
- ex.pop('conversations')
267
- ex["text"] = text
268
- ex["prompt"] = prompt
269
- return ex
270
-
271
-
272
- def extract_localized_narrative(ds):
273
- ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
274
- def _map(ex):
275
- return dict(
276
- image=ex["image"],
277
- text=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
278
- )
279
- return ds.map(_map)
280
-
281
-
282
- def float_to_text(val):
283
- return tf.strings.as_string(tf.cast(val * 100, tf.int32))
284
-
285
-
286
- @seqio.map_over_dataset
287
- def extract_vqa(ex):
288
- questions = ex["vqa"]["questions"]
289
- answers = ex["vqa"]["answers"]
290
- answers = tf.strings.reduce_join(answers, 1, separator="; ")
291
- qas = tf.strings.reduce_join(tf.stack([questions, answers], 1), separator=" ")
292
- return dict(
293
- image=ex["image"],
294
- text=tf.strings.reduce_join(qas, separator="\n")
295
- )
296
-
297
-
298
- @seqio.map_over_dataset
299
- def coco_image_id_from_path(ex):
300
- image_id = tf.strings.substr(ex["image/filename"], 0, tf.strings.length(ex["image/filename"])-4)
301
- ex["image_id"] = tf.strings.to_number(image_id)
302
- return ex
303
-
304
-
305
- @seqio.map_over_dataset
306
- def add_coco_url(ex):
307
- """Turns a COCO path into a URL, which can then be used in visualizations"""
308
- path = ex["image/filename"]
309
- if not tf.strings.regex_full_match(path, ".*/.*"):
310
- prefix = tf.strings.regex_replace(path, "COCO_", "")
311
- prefix = tf.strings.regex_replace(prefix, "_[0-9]+.jpg", "")
312
- path = tf.strings.join([prefix, path], separator="/")
313
-
314
- # images are hosted by the COCO website here
315
- url = tf.strings.join(["https://s3.us-east-1.amazonaws.com/images.cocodataset.org/", path])
316
- ex["metadata/image_url"] = url
317
- return ex
318
-
319
-
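Hand-tracing the regexes above with an illustrative filename: a bare name such as "COCO_train2014_000000123456.jpg" contains no "/", so the prefix "train2014" is extracted and the stored URL becomes https://s3.us-east-1.amazonaws.com/images.cocodataset.org/train2014/COCO_train2014_000000123456.jpg.
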
320
- def flatten_vqa(ds):
321
- parts = ["questions", "answers"]
322
- for k in ["id", "question_id"]:
323
- if k in ds.element_spec:
324
- parts.append(k)
325
- return flatten_parts(ds, parts)
326
-
327
-
328
- def format_gqa(ds, is_balanced=True, flatten=True):
329
- if is_balanced:
330
- ds = ds.filter(lambda x: tf.reduce_any(x["questions"]["is_balanced"]))
331
- def _filter_qs(ex):
332
- qs = ex["questions"]
333
- mask = qs["is_balanced"]
334
- qs = {k: tf.boolean_mask(v, mask) for k, v in qs.items()}
335
- ex["questions"] = qs
336
- return ex
337
- ds = ds.map(_filter_qs)
338
-
339
- if flatten:
340
- ds = flatten_parts(ds, ["questions"])
341
-
342
- def _rename(ex):
343
- out = ex["questions"]
344
- out["image"] = ex["image"]
345
- out["image_id"] = ex["image_id"]
346
- return out
347
- return ds.map(_rename)
348
-
349
-
350
- @seqio.map_over_dataset
351
- def fix_doqa_url(x):
352
- x["image_url"] = tf.strings.regex_replace(x["image_url"], "gs://", "")
353
- return x
354
-
355
-
356
- def _add_metadata(ex):
357
- out = {}
358
- if "id" in ex:
359
- out["metadata/example_id"] = ex["id"]
360
- elif "example_id" in ex:
361
- out["metadata/example_id"] = ex["example_id"]
362
- elif "question_id" in ex:
363
- out["metadata/example_id"] = ex["question_id"]
364
- if "image_url" in ex:
365
- out["metadata/image_url"] = ex["image_url"]
366
- for k, v in ex.items():
367
- if k.startswith("metadata/"):
368
- out[k] = v
369
- return out
370
-
371
-
372
- def image_only(ds):
373
- return ds.filter(lambda x: x["has_image"])
374
-
375
-
376
- def filter_difficult_direct_answer(ds):
377
- return ds.filter(lambda x: not x["difficult_direct_answer"])
378
-
379
-
380
- @seqio.map_over_dataset()
381
- def format_ai2d(ex, variable_style=True):
382
- abc = tf.constant(list("abcdefg".upper()))
383
- out = dict(image=ex["image"])
384
- out.update(_add_metadata(ex))
385
-
386
- options = ex["choices"]
387
- # >= 3 in case of none-of-the-above style answers
388
- n_options = tf.shape(ex["option_is_abc"])[0]
389
- if ex["abc_label"] and tf.reduce_sum(tf.cast(ex["option_is_abc"], tf.int32)) >= (n_options - 1):
390
- # The image labels are always upper, so use upper in the answer options
391
- options = tf.where(
392
- ex["option_is_abc"],
393
- tf.strings.upper(options),
394
- options
395
- )
396
- short_options = options
397
- style = "ai2_diagram_no_letter"
398
- else:
399
- short_options = abc[:tf.shape(options)[0]]
400
- options = tf.stack([short_options, options,], 1)
401
- options = tf.strings.reduce_join(options, axis=-1, separator=": ")
402
- style = "ai2_diagram"
403
-
404
- options = tf.strings.reduce_join(options, separator="\n")
405
- out["question"] = ex["question"]
406
- out["options"] = options
407
- if variable_style:
408
- out["style"] = style
409
- if ex["answer_idx"] < 0:
410
- out["text"] = "?"
411
- else:
412
- out["text"] = short_options[ex["answer_idx"]]
413
- out["metadata/answer_idx"] = ex["answer_idx"]
414
- tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
415
- out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
416
- out["metadata/has_transparent_box"] = ex.get("has_transparent_box", tf.constant(False))
417
- out["metadata/abc_label"] = ex["abc_label"]
418
- return out
419
-
420
-
421
- @gin.configurable()
422
- @seqio.map_over_dataset()
423
- def format_multiple_choice_qa(ex, option_format="abc"):
424
- assert option_format == "abc"
425
- abc = tf.constant(list("abcdefg".upper()))
426
- out = dict(image=ex["image"])
427
- out.update(_add_metadata(ex))
428
- options = ex["choices"]
429
- short_options = abc[:tf.shape(options)[0]]
430
- options = tf.stack([short_options, options,], 1)
431
- options = tf.strings.reduce_join(options, axis=-1, separator=": ")
432
- options = tf.strings.reduce_join(options, separator="\n")
433
- out["question"] = ex["question"]
434
- out["options"] = options
435
- if ex["answer_idx"] < 0:
436
- out["text"] = "?"
437
- else:
438
- out["text"] = short_options[ex["answer_idx"]]
439
- out["metadata/answer_idx"] = ex["answer_idx"]
440
- tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
441
- out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
442
- # out["metadata/option_names"] = tf.RaggedTensor.from_row_lengths(short_options, tf.shape(short_options))
443
- # out["metadata/option_names"] = short_options
444
- return out
445
-
446
-
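For intuition, the joined options string built above takes the following form (the choices are illustrative):

A: cat
B: dog
C: bird

`text` is the letter of the correct choice ("?" when answer_idx is negative), and metadata/option_names stores the letters joined with "|||", e.g. "A|||B|||C".
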
447
- @seqio.map_over_dataset()
448
- def output_options(ex):
449
- ex["metadata/options"] = ex["options"]
450
- return ex
451
-
452
-
453
- @seqio.map_over_dataset()
454
- def extract_tally_qa(ex):
455
- questions = ex.pop("questions")
456
- ex["questions"] = questions["question"]
457
- ex["answers"] = tf.strings.as_string(questions["answer"])
458
- ex["question_id"] = questions["question_id"]
459
- return ex
460
-
461
-
462
- @seqio.map_over_dataset()
463
- def count_bench_preprocessor(ex):
464
- return {
465
- "image": ex["image"],
466
- "text": tf.strings.as_string(ex["number"]),
467
- "object": ex["noun"],
468
- "question": tf.strings.join([
469
- "How many ", ex["noun"], " are there?"
470
- ]),
471
- "metadata/count": ex["number"],
472
- }
473
-
474
-
475
- def filter_human(ds):
476
- return ds.filter(lambda x: x["is_human"])
477
-
478
-
479
- def filter_aug(ds):
480
- return ds.filter(lambda x: not x["is_human"])
481
-
482
-
483
- @seqio.map_over_dataset()
484
- def reweight_chartqa(ex, human, aug):
485
- is_human = ex["metadata/is_human"]
486
- ex["text_weights"] = human if is_human else aug
487
- return ex
488
-
489
-
490
- @seqio.map_over_dataset()
491
- def chartqa_prompting(ex):
492
- question = tf.strings.join([ex["question"], " Answer:"])
493
- return dict(
494
- image=ex["image"],
495
- question=question,
496
- answer=ex["answer"]
497
- )
498
-
499
-
500
- @seqio.map_over_dataset()
501
- def chartqa_explanation(ex):
502
- question = tf.strings.join([ex["question"], " Explanation:"])
503
- out = {
504
- "image": ex["image"],
505
- "question": question,
506
- "answer": ex["answer"],
507
- }
508
- out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
509
- return out
510
-
511
-
512
- @seqio.map_over_dataset(num_seeds=1)
513
- def _preprocess_scifi(ex, seed):
514
- if "qa_pairs" in ex:
515
- q = ex["qa_pairs"]
516
- else:
517
- q = ex["qa"]
518
- ix = stateless_permutation(tf.shape(q["question"])[0], seed)
519
- return dict(
520
- image=ex["image"],
521
- question=tf.gather(q["question"], ix),
522
- explanation=tf.gather(q["explanation"], ix),
523
- answer=tf.gather(q["answer"], ix),
524
- )
525
-
526
- @seqio.map_over_dataset
527
- def scifi_explanation_only(ex):
528
- return dict(
529
- image=ex["image"],
530
- question=ex["question"],
531
- answer=ex["explanation"],
532
- )
533
-
534
-
535
- def filter_named_entity(ds):
536
- @seqio.map_over_dataset
537
- def _load_image(ex):
538
- ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
539
- return ex
540
-
541
- ds = _load_image(ds)
542
- return ds.filter(lambda x: tf.reduce_min(tf.shape(x["image"])[:2]) >= 32)
543
-
544
-
545
- @seqio.map_over_dataset()
546
- def extract_named_entity(ex):
547
- qs = ex["questions"]
548
- return {
549
- "image": ex["image"],
550
- "metadata/image_url": ex["url"],
551
- "metadata/entity": ex["entity"],
552
- "questions": qs["question"],
553
- "answers": qs["answer"],
554
- }
555
-
556
- @gin.configurable()
557
- def extract_individual_vqa(ds, test=False, answer_mode="best"):
558
-
559
- @seqio.map_over_dataset(num_seeds=1)
560
- def _extract(ex, seed):
561
- if "questions" in ex:
562
- question = ex["questions"]
563
- else:
564
- question = ex["question"]
565
- out = dict(
566
- image=ex["image"],
567
- question=question,
568
- )
569
- out.update(_add_metadata(ex))
570
- out["metadata/question"] = question
571
- if ex.get("answers") is not None:
572
- out["metadata/references"] = tf.strings.reduce_join(ex["answers"], separator="\n")
573
- elif ex.get("answer") is not None:
574
- out["metadata/references"] = ex["answer"]
575
-
576
- if not test:
577
- if "answer" in ex:
578
- answer = ex["answer"]
579
- else:
580
- answer = ex["answers"]
581
- if answer.dtype in [tf.int32, tf.int64]:
582
- answer = tf.strings.as_string(answer)
583
- if len(answer.shape) == 1 and tf.shape(answer)[0] == 0:
584
- answer = tf.expand_dims("", 0)
585
- if len(answer.shape) == len(question.shape):
586
- pass
587
- # Handle questions with multiple answers
588
- elif answer_mode == "random":
589
- assert len(answer.shape) == 1
590
- answer = answer[tf.random.stateless_uniform((), seed, 0, tf.shape(answer)[0], dtype=tf.int32)]
591
- elif answer_mode == "best":
592
- def _get_best(_answer):
593
- vals, _, counts = tf.unique_with_counts(_answer)
594
- count_thresh = tf.reduce_max(counts)
595
- vals = tf.boolean_mask(vals, counts >= count_thresh)
596
- return vals[tf.random.stateless_uniform((), seed, 0, tf.shape(vals)[0], dtype=tf.int32)]
597
- if len(answer.shape) == 1:
598
- answer = _get_best(answer)
599
- elif isinstance(answer, tf.RaggedTensor):
600
- n = tf.shape(answer)[0]
601
- answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
602
- for i in range(n):
603
- answer_arr = answer_arr.write(i, _get_best(answer[i]))
604
- answer = answer_arr.stack()
605
- else:
606
- answer = tf.map_fn(_get_best, answer)
607
- elif answer_mode == "all_segments":
608
- out["text"] = answer
609
- elif answer_mode == "all_segments_weighted":
610
- out["text"] = answer
611
- out["text_weights"] = 1.0 / tf.cast(tf.shape(answer)[-1], tf.float32)
612
- elif answer_mode == "all":
613
- if len(answer.shape) == 1:
614
- answer = stateless_shuffle(answer, seed)
615
- answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
616
- elif isinstance(answer, tf.RaggedTensor):
617
- n = tf.shape(answer)[0]
618
- answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
619
- for i in range(n):
620
- answer_arr = answer_arr.write(i, tf.strings.reduce_join(tf.random.shuffle(answer[i]), separator="\n", axis=-1))
621
- answer = answer_arr.stack()
622
- else:
623
- answer = tf.map_fn(tf.random.shuffle, answer)
624
- answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
625
- else:
626
- raise NotImplementedError()
627
- out["text"] = answer
628
- return out
629
- return _extract(ds)
630
-
631
-
632
- @seqio.map_over_dataset()
633
- def extract_khan_academy(ex):
634
- return dict(
635
- image=ex["image"],
636
- image_url=ex["image_url"],
637
- prompt="Answer this question",
638
- text=ex["gptResponse"]
639
- )
640
-
641
- @seqio.map_over_dataset()
642
- def extract_vaia_qa_latex_image(ex, add_short_answer=False, set_short_answer_first=False):
643
- if ex["has_image"]:
644
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
645
- image = tf.expand_dims(image, 0)[:1]
646
- else:
647
- # image = get_blank_image() # blank image
648
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
649
- image = tf.expand_dims(image, 0)[:0]
650
- img_h = tf.shape(image)[1]
651
- img_w = tf.shape(image)[2]
652
-
653
- if add_short_answer:
654
- if set_short_answer_first:
655
- answer = tf.strings.join(["Answer: ", ex["short_answer"], "\n\n", ex["answer"]])
656
- else:
657
- answer = tf.strings.join([ex["answer"], "\n\n", "Answer: ", ex["short_answer"]])
658
- else:
659
- answer = ex["answer"]
660
- out = dict(
661
- image=image, # 4-d tensor
662
- text=answer,
663
- prompt=tf.strings.join([ex["latex_question"], "\n"]),
664
- )
665
- out["metadata/images"] = image
666
- out.update(_add_metadata(ex))
667
- out["metadata/batch_id"] = ex["batch_id"]
668
- out["metadata/image_size"] = [img_w, img_h]
669
- return out
670
-
671
- @seqio.map_over_dataset()
672
- def extract_vqa_online(ex):
673
- out = dict(
674
- image=ex["image"],
675
- prompt=tf.strings.join([ex["question"], "\n"]),
676
- text=ex["answer"]
677
- )
678
- out.update(_add_metadata(ex))
679
- out["metadata/row_id"] = ex["row_id"]
680
- return out
681
-
682
-
683
- @seqio.map_over_dataset()
684
- def extract_scifi_joint(ex):
685
- if "qa_pairs" in ex:
686
- q = ex["qa_pairs"]
687
- else:
688
- q = ex["qa"]
689
- prompts = tf.concat([["Describe this image in detail."], q["question"]], 0)
690
- responses = tf.concat([ex["summary"][None], q["answer"]], 0)
691
- return dict(
692
- image=ex["image"],
693
- prompt=prompts,
694
- text=responses,
695
- )
696
-
697
-
698
- def remove_no_qa(ds):
699
- def _filter(ex):
700
- if "qa_pairs" in ex:
701
- q = ex["qa_pairs"]
702
- else:
703
- q = ex["qa"]
704
- return tf.shape(q["question"])[0] > 0
705
- return ds.filter(_filter)
706
-
707
-
708
- @seqio.map_over_dataset()
709
- def extract_scifi_qa_exp(ex):
710
- return dict(
711
- image=ex["image"],
712
- question=ex["question"], # Array of questions
713
- answer=tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]]),
714
- )
715
-
716
-
717
- @seqio.map_over_dataset(num_seeds=1)
718
- def extract_scifi_qa_demo(ex, seed):
719
- # if tf.random.stateless_uniform((), 0, 1) > 0.5:
720
- answer = tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]])
721
- # else:
722
- # answer = ex["explanation"]
723
- return dict(
724
- image=ex["image"],
725
- question=ex["question"], # Array of questions
726
- answer=answer,
727
- )
728
-
729
-
730
- @seqio.map_over_dataset()
731
- def clock_bench_preprocessor(ex):
732
- out = dict(
733
- image=ex["image"],
734
- prompt="What time is being shown?",
735
- )
736
- for k in ["hour", "minute", "second", "answerable"]:
737
- out[f"metadata/{k}"] = ex[k]
738
- return out
739
-
740
-
741
- def deg2rad(x):
742
- return x*math.pi/180.0
743
-
744
-
745
- def get_affine_matrix(center, angle, translate, scale, shear):
746
- # From https://github.com/pytorch/vision/blob/f96c42fca53230057b16941b078a0a9eee06e20f/torchvision/transforms/functional.py#L1006
747
- rot = deg2rad(angle)
748
- sx = deg2rad(shear[0])
749
- sy = deg2rad(shear[1])
750
-
751
- cx, cy = center
752
- tx, ty = translate
753
-
754
- # RSS without scaling
755
- a = tf.cos(rot - sy) / tf.cos(sy)
756
- b = -tf.cos(rot - sy) * tf.tan(sx) / tf.cos(sy) - tf.sin(rot)
757
- c = tf.sin(rot - sy) / tf.cos(sy)
758
- d = -tf.sin(rot - sy) * tf.tan(sx) / tf.cos(sy) + tf.cos(rot)
759
-
760
- matrix = [a, b, 0.0, c, d, 0.0]
761
- matrix = [x * scale for x in matrix]
762
- # Apply inverse of center translation: RSS * C^-1
763
- matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
764
- matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
765
- # Apply translation and center : T * C * RSS * C^-1
766
- matrix[2] += cx + tx
767
- matrix[5] += cy + ty
768
- return matrix
769
-
770
-
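A quick sanity check on the matrix math above (a sketch, not from the original file): with zero rotation, shear, and translation and unit scale, the returned coefficients reduce to the identity transform.

import tensorflow as tf

m = get_affine_matrix(center=[112., 112.], angle=0., translate=[0., 0.],
                      scale=1., shear=[0., 0.])
# values are [1, 0, 0, 0, 1, 0], i.e. the identity affine; callers above append
# [0., 0.] and pass the resulting 8-vector to tf.keras.ops.image.affine_transform.
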
771
- def quantize_point(coor, max_dim, mode="percent-precision-1"):
772
- max_dim = tf.cast(max_dim, tf.float32)
773
- coor = tf.cast(coor, tf.float32)
774
- x = (coor / max_dim)
775
- if mode == "percent-precision-1":
776
- return tf.strings.as_string(x*100, precision=1)
777
- elif mode == "zero_to_one":
778
- return tf.strings.as_string(x, precision=3)
779
- elif mode == "1k":
780
- return tf.strings.as_string(x*1000, precision=0)
781
- else:
782
- raise NotImplementedError(mode)
783
-
784
-
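The three quantization modes encode the same raw coordinate differently; a hand-worked example (values illustrative):

# a point at x = 150 in an image 400 pixels wide
quantize_point(150, 400, mode="percent-precision-1")  # -> "37.5"  (percent, 1 decimal)
quantize_point(150, 400, mode="zero_to_one")          # -> "0.375" (fraction, 3 decimals)
quantize_point(150, 400, mode="1k")                   # -> "375"   (per-mille, integer)
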
785
- def construct_pointing_format(label_text, alt_text, x_str, y_str):
786
- if alt_text is None:
787
- alt_text = label_text
788
- np = tf.shape(x_str)[0]
789
- if np == 0:
790
- output = ""
791
- elif np == 1:
792
- output = tf.strings.join([
793
- '<point x="', x_str[0], '" y="', y_str[0], '" alt="',
794
- alt_text, '">', label_text, '</point>'
795
- ])
796
- else:
797
- ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
798
- xs = tf.strings.join(["x", ids, '="', x_str, '"'])
799
- ys = tf.strings.join(["y", ids, '="', y_str, '"'])
800
- points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
801
- output = tf.strings.join(
802
- ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
803
- return output
804
-
805
-
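The markup this produces is the single-point and multi-point form used throughout the pointing tasks; hand-constructed examples (coordinates illustrative, not from the original file):

import tensorflow as tf

construct_pointing_format("dog", None, tf.constant(["37.5"]), tf.constant(["12.0"]))
# -> '<point x="37.5" y="12.0" alt="dog">dog</point>'

construct_pointing_format("dogs", None,
                          tf.constant(["10.0", "55.1"]), tf.constant(["20.3", "61.0"]))
# -> '<points x1="10.0" y1="20.3" x2="55.1" y2="61.0" alt="dogs">dogs</points>'
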
806
- def order_points(x, y, seed, point_order):
807
- if point_order == "natural":
808
- return x, y
809
-
810
- if point_order == "random":
811
- ix = stateless_permutation(tf.shape(x)[0], seed)
812
- elif point_order == "xy":
813
- x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
814
- ix = tf.argsort(x_float*100000 + y_float)
815
- elif point_order == "yx":
816
- x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
817
- ix = tf.argsort(y_float*100000 + x_float)
818
- else:
819
- raise NotImplementedError(point_order)
820
- return tf.gather(x, ix), tf.gather(y, ix)
821
-
822
-
823
- @gin.configurable()
824
- def points_to_text(x, y, w, h, seed, label=None, alt_text=None, point_mode="percent-precision-1",
825
- point_order="xy", point_list_mode="tag"):
826
- """Returns a string encoding of a list of points"""
827
- x = quantize_point(x, w, point_mode)
828
- y = quantize_point(y, h, point_mode)
829
- # Order the quantized points so the order matches what was generated; this can matter
830
- # when points have the same quantized value, e.g. (10.001, 20), (10.002, 10) should be
831
- # represented as (10, 10), (10, 20), but if we sort before quantization we get (10, 20), (10, 10)
832
- x, y = order_points(x, y, seed, point_order)
833
- if point_list_mode == "tag":
834
- return construct_pointing_format(label, alt_text, x, y)
835
- elif point_list_mode == "paren":
836
- n = tf.shape(x)[0]
837
- return tf.strings.reduce_join(tf.strings.join([
838
- "(", x, ", ", y, ")"
839
- ]), separator=", ")
840
- # if n == 0:
841
- # output = ""
842
- # else:
843
- # ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
844
- # xs = tf.strings.join(["x", ids, '="', x_str, '"'])
845
- # ys = tf.strings.join(["y", ids, '="', y_str, '"'])
846
- # points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
847
- # output = tf.strings.join(
848
- # ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
849
- # return output
850
- else:
851
- raise NotImplementedError(point_list_mode)
852
-
853
-
854
- def points_to_answer(x, y, w, h, seed, label, is_counting, alt_text=None):
855
- count = tf.shape(x)[0]
856
- if is_counting:
857
- if count == 0:
858
- return "There are none."
859
- else:
860
- point_text = points_to_text(x, y, w, h, seed, label, alt_text)
861
- return tf.strings.join([
862
- "Counting the ", point_text,
863
- " shows a total of ",
864
- tf.strings.as_string(count),
865
- "."
866
- ])
867
- else:
868
- if count == 0:
869
- return "There are none."
870
- else:
871
- return points_to_text(x, y, w, h, seed, label, alt_text)
872
-
873
-
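With the defaults in `points_to_text` (percent coordinates, x-then-y ordering), the counting variant produces answers like the following (a hand-traced sketch; the seed is unused for the default "xy" ordering):

import tensorflow as tf

points_to_answer(tf.constant([150.]), tf.constant([48.]), w=400, h=400,
                 seed=tf.constant([0, 0]), label="dog", is_counting=True)
# -> 'Counting the <point x="37.5" y="12.0" alt="dog">dog</point> shows a total of 1.'
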
874
- @seqio.map_over_dataset(num_seeds=2)
875
- def extract_point_qa(ex, seeds, answer_type="y_major"):
876
- ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
877
- img_h = tf.shape(ex["image"])[0]
878
- img_w = tf.shape(ex["image"])[1]
879
-
880
- questions = ex["questions"]
881
- question = questions["question"]
882
- n = tf.shape(question)[0]
883
- answers = tf.TensorArray(tf.string, size=n, element_shape=())
884
- point_text = questions["annotations"]["point_text"]
885
- point_seeds = tf.RaggedTensor.from_row_splits(
886
- row_splits=point_text.row_splits,
887
- values=tf.random.split(seeds[0], num=tf.shape(point_text.values)[0])
888
- )
889
- for question_ix in range(n):
890
- anno = questions["annotations"]
891
- answer = questions["answer_with_placeholders"][question_ix]
892
- n_anno = tf.shape(anno["point_text"][question_ix])[0]
893
- for anno_ix in range(n_anno):
894
- points = anno["points"][question_ix, anno_ix]
895
- point_text = points_to_answer(
896
- points[:, 0], points[:, 1], 100, 100,
897
- point_seeds[question_ix, anno_ix],
898
- anno["point_text"][question_ix, anno_ix],
899
- False,
900
- alt_text=anno["alt_text"][question_ix, anno_ix],
901
- )
902
- answer_split = tf.strings.split(answer, sep="<|POINT|>", maxsplit=1)
903
- answer = tf.strings.join([answer_split[0], point_text, answer_split[1]])
904
- # Make sure all placeholders were used
905
- tf.debugging.assert_equal(tf.shape(tf.strings.split(answer, sep="<|POINT|>"))[0], 1)
906
- answers = answers.write(question_ix, answer)
907
-
908
- messages = tf.stack([question, answers.stack()], axis=1)
909
- messages = tf.reshape(messages, [-1])
910
- conversation_ids = tf.range(tf.shape(messages)[0] // 2, dtype=tf.int32)
911
- conversation_ids = tf.repeat(conversation_ids, 2)
912
- out = dict(
913
- image=ex["image"],
914
- messages=tf.RaggedTensor.from_value_rowids(messages, conversation_ids)
915
- )
916
- ix = stateless_permutation(tf.shape(messages)[0], seeds[1])
917
- messages = tf.gather(messages, ix)
918
- out.update(_add_metadata(ex))
919
- out["metadata/image_size"] = [img_w, img_h]
920
- return out
921
-
922
-
923
- def select_point(mask):
924
- bs = tf.shape(mask)[0]
925
- valid = tf.cast(mask, tf.float32)
926
- h, w = tf.shape(mask)[1], tf.shape(mask)[2]
927
- ys = tf.range(h, dtype=tf.int32)
928
- xs = tf.range(w, dtype=tf.int32)
929
-
930
- n = tf.reduce_sum(valid, [1, 2])
931
- cy = tf.reduce_sum(tf.cast(ys[None, :, None], tf.float32) * valid, [1, 2]) / n # [bs]
932
- cx = tf.reduce_sum(tf.cast(xs[None, None, :], tf.float32) * valid, [1, 2]) / n # [bs]
933
-
934
- dist_y = tf.square(tf.range(h, dtype=tf.float32)[None, :] - cy[:, None]) # [bs, h]
935
- dist_x = tf.square(tf.range(w, dtype=tf.float32)[None, :] - cx[:, None]) # [bs, w]
936
- dist = dist_y[:, :, None] + dist_x[:, None, :] # [batch, h, w]
937
- dist = dist + (1 - valid) * 1e12
938
- min_dist = tf.argmin(tf.reshape(dist, [bs, -1]), axis=-1) # [batch]
939
- w = tf.cast(w, min_dist.dtype)
940
- cy = tf.cast(min_dist // w, tf.float32)
941
- cx = tf.cast(min_dist % w, tf.float32)
942
- return cx, cy
943
-
944
-
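`select_point` returns, for each mask in the batch, the coordinates of an in-mask pixel closest to the mask's centroid, so the chosen point lies on the object even when the centroid itself falls outside it. A small sketch, not part of the original file:

import tensorflow as tf

# one 5x5 ring-shaped mask: the centroid lands in the hole, so the nearest
# valid pixel on the ring is returned instead
mask = tf.constant([[[0, 1, 1, 1, 0],
                     [1, 0, 0, 0, 1],
                     [1, 0, 0, 0, 1],
                     [1, 0, 0, 0, 1],
                     [0, 1, 1, 1, 0]]], dtype=tf.bool)
cx, cy = select_point(mask)   # float tensors of shape [1], in pixel coordinates
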
945
- @seqio.map_over_dataset
946
- def refexp_pointing(ex):
947
- img_h = tf.shape(ex["image"])[0]
948
- img_w = tf.shape(ex["image"])[1]
949
- objects = ex["objects"]
950
-
951
- # Shuffle objects so that which object gets truncated (if the sequence is truncated) is randomized
952
- refexps = objects['refexp']['raw']
953
- bbox = objects["bbox"]
954
- mask = tf.squeeze(objects["mask"], -1)
955
-
956
- ix = tf.range(0, tf.shape(refexps)[0], dtype=tf.int32)
957
- ix = tf.random.shuffle(ix)
958
- refexps = tf.gather(refexps, ix)
959
- bbox = tf.gather(bbox, ix)
960
- mask = tf.gather(mask, ix)
961
-
962
- cx, cy = select_point(mask)
963
- answers = points_to_text(img_h, img_w, cx, cy)
964
-
965
- out = {
966
- "image": ex["image"],
967
- "refexp": refexps.values,
968
- "metadata/image_size": tf.stack([img_w, img_h,]),
969
- "text": tf.repeat(answers, refexps.row_lengths()),
970
- }
971
- if "image_url" in ex:
972
- out["metadata/image_url"] = ex["image_url"]
973
- return out
974
-
975
-
976
- @seqio.map_over_dataset
977
- def refexp_pointing_inf(ex):
978
- img_h = tf.shape(ex["image"])[0]
979
- img_w = tf.shape(ex["image"])[1]
980
-
981
- objects = ex["objects"]
982
- mask = tf.squeeze(objects["mask"], -1)
983
- cx, cy = select_point(mask)
984
- answers = points_to_text(img_h, img_w, cx, cy)
985
-
986
- refexps = objects["refexp"]["raw"]
987
-
988
- # We can't use `mask` directly since it is variable size, and thus it
989
- # will break batching. Here we serialize it instead
990
- serialized_masks = tf.map_fn(tf.io.serialize_tensor, mask, fn_output_signature=tf.string)
991
- out = {
992
- "image": ex["image"],
993
- "refexp": refexps,
994
- "metadata/bbox": objects["bbox"],
995
- "metadata/answer": answers,
996
- "metadata/mask": serialized_masks,
997
- "metadata/image_size": tf.stack([img_w, img_h]),
998
- }
999
- out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
1000
- return out
1001
-
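Because the ragged-size masks are serialized to scalar strings above (so batching still works), a downstream consumer has to invert the serialization; a minimal sketch of that round trip, assuming the same dtype that was serialized:

import tensorflow as tf

serialized = tf.io.serialize_tensor(tf.zeros([4, 7], dtype=tf.uint8))
restored = tf.io.parse_tensor(serialized, out_type=tf.uint8)  # back to shape [4, 7]
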
1002
- @seqio.map_over_dataset
1003
- def extract_andriod_control_inf(ex, mode):
1004
- if mode == "ll":
1005
- prompt = tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]])
1006
- elif mode == "hl_ll":
1007
- prompt = tf.strings.join([
1008
- "high_level: ", ex["metadata/hl_instruction"],
1009
- " low_level: ", ex["metadata/ll_instruction"]
1010
- ])
1011
- elif mode == "hl":
1012
- prompt = tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]])
1013
- elif mode == "hl_cot":
1014
- prompt = tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]])
1015
- else:
1016
- raise NotImplementedError()
1017
-
1018
- out = dict(
1019
- image=ex["image"],
1020
- prompt=prompt,
1021
- text=ex["metadata/target_action"]
1022
- )
1023
- out.update(_add_metadata(ex))
1024
- return out
1025
-
1026
- @seqio.map_over_dataset
1027
- def extract_android_control(ex):
1028
- # Each image has four tasks:
1029
- # low level -> action
1030
- # high+low level -> action
1031
- # high level -> action
1032
- # high level -> low level + action (CoT)
1033
- out = dict(
1034
- image=ex["image"],
1035
- prompt=tf.stack([
1036
- tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]]),
1037
- tf.strings.join([
1038
- "high_level: ", ex["metadata/hl_instruction"],
1039
- " low_level: ", ex["metadata/ll_instruction"]
1040
- ]),
1041
- tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]]),
1042
- tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]]),
1043
- ]),
1044
- text=tf.stack([
1045
- ex["metadata/target_action"],
1046
- ex["metadata/target_action"],
1047
- ex["metadata/target_action"],
1048
- tf.strings.join(["Plan: ", ex["metadata/ll_instruction"], " Action: ", ex["metadata/target_action"]]),
1049
- ])
1050
- )
1051
- # Only needed if visualizing
1052
- # ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1053
- # img_h = tf.shape(ex["image"])[0]
1054
- # img_w = tf.shape(ex["image"])[1]
1055
- # out["metadata/image_size"] = tf.stack([img_w, img_h,])
1056
- out.update(_add_metadata(ex))
1057
- return out
1058
-
1059
-
1060
- @seqio.map_over_dataset(num_seeds=1)
1061
- def refexp(ex, seed):
1062
- img_h = tf.shape(ex["image"])[0]
1063
- img_w = tf.shape(ex["image"])[1]
1064
- objects = ex["objects"]
1065
-
1066
- # Shuffle objects so that which object gets truncated (if the sequence is truncated) is randomized
1067
- refexps = objects['refexp']['raw']
1068
- bbox = objects["bbox"]
1069
- ix = stateless_permutation(tf.shape(refexps)[0], seed)
1070
- refexps = tf.gather(refexps, ix)
1071
- bbox = tf.gather(bbox, ix)
1072
-
1073
- x2 = bbox[:, 0] + bbox[:, 2]
1074
- y2 = bbox[:, 1] + bbox[:, 3]
1075
- with tf.control_dependencies([
1076
- tf.debugging.assert_equal(tf.reduce_any(x2 <= tf.cast(img_w, tf.float32)), True),
1077
- tf.debugging.assert_equal(tf.reduce_any(y2 <= tf.cast(img_h, tf.float32)), True)
1078
- ]):
1079
- answers = points_to_text(
1080
- img_h, img_w,
1081
- tf.reshape(tf.stack([bbox[:, 0], x2], 1), [-1]),
1082
- tf.reshape(tf.stack([bbox[:, 1], y2], 1), [-1]))
1083
- answers = tf.strings.reduce_join(tf.reshape(answers, [-1, 2]), separator=" ", axis=1)
1084
-
1085
- out = {
1086
- "image": ex["image"],
1087
- "refexp": refexps.values,
1088
- "metadata/bbox": bbox,
1089
- "metadata/image_size": tf.stack([img_w, img_h,]),
1090
- "text": tf.repeat(answers, refexps.row_lengths()),
1091
- }
1092
-
1093
- if "image_url" in ex:
1094
- out["image_url"] = ex["image_url"]
1095
- return out
1096
-
1097
-
1098
- @seqio.map_over_dataset
1099
- def refexp_inf(ex):
1100
- img_h = tf.shape(ex["image"])[0]
1101
- img_w = tf.shape(ex["image"])[1]
1102
- out = {
1103
- "image": ex["image"],
1104
- "refexp": ex["objects"]["refexp"]["raw"],
1105
- "metadata/bbox": ex["objects"]["bbox"],
1106
- "metadata/image_size": tf.stack([img_w, img_h,]),
1107
- }
1108
- out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
1109
- return out
1110
-
1111
-
1112
- def point_text_interleaved(*args):
1113
- raise NotImplementedError()
1114
-
1115
-
1116
- @seqio.map_over_dataset
1117
- def web_pointing_preprocessor(ex):
1118
- img_h = tf.shape(ex["image"])[0]
1119
- img_w = tf.shape(ex["image"])[1]
1120
-
1121
- question = point_text_interleaved(
1122
- img_h, img_w, ex["question"], ex["question_points"]["x"], ex["question_points"]["y"])
1123
- answer = point_text_interleaved(
1124
- img_h, img_w, ex["answer"], ex["answer_points"]["x"], ex["answer_points"]["y"])
1125
- answer_points = tf.stack([ex["answer_points"]["x"], ex["answer_points"]["y"]], axis=1)
1126
- return {
1127
- "question": question,
1128
- "answer": answer,
1129
- "image": ex["image"],
1130
- "metadata/image_size": [img_w, img_h],
1131
- "metadata/question_type": ex["question_type"],
1132
- "metadata/answer_points": tf.io.serialize_tensor(answer_points),
1133
- "metadata/answer": answer,
1134
- }
1135
-
1136
-
1137
- def filter_pointing(ds):
1138
- return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] >= 1)
1139
-
1140
-
1141
- def filter_qa(ds):
1142
- return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] == 0)
1143
-
1144
- # vaia filtering
1145
- def filter_image_only(ds):
1146
- return ds.filter(lambda ex: ex["has_image"])
1147
-
1148
- def filter_mc(ds):
1149
- return ds.filter(lambda ex: ex["is_mc"])
1150
-
1151
- def remove_is_long(ds):
1152
- return ds.filter(lambda ex: not ex["is_long"])
1153
-
1154
- def remove_has_multiple_parts(ds):
1155
- return ds.filter(lambda ex: not ex["has_multiple_parts"])
1156
-
1157
-
1158
- def _split(ds: tf.data.Dataset, keys, n_splits=2):
1159
- def _map(ex):
1160
- n = tf.shape(ex[keys[0]])[0]
1161
- if n < n_splits:
1162
- return tf.data.Dataset.from_tensors(ex)
1163
- else:
1164
- # import pdb; pdb.set_trace()
1165
- bs = n // n_splits
1166
- remainder = n - bs*n_splits
1167
- lens = tf.concat([
1168
- tf.ones([remainder], dtype=tf.int32),
1169
- tf.zeros([n_splits-remainder], dtype=tf.int32),
1170
- ], axis=0) + bs
1171
- tf.debugging.assert_equal(tf.reduce_sum(lens), n)
1172
- ends = tf.cumsum(lens)
1173
-
1174
- parts = []
1175
- for split_ix in range(n_splits):
1176
- part_ex = dict(ex)
1177
- e = ends[split_ix]
1178
- s = e - lens[split_ix]
1179
- for k in keys:
1180
- if isinstance(k, tuple):
1181
- assert len(k) == 2
1182
- part_ex[k[0]][k[1]] = ex[k[0]][k[1]][s:e]
1183
- else:
1184
- part_ex[k] = ex[k][s:e]
1185
- parts.append(part_ex)
1186
-
1187
- ds = tf.data.Dataset.from_tensors(parts[0])
1188
- for sub_ds in parts[1:]:
1189
- sub_ds = tf.data.Dataset.from_tensors(sub_ds)
1190
- ds = ds.concatenate(sub_ds)
1191
- return ds
1192
-
1193
- return ds.flat_map(_map)
1194
-
1195
-
1196
-
1197
- def split(ds, n=2):
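Hand-tracing the split arithmetic in `_split` above with illustrative sizes: an example whose key has 5 entries and n_splits=2 gives bs = 5 // 2 = 2 and remainder = 1, so lens = [3, 2] and the example is emitted as two examples holding the first 3 and the last 2 entries; keys not listed are copied unchanged into every part.
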
1198
- # return ds
1199
- return _split(ds, [k for k in [
1200
- "question",
1201
- "label",
1202
- "text",
1203
- "entity",
1204
- "messages"
1205
- ] if k in ds.element_spec], n_splits=n)
1206
-
1207
-
1208
- def split_points(ds, max_points=50):
1209
- label = "question" if "question" in ds.element_spec else "label"
1210
- return _split(ds, [
1211
- "question", label, "notInImage",
1212
- ("answer_points", "x"),
1213
- ("answer_points", "y"),
1214
- ])
1215
-
1216
-
1217
- @seqio.map_over_dataset
1218
- def fix_count_qa(ex):
1219
- ex["label"] = ex["label"][::2]
1220
- tf.debugging.assert_equal(tf.shape(ex["answer_points"]["x"])[0], tf.shape(ex["label"])[0])
1221
- return ex
1222
-
1223
-
1224
- def filter_points(ds, max_number=40):
1225
-
1226
- def _add_valid(ex):
1227
- valid = (
1228
- tf.reduce_all(ex["answer_points"]["x"] >= 0.0, axis=-1) &
1229
- tf.reduce_all(ex["answer_points"]["x"] <= 100.0, axis=-1) &
1230
- tf.reduce_all(ex["answer_points"]["y"] >= 0.0, axis=-1) &
1231
- tf.reduce_all(ex["answer_points"]["y"] <= 100.0, axis=-1) &
1232
- (ex["answer_points"]["y"].row_lengths() <= max_number)
1233
- )
1234
- ex["valid"] = valid
1235
- return ex
1236
- ds = ds.map(_add_valid)
1237
- ds = ds.filter(lambda ex: tf.reduce_any(ex["valid"]))
1238
- return ds
1239
-
1240
-
1241
- # def filter_points(ds, max_number=30):
1242
- # n_points = ds["answer_points"]["x"].row_lengths()
1243
- # parts = tf.TensorArray(tf.int32, size=tf.shape(n_points[0]), element_shape=tf.TensorShape([None]))
1244
- # total = 0
1245
- # on_row = 0
1246
- # for i in range(n_points):
1247
- # n = n_points[i]
1248
- # if n > max_number:
1249
- # continue
1250
- # if n + total > max_number:
1251
- #
1252
- # return ds
1253
-
1254
-
1255
- @seqio.map_over_dataset(num_seeds=2)
1256
- def pointing_preprocessor(ex, sequence_length, seeds, with_count=False):
1257
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1258
- img_h = tf.shape(image)[0]
1259
- img_w = tf.shape(image)[1]
1260
-
1261
- ix = tf.where(ex["valid"])[:, 0]
1262
- ix = stateless_shuffle(ix, seeds[0])
1263
- if "label" in ex:
1264
- question = tf.strings.lower(ex["label"])
1265
- else:
1266
- question = ex["question"]
1267
- question = tf.gather(question, ix) # [n_question]
1268
- points_x = tf.gather(ex["answer_points"]["x"], ix) # [n_question, n_points[ragged]]]
1269
- points_y = tf.gather(ex["answer_points"]["y"], ix)
1270
- not_in_image = tf.gather(ex["notInImage"], ix) # [n_question]
1271
-
1272
- n = tf.shape(points_x)[0]
1273
- point_text = tf.TensorArray(dtype=tf.string, size=n, element_shape=()) # [n_question]
1274
- point_seeds = tf.random.split(seeds[1], n)
1275
- for i in range(n):
1276
- answer = points_to_answer(points_x[i], points_y[i], 100, 100, point_seeds[i], question[i], with_count)
1277
- point_text = point_text.write(i, answer)
1278
- return {
1279
- "image": image,
1280
- "metadata/image_size": [img_w, img_h],
1281
- "entity": question,
1282
- "question": question,
1283
- "text": point_text.stack(),
1284
- }
1285
-
1286
-
1287
- @seqio.map_over_dataset
1288
- def pointing_inf_preprocessor(ex):
1289
- ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1290
- img_h = tf.shape(ex["image"])[0]
1291
- img_w = tf.shape(ex["image"])[1]
1292
-
1293
- question = ex["question"]
1294
- not_in_image = tf.shape(ex["answer_points"]["x"])[0] == 0
1295
-
1296
- # points are stored in normalized format, de-normalize here
1297
- points_x = ex["answer_points"]["x"] * tf.cast(img_w, tf.float32) / 100.0
1298
- points_y = ex["answer_points"]["y"] * tf.cast(img_h, tf.float32) / 100.0
1299
-
1300
- out = dict(
1301
- image=ex["image"],
1302
- question=question,
1303
- entity=question,
1304
- )
1305
- out.update(_add_metadata(ex))
1306
- out["metadata/not_in_image"] = not_in_image
1307
- # We can't use `mask` directly since it is variable size, and thus it
1308
- # will break batching. Here we serialize it instead
1309
- serialized_masks = tf.map_fn(tf.io.serialize_tensor, ex["masks"], fn_output_signature=tf.string)
1310
- serialized_masks = tf.strings.reduce_join(serialized_masks, separator="|||")
1311
- out["metadata/mask"] = serialized_masks
1312
- out["metadata/question"] = question
1313
- out["metadata/answer_points"] = tf.io.serialize_tensor(tf.stack([points_x, points_y], 1))
1314
- out["metadata/image_size"] = [img_w, img_h]
1315
-
1316
- return out
1317
-
1318
-
1319
- @seqio.map_over_dataset(num_seeds=1)
1320
- def count_qa_preprocessor_inf(ex, sequence_length, seed):
1321
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1322
- img_h = tf.shape(image)[0]
1323
- img_w = tf.shape(image)[1]
1324
-
1325
- entity = tf.strings.substr(
1326
- ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
1327
- entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
1328
- entity = tf.strings.lower(entity)
1329
- tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
1330
-
1331
- return {
1332
- "image": image,
1333
- "metadata/image_size": [img_w, img_h],
1334
- "metadata/count": tf.strings.to_number(ex["answer"]),
1335
- "question": ex["question"],
1336
- "entity": entity,
1337
- }
1338
-
1339
-
1340
- @seqio.map_over_dataset(num_seeds=1)
1341
- def count_qa_preprocessor(ex, sequence_length, seed, with_count=False,
1342
- for_inference=False):
1343
- point_answer = ex["point_answer"]
1344
- numbers_str = tf.strings.regex_replace(point_answer, r'\.$', '')
1345
- numbers_str = tf.strings.regex_replace(numbers_str, r'[^\d\.\s]+', '')
1346
- numbers_str = tf.strings.strip(numbers_str)
1347
- numbers = tf.strings.split(numbers_str)
1348
- float_numbers = tf.strings.to_number(numbers, out_type=tf.float32)
1349
- coordinates = tf.reshape(float_numbers, (-1, 3))
1350
- points_x = coordinates[:, 1]
1351
- points_y = coordinates[:, 2]
1352
-
1353
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1354
- img_h = tf.shape(image)[0]
1355
- img_w = tf.shape(image)[1]
1356
- entity = tf.strings.substr(
1357
- ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
1358
- entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
1359
- entity = tf.strings.lower(entity)
1360
- tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
1361
- count = tf.strings.to_number(ex["answer"], out_type=tf.int32)
1362
- if for_inference:
1363
- return {
1364
- "image": image,
1365
- "metadata/image_size": [img_w, img_h],
1366
- "metadata/count": count,
1367
- "question": ex["question"],
1368
- "entity": entity,
1369
- }
1370
- else:
1371
- tf.debugging.assert_equal(count, tf.shape(points_x)[0])
1372
- # points are already normalized so use w=1, h=1
1373
- answer = points_to_answer(points_x, points_y, 1, 1, seed, entity, with_count)
1374
- return {
1375
- "image": image,
1376
- "metadata/image_size": [img_w, img_h],
1377
- "metadata/count": count,
1378
- "question": ex["question"],
1379
- "entity": entity,
1380
- "text": answer,
1381
- }
1382
-
1383
-
1384
- @gin.configurable()
1385
- @seqio.map_over_dataset
1386
- def cleanup_preprocessor(ex, preprocess=False):
1387
- if preprocess:
1388
- ex["prompt"] = tf.strings.join(
1389
- [
1390
- "[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
1391
- ex["prompt"],
1392
- "\n[[Assistant]]: {after}"
1393
- ]
1394
- )
1395
- return ex
1396
- else:
1397
- return ex
1398
-
1399
-
1400
- @gin.configurable()
1401
- @seqio.map_over_dataset
1402
- def random_text_preprocessor(ex, preprocess=False):
1403
- ex["prompt"] = "What does the text say in this image?"
1404
- if preprocess:
1405
- ex["prompt"] = tf.strings.join(["[[User]]: ", ex["prompt"], "\n[[Assistant]]:"])
1406
- return ex
1407
- else:
1408
- return ex
1409
-
1410
-
1411
- @seqio.map_over_dataset(num_seeds=25)
1412
- def clock_augmentation(ex, seeds):
1413
- seeds = list(seeds)
1414
- image = ex["image"]
1415
-
1416
- # Apply shear, rotation, and scale through one affine matrix
1417
- height = tf.cast(tf.shape(image)[0], tf.float32)
1418
- width = tf.cast(tf.shape(image)[1], tf.float32)
1419
-
1420
- _call_id = [0]
1421
-
1422
- def _rng(_minval=0, _maxval=1, shape=(), dtype=tf.float32):
1423
- return tf.random.stateless_uniform(shape, seeds.pop(), _minval, _maxval, dtype=dtype)
1424
-
1425
- sel = _rng(0, 1)
1426
- if sel < 0.1:
1427
- # Straight on
1428
- shear_x = 0.
1429
- shear_y = 0.
1430
- rotation = 0.
1431
- elif sel < 0.5:
1432
- # Normal looking
1433
- shear_x = _rng(-10, 10)
1434
- shear_y = _rng(-10, 10)
1435
- rotation = _rng(-25, 25)
1436
- else:
1437
- # Allowed to be very wonky
1438
- # if tf.random.stateless_uniform((), seeds.pop(), 0, 1) > 0.8:
1439
- # image = image[:, ::-1]
1440
-
1441
- if _rng() > 0.5:
1442
- shear_x = _rng( -30, 30)
1443
- shear_y = _rng( -30, 30)
1444
- else:
1445
- shear_x = _rng( -10, 10)
1446
- shear_y = _rng( -10, 10)
1447
- rng = _rng( 0, 1)
1448
- if rng < 0.2:
1449
- rotation = _rng( -25, 25)
1450
- elif rng < 0.6:
1451
- rotation = _rng( -80, 80)
1452
- else:
1453
- rotation = _rng( -180, 180)
1454
-
1455
- if _rng() > 0.5:
1456
- scale = _rng( 0.3, 2)
1457
- else:
1458
- scale = _rng( 0.3, 1)
1459
- # Pad so upscaling/rotation will not move the image out of bounds
1460
- pad = tf.cast(tf.maximum(height, width)*0.5, tf.int32)
1461
- image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
1462
- height = tf.cast(tf.shape(image)[0], tf.float32)
1463
- width = tf.cast(tf.shape(image)[1], tf.float32)
1464
-
1465
- image = tf.keras.ops.image.affine_transform(
1466
- image,
1467
- tf.stack(get_affine_matrix(
1468
- [height/2, width/2],
1469
- rotation,
1470
- [0, 0],
1471
- 1/scale,
1472
- [shear_x, shear_y]
1473
- ) + [0., 0.]),
1474
- interpolation='bilinear',
1475
- fill_mode='constant',
1476
- fill_value=1.,
1477
- data_format='channels_last'
1478
- )
1479
-
1480
- # Crop to the non-white content; otherwise it would be impossible to put the image content at the corner of the final image
1481
- not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
1482
- no_white_ix = tf.where(not_white)
1483
- top_left = tf.reduce_min(no_white_ix, axis=0)
1484
- bottom_right = tf.reduce_max(no_white_ix, axis=0)
1485
- image = tf.image.crop_to_bounding_box(
1486
- image,
1487
- offset_height=tf.cast(top_left[0], tf.int32),
1488
- offset_width=tf.cast(top_left[1], tf.int32),
1489
- target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
1490
- target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
1491
- )
1492
-
1493
- # Translate
1494
- height, width = tf.shape(image)[0], tf.shape(image)[1]
1495
- translation_seed = _rng(0, 1)
1496
- if translation_seed < 0.2:
1497
- h_pad = _rng(0, height//2, (2,), dtype=tf.int32)
1498
- w_pad = _rng(0, width//2, (2,), dtype=tf.int32)
1499
- else:
1500
- h_pad = _rng(0, height*2, (2,), dtype=tf.int32)
1501
- w_pad = _rng(0, width*2, (2,), dtype=tf.int32)
1502
- image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
1503
- constant_values=1)
1504
-
1505
- # Random background color
1506
- # color_rng = tf.random.stateless_uniform((4,), seeds.pop(), 0, 1)
1507
- # random_color = color_rng[:3]
1508
- # valid = tf.reduce_all(tf.reduce_sum(tf.abs(random_color[None, None, :] - image), -1) > 0.03)
1509
- # if color_rng[0] < 0.2 and valid:
1510
- # image = tf.where(tf.reduce_all(image < 0.99, axis=-1, keepdims=True),
1511
- # image, image * 0 + random_color[None, None, :])
1512
-
1513
- # Mild color jitter
1514
- image = tf.image.stateless_random_hue(image, max_delta=0.05, seed=seeds.pop())
1515
- image = tf.image.stateless_random_brightness(image, max_delta=0.15, seed=seeds.pop())
1516
- image = tf.image.stateless_random_saturation(image, 0.8, 1.2, seed=seeds.pop())
1517
- image = tf.image.stateless_random_contrast(image, 0.8, 1.2, seed=seeds.pop())
1518
-
1519
- # ex["metadata/unaugmented_image"] = ex["image"]
1520
- ex["image"] = image
1521
- return ex
1522
-
1523
-
1524
- @seqio.map_over_dataset
1525
- def clocks_preprocessor(ex):
1526
- time_format = ex["time_format"]
1527
- shows_seconds = ex["shows_seconds"]
1528
- hour, minute, second = [tf.cast(ex[k], tf.int32) for k in ["hour", "minute", "second"]]
1529
- if hour == 0: # Midnight of the previous day
1530
- am_pm = "PM"
1531
- hour_str = 12
1532
- hour = 24
1533
- elif hour > 12:
1534
- am_pm = "PM"
1535
- hour_str = hour - 12
1536
- else:
1537
- hour_str = hour
1538
- am_pm = "AM"
1539
- hour_str = tf.strings.as_string(hour_str)
1540
- minute_str = tf.strings.as_string(minute)
1541
- if tf.strings.length(minute_str) == 1:
1542
- minute_str = tf.strings.join(["0", minute_str])
1543
-
1544
- second_str = tf.strings.as_string(second)
1545
- if tf.strings.length(second_str) == 1:
1546
- second_str = tf.strings.join(["0", second_str])
1547
-
1548
- prefix = "The time shown is "
1549
-
1550
- if time_format == "The time is not shown":
1551
- text = "The time is not shown in the image."
1552
- hour, minute, second = -1, -1, -1
1553
- else:
1554
- if not shows_seconds:
1555
- second = -1
1556
- if time_format == "12 hour clock (without AM/PM)" and shows_seconds:
1557
- if hour > 12:
1558
- hour = hour - 12
1559
- time = tf.strings.join([hour_str, ":", minute_str, ":", second_str])
1560
- elif time_format == "12 hour clock (with AM/PM)" and shows_seconds:
1561
- time = tf.strings.join([hour_str, ":", minute_str, ":", second_str, " ", am_pm])
1562
- elif time_format == "12 hour clock (with AM/PM)" and not shows_seconds:
1563
- time = tf.strings.join([hour_str, ":", minute_str, " ", am_pm])
1564
- elif time_format == "12 hour clock (without AM/PM)" and not shows_seconds:
1565
- if hour > 12:
1566
- hour = hour - 12
1567
- time = tf.strings.join([hour_str, ":", minute_str])
1568
- else:
1569
- time = "" # Should never occur, but needed for tf analysis
1570
- tf.debugging.assert_equal(tf.strings.length(time) > 0, True)
1571
- text = tf.strings.join(["The time shown is ", time])
1572
- image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1573
- image = tf.image.convert_image_dtype(image, tf.float32)[:-120] # remove the black shadow at the bottom
1574
- return {
1575
- "image": image,
1576
- "prompt": "What time is being shown?",
1577
- "text": text,
1578
- "metadata/time_format": time_format,
1579
- "metadata/hour": hour,
1580
- "metadata/minute": minute,
1581
- "metadata/text": text,
1582
- "metadata/second": second,
1583
- }
1584
-
1585
-
1586
- @seqio.map_over_dataset()
1587
- def atlas_obscura_preprocessor(ex):
1588
- out = dict(
1589
- image=ex["image"],
1590
- prompt="Where was this picture taken?",
1591
- text=tf.strings.join([
1592
- ex["place"],
1593
- " in ",
1594
- ex["city"]
1595
- ])
1596
- )
1597
- out["metadata/image_url"] = ex["image_url"]
1598
- out["metadata/references"] = out["text"]
1599
- return out
1600
-
1601
-
1602
- @seqio.map_over_dataset()
1603
- def famous_birthdays_preprocessor(ex):
1604
- out = dict(
1605
- image=ex["image"],
1606
- image_url=ex["image_url"],
1607
- prompt="Who is this?",
1608
- text=ex["name"]
1609
- )
1610
- out["metadata/references"] = out["text"]
1611
- return out
1612
-
1613
-
1614
- @seqio.map_over_dataset()
1615
- def mild_color_aug_preprocessor(ex):
1616
- if "image_url" in ex: # URL won't show the augmentations
1617
- del ex["image_url"]
1618
- # ex["metadata/unaugmented_image"] = ex["image"]
1619
- ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1620
- ex["image"] = mild_color_aug(ex["image"])
1621
- return ex
1622
-
1623
-
1624
- def build_text_with_points(text, points, img_h, img_w):
1625
- points = points_to_text(img_h, img_w, points[:, 0], points[:, 1])
1626
- parts = tf.strings.split(text, sep="<ANS>")
1627
- with_points = tf.strings.reduce_join(tf.reshape(tf.stack([
1628
- parts,
1629
- tf.pad(points, [[0, 1]], constant_values=""),
1630
- ], 1), [-1]), separator="")
1631
- return tf.strings.split(with_points, "\n\n")
1632
-
1633
-
1634
- @seqio.map_over_dataset()
1635
- def synth_count_preprocessor(example):
1636
- image_shape = tf.shape(example["image"])
1637
- h, w = image_shape[0], image_shape[1]
1638
- questions = build_text_with_points(example["questions"], example["question_points"], h, w)
1639
- answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
1640
- keep_q = tf.strings.regex_full_match(questions, "How many.*")
1641
- keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
1642
- keep = tf.logical_and(keep_q, keep_ans)
1643
- questions = tf.boolean_mask(questions, keep)
1644
- answers = tf.boolean_mask(answers, keep)
1645
- ix = tf.range(0, tf.shape(answers)[0], dtype=tf.int32)
1646
- ix = tf.random.shuffle(ix)
1647
- return dict(
1648
- image=example["image"],
1649
- prompt=tf.gather(questions, ix),
1650
- text=tf.gather(answers, ix),
1651
- )
1652
-
1653
-
1654
- def synth_count_inf_preprocessor(ds):
1655
-
1656
- @seqio.map_over_dataset(num_seeds=1)
1657
- def get_two(example, seed):
1658
- image_shape = tf.shape(example["image"])
1659
- h, w = image_shape[0], image_shape[1]
1660
- questions = build_text_with_points(example["questions"], example["question_points"], h, w)
1661
- answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
1662
- keep_q = tf.strings.regex_full_match(questions, "How many.*")
1663
- keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
1664
- keep = tf.logical_and(keep_q, keep_ans)
1665
- questions = tf.boolean_mask(questions, keep)
1666
- answers = tf.boolean_mask(answers, keep)
1667
-
1668
- ix = stateless_permutation(tf.shape(answers)[0], seed)[:2]
1669
- return {
1670
- "image": example["image"],
1671
- "prompt": tf.gather(questions, ix),
1672
- "metadata/references": tf.gather(answers, ix),
1673
- }
1674
-
1675
- ds = get_two(ds)
1676
- return flatten_parts(ds, ["prompt", "metadata/references"])
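- # Unlike the training preprocessor above, which keeps every valid "How many ..." / "There are N ..." pair,
- # this inference variant deterministically samples two pairs per image from `seed` and flattens them so
- # each question becomes a separate evaluation example.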
1677
-
1678
-
1679
- def mild_color_aug(image):
1680
- image = tf.image.random_hue(image, max_delta=0.05)
1681
- image = tf.image.random_brightness(image, max_delta=0.15)
1682
- image = tf.image.random_saturation(image, 0.7, 1.3)
1683
- image = tf.image.random_contrast(image, 0.8, 1.2)
1684
- return image
1685
-
1686
-
1687
- @seqio.map_over_dataset()
1688
- def name_entity_augmentation(ex, p_high_color=0.7):
1689
- ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1690
- image = ex["image"]
1691
- image = tf.image.convert_image_dtype(image, tf.float32)
1692
-
1693
- # Horizontal flip
1694
- if tf.random.uniform((), 0, 1) > 0.85:
1695
- image = image[:, ::-1]
1696
-
1697
- # Random crop
1698
- height = tf.cast(tf.shape(image)[0], tf.float32)
1699
- width = tf.cast(tf.shape(image)[1], tf.float32)
1700
- crop_rng = tf.random.uniform((), 0, 1)
1701
- if crop_rng < 0.2:
1702
- pass
1703
- else:
1704
- if crop_rng < 0.4:
1705
- h_crop = height * 0.15
1706
- w_crop = width * 0.15
1707
- else:
1708
- h_crop = height * 0.4
1709
- w_crop = width * 0.4
1710
- crop_h = tf.cast(tf.random.uniform((2,), 0, h_crop/2), tf.int32)
1711
- crop_w = tf.cast(tf.random.uniform((2,), 0, w_crop/2), tf.int32)
1712
- image = image[crop_h[0]:-crop_h[1]-1, crop_w[0]:-crop_w[1]-1]
1713
- height = tf.cast(tf.shape(image)[0], tf.float32)
1714
- width = tf.cast(tf.shape(image)[1], tf.float32)
1715
-
1716
- if tf.random.uniform(()) > p_high_color:
1717
- image = tf.image.random_hue(image, max_delta=0.05)
1718
- image = tf.image.random_brightness(image, max_delta=0.15)
1719
- image = tf.image.random_saturation(image, 0.7, 1.3)
1720
- image = tf.image.random_contrast(image, 0.8, 1.2)
1721
- else:
1722
- image = tf.image.random_hue(image, max_delta=0.1)
1723
- image = tf.image.random_brightness(image, max_delta=0.3)
1724
- image = tf.image.random_saturation(image, 0.0, 2.0)
1725
- image = tf.image.random_contrast(image, 0.2, 1.5)
1726
-
1727
- # Apply shear, rotation, and scale through one affine matrix
1728
- sel = tf.random.uniform((), 0, 1)
1729
- if sel < 0.1:
1730
- pass
1731
- else:
1732
- if sel < 0.15: # Scale only
1733
- shear_x = 0.  # use float literals so all branches produce the same dtype
1734
- shear_y = 0.
1735
- rotation = 0.
1736
- elif sel < 0.7: # Mild (elif so the scale-only settings are not overwritten)
1737
- shear_x = tf.random.uniform((), -2, 2)
1738
- shear_y = tf.random.uniform((), -2, 2)
1739
- rotation = tf.random.uniform((), -5, 5)
1740
- else: # Severe
1741
- shear_x = tf.random.uniform((), -10, 10)
1742
- shear_y = tf.random.uniform((), -10, 10)
1743
- rotation = tf.random.uniform((), -20, 20)
1744
-
1745
- max_scale = 1.2
1746
- scale = tf.random.uniform((), 0.4, max_scale)
1747
-
1748
- # Pad so upscaling/rotation will not move the image out of bounds
1749
- pad = tf.cast(tf.maximum(height, width)*0.2, tf.int32)
1750
- image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
1751
-
1752
- image = tf.keras.ops.image.affine_transform(
1753
- image,
1754
- tf.stack(get_affine_matrix(
1755
- [height/2, width/2],
1756
- rotation,
1757
- [0, 0],
1758
- 1/scale,
1759
- [shear_x, shear_y]
1760
- ) + [0., 0.]),
1761
- interpolation='bilinear',
1762
- fill_mode='constant',
1763
- fill_value=1.,
1764
- data_format='channels_last'
1765
- )
1766
-
1767
- # Crop away the white padding, otherwise the translation below could never place the content at a corner of the canvas
1768
- not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
1769
- no_white_ix = tf.where(not_white)
1770
- top_left = tf.reduce_min(no_white_ix, axis=0)
1771
- bottom_right = tf.reduce_max(no_white_ix, axis=0)
1772
-
1773
- # There is a very small chance the image now contains nothing but white space; in that case skip the crop
1774
- if (
1775
- (bottom_right[0] - top_left[0]) > 1 and (bottom_right[1] - top_left[1]) > 1
1776
- ):
1777
- image = tf.image.crop_to_bounding_box(
1778
- image,
1779
- offset_height=tf.cast(top_left[0], tf.int32),
1780
- offset_width=tf.cast(top_left[1], tf.int32),
1781
- target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
1782
- target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
1783
- )
1784
-
1785
- # Translate
1786
- height, width = tf.shape(image)[0], tf.shape(image)[1]
1787
- if tf.random.uniform((), 0, 1) < 0.1:
1788
- h_pad = tf.zeros((2,), dtype=tf.int32)
1789
- w_pad = tf.zeros((2,), dtype=tf.int32)
1790
- elif tf.random.uniform((), 0, 1) < 0.8:
1791
- h_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
1792
- w_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
1793
- else:
1794
- pad = tf.cast(tf.maximum(height, width), tf.int32)
1795
- h_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
1796
- w_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
1797
- image = tf.pad(image, [[h_pad[0], h_pad[1]], [w_pad[0], w_pad[1]], [0, 0]],  # pad height with h_pad and width with w_pad
1798
- constant_values=1)
1799
-
1800
- if "image_url" in ex: # URL won't show the augmentations
1801
- del ex["image_url"]
1802
- # ex["metadata/unaugmented_image"] = ex["image"]
1803
- ex["image"] = image
1804
- return ex
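- # Augmentation order applied above: optional horizontal flip, random crop, colour jitter (mild or strong),
- # a single shear/rotation/scale affine warp on a white-padded canvas, a crop back to the non-white content,
- # and finally a random translation via padding. image_url is dropped because it would no longer match the
- # augmented pixels.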
1805
-
1806
-
1807
- @seqio.map_over_dataset()
1808
- def wiki_art_preprocessor(ex):
1809
- out = dict(
1810
- image=ex["image"],
1811
- prompt="What is this?",
1812
- text=ex["question"]
1813
- )
1814
- out["metadata/title"] = ex["title"]
1815
- out["metadata/gt"] = ex["question"]
1816
- out["metadata/artist"] = ex["artist"]
1817
- out["metadata/painting_url"] = ex["painting_url"]
1818
- # if "metadata/unaugmented_image" in ex:
1819
- # out["metadata/unaugmented_image"] = ex["metadata/unaugmented_image"]
1820
- return out
1821
-
1822
- @seqio.map_over_dataset()
1823
- def oscar_preprocessor(ex):
1824
- out = dict(
1825
- image=ex["image"],
1826
- prompt=ex["question"]
1827
- )
1828
- out.update(_add_metadata(ex))
1829
- out["metadata/question"] = ex["question"]
1830
- out["metadata/answer"] = ex["answer"]
1831
- out["metadata/category"] = ex["category"]
1832
- return out
1833
-
1834
-
1835
- @seqio.map_over_dataset()
1836
- def tulu_preprocessor(ex):
1837
- return {
1838
- "messages": ex["messages"]["content"],
1839
- }
1840
- # logging.info("Debugging tulue")
1841
- # return {"messages": ex["messages"]["content"], "text_weights": 1e-6}
1842
-
1843
-
1844
- WIKI_DATA_QUESTION = "What is this? Respond with just a proper name."
1845
-
1846
-
1847
- @seqio.map_over_dataset()
1848
- def extract_wiki_data(ex):
1849
- return dict(
1850
- image=ex["image"],
1851
- image_url=ex["image_url"],
1852
- prompt=[
1853
- WIKI_DATA_QUESTION,
1854
- "What is this? Respond with the proper name of the main focus of the image and a few details about it."
1855
- ],
1856
- text=[
1857
- tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", "")),
1858
- ex["gptResponse"],
1859
- ]
1860
- )
1861
-
1862
-
1863
- @seqio.map_over_dataset()
1864
- def extract_wiki_data_name(ex):
1865
- target = tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", ""))
1866
- out = dict(
1867
- image=ex["image"],
1868
- image_url=ex["image_url"],
1869
- prompt=WIKI_DATA_QUESTION,
1870
- text=target,
1871
- )
1872
- out["metadata/references"] = target
1873
- return out
1874
-
1875
-
1876
- @seqio.map_over_dataset()
1877
- def extract_wiki_data_describe(ex):
1878
- out = dict(
1879
- image=ex["image"],
1880
- image_url=ex["image_url"],
1881
- prompt="What is this? Respond with the proper name of the main focus of the image and a few details about it.",
1882
- )
1883
- out["metadata/references"] = ex["gptResponse"]
1884
- return out
1885
-
1886
-
1887
- @gin.configurable()
1888
- def format_multiple_style_qa(ds, types=['multiple_choice', 'short_answer'], styles=['ai2_diagram', 'vqa2'], default_style='vqa2',
1889
- strip_instruction=False):
1890
- def _extract(ex):
1891
- prompt = ex["question"]
1892
- out = dict(image=ex["image"])
1893
- out.update(_add_metadata(ex))
1894
-
1895
- out["text"] = ex["answer"]
1896
- out["metadata/references"] = ex["answer"]
1897
-
1898
- if ex["metadata/question_type"] == 'multiple_choice':
1899
- style = styles[0]
1900
- else:
1901
- style = styles[1]
1902
- if strip_instruction:
1903
- if ex["metadata/question_type"] == "multiple_choice":
1904
- # parts = tf.strings.split(prompt, "\n")
1905
- # parts 1 is blank and part -1 is the instruction
1906
- # prompt = tf.strings.reduce_join(tf.concat([parts[:1], parts[2:-1]], 0), separator="\n")
1907
- prompt = prompt
1908
- else:
1909
- prompt = tf.strings.split(prompt, "\n")[0]
1910
-
1911
- out["style"] = style
1912
- out["prompt"] = prompt
1913
- return out
1914
- ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
1915
- return ds
1916
-
1917
-
1918
- @gin.configurable()
1919
- def extract_mmmu(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
1920
- assert option_format == "abc"
1921
- keys_tensor = tf.constant(types, dtype=tf.string)
1922
- values_tensor = tf.constant(styles, dtype=tf.string)
1923
- table = tf.lookup.StaticHashTable(
1924
- tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
1925
- default_value=tf.constant(default_style, dtype=tf.string),
1926
- )
1927
- def _extract(ex):
1928
- out = dict(image=tf.expand_dims(ex["image_1"], 0))
1929
- out.update(_add_metadata(ex))
1930
- style = table.lookup(ex["metadata/question_type"])
1931
- out["style"] = style
1932
- out["text"] = ex["answer"]
1933
- out["metadata/references"] = ex["answer"]
1934
-
1935
- if style == styles[0]:
1936
- abc = tf.constant(list("abcdefghi".upper()))
1937
- options = ex["options"]
1938
- num_options = tf.shape(options)[0]
1939
- dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
1940
- out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
1941
- out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
1942
-
1943
- short_options = abc[:num_options]
1944
- options = tf.stack([short_options, options,], 1)
1945
- options = tf.strings.reduce_join(options, axis=-1, separator=": ")
1946
- options = tf.strings.reduce_join(options, separator="\n")
1947
- out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
1948
- if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
1949
- # Following LLaVA, don't use any images if there are multiple image paths
1950
- # The rationale is that in that case the images are the answer options themselves
1951
- out["image"] = out["image"][:0]
1952
- else:
1953
- out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
1954
- out["prompt"] = ex["question"]
1955
- out["image"] = out["image"][:0]
1956
- return out
1957
- ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
1958
- return ds
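- # Multiple-choice prompts built above look roughly like:
- #   <question>
- #   A: <option 1>
- #   B: <option 2>
- #   ...
- # with "metadata/options" padded with empty strings to a fixed length of 9.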
1959
-
1960
- @gin.configurable()
1961
- def extract_mmmu_cot(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
1962
- assert option_format == "abc"
1963
- keys_tensor = tf.constant(types, dtype=tf.string)
1964
- values_tensor = tf.constant(styles, dtype=tf.string)
1965
- table = tf.lookup.StaticHashTable(
1966
- tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
1967
- default_value=tf.constant(default_style, dtype=tf.string),
1968
- )
1969
- def _extract(ex):
1970
- # out = dict(image=tf.expand_dims(ex["image_with_question"], 0))
1971
- out = dict(image=tf.expand_dims(ex["image_1"], 0))
1972
- out.update(_add_metadata(ex))
1973
- style = table.lookup(ex["metadata/question_type"])
1974
- # out["style"] = style
1975
- out["text"] = ex["answer"]
1976
- out["metadata/question"] = ex["question"]
1977
- out["metadata/references"] = ex["answer"]
1978
-
1979
- if style == styles[0]:
1980
- abc = tf.constant(list("abcdefghi".upper()))
1981
- options = ex["options"]
1982
- num_options = tf.shape(options)[0]
1983
- dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
1984
- out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
1985
- out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
1986
-
1987
- short_options = abc[:num_options]
1988
- options = tf.stack([short_options, options,], 1)
1989
- options = tf.strings.reduce_join(options, axis=-1, separator=": ")
1990
- options = tf.strings.reduce_join(options, separator="\n")
1991
- out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
1992
- # out["prompt"] = ex["question"]
1993
- if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
1994
- # Following LLaVA, don't use any images if there are multiple image paths
1995
- # The rationale is that in that case the images are the answer options themselves
1996
- out["image"] = out["image"][:0]
1997
- else:
1998
- out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
1999
- out["prompt"] = ex["question"]
2000
- # out["image"] = out["image"][:0]
2001
- return out
2002
- ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
2003
- return ds
2004
-
2005
-
2006
- @seqio.map_over_dataset
2007
- def reformat_math_vista(ex):
2008
- query = ex["query"]
2009
- query = tf.strings.split(query, sep="Question:")[-1]
2010
- query = tf.strings.strip(tf.strings.split(query, sep="Hint:")[0])
2011
- ex["query"] = query
2012
- return ex
2013
-
2014
-
2015
- @seqio.map_over_dataset
2016
- def extract_math_vista(ex, styles=['ai2_diagram', 'vqa2']):
2017
- out = dict(image=ex["image"])
2018
- out.update(_add_metadata(ex))
2019
-
2020
- is_mc = ex["metadata/question_type"] == 'multi_choice'
2021
- if is_mc:
2022
- style = styles[0]
2023
- abc = tf.constant(list("abcdefghi".upper()))
2024
- options = ex["choices"]
2025
- num_options = tf.shape(options)[0]
2026
- dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
2027
- out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
2028
- out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
2029
-
2030
- if ex["metadata/split"] != "test":
2031
- short_options = abc[:num_options]
2032
- answer_short_option = tf.boolean_mask(short_options, options == ex["answer"])[0]
2033
- out["text"] = answer_short_option
2034
- else:
2035
- out["text"] = ex["answer"]
2036
- else:
2037
- style = styles[1]
2038
- out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
2039
- out["text"] = ex["answer"]
2040
- out["style"] = style
2041
- out["prompt"] = ex["query"]
2042
- out["metadata/query"] = ex["query"]
2043
- out["metadata/references"] = ex["answer"]
2044
- return out
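- # For multiple-choice questions outside the test split, the raw answer string is mapped to its option letter
- # (A-I) so the target matches the ai2_diagram style; test-split and free-form answers are kept verbatim.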
2045
-
2046
-
2047
- NO_POINT_PREFIX = [
2048
- "No pointing: ",
2049
- "No pointing: ",
2050
- "no pointing:\n",
2051
- "No pointing:\n",
2052
- "Not pointing:\n",
2053
- "No Points: ",
2054
- "No Points: ",
2055
- "NO POINTING\n",
2056
- "No pontiing\n",
2057
- "No Points:\n ",
2058
- "No pointing\n",
2059
- "Do not point. ",
2060
- "Refrain from pointing. ",
2061
- "Avoid generating points . ",
2062
- "For this question, do not use points. ",
2063
- "Refrain from using points:\n",
2064
- "Don't include points in your response. ",
2065
- "Don't point. ",
2066
- "Don't use points. ",
2067
- "Please don't use points.\n\n",
2068
- "Please don't use points.\n\n",
2069
- "Respond without using points. ",
2070
- "Respond without pointing:\n",
2071
- "Do not generate ponits: ",
2072
- "Do not point. ",
2073
- "Do not point\n",
2074
- "no pointing\n\n",
2075
- "Answer without points: ",
2076
- "Answer this question without pointing: ",
2077
- "Answer without poiints. ",
2078
- "answer without points: ",
2079
- "answer with text only, do not points\n"
2080
- ]
2081
- assert all(x[-1].isspace() for x in NO_POINT_PREFIX)
2082
- NO_POINT_PREFIX_TF = tf.constant(NO_POINT_PREFIX)
2083
-
2084
-
2085
- def prefix_how_many(messages, seed):
2086
- question = messages[0]
2087
- if tf.strings.regex_full_match(tf.strings.lower(question), "how many.*"):
2088
- ix = tf.random.stateless_uniform((), seed, 0, len(NO_POINT_PREFIX), tf.int32)
2089
- question = tf.strings.join([NO_POINT_PREFIX_TF[ix], question])
2090
- return tf.concat([tf.expand_dims(question, 0), messages[1:]], axis=0)
2091
- else:
2092
- return messages
2093
-
2094
-
2095
- @seqio.map_over_dataset(num_seeds=1)
2096
- def prefix_how_many_messages(ex, seed):
2097
- messages = ex["messages"]
2098
- n = tf.shape(messages)[0]
2099
- seeds = tf.random.split(seed, n)
2100
- message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
2101
- for i in range(n):
2102
- message_arr = message_arr.write(i, prefix_how_many(messages[i], seeds[i]))
2103
- ex["messages"] = tf.RaggedTensor.from_row_splits(
2104
- values=message_arr.concat(), row_splits=messages.row_splits)
2105
- return ex
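- # Sketch: first-turn questions matching "How many ..." are counting questions, so they get a randomly chosen
- # NO_POINT_PREFIX instruction to discourage point annotations, e.g.
- #   "How many dogs are in the image?" -> "No pointing: How many dogs are in the image?"
- # All other turns pass through unchanged.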
2106
-
2107
-
2108
- def filter_single_turn(ds):
2109
- @seqio.map_over_dataset
2110
- def _filter(ex):
2111
- multi_turn = ex["messages"].row_lengths() > 2
2112
- ex["messages"] = tf.ragged.boolean_mask(ex["messages"], multi_turn)
2113
- return ex
2114
-
2115
- ds = _filter(ds)
2116
- ds = ds.filter(lambda x: tf.shape(x["messages"])[0] > 0)
2117
- return ds
2118
-
2119
-
2120
- @seqio.map_over_dataset(num_seeds=1)
2121
- def extract_cockatoo_qa_v2(ex, seed):
2122
- messages = tf.RaggedTensor.from_value_rowids(ex["messages"], ex["conversation_ids"])
2123
- ix = stateless_permutation(tf.shape(messages)[0], seed)
2124
- messages = tf.gather(messages, ix)
2125
- out = dict(
2126
- image=ex["image"],
2127
- messages=messages
2128
- )
2129
- out.update(_add_metadata(ex))
2130
- return out
2131
-
2132
-
2133
- def format_mmbench(ds):
2134
-
2135
- def _trim(ex):
2136
- num_passes = tf.shape(ex["id"])[0]
2137
- ex["choices"] = ex["choices"][:num_passes, :num_passes]
2138
- ex["answer"] = ex["answer"][:num_passes]
2139
- return ex
2140
-
2141
- ds = ds.map(_trim)
2142
- ds = flatten_parts(ds, ["id", "query", "choices", "answer"])
2143
-
2144
- def _extract(ex):
2145
- out = dict(image=ex["image"])
2146
- out.update(_add_metadata(ex))
2147
- out["prompt"] = ex["query"]
2148
- out["text"] = ex["answer"]
2149
- options = ex["choices"]
2150
- tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, r".*\|\|\|.*")), False)  # the "|||" separator must never appear inside an option
2151
- out["metadata/options"] = tf.strings.reduce_join(options, separator="|||")
2152
- out["metadata/question"] = ex["question"]
2153
- out["metadata/references"] = ex["answer"]
2154
- return out
2155
-
2156
- ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
2157
- return ds
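- # The candidate options are serialised into one string joined with "|||" so they fit the flat metadata
- # format; the assert above guards against an option containing the separator itself. Downstream evaluation
- # code is assumed to split "metadata/options" back on "|||".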
2158
-
2159
-
2160
- @seqio.map_over_dataset
2161
- def extract_lvis(ex, class_name_file="gs://oe-training-chrisc/cockatoo/data/lvis_class_names.json"):
2162
- with tf.io.gfile.GFile(class_name_file) as f:
2163
- class_names = json.load(f)
2164
- class_names_arr = [None]*len(class_names)
2165
- for k, v in class_names.items():
2166
- class_names_arr[int(k)] = v
2167
- assert all(x is not None for x in class_names_arr)
2168
- class_names_arr = tf.constant(class_names_arr)
2169
-
2170
- return dict(
2171
- image=ex["image"],
2172
- bbox=ex["objects"]["bbox"],
2173
- label=tf.gather(class_names_arr, ex["objects"]["label"]),
2174
- )
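- # The class-name JSON is read from GCS with plain Python, so it is loaded once when this map function is
- # traced rather than per example; labels are then gathered from the resulting constant string tensor.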
2175
-
2176
-
2177
- def extract_open_images_boxes(ds):
2178
- # ds = ds.filter(lambda ex: tf.logical_or(
2179
- # tf.shape(ex["cap/cap_caption"])[0] > 0,
2180
- # tf.shape(ex["detection/bbox"])[0] > 0
2181
- # ))
2182
- ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
2183
-
2184
- @seqio.map_over_dataset
2185
- def _map(ex):
2186
- bbox = tf.reshape(ex["detection/bbox"], (-1, 4))
2187
- bbox = tf.stack([
2188
- bbox[:, 2],
2189
- bbox[:, 0],
2190
- bbox[:, 3],
2191
- bbox[:, 1]
2192
- ], 1)
2193
- return dict(
2194
- image=tf.image.decode_jpeg(ex["image"]),
2195
- bbox=bbox,
2196
- label=ex["detection/label"],
2197
- caption=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
2198
- )
2199
-
2200
- return _map(ds)
2201
-
2202
-
2203
- @seqio.map_over_dataset
2204
- def region_captions_to_dense(ex):
2205
- if "captions" in ex:
2206
- captions = ex["captions"]["text"]
2207
- boxes = ex["captions"]["bbox"]
2208
- else:
2209
- captions = ex["label"]
2210
- boxes = ex["bbox"]
2211
-
2212
-
2213
- sh = tf.cast(tf.shape(ex["image"])[:2], tf.float32)
2214
- # image_h, image_w = sh[0], sh[1]
2215
- w = boxes[:, 2] - boxes[:, 0]
2216
- h = boxes[:, 3] - boxes[:, 1]
2217
-
2218
- cx = tf.cast(boxes[:, 0] + w/2, tf.float32)
2219
- cy = tf.cast(boxes[:, 1] + h/2, tf.float32)
2220
- # w = w / image_w
2221
- # h = h / image_h
2222
- coor = tf.strings.reduce_join(
2223
- float_to_text(tf.stack([cx, cy, w, h], 1)), separator=",", axis=1)
2224
-
2225
- area = w*h
2226
- if tf.random.uniform(()) < 0.5:
2227
- coor_text = "before"
2228
- captions = tf.strings.join([coor, captions], separator=": ")
2229
- else:
2230
- coor_text = "after"
2231
- captions = tf.strings.join([captions, coor], separator=": ")
2232
-
2233
- ix = tf.random.uniform((), 0, 6, tf.int32)
2234
- center = boxes
2235
- if ix == 0:
2236
- order_text = "left"
2237
- sort_by = boxes[:, 0]
2238
- elif ix == 1:
2239
- order_text = "right"
2240
- sort_by = -boxes[:, 2]
2241
- elif ix == 2:
2242
- order_text = "top"
2243
- sort_by = boxes[:, 1]
2244
- elif ix == 3:
2245
- order_text = "bottom"
2246
- sort_by = -boxes[:, 3]
2247
- elif ix == 4:
2248
- order_text = "largest"
2249
- sort_by = area
2250
- else:
2251
- order_text = "smallest"
2252
- sort_by = -area
2253
- ixs = tf.argsort(sort_by)
2254
- captions = tf.gather(captions, ixs)
2255
- text = tf.strings.join([
2256
- order_text,
2257
- coor_text,
2258
- tf.strings.reduce_join(captions, separator="\n")
2259
- ], separator="; ")
2260
-
2261
- if "caption" in ex:
2262
- if tf.random.uniform(()) > 0.5:
2263
- text = tf.strings.join([text, "\ncaption: ", ex["caption"]])
2264
- else:
2265
- text = tf.strings.join(["caption: ", ex["caption"], "\n", text])
2266
-
2267
- return dict(
2268
- image=ex["image"],
2269
- text=text
2270
- )
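- # Rough shape of the output (the exact number formatting comes from float_to_text, defined earlier in this
- # file): every region becomes "cx,cy,w,h: caption" or "caption: cx,cy,w,h", the regions are sorted by a
- # randomly chosen criterion, and the target looks roughly like
- #   "largest; before; 12,34,56,78: a red car\n90,12,34,56: a tree"
- # with the global caption optionally prepended or appended.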
2271
-
2272
-
2273
- @seqio.map_over_dataset()
2274
- def join_captions(ex):
2275
- text = tf.random.shuffle(ex['text'])
2276
- ex["text"] = tf.strings.reduce_join(text, separator="\n")
2277
- return ex
2278
-
2279
-
2280
- @seqio.map_over_dataset(num_seeds=1)
2281
- def extract_figureqa(ex, seed):
2282
- questions = ex["questions"]
2283
- n = stateless_permutation(tf.shape(questions["question"])[0], seed)
2284
- return dict(
2285
- image=ex["image"],
2286
- questions=tf.gather(questions["question"], n),
2287
- question_id=tf.gather(questions["question_id"], n),
2288
- answer=tf.gather(tf.strings.as_string(questions["answer"]), n)
2289
- )
2290
-
2291
-
2292
- @seqio.map_over_dataset
2293
- def convert_figureqa_answer(ex):
2294
- keys_tensor = tf.constant(["0", "1"])
2295
- values_tensor = tf.constant(["no", "yes"])
2296
- table = tf.lookup.StaticHashTable(
2297
- tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
2298
- default_value=tf.constant("nan", dtype=tf.string),
2299
- )
2300
- answer = table.lookup(ex["answer"])
2301
- ex["answer"] = answer
2302
- return ex
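- # FigureQA stores binary answers, so the lookup table maps "0" -> "no" and "1" -> "yes", with "nan" as the
- # fallback for anything unexpected.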
2303
-
2304
-
2305
- @seqio.map_over_dataset()
2306
- def build_question_with_hint(ex):
2307
- hint = ex["hint"]
2308
- if tf.strings.length(hint) > 0:
2309
- ex["question"] = tf.strings.join([hint, ex["question"]], separator="\n")
2310
- return ex
2311
-
2312
- @seqio.map_over_dataset()
2313
- def build_question_with_context(ex):
2314
- context = ex["context"]
2315
- if tf.strings.length(context) > 0:
2316
- ex["question"] = tf.strings.join([context, ex["question"]], separator="\n")
2317
- return ex
2318
-
2319
-
2320
- def max_words(ds, max_words):
2321
- return ds.filter(lambda x: x["n_words"] <= max_words)
2322
-
2323
-
2324
- @seqio.map_over_dataset
2325
- def format_pdfa_eng_wds(example):
2326
- return dict(
2327
- image=example["image"],
2328
- text=tf.strings.reduce_join(example["lines"]["text"], separator="\n"),
2329
- )
2330
-
2331
-
2332
- @gin.configurable()
2333
- def accuracy_conditioned_joint(ds, sequence_length, is_eval=False, eval_quality=17,
2334
- transcript_quality=None):
2335
- # v2: Transcripts no longer get a quality score
2336
- is_training = sequence_length.get('is_training', True)
2337
- if not is_training:
2338
- if is_eval:
2339
- prompt = f"quality {eval_quality}:"
2340
- else:
2341
- prompt = f"quality 17:"
2342
-
2343
- @seqio.map_over_dataset
2344
- def _with_prompt(ex):
2345
- out = dict(
2346
- image=ex["image"],
2347
- url=ex["url"],
2348
- prompt=prompt,
2349
- )
2350
- if "text" in ex:
2351
- out["text"] = ex["text"]
2352
- elif "caption" in ex:
2353
- out["text"] = ex["caption"]
2354
- return out
2355
- return _with_prompt(ds)
2356
-
2357
- elif is_eval:
2358
- raise ValueError("is_eval=True requires is_training=False")
2359
-
2360
- # Training path: build caption, transcript, and edited-caption targets for each example
2361
- @seqio.map_over_dataset
2362
- def _with_transcript(ex):
2363
- if tf.shape(ex["edited_captions"]["caption"])[0] > 0:
2364
- edited_caption = ex["edited_captions"]["caption"][0]
2365
- n = ex["edited_captions"]["n_edits"][0]
2366
- else:
2367
- edited_caption = ""
2368
- n = 0
2369
- text = [
2370
- ex["caption"],
2371
- ex["transcripts"][tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)],
2372
- edited_caption
2373
- ]
2374
- edit_quality = 17 - n
2375
- prompt = [
2376
- "quality 17:",
2377
- "" if transcript_quality is None else f"quality: {edit_quality}:",
2378
- tf.strings.join(["quality ", tf.strings.as_string(edit_quality), ":"])
2379
- ]
2380
- return dict(
2381
- image=ex["image"],
2382
- text=tf.stack(text, 0),
2383
- url=ex["url"],
2384
- prompt=tf.stack(prompt, 0),
2385
- style=["long_caption", "transcript", "long_caption"]
2386
- )
2387
- return _with_transcript(ds)
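- # At training time each image yields three targets: the full caption (prompt "quality 17:"), one randomly
- # chosen transcript (unscored unless transcript_quality is set), and the top edited caption, whose quality
- # score is 17 minus its number of edits. At inference time a single fixed "quality N:" prompt is used.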
2388
-
2389
-
2390
- def select_dense_caption_sample(ds, samples=200):
2391
- def compute_hash(string: str) -> str:
2392
- return hashlib.sha256(string.encode("utf-8")).hexdigest()
2393
-
2394
- with tf.io.gfile.GFile("gs://oe-training-chrisc/cockatoo/data/dense-caption-eval-v0-final-data.json") as f:
2395
- data = json.load(f)
2396
- for ex in data:
2397
- ex["image_id"] = compute_hash(ex["image"])
2398
- data.sort(key=lambda x: x["image_id"])
2399
- np.random.RandomState(12312).shuffle(data)
2400
- keep = tf.constant([x["image"] for x in data[:samples]])
2401
-
2402
- def _keep(ex):
2403
- return tf.reduce_any(ex["url"] == keep)
2404
- ds = ds.filter(_keep)
2405
- ds = tf.data.experimental.assert_cardinality(samples)(ds)
2406
- return ds
2407
-
2408
- @seqio.map_over_dataset()
2409
- def charxiv_preprocessor(ex):
2410
- question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4", "reasoning_q"]
2411
- answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4", "reasoning_a"]
2412
-
2413
- questions = [ex[name] for name in question_names]
2414
- answers = [ex[name] for name in answer_names]
2415
-
2416
- return dict(
2417
- image=ex["image"],
2418
- question=tf.stack(questions, 0),
2419
- answer=tf.stack(answers, 0)
2420
- )
2421
-
2422
- @seqio.map_over_dataset()
2423
- def charxiv_descriptive_preprocessor(ex):
2424
- question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4"]
2425
- answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4"]
2426
-
2427
- questions = [ex[name] for name in question_names]
2428
- answers = [ex[name] for name in answer_names]
2429
-
2430
- return dict(
2431
- image=ex["image"],
2432
- question=tf.stack(questions, 0),
2433
- answer=tf.stack(answers, 0)
2434
- )
2435
-
2436
- @seqio.map_over_dataset()
2437
- def charxiv_reasoning_preprocessor(ex):
2438
- return dict(
2439
- image=ex["image"],
2440
- question=ex["reasoning_q"],
2441
- answer=ex["reasoning_a"]
2442
- )
2443
-
2444
- @seqio.map_over_dataset()
2445
- def tablevqa_preprocessor(ex):
2446
- return dict(
2447
- image=ex["image"],
2448
- question=ex["question"],
2449
- answer=ex["gt"]
2450
- )
2451
-
2452
- @seqio.map_over_dataset()
2453
- def vtabfact_preprocessor(ex):
2454
- return dict(
2455
- image=ex["image"],
2456
- question=tf.strings.join([ex["question"], "Answer with yes or no."], separator="\n"),
2457
- answer=ex["gt"]
2458
- )
2459
-
2460
- @seqio.map_over_dataset()
2461
- def nutrition_fact_preprocessor(ex):
2462
- question_names = ["descriptive_q", "reasoning_q"]
2463
- answer_names = ["descriptive_a", "reasoning_a"]
2464
-
2465
- questions = [ex[name] for name in question_names]
2466
- answers = [ex[name] for name in answer_names]
2467
-
2468
- return dict(
2469
- image=ex["image"],
2470
- question=tf.stack(questions, 0),
2471
- answer=tf.stack(answers, 0)
2472
- )