Spaces:
Runtime error
Runtime error
import sys | |
import torch.utils.data as data | |
from os import listdir | |
from utils.tools import default_loader, is_image_file, normalize | |
import os | |
import torchvision.transforms as transforms | |
class Dataset(data.Dataset):
    """Image dataset that loads, crops or resizes, and normalizes images.

    Image files are discovered either directly inside ``data_path`` or, when
    ``with_subfolder`` is true, anywhere below its immediate subdirectories.
    """

    def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
        """
        Args:
            data_path (str): Root directory containing the images.
            image_shape (sequence): Target shape; all entries except the last
                are kept and used as the (height, width) output size
                (presumably ``image_shape`` is ``(h, w, channels)`` -- confirm
                against callers).
            with_subfolder (bool): Search the subdirectories of ``data_path``
                instead of ``data_path`` itself.
            random_crop (bool): If true, randomly crop to the target size
                (upscaling first when the image is too small); otherwise
                resize to the target size.
            return_name (bool): If true, ``__getitem__`` returns
                ``(sample_name, tensor)`` instead of just the tensor.
        """
        super(Dataset, self).__init__()
        if with_subfolder:
            # Entries are paths that already start with data_path.
            self.samples = self._find_samples_in_subfolders(data_path)
        else:
            # Entries are bare file names relative to data_path. Sorted so the
            # index -> file mapping does not depend on filesystem order.
            self.samples = sorted(x for x in listdir(data_path) if is_image_file(x))
        self.data_path = data_path
        self.with_subfolder = with_subfolder
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name

    def _sample_path(self, index):
        """Return the path used to load sample ``index``."""
        sample = self.samples[index]
        if self.with_subfolder:
            # Subfolder samples already include data_path; joining it again
            # would duplicate the prefix whenever data_path is relative.
            return sample
        return os.path.join(self.data_path, sample)

    @staticmethod
    def _short_edge_for_crop(imgw, imgh, crop_h, crop_w):
        """Smallest int ``size`` for ``transforms.Resize(size)`` that lets a
        ``crop_h`` x ``crop_w`` crop fit in the resized image.

        ``Resize`` with an int matches the shorter image edge to ``size`` and
        keeps the aspect ratio, so both target dimensions must be satisfied
        after scaling by ``size / short_edge``. For a square crop the result
        is simply the crop size.
        """
        short = min(imgw, imgh)
        # Integer ceiling division: ceil(crop * short / dim) without floats.
        need_h = -(-(crop_h * short) // imgh)
        need_w = -(-(crop_w * short) // imgw)
        return max(need_h, need_w)

    def __getitem__(self, index):
        """Load, crop/resize and normalize sample ``index``.

        Returns:
            The image after ``ToTensor`` and ``normalize`` (from
            ``utils.tools``), or ``(self.samples[index], image)`` when
            ``return_name`` is set.
        """
        path = self._sample_path(index)
        img = default_loader(path)
        if self.random_crop:
            imgw, imgh = img.size
            crop_h, crop_w = self.image_shape[0], self.image_shape[1]
            if imgh < crop_h or imgw < crop_w:
                # Upscale just enough that the random crop fits on both axes.
                img = transforms.Resize(self._short_edge_for_crop(imgw, imgh, crop_h, crop_w))(img)
            img = transforms.RandomCrop(self.image_shape)(img)
        else:
            img = transforms.Resize(self.image_shape)(img)
            # After resizing to exactly image_shape this crop keeps the whole image.
            img = transforms.RandomCrop(self.image_shape)(img)
        img = transforms.ToTensor()(img)  # turn the image to a tensor
        img = normalize(img)
        if self.return_name:
            return self.samples[index], img
        return img

    def _find_samples_in_subfolders(self, dir):
        """Collect image files located anywhere below the subdirectories of ``dir``.

        Image files placed directly in ``dir`` (not in a subdirectory) are not
        collected.

        Args:
            dir (str): Root directory path.

        Returns:
            list[str]: Paths built with ``os.path.join`` starting from ``dir``,
            ordered by subdirectory name, then walk root path, then file name.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        samples = []
        for target in classes:
            class_dir = os.path.join(dir, target)
            if not os.path.isdir(class_dir):
                continue
            for root, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    if is_image_file(fname):
                        samples.append(os.path.join(root, fname))
        return samples

    def __len__(self):
        """Number of discovered samples."""
        return len(self.samples)