Spaces:
Running
on
A10G
Running
on
A10G
File size: 3,111 Bytes
19a149b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import json
import cv2
import numpy as np
import os
from torch.utils.data import Dataset
from PIL import Image
import cv2
from .data_utils import *
from .base import BaseDataset
class SaliencyDataset(BaseDataset):
    """Image/mask pairs pooled from MSRA-10k, DUTS-TR, DUTS-TE and HFlickr.

    Every sample uses the same image and the same binary mask as both the
    reference and the target. The pair is handed to ``self.process_pairs``
    and a time step from ``self.sample_timestep`` is attached; both are
    expected to come from ``BaseDataset`` (not visible in this file).
    """

    def __init__(self, MSRA_root, TR_root, TE_root, HFlickr_root):
        """Index every image together with the path of its mask.

        Args:
            MSRA_root: Directory of MSRA-10k ``.jpg`` images. The mask path
                is the image path with ``.jpg`` replaced by ``.png``.
            TR_root: Directory of DUTS-TR images. The mask path additionally
                has ``DUTS-TR-Image`` replaced by ``DUTS-TR-Mask``.
            TE_root: Directory of DUTS-TE images. The mask path additionally
                has ``DUTS-TE-Image`` replaced by ``DUTS-TE-Mask``.
            HFlickr_root: Directory of HFlickr ``.png`` masks. Images are
                looked up in the same path with ``masks`` replaced by
                ``real_images``.
        """
        image_mask_dict = {}

        # ====== MSRA-10k ======
        for name in self._files_with_suffix(MSRA_root, '.jpg'):
            image_path = os.path.join(MSRA_root, name)
            image_mask_dict[image_path] = image_path.replace('.jpg', '.png')

        # ===== DUT-TR ========
        for name in self._files_with_suffix(TR_root, '.jpg'):
            image_path = os.path.join(TR_root, name)
            mask_path = image_path.replace('.jpg', '.png')
            mask_path = mask_path.replace('DUTS-TR-Image', 'DUTS-TR-Mask')
            image_mask_dict[image_path] = mask_path

        # ===== DUT-TE ========
        for name in self._files_with_suffix(TE_root, '.jpg'):
            image_path = os.path.join(TE_root, name)
            mask_path = image_path.replace('.jpg', '.png')
            mask_path = mask_path.replace('DUTS-TE-Image', 'DUTS-TE-Mask')
            image_mask_dict[image_path] = mask_path

        # ===== HFlickr =======
        # Only ``.png`` entries are masks; anything else in the directory
        # must not be registered as one.
        hflickr_image_root = HFlickr_root.replace('masks', 'real_images')
        for mask_name in self._files_with_suffix(HFlickr_root, '.png'):
            # The image name is the part of the mask name before the first
            # underscore. Masks sharing that prefix map to the same key, so
            # only the last one listed is kept.
            image_name = mask_name.split('_')[0] + '.jpg'
            image_path = os.path.join(hflickr_image_root, image_name)
            image_mask_dict[image_path] = os.path.join(HFlickr_root, mask_name)

        self.image_mask_dict = image_mask_dict
        self.data = list(image_mask_dict)
        # The three settings below are not read anywhere in this class.
        # NOTE(review): presumably consumed by BaseDataset - confirm.
        self.size = (512, 512)
        self.clip_size = (224, 224)
        self.dynamic = 0

    @staticmethod
    def _files_with_suffix(root, suffix):
        """Return the entry names in ``root`` that end with ``suffix``."""
        return [name for name in os.listdir(root) if name.endswith(suffix)]

    def __len__(self):
        """Return a fixed length of 20000, independent of ``len(self.data)``.

        NOTE(review): presumably a nominal epoch size, with the base class
        choosing the actual sample index itself - confirm against
        BaseDataset.
        """
        return 20000

    def check_region_size(self, image, yyxx, ratio, mode='max'):
        """Check a box's height and width against a fraction of the image.

        Args:
            image: Array whose first two dimensions are height and width.
            yyxx: Box given as ``(y1, y2, x1, x2)``.
            ratio: Factor applied to the image height and width to obtain
                the thresholds.
            mode: ``'max'`` fails when the box height or width exceeds its
                threshold; ``'min'`` fails when either falls below it. Any
                other value always passes.

        Returns:
            ``True`` when the box passes the check, ``False`` otherwise.
        """
        limit_h = image.shape[0] * ratio
        limit_w = image.shape[1] * ratio
        y1, y2, x1, x2 = yyxx
        box_h, box_w = y2 - y1, x2 - x1

        if mode == 'max':
            return not (box_h > limit_h or box_w > limit_w)
        if mode == 'min':
            return not (box_h < limit_h or box_w < limit_w)
        return True

    def get_sample(self, idx):
        """Load sample ``idx`` and build the item via ``process_pairs``.

        Args:
            idx: Index into ``self.data``.

        Returns:
            The result of ``self.process_pairs`` with an extra
            ``'time_steps'`` entry set to ``self.sample_timestep()``.

        Raises:
            OSError: If the mask or the image cannot be read.
        """
        # ==== get pairs =====
        image_path = self.data[idx]
        mask_path = self.image_mask_dict[image_path]

        # cv2.imread signals failure by returning None instead of raising.
        instances_mask = cv2.imread(mask_path)
        if instances_mask is None:
            raise OSError('cv2.imread could not read mask: %s' % mask_path)
        if instances_mask.ndim == 3:
            # Keep only the first channel of a multi-channel mask.
            instances_mask = instances_mask[:, :, 0]
        instances_mask = (instances_mask > 128).astype(np.uint8)

        # ======================
        ref_image = cv2.imread(image_path)
        if ref_image is None:
            raise OSError('cv2.imread could not read image: %s' % image_path)
        ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)

        # Reference and target are the very same image and mask objects.
        tar_image = ref_image
        ref_mask = instances_mask
        tar_mask = instances_mask

        item_with_collage = self.process_pairs(ref_image, ref_mask, tar_image, tar_mask)
        item_with_collage['time_steps'] = self.sample_timestep()
        return item_with_collage
|