Fashion_VAE / vae.py
coledie
Add app.
f24146d
raw
history blame
No virus
4.24 kB
"""MNIST digit classificatin."""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets
import torch.nn.functional as F
from torchvision import transforms
class Encoder(nn.Module):
    """Map a batch of images to the mean and log-variance of a latent Gaussian.

    Args:
        image_dim: spatial size of the input images, e.g. (28, 28).
        latent_dim: spatial size of the latent code, e.g. (14, 14).
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
        )
        # The padded convs preserve spatial size; the two 2x poolings shrink each
        # side by 4 (floor), and the last conv has 32 output channels.  For the
        # default 28x28 input this is 32 * 7 * 7 = 1568 (previously hard-coded).
        n_features = 32 * (image_dim[0] // 4) * (image_dim[1] // 4)
        # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
        self.l_mu = nn.Linear(n_features, int(np.prod(self.latent_dim)))
        self.l_sigma = nn.Linear(n_features, int(np.prod(self.latent_dim)))

    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, prod(latent_dim)).

        The second head is consumed as log(sigma^2) by the KL term of the
        training loss, despite its `l_sigma` name.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        x = self.cnn(x)
        mu = self.l_mu(x)
        sigma = self.l_sigma(x)
        return mu, sigma
class Decoder(nn.Module):
    """Map a batch of latent codes back to flattened images in [0, 1].

    Args:
        image_dim: spatial size of the reconstructed images, e.g. (28, 28).
        latent_dim: spatial size of the latent code, e.g. (14, 14).
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        # The padded convs preserve spatial size; the two 2x poolings shrink each
        # side by 4 (floor), and the last conv has 32 output channels.  For the
        # default 14x14 latent this is 32 * 3 * 3 = 288 (previously hard-coded).
        n_features = 32 * (latent_dim[0] // 4) * (latent_dim[1] // 4)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
            # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
            nn.Linear(n_features, int(np.prod(self.image_dim))),
            # Sigmoid keeps pixel values in [0, 1], matching the BCE loss.
            nn.Sigmoid(),
        )

    def forward(self, c):
        """Decode latent codes c into (batch, prod(image_dim)) pixel values."""
        c = c.reshape((-1, 1, *self.latent_dim))
        x = self.cnn(c)
        return x
class VAE(nn.Module):
    """Convolutional variational autoencoder (Kingma & Welling, 2013).

    Args:
        image_dim: spatial size of the input/output images.
        latent_dim: spatial size of the latent code.
    """

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.encoder = Encoder(image_dim, latent_dim)
        self.decoder = Decoder(image_dim, latent_dim)

    def forward(self, x):
        """Encode x, sample a latent code, and decode.

        Returns:
            (xhat, mu, logvar): the reconstruction plus the Gaussian
            parameters needed for the KL term of the ELBO.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        mu, logvar = self.encoder(x)
        # Reparameterization trick.  The training loss treats the second
        # encoder head as log(sigma^2), so the standard deviation is
        # exp(0.5 * logvar).  The original code multiplied the noise by the
        # log-variance itself, which is wrong (and can be negative).
        std = torch.exp(0.5 * logvar)
        c = mu + std * torch.randn_like(std)
        xhat = self.decoder(c)
        return xhat, mu, logvar
if __name__ == '__main__':
    N_EPOCHS = 50
    LEARNING_RATE = .001

    # Fall back to CPU so the script also runs on machines without a GPU
    # (the original hard-coded .cuda() and crashed otherwise).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # 80/20 train/test split of Fashion-MNIST (downloaded on first run).
    dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_train, dataset_test = torch.utils.data.random_split(
        dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base)))
    )

    model.train()
    dataloader = torch.utils.data.DataLoader(dataset_train,
                                             batch_size=512,
                                             shuffle=True,
                                             num_workers=0)
    for epoch in range(N_EPOCHS):
        total_loss = 0
        # Labels are unused: the VAE is trained unsupervised.
        for x, _label in dataloader:
            x = x.to(device)
            optimizer.zero_grad()
            xhat, mu, logvar = model(x)
            # Reconstruction term of the ELBO.
            BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
            # https://arxiv.org/abs/1312.6114
            # KL(q||p) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = BCE + KLD
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{epoch}: {total_loss:.4f}")

    # Move to CPU before saving so the pickle loads on GPU-less machines.
    # NOTE(review): this saves the whole pickled module, not a state_dict —
    # downstream code presumably loads it with torch.load directly; confirm
    # before switching formats.
    model.cpu()
    with open("vae.pt", "wb") as file:
        torch.save(model, file)

    # Visualize a 4x4 grid of reconstructions from the held-out split.
    model.eval()
    dataloader = torch.utils.data.DataLoader(dataset_test,
                                             batch_size=512,
                                             shuffle=True,
                                             num_workers=0)
    COLS = 4
    ROWS = 4
    fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
                             constrained_layout=True)
    x, _label = next(iter(dataloader))
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        xhat, mu, logvar = model(x)
    xhat = xhat.reshape((-1, 28, 28))
    for row in range(ROWS):
        for col in range(COLS):
            axes[row, col].imshow(xhat[row * COLS + col].numpy(), cmap="gray")
    plt.show()