Spaces:
Runtime error
Runtime error
File size: 2,817 Bytes
89c278d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import sys
import torch.utils.data as data
from os import listdir
from utils.tools import default_loader, is_image_file, normalize
import os
import torchvision.transforms as transforms
class Dataset(data.Dataset):
    """Image dataset that loads, resizes/crops and normalizes image files.

    Images are discovered either directly inside ``data_path`` or, when
    ``with_subfolder`` is True, recursively inside every immediate
    subdirectory of ``data_path``. Each item is loaded with
    ``default_loader``, resized/cropped to ``image_shape``, converted with
    ``ToTensor`` and passed through ``normalize``.
    """

    def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
        """
        Args:
            data_path (string): Root directory containing the images.
            image_shape (sequence): Target shape; every entry except the last
                is kept and used as the output size (normally ``(H, W)``).
                Presumably callers pass ``[H, W, C]`` -- TODO confirm.
            with_subfolder (bool): If True, collect images recursively from
                the subdirectories of ``data_path`` instead of listing
                ``data_path`` itself.
            random_crop (bool): If True, take a random crop of
                ``image_shape`` (upscaling first when the image is too
                small); otherwise resize the whole image to ``image_shape``.
            return_name (bool): If True, ``__getitem__`` returns
                ``(sample_name, tensor)`` instead of just the tensor.
        """
        super(Dataset, self).__init__()
        if with_subfolder:
            self.samples = self._find_samples_in_subfolders(data_path)
        else:
            self.samples = [x for x in listdir(data_path) if is_image_file(x)]
        self.data_path = data_path
        # Remembered so __getitem__ knows whether samples already include
        # data_path (subfolder mode) or are bare file names (flat mode).
        self.with_subfolder = with_subfolder
        # Drop the last entry (presumably the channel count).
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name

    def __getitem__(self, index):
        """Load, resize/crop and normalize the sample at ``index``.

        Returns:
            The normalized image tensor, or ``(sample_name, tensor)`` when
            ``return_name`` is True.
        """
        sample = self.samples[index]
        if self.with_subfolder:
            # _find_samples_in_subfolders already joined data_path into each
            # entry; joining it again would duplicate a relative data_path
            # (e.g. "data/train/data/train/cls/x.png").
            path = sample
        else:
            path = os.path.join(self.data_path, sample)
        img = default_loader(path)
        if self.random_crop:
            img = self._resize_to_cover(img)
            img = transforms.RandomCrop(self.image_shape)(img)
        else:
            img = transforms.Resize(self.image_shape)(img)
            img = transforms.RandomCrop(self.image_shape)(img)
        img = transforms.ToTensor()(img)  # turn the image to a tensor
        img = normalize(img)
        if self.return_name:
            return sample, img
        return img

    def _resize_to_cover(self, img):
        """Upscale ``img`` if needed so a crop of ``image_shape`` fits in it.

        Images whose height and width already reach the target are returned
        unchanged.
        """
        target_h, target_w = self.image_shape[0], self.image_shape[1]
        imgw, imgh = img.size
        if imgh >= target_h and imgw >= target_w:
            return img
        # Resizing the shorter edge to the smaller target side always covers
        # square targets, so keep that behaviour whenever it suffices ...
        resized = transforms.Resize(min(self.image_shape))(img)
        resw, resh = resized.size
        if resh >= target_h and resw >= target_w:
            return resized
        # ... but for non-square targets it can leave one side too small and
        # RandomCrop would fail; resizing the shorter edge to the larger
        # target side makes both sides at least that large.
        return transforms.Resize(max(self.image_shape))(img)

    def _find_samples_in_subfolders(self, dir):
        """
        Collects image files from every immediate subdirectory of ``dir``.

        Each subdirectory is walked recursively. Subdirectories, walked
        directories and file names are all processed in sorted order, so the
        result is deterministic.

        Args:
            dir (string): Root directory path.
        Returns:
            list: Paths of image files, each built with ``os.path.join``
            starting from ``dir``. The paths are therefore relative to the
            current working directory when ``dir`` is relative, not relative
            to ``dir``.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        samples = []
        for target in classes:
            class_dir = os.path.join(dir, target)
            # Sorting os.walk's (root, dirs, files) tuples orders them by root.
            for root, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    if is_image_file(fname):
                        samples.append(os.path.join(root, fname))
        return samples

    def __len__(self):
        """Return the number of discovered image samples."""
        return len(self.samples)
|