# Fashion_VAE / model.py — author: coledie (commit ea698d3, "Add model.")
"""Variational autoencoder for Fashion-MNIST / MNIST images."""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets
import torch.nn.functional as F
from torchvision import transforms
class Encoder(nn.Module):
    """Convolutional encoder mapping images to the parameters of a
    diagonal Gaussian over the latent space.

    Args:
        image_dim: (height, width) of the input images, e.g. (28, 28).
        latent_dim: shape of the latent code; its product is the number
            of latent dimensions.
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
        )
        # The 'same'-padded convs preserve spatial size and the two 2x2
        # max-pools shrink each spatial dim by 4x; the final conv has 32
        # channels. Derive the flattened feature size from image_dim
        # instead of hard-coding it (1568 for 28x28 inputs).
        n_features = 32 * (image_dim[0] // 4) * (image_dim[1] // 4)
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        n_latent = int(np.prod(self.latent_dim))
        self.l_mu = nn.Linear(n_features, n_latent)
        self.l_sigma = nn.Linear(n_features, n_latent)

    def forward(self, x):
        """Return the Gaussian parameters (mu, sigma-head output) of q(z|x).

        Accepts any tensor reshapeable to (-1, 1, *image_dim).
        NOTE(review): the second head is named l_sigma, but the training
        loss in this file consumes its output as log-variance.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        x = self.cnn(x)
        mu = self.l_mu(x)
        sigma = self.l_sigma(x)
        return mu, sigma
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent code back to an image in
    [0, 1] (sigmoid output), flattened to np.prod(image_dim) values.

    Args:
        image_dim: (height, width) of the reconstructed images.
        latent_dim: (height, width) shape the latent code is reshaped to
            before the convolutions, e.g. (14, 14).
    """

    def __init__(self, image_dim, latent_dim):
        super().__init__()
        self.image_dim = image_dim
        self.latent_dim = latent_dim
        # Two 2x2 max-pools shrink each latent spatial dim by 4x and the
        # final conv has 32 channels, so the flattened size is derived
        # from latent_dim instead of hard-coded (288 for a 14x14 latent).
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        n_features = 32 * (latent_dim[0] // 4) * (latent_dim[1] // 4)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(1, -1),
            nn.Linear(n_features, int(np.prod(self.image_dim))),
            nn.Sigmoid(),
        )

    def forward(self, c):
        """Decode latent batch c (reshapeable to (-1, 1, *latent_dim))
        into flattened reconstructions of shape (batch, prod(image_dim))."""
        c = c.reshape((-1, 1, *self.latent_dim))
        x = self.cnn(c)
        return x
class VAE(nn.Module):
    """Variational autoencoder: convolutional Encoder/Decoder pair with
    the reparameterization trick (Kingma & Welling, arXiv:1312.6114).

    Args:
        image_dim: (height, width) of the input/output images.
        latent_dim: spatial shape of the latent code.
    """

    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
        super().__init__()
        self.image_dim = image_dim
        self.encoder = Encoder(image_dim, latent_dim)
        self.decoder = Decoder(image_dim, latent_dim)

    def forward(self, x):
        """Encode x, sample a latent via reparameterization, and decode.

        Returns:
            (xhat, mu, logvar): flattened reconstruction plus the
            Gaussian posterior parameters.
        """
        x = x.reshape((-1, 1, *self.image_dim))
        mu, logvar = self.encoder(x)
        # BUG FIX: the training loss consumes the encoder's second output
        # as log-variance (KLD uses 1 + logvar - mu^2 - exp(logvar)), so
        # the sample must scale the noise by std = exp(0.5 * logvar),
        # not by the raw head output.
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        xhat = self.decoder(z)
        return xhat, mu, logvar
if __name__ == '__main__':
    N_EPOCHS = 100
    LEARNING_RATE = .001

    model = VAE().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train on the union of Fashion-MNIST and MNIST (both stored under
    # the "MNIST" root directory).
    dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_base_2 = torchvision.datasets.MNIST("MNIST", download=True, transform=transforms.ToTensor())
    dataset_base = torch.utils.data.ConcatDataset([dataset_base, dataset_base_2])
    # Compute the test size as the remainder so the split always sums to
    # len(dataset_base); two independent int() truncations can drop
    # samples and make random_split raise.
    n_train = int(.8 * len(dataset_base))
    dataset_train, dataset_test = torch.utils.data.random_split(
        dataset_base, (n_train, len(dataset_base) - n_train)
    )

    model.train()
    dataloader = torch.utils.data.DataLoader(dataset_train,
                                             batch_size=512,
                                             shuffle=True,
                                             num_workers=0)
    for epoch in range(N_EPOCHS):
        total_loss = 0
        for x, _label in dataloader:  # labels are unused by the VAE
            x = x.cuda()
            optimizer.zero_grad()
            xhat, mu, logvar = model(x)
            # Reconstruction term.
            BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
            # KL divergence to the unit Gaussian prior.
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = BCE + KLD
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{epoch}: {total_loss:.4f}")

    model.cpu()
    with open("vae.pt", "wb") as file:
        torch.save(model, file)

    # Qualitative check: reconstruct a batch of held-out images and show
    # a ROWS x COLS grid of the outputs.
    model.eval()
    dataloader = torch.utils.data.DataLoader(dataset_test,
                                             batch_size=512,
                                             shuffle=True,
                                             num_workers=0)
    COLS = 4
    ROWS = 4
    fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
                             constrained_layout=True)
    x, _label = next(iter(dataloader))
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        xhat, mu, logvar = model(x)
    xhat = xhat.reshape((-1, 28, 28))
    for row in range(ROWS):
        for col in range(COLS):
            axes[row, col].imshow(xhat[row * COLS + col].numpy(), cmap="gray")
    plt.show()