File size: 2,846 Bytes
fa128ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset
class Invert:
    """Transform that complements its input elementwise: ``v -> 1 - v``.

    For tensors produced by ``transforms.ToTensor`` from uint8 images (values
    in [0, 1]) this swaps dark and light pixels.
    """

    def __call__(self, x):
        complement = 1 - x
        return complement
class Gray:
    """Transform that keeps only the first slice along dimension 0.

    After ``transforms.ToTensor`` (C x H x W) this keeps channel 0 and preserves
    the channel axis, giving a 1 x H x W tensor.

    NOTE(review): this is not a luminance conversion. It matches grayscale only
    when all input channels are equal (e.g. RGB-loaded grayscale images).
    """

    def __call__(self, x):
        first_channel = x[:1]
        return first_channel
def _image_folder(root, size=None, extra=()):
    """Build a ``dset.ImageFolder`` with an optional resize/center-crop pipeline.

    Args:
        root: directory passed to ``dset.ImageFolder``.
        size: if not None, images are resized so the shorter side equals
            ``size`` and then center-cropped to ``size x size``. If None, no
            resizing is applied.
        extra: transforms appended after ``ToTensor`` (applied in order).

    Returns:
        The configured ``dset.ImageFolder`` dataset.
    """
    steps = []
    if size is not None:
        # Resize replaces the deprecated/removed ``transforms.Scale``; with an
        # int argument both match the shorter image side to ``size``.
        steps.append(transforms.Resize(size))
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    steps.extend(extra)
    return dset.ImageFolder(root=root, transform=transforms.Compose(steps))


def _load_quickdraw(path):
    """Load a QuickDraw ``.npy`` file as a ``TensorDataset`` of (image, image) pairs.

    Each row of the array is reshaped to 28 x 28, so it must hold 784 values.
    Values are divided by 255 (presumably uint8 pixels, giving [0, 1]) and
    stored as float32. Input and target are the same tensor.
    """
    images = np.load(path)
    images = images.reshape((images.shape[0], 28, 28))
    images = (images / 255.).astype(np.float32)
    images = torch.from_numpy(images)
    return TensorDataset(images, images)


def load_dataset(dataset_name, split='full'):
    """Return the dataset registered under ``dataset_name``.

    Args:
        dataset_name: one of 'mnist', 'coco', 'quickdraw', 'shoes', 'footwear',
            'celeba', 'birds', 'sketchy', 'fonts'. Data is read from paths
            relative to the current working directory under ``data/``.
        split: sub-directory name used only by 'birds', 'sketchy' and 'fonts'
            (``data/<name>/<split>``). Ignored by the other datasets.

    Returns:
        A torch-compatible dataset: ``dset.MNIST``, ``dset.ImageFolder`` or
        ``TensorDataset`` depending on ``dataset_name``.

    Raises:
        ValueError: if ``dataset_name`` is not recognised.
    """
    if dataset_name == 'mnist':
        # download=True fetches the files into data/mnist when missing.
        return dset.MNIST(
            root='data/mnist',
            download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
    if dataset_name == 'quickdraw':
        return _load_quickdraw('data/quickdraw/teapot.npy')
    if dataset_name == 'coco':
        return _image_folder('data/coco', size=64)
    if dataset_name == 'shoes':
        return _image_folder('data/shoes/ut-zap50k-images/Shoes', size=64)
    if dataset_name == 'footwear':
        return _image_folder('data/shoes/ut-zap50k-images', size=64)
    if dataset_name == 'celeba':
        return _image_folder('data/celeba', size=32)
    if dataset_name == 'birds':
        return _image_folder('data/birds/' + split, size=32)
    if dataset_name == 'sketchy':
        # Keep a single channel after ToTensor.
        return _image_folder('data/sketchy/' + split, size=64, extra=(Gray(),))
    if dataset_name == 'fonts':
        # No resizing; invert intensities, then keep a single channel.
        return _image_folder('data/fonts/' + split, extra=(Invert(), Gray()))
    raise ValueError('Error : unknown dataset {!r}'.format(dataset_name))
|