# File size: 16,215 Bytes
# bb6012a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
import os
from typing import List, Union
import random

import json
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from loguru import logger

from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
_tokenizer = _Tokenizer()

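# NOTE: ModifiedDatasetDuplet.__getitem__ calls `text_tokenize`. Uncomment the line below (and make
# sure the "./Taiyi-CLIP-s" checkpoint directory is available) before using that class.
# text_tokenize = AutoTokenizer.from_pretrained("./Taiyi-CLIP-s", model_max_length=512)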
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Tokenize one or more strings into fixed-width rows of token ids.

    Every encoding is wrapped in the start-of-text and end-of-text tokens
    and written left-aligned into a zero-initialised row.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        Width of every output row (77 by default)
    truncate : bool
        When True, an encoding longer than ``context_length`` is cut to fit
        and its last kept position is overwritten with the end-of-text
        token; when False such an input raises

    Returns
    -------
    A two-dimensional ``torch.long`` tensor, shape = [number of input strings, context_length]

    Raises
    ------
    RuntimeError
        If an encoding exceeds ``context_length`` and ``truncate`` is False
    """
    batch = [texts] if isinstance(texts, str) else texts

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]

    encoded = []
    for text in batch:
        encoded.append([start_id] + _tokenizer.encode(text) + [end_id])

    result = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        overflow = len(ids) > context_length
        if overflow and not truncate:
            raise RuntimeError(
                f"Input {batch[row]} is too long for context length {context_length}"
            )
        if overflow:
            # Keep the prefix that fits and force the final slot to end-of-text.
            ids = ids[:context_length]
            ids[-1] = end_id
        result[row, :len(ids)] = torch.tensor(ids)

    return result

def select_idxs(seq_length, n_to_select, n_from_select, seed=42):
    """
    Pick ``n_to_select`` indexes out of every consecutive run of
    ``n_from_select`` indexes of ``range(seq_length)`` and regroup the picks
    into ``n_to_select`` parallel lists.

    Illustration (actual picks depend on the seed):

    seq_length = 20, n_from_select = 5, n_to_select = 2

    consecutive runs:
    [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]

    picks per run, e.g.:
    [[0, 4], [7, 9], [13, 14], [16, 18]]

    output, one list per pick position:
    [[0, 7, 13, 16], [4, 9, 14, 18]]

    Note: the module-level ``random`` generator is re-seeded on every call.

    :param seq_length: length of sequence, say 10
    :param n_to_select: number of elements to select from each run
    :param n_from_select: number of consecutive elements in each run
    :param seed: seed passed to ``random.seed`` before sampling
    :return: list of ``n_to_select`` lists, each of length ``seq_length // n_from_select``
    """
    random.seed(seed)
    n_chunks = seq_length // n_from_select
    # One sample per run, drawn in run order so the RNG stream is consumed
    # exactly once per run.
    picks = [random.sample(range(n_from_select), n_to_select) for _ in range(n_chunks)]
    return [
        [chunk * n_from_select + picks[chunk][slot] for chunk in range(n_chunks)]
        for slot in range(n_to_select)
    ]

def read_json(file_name, suppress_console_info=False):
    """
    Read a JSON file and return the decoded object.

    :param file_name: input JSON path
    :param suppress_console_info: when True, skip the "Read from:" console message
    :return: decoded JSON content (a dictionary for the dataset files used here)
    """
    # Decode explicitly as UTF-8 instead of the platform default encoding,
    # so non-ASCII captions load identically on every OS.
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data

def get_image_file_names(data, suppress_console_info=False):
    """
    Build the relative path ``"<sample_id>/<image_name>"`` for every image record.

    :param data: original data from JSON; ``data['images']`` is iterated and each
        record must provide ``image_name`` and ``sample_id``
    :param suppress_console_info: toggle console printing
    :return: list of strings (file names), in record order
    """
    name_and_sample = (
        (record["image_name"], record["sample_id"]) for record in data['images']
    )
    file_names = [f'{sample_id}/{image_name}' for image_name, sample_id in name_and_sample]

    if suppress_console_info:
        return file_names
    print("Total number of files:", len(file_names))
    return file_names

def get_images(file_names, args):
    """
    Load images from disk, preprocess them and stack them into one array.

    Each image goes through Resize(224) -> CenterCrop(224) -> ToTensor ->
    Normalize with fixed three-value mean/std constants.

    :param file_names: image paths relative to ``args.imgs_folder``
    :param args: object exposing ``imgs_folder``
    :return: numpy array holding the preprocessed images in input order
    """
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    batch = [
        np.array(preprocess(Image.open(os.path.join(args.imgs_folder, name))))
        for name in file_names
    ]
    return np.array(batch)

def get_captions(data, suppress_console_info=False):
    """
    Get lists of formatted captions and of their augmented variants.

    Formatting removes the characters ``. , ! ?`` and lower-cases the text.
    Augmented variants are read from the optional sentence keys ``aug_rb``,
    ``aug_bt_prob`` and ``aug_bt_chain``; a sentence that lacks a key (or
    holds a non-string value there) is skipped for that variant, so the
    augmented lists may be shorter than the caption list.

    :param data: original data from JSON; every ``data['images'][i]['sentences'][j]``
        must provide ``raw``
    :param suppress_console_info: when True, skip logging the list sizes
    :return: tuple ``(captions, augmented_rb, augmented_bt_prob, augmented_bt_chain)``,
        each a list of strings
    """
    def format_caption(string):
        return string.replace('.', '').replace(',', '').replace('!', '').replace('?', '').lower()

    captions = []
    # Insertion order matters only for readability; each key fills its own list.
    augmented = {'aug_rb': [], 'aug_bt_prob': [], 'aug_bt_chain': []}

    for img in data['images']:
        for sent in img['sentences']:
            captions.append(format_caption(sent['raw']))
            for key, bucket in augmented.items():
                # Explicit presence/type check instead of a bare `except: pass`,
                # which would also hide unrelated errors and KeyboardInterrupt.
                value = sent.get(key)
                if isinstance(value, str):
                    bucket.append(format_caption(value))

    augmented_captions_rb = augmented['aug_rb']
    augmented_captions_bt_prob = augmented['aug_bt_prob']
    augmented_captions_bt_chain = augmented['aug_bt_chain']

    if not suppress_console_info:
        logger.info("Total number of captions:{}", len(captions))
        logger.info("Total number of augmented captions RB:{}", len(augmented_captions_rb))
        logger.info("Total number of augmented captions BT (prob):{}", len(augmented_captions_bt_prob))
        logger.info("Total number of augmented captions BT (chain):{}", len(augmented_captions_bt_chain))
    return captions, augmented_captions_rb, augmented_captions_bt_prob, augmented_captions_bt_chain

def get_labels(data, suppress_console_info=False):
    """
    Collect the ``classcode`` value of every image record.

    :param data: original data from JSON; ``data['images']`` is iterated
    :param suppress_console_info: toggle console printing
    :return: list of labels, in record order
    """
    labels = [record["classcode"] for record in data['images']]
    if suppress_console_info:
        return labels
    print("Total number of labels:", len(labels))
    return labels

def remove_tokens(data):
    """
    Remove the 'tokens' key from every caption record, if present.

    The input is modified in place and also returned.

    :param data: original data; ``data['images'][i]['sentences']`` is iterated
    :return: the same ``data`` object, without 'tokens' entries
    """
    for img in data['images']:
        for sent in img['sentences']:
            # pop with a default is a no-op for records without the key and,
            # unlike a bare `except: pass`, does not hide unrelated errors.
            sent.pop("tokens", None)
    return data

def write_json(file_name, data):
    """
    Serialize ``data`` to a tab-indented JSON file.

    Whatever extension ``file_name`` carries is replaced by ``.json``.

    :param file_name: output path (extension is forced to ``.json``)
    :param data: dictionary
    :return: None
    """
    folder, base = os.path.split(file_name)
    stem = os.path.splitext(base)[0]
    target = os.path.join(folder, stem + '.json')

    serialized = json.dumps(data, indent='\t')
    with open(target, 'w') as out:
        out.write(serialized)
    print("Written to:", target)

def get_split_idxs(arr_len, args):
    """
    Get indexes for the training, evaluation, query and db subsets.

    The full index range is first split into train / eval using
    ``args.dataset_train_split``; the eval part is then split into
    query / db using ``args.dataset_query_split``.

    :param arr_len: array length
    :param args: object exposing ``dataset_train_split`` and ``dataset_query_split``
    :return: four lists ``(idx_train, idx_eval, idx_query, idx_db)``
    """
    train_part, eval_part = split_indexes(list(range(arr_len)), args.dataset_train_split)
    query_part, db_part = split_indexes(eval_part, args.dataset_query_split)
    return train_part, eval_part, query_part, db_part

def split_indexes(idx_all, split):
    """
    Randomly split a list in two parts.

    ``int(len(idx_all) * split)`` elements are drawn with ``random.sample``
    (module-level generator); everything not drawn forms the second part.

    :param idx_all: array to split
    :param split: portion to put in the first part
    :return: ``(selected, rest)``, both sorted ascending
    """
    n_selected = int(len(idx_all) * split)

    picked = random.sample(idx_all, n_selected)
    picked.sort()

    leftover = set(idx_all) - set(picked)
    return picked, sorted(leftover)

def get_caption_idxs(idx_train, idx_query, idx_db):
    """
    Translate image index sets into caption index sets (5 captions per image).

    :param idx_train: train image (and label) indexes
    :param idx_query: query image (and label) indexes
    :param idx_db: db image (and label) indexes
    :return: caption indexes ``(train, query, db)`` for the corresponding index sets
    """
    train_caps, query_caps = (
        get_caption_idxs_from_img_idxs(subset, num=5)
        for subset in (idx_train, idx_query)
    )
    # The db subset relies on the helper's default captions-per-image count.
    db_caps = get_caption_idxs_from_img_idxs(idx_db)
    return train_caps, query_caps, db_caps

def get_caption_idxs_from_img_idxs(img_idxs, num=5):
    """
    Get caption indexes, assuming ``num`` consecutive captions per image.

    Image ``k`` owns captions ``k * num .. k * num + num - 1``.
    Say, img indexes - [0, 10, 100] and num = 5
    Then, caption indexes - [0, 1, 2, 3, 4, 50, 51, 52, 53, 54, 500, 501, 502, 503, 504]

    :param img_idxs: image (and label) indexes
    :param num: number of captions per image
    :return: caption indexes, grouped by image in input order
    """
    return [img * num + offset for img in img_idxs for offset in range(num)]

def split_data(images, captions, labels, captions_aug, images_aug, args):
    """
    Split dataset to get training, query and db subsets.

    :param images: image embeddings array
    :param captions: caption embeddings array
    :param labels: labels array
    :param captions_aug: augmented caption embeddings
    :param images_aug: augmented image embeddings
    :param args: forwarded to ``get_split_idxs`` (which reads the split ratios)
    :return: ``(train, query, db)``; each is a 6-tuple
        ``(images, captions, labels, (image_idxs, caption_idxs), captions_aug, images_aug)``
        restricted to that subset

    NOTE(review): the inputs are indexed with lists of indexes, so they are
    presumably numpy arrays (or similar) supporting fancy indexing - confirm.
    """
    # get_split_idxs returns four lists (train, eval, query, db); the eval
    # list is not needed here. Unpacking only three names raised ValueError.
    idx_tr, _idx_eval, idx_q, idx_db = get_split_idxs(len(images), args)
    idx_tr_cap, idx_q_cap, idx_db_cap = get_caption_idxs(idx_tr, idx_q, idx_db)

    def subset(img_idxs, cap_idxs):
        # Image-level arrays use image indexes, caption-level arrays use caption indexes.
        return (images[img_idxs], captions[cap_idxs], labels[img_idxs],
                (img_idxs, cap_idxs), captions_aug[cap_idxs], images_aug[img_idxs])

    train = subset(idx_tr, idx_tr_cap)
    query = subset(idx_q, idx_q_cap)
    db = subset(idx_db, idx_db_cap)

    return train, query, db

def select_idxs(seq_length, n_to_select, n_from_select, seed=42):
    """
    Select n_to_select indexes from each consecutive group of n_from_select
    indexes of a range of length seq_length, and return the selections split
    into n_to_select separate lists.

    NOTE(review): this is a second definition of ``select_idxs`` in this
    module; being later, it is the one bound at import time.

    Example (illustrative, actual picks depend on the seed):

    seq_length = 20
    n_from_select = 5
    n_to_select = 2

    groups of length n_from_select:
    [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]

    n_to_select elements picked from each group:
    [[0, 4], [7, 9], [13, 14], [16, 18]]

    output, n_to_select lists of length seq_length // n_from_select:
    [[0, 7, 13, 16], [4, 9, 14, 18]]

    The module-level ``random`` generator is re-seeded with ``seed``.

    :param seq_length: length of sequence, say 10
    :param n_to_select: number of elements to select
    :param n_from_select: number of consecutive elements per group
    :param seed: value passed to ``random.seed``
    :return: list of n_to_select index lists
    """
    random.seed(seed)
    covered = (seq_length // n_from_select) * n_from_select

    per_group = []
    for start in range(0, covered, n_from_select):
        offsets = random.sample(range(n_from_select), n_to_select)
        per_group.append([start + off for off in offsets])

    # Transpose: position j of every group's picks goes to output list j.
    return [[group[j] for group in per_group] for j in range(n_to_select)]

class AbstractDataset(torch.utils.data.Dataset):
    """
    Base class storing the arrays shared by the dataset classes derived from it.

    Subclasses must override ``__getitem__`` and ``__len__``.
    """

    def __init__(self, images, captions, labels, targets, idxs):
        """
        Store the dataset arrays.

        :param images: image data, stored as given
        :param captions: caption data, stored as given
        :param labels: labels, stored as given
        :param targets: targets, stored as given
        :param idxs: indexable whose first element is converted to a numpy
            array and kept as ``self.idxs``
        """
        self.image_replication_factor = 1  # default value, how many times we need to replicate image

        self.images = images
        self.captions = captions
        self.labels = labels
        self.targets = targets

        self.idxs = np.array(idxs[0])

    def __getitem__(self, index):
        # Fail loudly instead of silently yielding None samples.
        raise NotImplementedError("subclasses must implement __getitem__")

    def __len__(self):
        # Returning None made len() fail with an opaque TypeError.
        raise NotImplementedError("subclasses must implement __len__")

class CISENDataset(torch.utils.data.Dataset):
    """
    Image/caption pair dataset.

    Sample ``i`` pairs ``images[i]`` with the tokenized ``captions[i]``.
    """

    def __init__(self, images, captions, args):
        """
        Initialization.

        :param images: indexable of numpy image arrays
        :param captions: indexable of captions accepted by ``tokenize``
        :param args: object exposing ``word_len`` (token row width)
        """
        super().__init__()
        self.images, self.captions = images, captions
        self.word_len = args.word_len

    def __getitem__(self, index):
        """
        Return the (img, txt) pair at ``index``.

        :param index: index of sample
        :return: tuple (float32 image tensor, int64 token-id tensor)
        """
        image = torch.from_numpy(self.images[index].astype('float32'))
        token_ids = tokenize(self.captions[index], self.word_len).squeeze(0)
        caption = torch.from_numpy(np.array(token_ids).astype('int64'))
        return image, caption

    def __len__(self):
        """Number of samples, i.e. number of images."""
        return len(self.images)


class DatasetDuplet(AbstractDataset):
    """
    Duplet dataset: each sample couples an image with its tokenized caption,
    together with the sample index, label and target.
    """

    def __init__(self, images, captions, labels, targets, idxs, args):
        """
        Initialization.

        :param images: indexable of numpy image arrays
        :param captions: indexable of captions
        :param labels: labels vector
        :param targets: targets vector
        :param idxs: forwarded to the base class
        :param args: object exposing ``word_len`` (token row width)
        """
        super().__init__(images, captions, labels, targets, idxs)
        self.word_len = args.word_len

    def __getitem__(self, index):
        """
        Return the sample at ``index``.

        :param index: index of sample
        :return: tuple (index, img, txt, label, target)
        """
        image = torch.from_numpy(self.images[index].astype('float32'))
        # NOTE(review): the caption is concatenated with itself before
        # tokenizing - confirm this duplication is intended.
        doubled = self.captions[index] + self.captions[index]
        token_ids = tokenize(doubled, self.word_len).squeeze(0)
        caption = torch.from_numpy(np.array(token_ids).astype('int64'))
        return index, image, caption, self.labels[index], self.targets[index]

    def __len__(self):
        """Number of samples, i.e. number of images."""
        return len(self.images)

class ModifiedDatasetDuplet(AbstractDataset):
    """
    Duplet dataset variant that tokenizes captions with ``text_tokenize``.

    Each sample couples an image with its tokenized caption, together with
    the sample index, label and target.

    NOTE(review): ``text_tokenize`` is not defined in the visible part of
    this module (its assignment near the imports is commented out); it must
    exist at module level before ``__getitem__`` is called.
    """

    def __init__(self, images, captions, labels, targets, idxs, args):
        """
        Initialization.

        :param images: indexable of numpy image arrays
        :param captions: indexable of captions
        :param labels: labels vector
        :param targets: targets vector
        :param idxs: forwarded to the base class
        :param args: accepted for signature parity with the other datasets; not used
        """
        super().__init__(images, captions, labels, targets, idxs)

    def __getitem__(self, index):
        """
        Return the sample at ``index``.

        :param index: index of sample
        :return: tuple (index, img, txt, label, target)
        """
        # Tokenize once; previously the caption was tokenized twice and the
        # first result was discarded.
        token_ids = text_tokenize(
            self.captions[index],
            return_tensors='pt',
            padding='max_length',
            truncation='longest_first',
        )['input_ids']
        return (
            index,
            torch.from_numpy(self.images[index].astype('float32')),
            torch.from_numpy(np.array(token_ids).astype('int64')),
            self.labels[index],
            self.targets[index]
        )

    def __len__(self):
        """Number of samples, i.e. number of images."""
        return len(self.images)