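"""Dataset definitions and label utilities for unsupervised semantic segmentation.

Provides colormap helpers for PASCAL VOC and Cityscapes, plus Dataset wrappers
for directory-based data, Potsdam, COCO-Stuff, and Cityscapes, and a
ContrastiveSegDataset that pairs each sample with a precomputed nearest neighbor.
"""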
import os
import random
from os.path import join

import numpy as np
import torch.multiprocessing
from PIL import Image
from scipy.io import loadmat
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.datasets.cityscapes import Cityscapes
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
def bit_get(val, idx):
    """Gets the bit value.

    Args:
        val: Input value, int or numpy int array.
        idx: Which bit of the input val.

    Returns:
        The "idx"-th bit of input val.
    """
    return (val >> idx) & 1
def create_pascal_label_colormap():
    """Creates a label colormap used in PASCAL VOC segmentation benchmark.

    Returns:
        A colormap for visualizing segmentation results.
    """
    colormap = np.zeros((512, 3), dtype=int)
    ind = np.arange(512, dtype=int)

    for shift in reversed(list(range(8))):
        for channel in range(3):
            colormap[:, channel] |= bit_get(ind, channel) << shift
        ind >>= 3

    return colormap
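# A minimal usage sketch (not part of the original pipeline): index the colormap
# with an integer label map to obtain an RGB visualization. `label_map` here is
# a hypothetical (H, W) array of class indices.
#
#     colormap = create_pascal_label_colormap()
#     label_map = np.random.randint(0, 21, size=(64, 64))  # 21 PASCAL classes
#     rgb = colormap[label_map]  # (64, 64, 3) array of RGB values
#     Image.fromarray(rgb.astype(np.uint8)).save("label_vis.png")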
def create_cityscapes_colormap():
    """Creates the standard Cityscapes color palette, with black for void."""
    colors = [(128, 64, 128),   # road
              (244, 35, 232),   # sidewalk
              (250, 170, 160),  # parking
              (230, 150, 140),  # rail track
              (70, 70, 70),     # building
              (102, 102, 156),  # wall
              (190, 153, 153),  # fence
              (180, 165, 180),  # guard rail
              (150, 100, 100),  # bridge
              (150, 120, 90),   # tunnel
              (153, 153, 153),  # pole
              (153, 153, 153),  # polegroup (shares the pole color)
              (250, 170, 30),   # traffic light
              (220, 220, 0),    # traffic sign
              (107, 142, 35),   # vegetation
              (152, 251, 152),  # terrain
              (70, 130, 180),   # sky
              (220, 20, 60),    # person
              (255, 0, 0),      # rider
              (0, 0, 142),      # car
              (0, 0, 70),       # truck
              (0, 60, 100),     # bus
              (0, 0, 90),       # caravan
              (0, 0, 110),      # trailer
              (0, 80, 100),     # train
              (0, 0, 230),      # motorcycle
              (119, 11, 32),    # bicycle
              (0, 0, 0)]        # void / unlabeled
    return np.array(colors)
class DirectoryDataset(Dataset):
    def __init__(self, root, path, image_set, transform, target_transform):
        super(DirectoryDataset, self).__init__()
        self.split = image_set
        self.dir = join(root, path)
        self.img_dir = join(self.dir, "imgs", self.split)
        self.label_dir = join(self.dir, "labels", self.split)
        self.transform = transform
        self.target_transform = target_transform

        self.img_files = np.array(sorted(os.listdir(self.img_dir)))
        assert len(self.img_files) > 0, "No images found in {}".format(self.img_dir)
        if os.path.exists(join(self.dir, "labels")):
            self.label_files = np.array(sorted(os.listdir(self.label_dir)))
            assert len(self.img_files) == len(self.label_files)
        else:
            self.label_files = None

        self.fine_to_coarse = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: -1}

    def __getitem__(self, index):
        image_fn = self.img_files[index]
        img = Image.open(join(self.img_dir, image_fn))
        if self.label_files is not None:
            label_fn = self.label_files[index]
            label = Image.open(join(self.label_dir, label_fn))

        # Re-seed before each transform so random augmentations are applied
        # identically to the image and its label.
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.transform(img)

        if self.label_files is not None:
            random.seed(seed)
            torch.manual_seed(seed)
            label = self.target_transform(label)
            new_label_map = torch.zeros_like(label)
            for fine, coarse in self.fine_to_coarse.items():
                new_label_map[label == fine] = coarse
            label = new_label_map
        else:
            # No labels on disk: use -1 (the ignore index) everywhere.
            label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1

        mask = (label > 0).to(torch.float32)
        return img, label, mask

    def __len__(self):
        return len(self.img_files)
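# A minimal sketch of the on-disk layout DirectoryDataset expects (the dataset
# name "my_data" and split "train" are hypothetical):
#
#     <root>/my_data/imgs/train/0001.png
#     <root>/my_data/labels/train/0001.png   # optional; omit the labels/ folder
#                                            # to get all-(-1) ignore labels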
class Potsdam(Dataset):
    def __init__(self, root, image_set, transform, target_transform, coarse_labels):
        super(Potsdam, self).__init__()
        self.split = image_set
        self.root = os.path.join(root, "potsdam")
        self.transform = transform
        self.target_transform = target_transform
        split_files = {
            "train": ["labelled_train.txt"],
            "unlabelled_train": ["unlabelled_train.txt"],
            # "train": ["unlabelled_train.txt"],
            "val": ["labelled_test.txt"],
            "train+val": ["labelled_train.txt", "labelled_test.txt"],
            "all": ["all.txt"]
        }
        assert self.split in split_files.keys()

        self.files = []
        for split_file in split_files[self.split]:
            with open(join(self.root, split_file), "r") as f:
                self.files.extend(fn.rstrip() for fn in f.readlines())

        self.coarse_labels = coarse_labels
        self.fine_to_coarse = {0: 0, 4: 0,  # roads and cars
                               1: 1, 5: 1,  # buildings and clutter
                               2: 2, 3: 2,  # vegetation and trees
                               255: -1}

    def __getitem__(self, index):
        image_id = self.files[index]
        img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"]
        img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3])  # TODO add ir channel back
        try:
            label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"]
            label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
        except FileNotFoundError:
            # Unlabelled tile: an all-255 map, which fine_to_coarse sends to -1.
            label = to_pil_image(torch.ones(1, img.height, img.width))

        # Shared seed keeps random augmentations aligned between image and label.
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.transform(img)
        random.seed(seed)
        torch.manual_seed(seed)
        label = self.target_transform(label).squeeze(0)

        if self.coarse_labels:
            new_label_map = torch.zeros_like(label)
            for fine, coarse in self.fine_to_coarse.items():
                new_label_map[label == fine] = coarse
            label = new_label_map

        mask = (label > 0).to(torch.float32)
        return img, label, mask

    def __len__(self):
        return len(self.files)
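# Note: each Potsdam sample lives in MATLAB files -- <root>/potsdam/imgs/<id>.mat
# holding an "img" array and <root>/potsdam/gt/<id>.mat holding a "gt" array --
# with the split .txt files listing the ids, as read above.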
class PotsdamRaw(Dataset):
    def __init__(self, root, image_set, transform, target_transform, coarse_labels):
        super(PotsdamRaw, self).__init__()
        self.split = image_set
        self.root = os.path.join(root, "potsdamraw", "processed")
        self.transform = transform
        self.target_transform = target_transform

        # Every one of the 38 scenes is pre-cut into a 15x15 grid of crops.
        self.files = []
        for im_num in range(38):
            for i_h in range(15):
                for i_w in range(15):
                    self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w))

        self.coarse_labels = coarse_labels
        self.fine_to_coarse = {0: 0, 4: 0,  # roads and cars
                               1: 1, 5: 1,  # buildings and clutter
                               2: 2, 3: 2,  # vegetation and trees
                               255: -1}

    def __getitem__(self, index):
        image_id = self.files[index]
        img = loadmat(join(self.root, "imgs", image_id))["img"]
        img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3])  # TODO add ir channel back
        try:
            label = loadmat(join(self.root, "gt", image_id))["gt"]
            label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
        except FileNotFoundError:
            label = to_pil_image(torch.ones(1, img.height, img.width))

        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.transform(img)
        random.seed(seed)
        torch.manual_seed(seed)
        label = self.target_transform(label).squeeze(0)

        if self.coarse_labels:
            new_label_map = torch.zeros_like(label)
            for fine, coarse in self.fine_to_coarse.items():
                new_label_map[label == fine] = coarse
            label = new_label_map

        mask = (label > 0).to(torch.float32)
        return img, label, mask

    def __len__(self):
        return len(self.files)
class Coco(Dataset):
    def __init__(self, root, image_set, transform, target_transform,
                 coarse_labels, exclude_things, subset=None):
        super(Coco, self).__init__()
        self.split = image_set
        self.root = join(root, "cocostuff")
        self.coarse_labels = coarse_labels
        self.transform = transform
        self.label_transform = target_transform
        self.subset = subset
        self.exclude_things = exclude_things

        if self.subset is None:
            self.image_list = "Coco164kFull_Stuff_Coarse.txt"
        elif self.subset == 6:  # IIC Coarse
            self.image_list = "Coco164kFew_Stuff_6.txt"
        elif self.subset == 7:  # IIC Fine
            self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
        else:
            raise ValueError("Unknown subset: {}".format(self.subset))

        assert self.split in ["train", "val", "train+val"]
        split_dirs = {
            "train": ["train2017"],
            "val": ["val2017"],
            "train+val": ["train2017", "val2017"]
        }

        self.image_files = []
        self.label_files = []
        for split_dir in split_dirs[self.split]:
            with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
                img_ids = [fn.rstrip() for fn in f.readlines()]
                for img_id in img_ids:
                    self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
                    self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))

        # Maps the 182 fine COCO-Stuff classes onto 27 coarse categories.
        self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
                               13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
                               25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
                               37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
                               49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
                               61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
                               73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
                               85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
                               97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
                               107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
                               117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
                               127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
                               137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
                               147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
                               157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
                               167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
                               177: 26, 178: 26, 179: 19, 180: 19, 181: 24}

        self._label_names = [
            "ground-stuff",
            "plant-stuff",
            "sky-stuff",
        ]
        self.cocostuff3_coarse_classes = [23, 22, 21]
        self.first_stuff_index = 12

    def __getitem__(self, index):
        image_path = self.image_files[index]
        label_path = self.label_files[index]

        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.transform(Image.open(image_path).convert("RGB"))
        random.seed(seed)
        torch.manual_seed(seed)
        label = self.label_transform(Image.open(label_path)).squeeze(0)
        label[label == 255] = -1  # to be consistent with 10k

        coarse_label = torch.zeros_like(label)
        for fine, coarse in self.fine_to_coarse.items():
            coarse_label[label == fine] = coarse
        coarse_label[label == -1] = -1

        if self.coarse_labels:
            # Collapse to the 3-class CocoStuff-3 setting (ground, plant, sky).
            coarser_labels = -torch.ones_like(label)
            for i, c in enumerate(self.cocostuff3_coarse_classes):
                coarser_labels[coarse_label == c] = i
            return img, coarser_labels, coarser_labels >= 0
        else:
            if self.exclude_things:
                # Shift stuff classes to start at 0 and mask out "things".
                return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index)
            else:
                return img, coarse_label, coarse_label >= 0

    def __len__(self):
        return len(self.image_files)
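# The same Coco class backs three benchmark settings used further below:
# cocostuff27 (all 27 coarse classes), cocostuff15 (the 15 stuff classes only,
# via exclude_things=True), and cocostuff3 (coarse_labels=True, 3 classes).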
class CityscapesSeg(Dataset):
    def __init__(self, root, image_set, transform, target_transform):
        super(CityscapesSeg, self).__init__()
        self.split = image_set
        self.root = join(root, "cityscapes")
        if image_set == "train":
            # our_image_set = "train_extra"
            # mode = "coarse"
            our_image_set = "train"
            mode = "fine"
        else:
            our_image_set = image_set
            mode = "fine"
        self.inner_loader = Cityscapes(self.root, our_image_set,
                                       mode=mode,
                                       target_type="semantic",
                                       transform=None,
                                       target_transform=None)
        self.transform = transform
        self.target_transform = target_transform
        # The first 7 Cityscapes label ids are void classes.
        self.first_nonvoid = 7

    def __getitem__(self, index):
        if self.transform is not None:
            image, target = self.inner_loader[index]

            seed = np.random.randint(2147483647)
            random.seed(seed)
            torch.manual_seed(seed)
            image = self.transform(image)
            random.seed(seed)
            torch.manual_seed(seed)
            target = self.target_transform(target)

            # Shift labels so the 27 non-void classes start at 0 and all void
            # classes become -1. Note: here `mask` flags void pixels, unlike
            # the datasets above where it flags valid ones.
            target = target - self.first_nonvoid
            target[target < 0] = -1
            mask = target == -1
            return image, target.squeeze(0), mask
        else:
            return self.inner_loader[index]

    def __len__(self):
        return len(self.inner_loader)
class CroppedDataset(Dataset):
    def __init__(self, root, dataset_name, crop_type, crop_ratio, image_set, transform, target_transform):
        super(CroppedDataset, self).__init__()
        self.dataset_name = dataset_name
        self.split = image_set
        self.root = join(root, "cropped", "{}_{}_crop_{}".format(dataset_name, crop_type, crop_ratio))
        self.transform = transform
        self.target_transform = target_transform
        self.img_dir = join(self.root, "img", self.split)
        self.label_dir = join(self.root, "label", self.split)
        self.num_images = len(os.listdir(self.img_dir))
        assert self.num_images == len(os.listdir(self.label_dir))

    def __getitem__(self, index):
        image = Image.open(join(self.img_dir, "{}.jpg".format(index))).convert('RGB')
        target = Image.open(join(self.label_dir, "{}.png".format(index)))

        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        image = self.transform(image)
        random.seed(seed)
        torch.manual_seed(seed)
        target = self.target_transform(target)

        # Stored labels are offset by one (0 marks void); shift back so void
        # becomes the -1 ignore index.
        target = target - 1
        mask = target == -1
        return image, target.squeeze(0), mask

    def __len__(self):
        return self.num_images
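# Expected layout for pre-cropped data, as implied by the paths above (the
# "five" crop type and 0.5 ratio are just example values):
#
#     <root>/cropped/cityscapes_five_crop_0.5/img/train/0.jpg
#     <root>/cropped/cityscapes_five_crop_0.5/label/train/0.png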
class MaterializedDataset(Dataset):
    def __init__(self, ds):
        self.ds = ds
        self.materialized = []
        # Eagerly load every sample into memory once, using worker processes;
        # the collate_fn unwraps the single-element batches.
        loader = DataLoader(ds, num_workers=12, collate_fn=lambda l: l[0])
        for batch in tqdm(loader):
            self.materialized.append(batch)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, ind):
        return self.materialized[ind]
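# Caution: MaterializedDataset trades memory for speed -- it holds the entire
# wrapped dataset in RAM, so it is only suitable for small datasets or crops.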
class ContrastiveSegDataset(Dataset):
    def __init__(self,
                 pytorch_data_dir,
                 dataset_name,
                 crop_type,
                 image_set,
                 transform,
                 target_transform,
                 cfg,
                 aug_geometric_transform=None,
                 aug_photometric_transform=None,
                 num_neighbors=5,
                 compute_knns=False,
                 mask=False,
                 pos_labels=False,
                 pos_images=False,
                 extra_transform=None,
                 model_type_override=None
                 ):
        super(ContrastiveSegDataset, self).__init__()
        self.num_neighbors = num_neighbors
        self.image_set = image_set
        self.dataset_name = dataset_name
        self.mask = mask
        self.pos_labels = pos_labels
        self.pos_images = pos_images
        self.extra_transform = extra_transform

        # Resolve the backing dataset class and its constructor arguments.
        if dataset_name == "potsdam":
            self.n_classes = 3
            dataset_class = Potsdam
            extra_args = dict(coarse_labels=True)
        elif dataset_name == "potsdamraw":
            self.n_classes = 3
            dataset_class = PotsdamRaw
            extra_args = dict(coarse_labels=True)
        elif dataset_name == "directory":
            self.n_classes = cfg.dir_dataset_n_classes
            dataset_class = DirectoryDataset
            extra_args = dict(path=cfg.dir_dataset_name)
        elif dataset_name == "cityscapes" and crop_type is None:
            self.n_classes = 27
            dataset_class = CityscapesSeg
            extra_args = dict()
        elif dataset_name == "cityscapes" and crop_type is not None:
            self.n_classes = 27
            dataset_class = CroppedDataset
            extra_args = dict(dataset_name="cityscapes", crop_type=crop_type, crop_ratio=cfg.crop_ratio)
        elif dataset_name == "cocostuff3":
            self.n_classes = 3
            dataset_class = Coco
            extra_args = dict(coarse_labels=True, subset=6, exclude_things=True)
        elif dataset_name == "cocostuff15":
            self.n_classes = 15
            dataset_class = Coco
            extra_args = dict(coarse_labels=False, subset=7, exclude_things=True)
        elif dataset_name == "cocostuff27" and crop_type is not None:
            self.n_classes = 27
            dataset_class = CroppedDataset
            extra_args = dict(dataset_name="cocostuff27", crop_type=cfg.crop_type, crop_ratio=cfg.crop_ratio)
        elif dataset_name == "cocostuff27" and crop_type is None:
            self.n_classes = 27
            dataset_class = Coco
            extra_args = dict(coarse_labels=False, subset=None, exclude_things=False)
            if image_set == "val":
                extra_args["subset"] = 7
        else:
            raise ValueError("Unknown dataset: {}".format(dataset_name))
        self.aug_geometric_transform = aug_geometric_transform
        self.aug_photometric_transform = aug_photometric_transform

        self.dataset = dataset_class(
            root=pytorch_data_dir,
            image_set=self.image_set,
            transform=transform,
            target_transform=target_transform, **extra_args)

        if model_type_override is not None:
            model_type = model_type_override
        else:
            model_type = cfg.model_type

        nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name
        feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format(
            model_type, nice_dataset_name, image_set, crop_type, cfg.res))

        if pos_labels or pos_images:
            if not os.path.exists(feature_cache_file) or compute_knns:
                raise ValueError("Could not find nn file {}; please run precompute_knns".format(feature_cache_file))
            else:
                loaded = np.load(feature_cache_file)
                self.nns = loaded["nns"]
            assert len(self.dataset) == self.nns.shape[0]

    def __len__(self):
        return len(self.dataset)

    def _set_seed(self, seed):
        random.seed(seed)  # apply this seed to img transforms
        torch.manual_seed(seed)  # needed for torchvision 0.7
    def __getitem__(self, ind):
        pack = self.dataset[ind]

        if self.pos_images or self.pos_labels:
            # Draw one of the sample's precomputed nearest neighbors as the
            # positive pair (index 0 is typically the sample itself, so it is skipped).
            ind_pos = self.nns[ind][torch.randint(low=1, high=self.num_neighbors + 1, size=[]).item()]
            pack_pos = self.dataset[ind_pos]

        seed = np.random.randint(2147483647)  # make a seed with numpy generator
        self._set_seed(seed)

        # Per-pixel coordinates in [-1, 1]; re-seeding below makes the warp of
        # this grid use the same random draws as the warp of the image.
        coord_entries = torch.meshgrid([torch.linspace(-1, 1, pack[0].shape[1]),
                                        torch.linspace(-1, 1, pack[0].shape[2])])
        coord = torch.cat([t.unsqueeze(0) for t in coord_entries], 0)

        if self.extra_transform is not None:
            extra_trans = self.extra_transform
        else:
            extra_trans = lambda i, x: x

        def squeeze_tuple(label_raw):
            if isinstance(label_raw, tuple):
                return tuple(x.squeeze() for x in label_raw)
            else:
                return label_raw.squeeze()

        ret = {
            "ind": ind,
            "img": extra_trans(ind, pack[0]),
            "label": squeeze_tuple(extra_trans(ind, pack[1]))
        }

        if self.pos_images:
            ret["img_pos"] = extra_trans(ind, pack_pos[0])
            ret["ind_pos"] = ind_pos

        if self.mask:
            ret["mask"] = pack[2]

        if self.pos_labels:
            ret["label_pos"] = squeeze_tuple(extra_trans(ind, pack_pos[1]))
            ret["mask_pos"] = pack_pos[2]

        if self.aug_photometric_transform is not None:
            # Assumes aug_geometric_transform is also provided.
            img_aug = self.aug_photometric_transform(self.aug_geometric_transform(pack[0]))
            self._set_seed(seed)
            coord_aug = self.aug_geometric_transform(coord)
            ret["img_aug"] = img_aug
            ret["coord_aug"] = coord_aug.permute(1, 2, 0)

        return ret
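# A minimal usage sketch, kept as comments so importing this module has no side
# effects. The cfg object and transforms below are hypothetical stand-ins for
# the real config, shown only to illustrate how the pieces fit together:
#
#     from types import SimpleNamespace
#     from torchvision import transforms as T
#
#     def to_label_tensor(im):
#         # labels must come out as integer tensors so -1 can mark void
#         return torch.from_numpy(np.array(im)).long()
#
#     cfg = SimpleNamespace(model_type="vit_small", res=224, crop_type=None,
#                           crop_ratio=0.5, dir_dataset_name=None,
#                           dir_dataset_n_classes=None)
#     t = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
#     tt = T.Compose([T.Resize(224, interpolation=T.InterpolationMode.NEAREST),
#                     T.CenterCrop(224), to_label_tensor])
#     dataset = ContrastiveSegDataset(
#         pytorch_data_dir="/data", dataset_name="cocostuff27", crop_type=None,
#         image_set="train", transform=t, target_transform=tt, cfg=cfg)
#     loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
#     batch = next(iter(loader))  # dict with "ind", "img", "label"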