import os
import random
import json
from typing import List, Union

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from loguru import logger
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from .simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()
# Hugging Face tokenizer used by ModifiedDatasetDuplet; uncomment it (and point it at a
# local checkpoint) before using that class.
# text_tokenize = AutoTokenizer.from_pretrained("./Taiyi-CLIP-s", model_max_length=512)


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result
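

# Illustrative usage of tokenize (a minimal sketch, not part of the original pipeline):
#
#     toks = tokenize(["a photo of a dog", "a photo of a cat"])
#     toks.shape        # torch.Size([2, 77])
#     toks[0, 0]        # id of the <|startoftext|> token
#
# Rows are zero-padded after the <|endoftext|> token; pass truncate=True to clip captions
# whose encoding exceeds context_length instead of raising RuntimeError.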


def select_idxs(seq_length, n_to_select, n_from_select, seed=42):
    """
    Select n_to_select indexes from each consecutive chunk of n_from_select indexes in a range of
    length seq_length, and split the selected indexes into separate arrays.

    Example:
        seq_length = 20
        n_from_select = 5
        n_to_select = 2
        input, range of length seq_length:
            range = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
        sequences of length n_from_select:
            sequences = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]
        selected n_to_select elements from each sequence:
            selected = [[0, 4], [7, 9], [13, 14], [16, 18]]
        output, n_to_select lists of length seq_length / n_from_select:
            output = [[0, 7, 13, 16], [4, 9, 14, 18]]

    :param seq_length: length of the sequence
    :param n_to_select: number of elements to select from each chunk
    :param n_from_select: number of consecutive elements per chunk
    :param seed: random seed
    :return: list of n_to_select lists of selected indexes
    """
    random.seed(seed)
    idxs = [[] for _ in range(n_to_select)]
    for i in range(seq_length // n_from_select):
        ints = random.sample(range(n_from_select), n_to_select)
        for j in range(n_to_select):
            idxs[j].append(i * n_from_select + ints[j])
    return idxs


def read_json(file_name, suppress_console_info=False):
    """
    Read JSON

    :param file_name: input JSON path
    :param suppress_console_info: toggle console printing
    :return: dictionary from JSON
    """
    with open(file_name, 'r') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data


def get_image_file_names(data, suppress_console_info=False):
    """
    Get list of image file names

    :param data: original data from JSON
    :param suppress_console_info: toggle console printing
    :return: list of strings (file names)
    """
    file_names = []
    for img in data['images']:
        image_name = img["image_name"]
        sample_id = img["sample_id"]
        path_data = f'{sample_id}/{image_name}'
        file_names.append(path_data)
    if not suppress_console_info:
        print("Total number of files:", len(file_names))
    return file_names


def get_images(file_names, args):
    """
    Load and preprocess images with the standard CLIP transform
    (224x224 center crop, CLIP normalization statistics).

    :param file_names: image file names relative to args.imgs_folder
    :param args: parsed arguments, must provide args.imgs_folder
    :return: numpy array of preprocessed images
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711))
    ])
    imgs = []
    for i in range(len(file_names)):
        img = np.array(transform(Image.open(os.path.join(args.imgs_folder, file_names[i]))))
        imgs.append(img)
    return np.array(imgs)


def get_captions(data, suppress_console_info=False):
    """
    Get lists of formatted captions (original and augmented, where available)

    :param data: original data from JSON
    :param suppress_console_info: toggle console printing
    :return: lists of strings: captions, augmented captions RB, augmented captions BT (prob),
             augmented captions BT (chain)
    """
    def format_caption(string):
        return string.replace('.', '').replace(',', '').replace('!', '').replace('?', '').lower()

    captions = []
    augmented_captions_rb = []
    augmented_captions_bt_prob = []
    augmented_captions_bt_chain = []
    for img in data['images']:
        for sent in img['sentences']:
            captions.append(format_caption(sent['raw']))
            # augmented captions are optional and may be missing from the JSON
            try:
                augmented_captions_rb.append(format_caption(sent['aug_rb']))
            except KeyError:
                pass
            try:
                augmented_captions_bt_prob.append(format_caption(sent['aug_bt_prob']))
            except KeyError:
                pass
            try:
                augmented_captions_bt_chain.append(format_caption(sent['aug_bt_chain']))
            except KeyError:
                pass
    if not suppress_console_info:
        logger.info("Total number of captions: {}", len(captions))
        logger.info("Total number of augmented captions RB: {}", len(augmented_captions_rb))
        logger.info("Total number of augmented captions BT (prob): {}", len(augmented_captions_bt_prob))
        logger.info("Total number of augmented captions BT (chain): {}", len(augmented_captions_bt_chain))
    return captions, augmented_captions_rb, augmented_captions_bt_prob, augmented_captions_bt_chain


def get_labels(data, suppress_console_info=False):
    """
    Get list of labels

    :param data: original data from JSON
    :param suppress_console_info: toggle console printing
    :return: list of ints (labels)
    """
    labels = []
    for img in data['images']:
        labels.append(img["classcode"])
    if not suppress_console_info:
        print("Total number of labels:", len(labels))
    return labels


def remove_tokens(data):
    """
    Removes the 'tokens' key from each caption record, if it exists; roughly halves the size of the file

    :param data: original data
    :return: data without tokens
    """
    for img in data['images']:
        for sent in img['sentences']:
            # drop the key if present; ignore records that never had it
            sent.pop("tokens", None)
    return data


def write_json(file_name, data):
    """
    Write dictionary to JSON file

    :param file_name: output path (the extension is replaced with .json)
    :param data: dictionary
    :return: None
    """
    bn = os.path.basename(file_name)
    dn = os.path.dirname(file_name)
    name, ext = os.path.splitext(bn)
    file_name = os.path.join(dn, name + '.json')
    with open(file_name, 'w') as f:
        f.write(json.dumps(data, indent='\t'))
    print("Written to:", file_name)


def get_split_idxs(arr_len, args):
    """
    Get indexes for the training, evaluation (query + db), query and db subsets

    :param arr_len: array length
    :param args: parsed arguments, must provide args.dataset_train_split and args.dataset_query_split
    :return: indexes for the training, evaluation, query and db subsets
    """
    idx_all = list(range(arr_len))
    idx_train, idx_eval = split_indexes(idx_all, args.dataset_train_split)
    idx_query, idx_db = split_indexes(idx_eval, args.dataset_query_split)
    return idx_train, idx_eval, idx_query, idx_db


def split_indexes(idx_all, split):
    """
    Splits a list of indexes into two parts.

    :param idx_all: list of indexes to split
    :param split: portion of elements to put into the first part
    :return: the randomly selected portion and the remaining indexes, both sorted
    """
    idx_length = len(idx_all)
    selection_length = int(idx_length * split)
    idx_selection = sorted(random.sample(idx_all, selection_length))
    idx_rest = sorted(list(set(idx_all).difference(set(idx_selection))))
    return idx_selection, idx_rest


def get_caption_idxs(idx_train, idx_query, idx_db):
    """
    Get caption indexes.

    :param idx_train: train image (and label) indexes
    :param idx_query: query image (and label) indexes
    :param idx_db: db image (and label) indexes
    :return: caption indexes for the corresponding index sets
    """
    idx_train_cap = get_caption_idxs_from_img_idxs(idx_train, num=5)
    idx_query_cap = get_caption_idxs_from_img_idxs(idx_query, num=5)
    idx_db_cap = get_caption_idxs_from_img_idxs(idx_db)
    return idx_train_cap, idx_query_cap, idx_db_cap


def get_caption_idxs_from_img_idxs(img_idxs, num=5):
    """
    Get caption indexes. There are 5 captions for each image (and label).
    Say, img indexes - [0, 10, 100]
    Then, caption indexes - [0, 1, 2, 3, 4, 50, 51, 52, 53, 54, 500, 501, 502, 503, 504]

    :param img_idxs: image (and label) indexes
    :param num: number of captions per image
    :return: caption indexes
    """
    caption_idxs = []
    for idx in img_idxs:
        for i in range(num):  # each image has `num` captions
            caption_idxs.append(idx * num + i)
    return caption_idxs


def split_data(images, captions, labels, captions_aug, images_aug, args):
    """
    Split dataset to get training, query and db subsets

    :param images: image embeddings array
    :param captions: caption embeddings array
    :param labels: labels array
    :param captions_aug: augmented caption embeddings
    :param images_aug: augmented image embeddings
    :param args: parsed arguments with the split ratios
    :return: tuples of (images, captions, labels, indexes, augmented captions, augmented images)
             for the train, query and db subsets
    """
    # get_split_idxs also returns the combined evaluation indexes, which are not needed here
    idx_tr, _, idx_q, idx_db = get_split_idxs(len(images), args)
    idx_tr_cap, idx_q_cap, idx_db_cap = get_caption_idxs(idx_tr, idx_q, idx_db)
    train = (images[idx_tr], captions[idx_tr_cap], labels[idx_tr], (idx_tr, idx_tr_cap),
             captions_aug[idx_tr_cap], images_aug[idx_tr])
    query = (images[idx_q], captions[idx_q_cap], labels[idx_q], (idx_q, idx_q_cap),
             captions_aug[idx_q_cap], images_aug[idx_q])
    db = (images[idx_db], captions[idx_db_cap], labels[idx_db], (idx_db, idx_db_cap),
          captions_aug[idx_db_cap], images_aug[idx_db])
    return train, query, db
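

# Illustrative unpacking of the split (a sketch; the variable names are placeholders):
#
#     train, query, db = split_data(images, captions, labels, captions_aug, images_aug, args)
#     imgs_tr, caps_tr, labels_tr, (idx_tr, idx_tr_cap), caps_aug_tr, imgs_aug_tr = train
#
# Each subset keeps both the image-level and caption-level indexes so the 5-captions-per-image
# correspondence can be recovered downstream.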


class AbstractDataset(torch.utils.data.Dataset):
    def __init__(self, images, captions, labels, targets, idxs):
        self.image_replication_factor = 1  # default value, how many times we need to replicate image
        self.images = images
        self.captions = captions
        self.labels = labels
        self.targets = targets
        self.idxs = np.array(idxs[0])

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class CISENDataset(torch.utils.data.Dataset):
    """
    Class for dataset representation.

    Each image has 5 corresponding captions.
    Duplet dataset sample - img-txt (image and corresponding caption)
    """

    def __init__(self, images, captions, args):
        """
        Initialization.

        :param images: image embeddings vector
        :param captions: captions embeddings vector
        :param args: parsed arguments, must provide args.word_len
        """
        super().__init__()
        self.images = images
        self.captions = captions
        # self.targets = targets
        # self.labels = labels
        self.word_len = args.word_len

    def __getitem__(self, index):
        """
        Returns a tuple (img, txt) - image and corresponding tokenized caption

        :param index: index of sample
        :return: tuple (img, txt)
        """
        return (
            torch.from_numpy(self.images[index].astype('float32')),
            # tokenize already returns a LongTensor of shape [1, word_len]
            tokenize(self.captions[index], self.word_len).squeeze(0)
            # ,torch.from_numpy(self.targets[index])
        )

    def __len__(self):
        return len(self.images)


class DatasetDuplet(AbstractDataset):
    """
    Class for dataset representation.

    Each image has 5 corresponding captions.
    Duplet dataset sample - img-txt (image and corresponding caption)
    """

    def __init__(self, images, captions, labels, targets, idxs, args):
        """
        Initialization.

        :param images: image embeddings vector
        :param captions: captions embeddings vector
        :param labels: labels vector
        :param targets: targets vector
        :param idxs: sample indexes
        :param args: parsed arguments, must provide args.word_len
        """
        super().__init__(images, captions, labels, targets, idxs)
        self.word_len = args.word_len

    def __getitem__(self, index):
        """
        Returns a tuple (index, img, txt, label, target) - image and corresponding caption

        :param index: index of sample
        :return: tuple (index, img, txt, label, target)
        """
        # note: the caption is concatenated with itself before tokenization
        caption = self.captions[index] + self.captions[index]
        return (
            index,
            torch.from_numpy(self.images[index].astype('float32')),
            tokenize(caption, self.word_len).squeeze(0),
            self.labels[index],
            self.targets[index]
        )

    def __len__(self):
        return len(self.images)


class ModifiedDatasetDuplet(AbstractDataset):
    """
    Class for dataset representation.

    Each image has 5 corresponding captions.
    Duplet dataset sample - img-txt (image and corresponding caption)

    Note: this class relies on the Hugging Face tokenizer `text_tokenize`, which is
    commented out next to the imports above and must be enabled before use.
    """

    def __init__(self, images, captions, labels, targets, idxs, args):
        """
        Initialization.

        :param images: image embeddings vector
        :param captions: captions embeddings vector
        :param labels: labels vector
        :param targets: targets vector
        :param idxs: sample indexes
        :param args: parsed arguments
        """
        super().__init__(images, captions, labels, targets, idxs)

    def __getitem__(self, index):
        """
        Returns a tuple (index, img, txt, label, target) - image and corresponding caption

        :param index: index of sample
        :return: tuple (index, img, txt, label, target)
        """
        # tokenize once and reuse the result instead of calling the tokenizer twice
        text = text_tokenize(self.captions[index], return_tensors='pt',
                             padding='max_length', truncation='longest_first')['input_ids']
        return (
            index,
            torch.from_numpy(self.images[index].astype('float32')),
            text.long(),
            self.labels[index],
            self.targets[index]
        )

    def __len__(self):
        return len(self.images)
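

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training pipeline).
    # The argument object and the dummy data below are assumptions; real runs pass the
    # parsed experiment arguments and precomputed image arrays instead. Because of the
    # relative import above, run this with `python -m <package>.<module>`.
    from argparse import Namespace

    dummy_args = Namespace(word_len=77)
    dummy_images = np.random.rand(4, 3, 224, 224).astype('float32')
    dummy_captions = [
        "a photo of a river",
        "an aerial view of a parking lot",
        "a photo of farmland",
        "a photo of a harbor",
    ]

    dataset = CISENDataset(dummy_images, dummy_captions, dummy_args)
    img, txt = dataset[0]
    print(img.shape, txt.shape)  # expected: torch.Size([3, 224, 224]) torch.Size([77])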