File size: 3,739 Bytes
bb6012a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms


def read_json(file_name, suppress_console_info=False):
    """Load and return the JSON document stored at ``file_name``.

    Args:
        file_name: path to a JSON file. It is decoded as UTF-8 (the encoding
            required for JSON interchange) instead of the platform locale, so
            non-ASCII captions/label names load the same on every machine.
        suppress_console_info: when False, print the path after reading it.

    Returns:
        The decoded JSON value (callers in this module expect a dict).
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Report only after the file has been fully read and closed.
    if not suppress_console_info:
        print("Read from:", file_name)
    return data

def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
    """Index every entry of ``data['images']`` by a composite image id.

    The id of an entry is ``f'{sample_id}_{image_name}'``. If two entries
    produce the same id, the later one wins.

    Args:
        data: parsed dataset JSON. Its ``'images'`` list holds dicts with the
            keys ``image_name``, ``sample_id``, ``description``, ``labels``,
            ``lat`` and ``lon``.
        imgs_folder: root directory of the images; a trailing path separator
            is optional because paths are joined with ``os.path.join``.
        feature_folder: root directory of the ``.npy`` feature files (same
            rule for the trailing separator).
        suppress_console_info: not used; kept for interface compatibility.

    Returns:
        Six dicts keyed by image id, in this order: image paths
        (``<imgs_folder>/<sample_id>/<image_name>``), feature paths
        (``<feature_folder>/<sample_id>/<image_name>.npy``), captions,
        labels, latitudes, longitudes.
    """
    image_file_names = {}
    feature_pathes = {}
    captions = {}
    labels = {}
    lats = {}
    lons = {}

    for img in data['images']:
        # Stringify explicitly so numeric ids/names behave as they did with
        # the previous f-string concatenation.
        image_name = f'{img["image_name"]}'
        sample_id = f'{img["sample_id"]}'
        image_id = f'{sample_id}_{image_name}'

        image_file_names[image_id] = os.path.join(imgs_folder, sample_id, image_name)
        feature_pathes[image_id] = os.path.join(feature_folder, sample_id, f'{image_name}.npy')
        captions[image_id] = img["description"]
        labels[image_id] = img["labels"]
        lats[image_id] = img["lat"]
        lons[image_id] = img["lon"]

    return image_file_names, feature_pathes, captions, labels, lats, lons


def get_data(image_file_names, captions, feature_pathes, labels, lats, lons, image_id):
    """Gather all stored attributes of ``image_id`` from the per-field dicts.

    The argument order (paths, captions, feature paths, ...) differs from the
    return order:
    ``(image_file_name, feature_path, caption, label, lat, lon)``.

    Raises:
        KeyError: if ``image_id`` is missing from any of the dicts; the dicts
            are consulted in return order.
    """
    ordered_tables = (
        image_file_names,
        feature_pathes,
        captions,
        labels,
        lats,
        lons,
    )
    return tuple(table[image_id] for table in ordered_tables)


# Image preprocessing: resize the shorter side to 224, centre-crop to 224x224,
# convert to a float tensor and normalise with OpenAI CLIP's mean/std.
# Built once at import time instead of on every call.
_CLIP_TRANSFORM = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])


def _label_names(label_map, label_ids):
    """Return the names in ``label_map`` (name -> id) whose id occurs in ``label_ids``.

    Names come out in ``label_map`` order, once per matching occurrence in
    ``label_ids``, so a duplicated id yields a duplicated name.
    """
    return [
        name
        for name, label_id in label_map.items()
        for wanted in label_ids
        if wanted == label_id
    ]


def _load_rgb_image(path):
    """Open ``path`` with PIL, convert it to RGB, and close the source file."""
    with Image.open(path) as raw:
        # convert() returns an independent, fully loaded image, so it stays
        # valid after the file handle is closed.
        return raw.convert('RGB')


def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
    """Load the image, features and metadata stored for one dataset entry.

    The dataset JSON is re-read and re-indexed on every call. To look up many
    ids, call ``get_file_names`` once and then ``get_data`` per id.

    Args:
        data_dir: path to the dataset JSON file (a file, despite the name). It
            must contain an ``'images'`` list and a ``'labels'`` mapping of
            label name -> label id.
        imgs_folder: root directory of the images.
        feature_folder: root directory of the precomputed ``.npy`` features.
        image_id: composite id ``f'{sample_id}_{image_name}'``. The ``None``
            default matches no entry and raises ``KeyError``.

    Returns:
        ``(img, img_, caption, image_feature, label, label_en, lat, lon)``:
            img: the RGB ``PIL.Image``.
            img_: ``float32`` tensor of shape (3, 224, 224), i.e. ``img`` after
                resize/centre-crop and CLIP normalisation.
            caption: the entry's ``description`` value.
            image_feature: numpy array loaded with ``np.load`` from the
                entry's ``.npy`` file.
            label: the entry's ``labels`` value (presumably a list of label
                ids; verify against the dataset).
            label_en: names from ``data_info['labels']`` whose id is in ``label``.
            lat, lon: the entry's ``lat`` / ``lon`` values.

    Raises:
        KeyError: if ``image_id`` is not present in the dataset.
    """
    data_info = read_json(data_dir)
    image_file_names, image_features_path, captions, labels, lats, lons = get_file_names(
        data_info, imgs_folder, feature_folder)

    image_file_name, image_feature_path, caption, label, lat, lon = get_data(
        image_file_names, captions, image_features_path, labels, lats, lons, image_id)

    label_en = _label_names(data_info['labels'], label)
    image_feature = np.load(image_feature_path)

    img = _load_rgb_image(image_file_name)
    # ToTensor + Normalize already produce a float32 tensor; no numpy round-trip needed.
    img_ = _CLIP_TRANSFORM(img)

    return img, img_, caption, image_feature, label, label_en, lat, lon


# --- Manual smoke test ------------------------------------------------------
# These paths only exist on the original author's machine. They are read solely
# by the commented-out call below; uncomment it locally to load and print one
# sample.
image_id = 'sample44_889.jpg'
feature_folder = '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/'
imgs_folder = '/data02/xy/Clip-hash//datasets/image/'
data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'

# sample = read_by_image_id(data_dir, imgs_folder, feature_folder, image_id)
# img, img_, caption, image_feature, label, label_en, lat, lon = sample
# print(img, img_, caption, image_feature, label, label_en, lat, lon)