Spaces:
Runtime error
Runtime error
"""Variational autoencoder (VAE) trained on MNIST and FashionMNIST images."""
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision.datasets | |
import torch.nn.functional as F | |
from torchvision import transforms | |
class Encoder(nn.Module):
    """Convolutional encoder mapping a single-channel image to latent statistics.

    Two stages of conv(5x5, padding=2) + 2x2 max-pool shrink each spatial side
    by a factor of 4 (floor division). The flattened features then feed two
    linear heads of size prod(latent_dim).

    Args:
        image_dim: (height, width) of the input image.
        latent_dim: shape of the latent code; each head outputs
            prod(latent_dim) features.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
        )
        # Flattened CNN output: 32 channels * (H//2//2) * (W//2//2).
        # The "same"-padded convs keep the size; each pool floors it by 2.
        # For 28x28 inputs this is 1568, the value previously hard-coded.
        height, width = image_dim
        n_features = 32 * (height // 2 // 2) * (width // 2 // 2)
        # np.prod rather than np.product, which was removed in NumPy 2.0.
        n_latent = int(np.prod(self.latent_dim))
        self.l_mu = nn.Linear(n_features, n_latent)
        self.l_sigma = nn.Linear(n_features, n_latent)

    def forward(self, x):
        """Encode ``x`` into ``(mu, sigma)``.

        ``x`` is reshaped to (batch, 1, *image_dim), so any tensor holding
        prod(image_dim) values per sample is accepted. Both returned tensors
        have shape (batch, prod(latent_dim)).

        NOTE(review): the second output is named ``sigma`` here, but the
        training loss treats it as a log-variance.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        features = self.cnn(x)
        mu = self.l_mu(features)
        sigma = self.l_sigma(features)
        return mu, sigma
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent code to a flattened image.

    The latent vector is viewed as a 1-channel (latent_h, latent_w) map. Two
    conv(5x5, padding=2) + 2x2 max-pool stages follow. A linear layer with a
    sigmoid then emits prod(image_dim) values in (0, 1).

    Args:
        image_dim: shape of the reconstructed image (e.g. (28, 28)).
        latent_dim: (height, width) the latent vector is reshaped to.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        # Flattened CNN output: 32 channels * (h//2//2) * (w//2//2).
        # For a 14x14 latent this is 288, the value previously hard-coded.
        latent_h, latent_w = latent_dim
        n_features = 32 * (latent_h // 2 // 2) * (latent_w // 2 // 2)
        # np.prod rather than np.product, which was removed in NumPy 2.0.
        n_pixels = int(np.prod(self.image_dim))
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
            nn.Linear(n_features, n_pixels),
            nn.Sigmoid(),
        )

    def forward(self, c):
        """Decode latent ``c`` into a (batch, prod(image_dim)) tensor in (0, 1)."""
        c = c.reshape((-1, 1, *self.latent_dim))
        return self.cnn(c)
class VAE(nn.Module):
    """Variational autoencoder composed of an ``Encoder`` and a ``Decoder``.

    Args:
        image_dim: (height, width) of input/output images.
        latent_dim: (height, width) shape of the latent code.
    """

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.encoder = Encoder(image_dim, latent_dim)
        self.decoder = Decoder(image_dim, latent_dim)

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Returns:
            ``(xhat, mu, logvar)``. ``xhat`` is the decoder output, and
            ``mu`` / ``logvar`` are the encoder's mean and log-variance heads.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        mu, logvar = self.encoder(x)
        # Reparameterization trick: z = mu + std * eps, std = exp(logvar / 2).
        # The KL term in the training loss treats this head as log(sigma^2),
        # so it must be exponentiated here, not used directly as the std.
        std = torch.exp(0.5 * logvar)
        c = mu + std * torch.randn_like(std)
        xhat = self.decoder(c)
        return xhat, mu, logvar
if __name__ == '__main__':
    N_EPOCHS = 100
    LEARNING_RATE = .001
    BATCH_SIZE = 512
    TRAIN_FRACTION = .8

    # Fall back to the CPU when no CUDA device is available; unconditional
    # .cuda() calls crash on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train on the union of FashionMNIST and MNIST.
    dataset_fashion = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_digits = torchvision.datasets.MNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_base = torch.utils.data.ConcatDataset([dataset_fashion, dataset_digits])

    # Take the test size as the remainder so the two lengths always sum to
    # len(dataset_base). Two independent int() truncations can fall short,
    # which makes random_split raise.
    n_train = int(TRAIN_FRACTION * len(dataset_base))
    n_test = len(dataset_base) - n_train
    dataset_train, dataset_test = torch.utils.data.random_split(dataset_base, (n_train, n_test))

    model.train()
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=0)
    for epoch in range(N_EPOCHS):
        total_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            optimizer.zero_grad()
            xhat, mu, logvar = model(x)
            # Reconstruction term: pixel-wise BCE between output and input.
            bce = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
            # KL(q(z|x) || N(0, I)), https://arxiv.org/abs/1312.6114
            # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), averaged here.
            kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = bce + kld
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{epoch}: {total_loss:.4f}")

    # Save a CPU copy so the checkpoint loads on machines without a GPU.
    model.cpu()
    with open("vae.pt", "wb") as file:
        torch.save(model, file)

    # Show reconstructions for one shuffled batch of held-out images.
    model.eval()
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True,
                                              num_workers=0)
    COLS = 4
    ROWS = 4
    fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
                             constrained_layout=True)
    x, _ = next(iter(test_loader))
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        xhat, mu, logvar = model(x)
    xhat = xhat.reshape((-1, 28, 28))
    # axes.flat visits row-major order, matching index row * COLS + col.
    for ax, image in zip(axes.flat, xhat):
        ax.imshow(image.numpy(), cmap="gray")
    plt.show()