"""MNIST digit classificatin.""" import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torchvision.datasets import torch.nn.functional as F from torchvision import transforms class Encoder(nn.Module): def __init__(self, image_dim, latent_dim): super().__init__() self.image_dim = image_dim self.latent_dim = latent_dim self.cnn = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2), nn.MaxPool2d(kernel_size=2), nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), nn.MaxPool2d(kernel_size=2), nn.Flatten(1, -1), ) self.l_mu = nn.Linear(1568, np.product(self.latent_dim)) self.l_sigma = nn.Linear(1568, np.product(self.latent_dim)) def forward(self, x): x = x.reshape((-1, 1, *self.image_dim)) x = self.cnn(x) mu = self.l_mu(x) sigma = self.l_sigma(x) return mu, sigma class Decoder(nn.Module): def __init__(self, image_dim, latent_dim): super().__init__() self.image_dim = image_dim self.latent_dim = latent_dim self.cnn = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2), nn.MaxPool2d(kernel_size=2), nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), nn.MaxPool2d(kernel_size=2), nn.Flatten(1, -1), nn.Linear(288, np.product(self.image_dim)), nn.Sigmoid(), ) def forward(self, c): c = c.reshape((-1, 1, *self.latent_dim)) x = self.cnn(c) return x class VAE(nn.Module): def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)): super().__init__() self.image_dim = image_dim self.encoder = Encoder(image_dim, latent_dim) self.decoder = Decoder(image_dim, latent_dim) def forward(self, x): x = x.reshape((-1, 1, *self.image_dim)) mu, sigma = self.encoder(x) c = mu + sigma * torch.randn_like(sigma) xhat = self.decoder(c) return xhat, mu, sigma if __name__ == '__main__': N_EPOCHS = 50 LEARNING_RATE = .001 model = VAE().cuda() optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) loss_fn = torch.nn.MSELoss() dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor()) dataset_train, dataset_test = torch.utils.data.random_split( dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base))) ) model.train() dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=512, shuffle=True, num_workers=0) i = 0 for epoch in range(N_EPOCHS): total_loss = 0 for x, label in dataloader: x = x.cuda() label = label.cuda() optimizer.zero_grad() xhat, mu, logvar = model(x) BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean') # https://arxiv.org/abs/1312.6114 # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) loss = BCE + KLD loss.backward() optimizer.step() total_loss += loss.item() print(f"{epoch}: {total_loss:.4f}") model.cpu() with open("vae.pt", "wb") as file: torch.save(model, file) model.eval() dataloader = torch.utils.data.DataLoader(dataset_test, batch_size=512, shuffle=True, num_workers=0) n_correct = 0 COLS = 4 ROWS = 4 fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5), constrained_layout=True) dataloader_gen = iter(dataloader) x, label = next(dataloader_gen) xhat, mu, logvar = model(x) xhat = xhat.reshape((-1, 28, 28)) for row in range(ROWS): for col in range(COLS): axes[row, col].imshow(xhat[row * COLS + col].detach().numpy(), cmap="gray") plt.show()