Spaces:
Sleeping
Sleeping
import json
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
def read_json(file_name, suppress_console_info=False):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_name: Path of the JSON file to read.
        suppress_console_info: If True, do not print the "Read from:" message.

    Returns:
        The deserialized JSON value.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # default encoding, which differs between systems.
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data
def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
    """Index every image record in ``data`` by a ``"<sample_id>_<image_name>"`` key.

    Args:
        data: Parsed annotation JSON. ``data['images']`` must be an iterable of
            records with keys ``image_name``, ``sample_id``, ``description``,
            ``labels``, ``lat`` and ``lon``.
        imgs_folder: Root directory of the images; each image path is built as
            ``<imgs_folder>/<sample_id>/<image_name>``.
        feature_folder: Root directory of the precomputed features; each
            feature path is built as
            ``<feature_folder>/<sample_id>/<image_name>.npy``.
        suppress_console_info: Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        Six dicts keyed by image id, in this order: image paths, feature
        paths, captions, labels, latitudes, longitudes.
    """
    image_file_names = {}
    feature_pathes = {}
    captions = {}
    labels = {}
    lats = {}
    lons = {}
    for img in data['images']:
        image_name = str(img["image_name"])
        sample_id = str(img["sample_id"])
        image_id = f'{sample_id}_{image_name}'
        # os.path.join adds a separator only when the folder lacks one, so
        # "dir" and "dir/" give the same path; plain string concatenation
        # turned "dir" into "dirsample44/...".
        image_file_names[image_id] = os.path.join(imgs_folder, sample_id, image_name)
        feature_pathes[image_id] = os.path.join(feature_folder, sample_id, f'{image_name}.npy')
        captions[image_id] = img["description"]
        labels[image_id] = img["labels"]
        lats[image_id] = img["lat"]
        lons[image_id] = img["lon"]
    return image_file_names, feature_pathes, captions, labels, lats, lons
def get_data(image_file_names, captions, feature_pathes, labels, lats, lons, image_id):
    """Fetch one image's entries from the per-field lookup dicts.

    Note that the argument order (captions before feature paths) differs from
    the return order (feature path before caption).

    Returns:
        ``(image_file_name, feature_path, caption, label, lat, lon)``.

    Raises:
        KeyError: if ``image_id`` is missing from any of the dicts; the dicts
        are consulted in return order, so the first missing one raises.
    """
    tables = (image_file_names, feature_pathes, captions, labels, lats, lons)
    return tuple(table[image_id] for table in tables)
def _label_ids_to_names(label_ids, label_map):
    """Translate label ids into names using a ``name -> id`` mapping.

    Names come out in ``label_map``'s iteration order (not the order of
    ``label_ids``). A name is emitted once per matching occurrence in
    ``label_ids``, and ids with no entry in ``label_map`` are skipped.
    """
    return [
        name
        for name, label_id in label_map.items()
        for candidate in label_ids
        if candidate == label_id
    ]


def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
    """Load everything recorded for one image in the annotation file.

    Args:
        data_dir: Path of the annotation JSON file (re-read and parsed on
            every call).
        imgs_folder: Root directory of the images, laid out as
            ``<sample_id>/<image_name>``.
        feature_folder: Root directory of the ``.npy`` features, laid out as
            ``<sample_id>/<image_name>.npy``.
        image_id: Key of the form ``"<sample_id>_<image_name>"``. Required in
            practice despite the ``None`` default.

    Returns:
        ``(img, img_, caption, image_feature, label, label_en, lat, lon)``:
            img: the image as a PIL RGB image.
            img_: ``img`` resized (shorter side to 224), center-cropped to
                224x224 and channel-normalized, as a float32 tensor of shape
                (3, 224, 224).
            caption: the record's ``description``.
            image_feature: the numpy array loaded from the feature ``.npy`` file.
            label: the record's ``labels`` value, unchanged.
            label_en: names from the file's top-level ``labels`` mapping
                (name -> id) whose id occurs in ``label``.
            lat, lon: the record's ``lat`` / ``lon`` values.

    Raises:
        KeyError: if ``image_id`` is None or is not in the annotation file.
    """
    if image_id is None:
        # Fail fast rather than parsing the whole annotation file only to hit
        # KeyError(None) in the lookup below.
        raise KeyError("read_by_image_id() requires an image_id")

    data_info = read_json(data_dir)
    image_file_names, image_features_path, captions, labels, lats, lons = get_file_names(
        data_info, imgs_folder, feature_folder)
    image_file_name, image_feature_path, caption, label, lat, lon = get_data(
        image_file_names, captions, image_features_path, labels, lats, lons, image_id)

    label_en = _label_ids_to_names(label, data_info['labels'])
    image_feature = np.load(image_feature_path)

    # Image.open is lazy and keeps the file handle open. Closing it here is
    # safe because convert() returns an independent, fully loaded copy.
    with Image.open(image_file_name) as raw_img:
        img = raw_img.convert('RGB')

    # The mean/std below are the per-channel values used by OpenAI CLIP's
    # preprocessing (presumably chosen to match the CLIP-based features; confirm).
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    # ToTensor + Normalize already yield a float32 CPU tensor; .to() only
    # pins the dtype and does not copy when it already matches.
    img_ = transform(img).to(torch.float32)
    return img, img_, caption, image_feature, label, label_en, lat, lon
# Smoke-test configuration: example paths on the original author's machine.
data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'
imgs_folder = '/data02/xy/Clip-hash//datasets/image/'
feature_folder = '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/'
image_id = 'sample44_889.jpg'

if __name__ == "__main__":
    # Runs only when this file is executed directly, so importing the module
    # performs no file I/O.
    img, img_, caption, image_feature, label, label_en, lat, lon = read_by_image_id(
        data_dir, imgs_folder, feature_folder, image_id)
    print(img, img_, caption, image_feature, label, label_en, lat, lon)