Spaces:
Runtime error
Runtime error
from os.path import expanduser | |
import torch | |
import json | |
import torchvision | |
from general_utils import get_from_repository | |
from general_utils import log | |
from torchvision import transforms | |
# Five pairs of WordNet-style synset identifiers (``<name>.n.<sense>``).
# NOTE(review): presumably the Pascal VOC classes held out as "unseen" in the
# zero-shot splits, two per increment -- confirm against the code that
# consumes this list.
PASCAL_VOC_CLASSES_ZS = [
    ['cattle.n.01', 'motorcycle.n.01'],
    ['aeroplane.n.01', 'sofa.n.01'],
    ['cat.n.01', 'television.n.03'],
    ['train.n.01', 'bottle.n.01'],
    ['chair.n.01', 'pot_plant.n.01'],
]
class PascalZeroShot(object):
    """Pascal VOC segmentation dataset for zero-shot experiments.

    Thin wrapper around ``VOCSegmentation`` from the vendored
    ``third_party/JoEm`` package. Every item is resized to a square of
    ``image_size`` pixels and returned as a pair of 1-tuples
    ``((image,), (label,))``.

    Attributes set by ``__init__``:
        pascal_classes: the ``VOC`` object exported by JoEm's ``data_loader``.
        image_size: side length used when resizing images and labels.
        transform: torchvision transform that resizes the image.
        voc: the underlying ``VOCSegmentation`` dataset.
        unseen_idx: result of ``get_unseen_idx(n_unseen)``.
    """

    def __init__(self, split, n_unseen, image_size=224) -> None:
        """Build the underlying JoEm dataset for the requested split.

        Args:
            split: ``'train'`` or ``'val'``.
            n_unseen: forwarded to JoEm's ``get_unseen_idx`` / ``get_seen_idx``
                to select the unseen/seen class indices.
            image_size: side length (pixels) of the square output.

        Raises:
            ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
        """
        super().__init__()

        # Fail fast: without this check an unknown split would leave
        # ``self.voc`` unset and only blow up later in __len__/__getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        import sys

        # JoEm is vendored; its directory is put on sys.path before importing
        # from it. Guarded so repeated instantiation does not keep appending
        # the same entry.
        joem_root = 'third_party/JoEm'
        if joem_root not in sys.path:
            sys.path.append(joem_root)

        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            # Training uses JoEm's own augmentation (base/crop size 312) and
            # asks it to remove images with unseen classes.
            self.voc = VOCSegmentation(
                get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                split=split, transform=True,
                transform_args=dict(base_size=312, crop_size=312),
                ignore_bg=False, ignore_unseen=False, remv_unseen_img=True,
            )
        else:  # split == 'val'
            self.voc = VOCSegmentation(
                get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                split=split, transform=False,
                ignore_bg=False, ignore_unseen=False,
            )

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.voc)

    def __getitem__(self, i):
        """Return sample ``i`` as ``((image,), (label,))``.

        The underlying sample is indexed with the keys ``'image'`` and
        ``'label'``. The label is cast with ``.long()`` and both image and
        label are resized to ``(image_size, image_size)``.
        """
        sample = self.voc[i]

        image = self.transform(sample['image'])

        # Nearest-neighbour interpolation so resized pixels keep values that
        # already exist in the label map instead of interpolated ones.
        resize_label = transforms.Resize(
            (self.image_size, self.image_size),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )
        label = sample['label'].long()
        # Resize is applied with a temporary leading dimension, which is
        # stripped again afterwards.
        label = resize_label(label.unsqueeze(0))[0]

        return (image,), (label,)