coledie committed
Commit f24146d
1 Parent(s): 530f894
Files changed (2)
  1. app.py +31 -11
  2. vae.py +129 -0
app.py CHANGED
@@ -1,15 +1,35 @@
 import numpy as np
+import torch
 import gradio as gr
+from vae import *
+import matplotlib.image as mpimg
 
-def sepia(input_img):
-    sepia_filter = np.array([
-        [0.393, 0.769, 0.189],
-        [0.349, 0.686, 0.168],
-        [0.272, 0.534, 0.131]
-    ])
-    sepia_img = input_img.dot(sepia_filter.T)
-    sepia_img /= sepia_img.max()
-    return sepia_img
-
-demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), "image")
+
+with open("vae.pt", "rb") as file:
+    vae = torch.load(file)
+vae.eval()
+
+
+def generate_image(filename):
+    image = mpimg.imread(filename)[:, :, 0] / 255
+
+    grayscale = vae(torch.Tensor(image))[0].reshape((28, 28))
+
+    return grayscale.detach().numpy()
+
+
+examples = [f"examples/{i}.jpg" for i in range(10)]
+
+demo = gr.Interface(generate_image,
+                    gr.Image(type="filepath"),
+                    "image",
+                    examples,
+                    title="VAE running on Fashion MNIST",
+                    description=".",
+                    article="...",
+                    allow_flagging=False,
+                    )
+
+
 demo.launch()
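
To smoke-test the new inference path without launching the UI, the same load-and-forward steps can be run directly. A minimal sketch, assuming the vae.pt checkpoint and the RGB JPEGs under examples/ referenced above sit next to app.py:

# Sketch: exercise app.py's load-and-infer path by hand (assumes vae.pt and
# examples/0.jpg exist alongside this script, as the commit expects).
import torch
import matplotlib.image as mpimg
from vae import VAE  # the class must be importable for torch.load to unpickle vae.pt

with open("vae.pt", "rb") as file:
    vae = torch.load(file)
vae.eval()

image = mpimg.imread("examples/0.jpg")[:, :, 0] / 255  # red channel of a uint8 JPEG -> [0, 1]
with torch.no_grad():  # inference only, no gradients needed
    xhat, mu, logvar = vae(torch.Tensor(image))
print(xhat.reshape((28, 28)).shape)  # torch.Size([28, 28]), same array Gradio displays

The `from vae import *` in app.py exists for the same reason as the explicit import here: torch.load unpickles the whole VAE module object, so its class definition must be importable at load time.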
vae.py ADDED
@@ -0,0 +1,129 @@
+"""Fashion MNIST variational autoencoder."""
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision.datasets
+import torch.nn.functional as F
+from torchvision import transforms
+
+
+class Encoder(nn.Module):
+    def __init__(self, image_dim, latent_dim):
+        super().__init__()
+        self.image_dim = image_dim
+        self.latent_dim = latent_dim
+        self.cnn = nn.Sequential(
+            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Flatten(1, -1),
+        )
+        self.l_mu = nn.Linear(1568, np.prod(self.latent_dim))     # mean head
+        self.l_sigma = nn.Linear(1568, np.prod(self.latent_dim))  # log-variance head
+
+    def forward(self, x):
+        x = x.reshape((-1, 1, *self.image_dim))
+        x = self.cnn(x)
+        mu = self.l_mu(x)
+        logvar = self.l_sigma(x)  # interpreted as log(sigma^2), matching the KLD term below
+        return mu, logvar
+
+
+class Decoder(nn.Module):
+    def __init__(self, image_dim, latent_dim):
+        super().__init__()
+        self.image_dim = image_dim
+        self.latent_dim = latent_dim
+        self.cnn = nn.Sequential(
+            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Flatten(1, -1),
+            nn.Linear(288, np.prod(self.image_dim)),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, c):
+        c = c.reshape((-1, 1, *self.latent_dim))
+        x = self.cnn(c)
+        return x
+
+
+class VAE(nn.Module):
+    def __init__(self, image_dim=(28, 28), latent_dim=(14, 14)):
+        super().__init__()
+        self.image_dim = image_dim
+        self.encoder = Encoder(image_dim, latent_dim)
+        self.decoder = Decoder(image_dim, latent_dim)
+
+    def forward(self, x):
+        x = x.reshape((-1, 1, *self.image_dim))
+        mu, logvar = self.encoder(x)
+        c = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterization: std = exp(logvar / 2)
+        xhat = self.decoder(c)
+        return xhat, mu, logvar
+
+
+if __name__ == '__main__':
+    N_EPOCHS = 50
+    LEARNING_RATE = .001
+
+    model = VAE().cuda()
+    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+    loss_fn = torch.nn.MSELoss()
+
+    dataset_base = torchvision.datasets.FashionMNIST("MNIST", download=True, transform=transforms.ToTensor())
+    dataset_train, dataset_test = torch.utils.data.random_split(
+        dataset_base, (int(.8 * len(dataset_base)), int(.2 * len(dataset_base)))
+    )
+
+    model.train()
+    dataloader = torch.utils.data.DataLoader(dataset_train,
+                                             batch_size=512,
+                                             shuffle=True,
+                                             num_workers=0)
+    i = 0
+    for epoch in range(N_EPOCHS):
+        total_loss = 0
+        for x, label in dataloader:
+            x = x.cuda()
+            label = label.cuda()
+            optimizer.zero_grad()
+            xhat, mu, logvar = model(x)
+
+            BCE = F.binary_cross_entropy(xhat, x.reshape(xhat.shape), reduction='mean')
+            # https://arxiv.org/abs/1312.6114
+            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+            KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
+            loss = BCE + KLD
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item()
+        print(f"{epoch}: {total_loss:.4f}")
+
+    model.cpu()
+    with open("vae.pt", "wb") as file:
+        torch.save(model, file)
+    model.eval()
+    dataloader = torch.utils.data.DataLoader(dataset_test,
+                                             batch_size=512,
+                                             shuffle=True,
+                                             num_workers=0)
+    n_correct = 0
+
+    COLS = 4
+    ROWS = 4
+    fig, axes = plt.subplots(ncols=COLS, nrows=ROWS, figsize=(5.5, 3.5),
+                             constrained_layout=True)
+
+    dataloader_gen = iter(dataloader)
+    x, label = next(dataloader_gen)
+    xhat, mu, logvar = model(x)
+    xhat = xhat.reshape((-1, 28, 28))
+    for row in range(ROWS):
+        for col in range(COLS):
+            axes[row, col].imshow(xhat[row * COLS + col].detach().numpy(), cmap="gray")
+    plt.show()
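
Because the KLD term regularizes the encoder's posterior toward N(0, I), the trained decoder can also be sampled with no input image at all. A minimal sketch, not part of this commit, assuming the vae.pt written by the training block above; sample quality depends on how strongly the KLD actually shaped the latent space:

# Sketch: draw latent codes from the prior z ~ N(0, I) and decode them into
# new Fashion MNIST-like images (assumes vae.pt from the training run above).
import torch
import matplotlib.pyplot as plt
from vae import VAE

with open("vae.pt", "rb") as file:
    model = torch.load(file)
model.eval()

with torch.no_grad():
    z = torch.randn(16, 14 * 14)  # 16 codes matching latent_dim=(14, 14)
    samples = model.decoder(z).reshape((-1, 28, 28))

fig, axes = plt.subplots(ncols=4, nrows=4, constrained_layout=True)
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].numpy(), cmap="gray")
plt.show()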