# LuojiaHOG / get_data_by_image_id.py
# (Uploaded by aleo1, commit bb6012a "Upload 41 files"; 3.74 kB.)
import json
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
def read_json(file_name, suppress_console_info=False):
    """Load and return the decoded contents of a JSON file.

    Args:
        file_name: Path to the JSON file.
        suppress_console_info: When False (the default), print the path that
            was read.

    Returns:
        The decoded JSON value.
    """
    # JSON text is UTF-8 (RFC 8259). Say so explicitly so that decoding does
    # not depend on the platform's locale default encoding, which would break
    # on non-ASCII content such as captions.
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data
def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
    """Index the per-image records in ``data['images']`` by image id.

    Each record must provide ``image_name``, ``sample_id``, ``description``,
    ``labels``, ``lat`` and ``lon``. The image id is
    ``f'{sample_id}_{image_name}'``.

    Args:
        data: Parsed metadata JSON, e.g. as returned by ``read_json``.
        imgs_folder: Root folder of the images. The image path is
            ``<imgs_folder>/<sample_id>/<image_name>``.
        feature_folder: Root folder of the precomputed features. The feature
            path is ``<feature_folder>/<sample_id>/<image_name>.npy``.
        suppress_console_info: Unused; kept for API compatibility.

    Returns:
        Six dicts keyed by image id: image paths, feature paths, captions
        (the ``description`` field), labels, ``lat`` values and ``lon`` values.
    """
    image_file_names = {}
    feature_pathes = {}
    captions = {}
    labels = {}
    lats = {}
    lons = {}
    for img in data['images']:
        image_name = img["image_name"]
        sample_id = img["sample_id"]
        image_id = f'{sample_id}_{image_name}'
        # os.path.join works whether or not the folder ends with a separator.
        # Plain concatenation silently produced '<folder><sample_id>/...' when
        # the trailing slash was missing.
        image_file_names[image_id] = os.path.join(imgs_folder, str(sample_id), image_name)
        feature_pathes[image_id] = os.path.join(feature_folder, str(sample_id), f'{image_name}.npy')
        captions[image_id] = img["description"]
        labels[image_id] = img["labels"]
        lats[image_id] = img["lat"]
        lons[image_id] = img["lon"]
    return image_file_names, feature_pathes, captions, labels, lats, lons
def get_data(image_file_names, captions, feature_pathes, labels, lats, lons, image_id):
    """Collect every stored field for a single image id.

    The argument order (captions before feature paths) differs from the
    return order, which is:
    ``(image path, feature path, caption, labels, lat, lon)``.

    Raises:
        KeyError: if ``image_id`` is missing from any of the lookup tables.
    """
    tables_in_return_order = (
        image_file_names,
        feature_pathes,
        captions,
        labels,
        lats,
        lons,
    )
    return tuple(table[image_id] for table in tables_in_return_order)
def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
    """Load the image, precomputed feature and metadata for one image id.

    Args:
        data_dir: Path to the metadata JSON file (a file, despite the name).
            It must contain an ``images`` list and a ``labels`` mapping from
            label name to label id.
        imgs_folder: Root folder of the images.
        feature_folder: Root folder of the ``.npy`` feature files.
        image_id: Id of the form ``f'{sample_id}_{image_name}'``. Required in
            practice: the ``None`` default never matches an id and raises
            KeyError.

    Returns:
        ``(img, img_, caption, image_feature, label, label_en, lat, lon)``:
            img: the image as an RGB ``PIL.Image``.
            img_: ``img`` resized, center-cropped to 224x224 and normalized,
                as a float32 tensor of shape (3, 224, 224).
            caption: the record's ``description`` field.
            image_feature: the array loaded with ``np.load`` from the feature
                file (not a tensor).
            label: the label ids stored for the image.
            label_en: the names in ``data['labels']`` whose id occurs in
                ``label``, in that mapping's iteration order.
            lat, lon: the record's ``lat`` / ``lon`` values.

    Raises:
        KeyError: if ``image_id`` is not present in the metadata.
    """
    data_info = read_json(data_dir)
    image_file_names, image_features_path, captions, labels, lats, lons = get_file_names(
        data_info, imgs_folder, feature_folder)
    image_file_name, image_feature_path, caption, label, lat, lon = get_data(
        image_file_names, captions, image_features_path, labels, lats, lons, image_id)

    # Resolve label ids to names. The outer iteration follows the name->id
    # mapping's order, and a name is emitted once per matching occurrence in
    # ``label``. This is the same result as the previous nested loops.
    label_en = [
        label_name
        for label_name, label_id in data_info['labels'].items()
        for stored_id in label
        if stored_id == label_id
    ]

    image_feature = np.load(image_feature_path)

    # Close the underlying file handle. convert() returns an independent,
    # fully loaded image, so it remains usable after the with-block.
    with Image.open(image_file_name) as raw_img:
        img = raw_img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # NOTE(review): these appear to be the OpenAI CLIP mean/std. Confirm
        # they match the model used to produce the features.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    # The transform already yields a tensor. The former tensor -> numpy ->
    # tensor round-trip only copied it. Cast explicitly to keep the float32
    # guarantee regardless of torch's default dtype.
    img_ = transform(img).to(torch.float32)
    return img, img_, caption, image_feature, label, label_en, lat, lon
# ---------------------------------------------------------------------------
# Manual smoke-test settings. These paths come from the original author's
# machine; point them at local copies before uncommenting the call below.
# ``data_dir`` is the metadata JSON *file*, not a directory.
# ---------------------------------------------------------------------------
data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'
imgs_folder, feature_folder = (
    '/data02/xy/Clip-hash//datasets/image/',
    '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/',
)
image_id = 'sample44_889.jpg'

# To run the lookup end to end, uncomment:
# (img, img_, caption, image_feature, label, label_en,
#  lat, lon) = read_by_image_id(data_dir, imgs_folder, feature_folder, image_id)
# print(img, img_, caption, image_feature, label, label_en, lat, lon)