|
import os |
|
import matplotlib as mpl |
|
mpl.use('Agg') |
|
import matplotlib.pyplot as plt |
|
from functools import partial |
|
|
|
from clize import run |
|
import numpy as np |
|
from skimage.io import imsave |
|
|
|
from viz import grid_of_images_default |
|
|
|
import torch.nn as nn |
|
import torch |
|
|
|
from model import DenseAE |
|
from model import ConvAE |
|
from model import DeepConvAE |
|
from model import SimpleConvAE |
|
from model import ZAE |
|
from model import KAE |
|
from data import load_dataset |
|
|
|
# Global compute device: prefer CUDA when available. Models and tensors
# throughout this script are moved onto it with .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
def plot_dataset(code_2d, categories):
    """Scatter-plot the 2D embedding of real examples, one color per digit.

    code_2d is an (N, 2) array of 2D coordinates and categories an (N,)
    array of digit labels in 0..9; each class gets its own color and a
    legend entry.
    """
    palette = [
        'r', 'b', 'g', 'crimson', 'gold',
        'yellow', 'maroon', 'm', 'c', 'orange',
    ]
    for digit, color in enumerate(palette):
        mask = (categories == digit)
        plt.scatter(
            code_2d[mask, 0],
            code_2d[mask, 1],
            marker='+',
            c=color,
            s=40,
            alpha=0.7,
            label="digit {}".format(digit)
        )
|
|
|
|
|
def plot_generated(code_2d, categories):
    """Scatter-plot generated samples (tagged with a negative category) in gray."""
    mask = (categories < 0)
    xs = code_2d[mask, 0]
    ys = code_2d[mask, 1]
    plt.scatter(xs, ys, marker='+', c='gray', s=30)
|
|
|
|
|
def grid_embedding(h):
    """Assign the N=(size*size) points of embedding `h` to a size x size grid.

    Solves a linear assignment problem (lapjv) between a uniform 2D grid and
    the embedded points; returns the row permutation mapping grid cells to
    examples.
    """
    from lapjv import lapjv
    from scipy.spatial.distance import cdist

    n = h.shape[0]
    size = int(np.sqrt(n))
    assert size ** 2 == n, 'Nb of examples must be a square number'
    axis = np.linspace(0, 1, size)
    grid = np.dstack(np.meshgrid(axis, axis)).reshape(-1, 2)
    cost = cdist(grid, h, "sqeuclidean").astype('float32')
    # rescale so the solver works on a well-conditioned integer-ish range
    cost = cost * (100000 / cost.max())
    _, rows, _ = lapjv(cost)
    return rows
|
|
|
|
|
def save_weights(m, folder='.'):
    """Save a module's first-layer filters as an image grid (used via `model.apply`).

    Handles two cases and silently ignores every other module type:
    - nn.Linear layers whose input or output dimension is 28*28: each
      784-dim row is reshaped to a 1x28x28 image and saved as
      '{folder}/feat_{out_dim}.png'.
    - nn.ConvTranspose2d layers decoding to 1 or 3 channels from a typical
      filter-bank width: saved as '{folder}/feat.png'.
    """
    if isinstance(m, nn.Linear):
        w = m.weight.data
        if w.size(1) == 28*28 or w.size(0) == 28*28:
            w0 = w.size(0)
            if w0 == 28*28:
                # orient the matrix so that each row is one 784-dim filter
                w = w.transpose(0, 1).contiguous()
            w = w.view(w.size(0), 1, 28, 28)
            # .cpu().numpy() replaces the slow np.array(w.tolist()) round-trip
            gr = grid_of_images_default(w.cpu().numpy(), normalize=True)
            imsave('{}/feat_{}.png'.format(folder, w0), gr)
    elif isinstance(m, nn.ConvTranspose2d):
        w = m.weight.data
        if w.size(0) in (32, 64, 128, 256, 512) and w.size(1) in (1, 3):
            gr = grid_of_images_default(w.cpu().numpy(), normalize=True)
            imsave('{}/feat.png'.format(folder), gr)
|
|
|
@torch.no_grad()
def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None, binarize_threshold=None):
    """Generate samples by repeatedly feeding noise through the autoencoder.

    Starts from uniform random images and applies `ae` nb_iter - 1 times,
    keeping every intermediate step.

    ae: the (callable) autoencoder, assumed to live on the default device.
    nb_examples: number of images generated in parallel.
    nb_iter: number of refinement steps stored (index 0 is the raw noise).
    w, h, c: image width, height and channel count.
    batch_size: mini-batch size for the forward passes (defaults to all).
    binarize_threshold: if given, threshold each reconstruction to {0, 1}.

    Returns a CPU tensor of shape (nb_iter, nb_examples, c, h, w).
    """
    if batch_size is None:
        batch_size = nb_examples
    # same rule as the module-level `device` constant, computed locally so the
    # function is self-contained
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    # images follow the (channels, height, width) convention used by the
    # dataset and by test()'s reshapes (the old code allocated (c, w, h),
    # transposed for non-square images)
    x = torch.rand(nb_iter, nb_examples, c, h, w)
    for i in range(1, nb_iter):
        for j in range(0, nb_examples, batch_size):
            prev = x[i - 1][j:j + batch_size].to(dev)
            out = ae(prev)
            if binarize_threshold is not None:
                out = (out > binarize_threshold).float()
            x[i][j:j + batch_size] = out.data.cpu()
    return x
|
|
|
|
|
def build_model(name, w, h, c):
    """Instantiate the autoencoder architecture `name` for (c, h, w) images.

    Raises ValueError if `name` is not a known architecture.
    """
    factories = {
        'convae': partial(
            ConvAE,
            nb_filters=128,
            spatial=True,
            channel=True,
            channel_stride=4,
        ),
        'zae': partial(
            ZAE,
            theta=3,
            nb_hidden=1000,
        ),
        'kae': partial(
            KAE,
            nb_active=1000,
            nb_hidden=1000,
        ),
        'denseae': partial(
            DenseAE,
            encode_hidden=[1000],
            decode_hidden=[],
            ksparse=True,
            nb_active=50,
        ),
        'simple_convae': partial(
            SimpleConvAE,
            nb_filters=128,
        ),
        'deep_convae': partial(
            DeepConvAE,
            nb_filters=128,
            spatial=True,
            channel=True,
            channel_stride=4,
            nb_layers=3,
        ),
    }
    if name not in factories:
        raise ValueError('Unknown model')
    return factories[name](w=w, h=h, c=c)
|
|
|
|
|
def salt_and_pepper(X, proba=0.5):
    """Corrupt X with salt-and-pepper noise.

    Each element is kept with probability 1 - proba; a corrupted element
    becomes 1 ("salt") or 0 ("pepper") with equal probability.

    The random masks are drawn directly on X's own device. (The original
    drew them on the module-global `device`, which raised a device-mismatch
    error whenever X lived on CPU while CUDA was available.)
    """
    keep = (torch.rand(X.size(), device=X.device) <= (1 - proba)).float()
    salt = (torch.rand(X.size(), device=X.device) <= 0.5).float()
    noise = (keep == 0).float() * salt
    return X * keep + noise
|
|
|
|
|
def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walkback=False, denoise=False, epochs=100, batch_size=64, log_interval=100):
    """Train an autoencoder and periodically dump reconstructions/weights.

    dataset: dataset name passed to load_dataset.
    folder: output directory for 'model.th', 'rec.png' and filter images.
    resume: reload '{folder}/model.th' instead of building a fresh model.
    model: architecture name understood by build_model.
    walkback: walkback training — 5 corrupt/reconstruct/resample steps per batch.
    denoise: train as a denoising AE on salt-and-pepper-corrupted inputs.
    epochs, batch_size, log_interval: usual training-loop knobs; logging,
    image dumps and checkpointing all happen every `log_interval` updates.
    """
    gamma = 0.99  # EMA coefficient for the running average of the batch loss
    dataset = load_dataset(dataset, split='train')
    x0, _ = dataset[0]
    c, h, w = x0.size()
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    if resume:
        # map_location lets a GPU-saved checkpoint resume on a CPU-only
        # machine, consistent with how test() loads the model
        ae = torch.load('{}/model.th'.format(folder), map_location=device)
        ae = ae.to(device)
    else:
        ae = build_model(model, w=w, h=h, c=c)
        ae = ae.to(device)
    optim = torch.optim.Adadelta(ae.parameters(), lr=0.1, eps=1e-7, rho=0.95, weight_decay=0)
    avg_loss = 0.
    nb_updates = 0
    _save_weights = partial(save_weights, folder=folder)

    for epoch in range(epochs):
        for X, y in dataloader:
            ae.zero_grad()
            X = X.to(device)
            if hasattr(ae, 'nb_active'):
                # anneal sparsity: one fewer active unit per update, floor of 32
                ae.nb_active = max(ae.nb_active - 1, 32)

            if walkback:
                loss = 0.
                x = X.data
                nb = 5
                for _ in range(nb):
                    x = salt_and_pepper(x, proba=0.3)
                    x = x.to(device)
                    x = ae(x)
                    Xr = x
                    loss += (((x - X) ** 2).view(X.size(0), -1).sum(1).mean()) / nb
                    # resample a binary image from the reconstruction
                    # probabilities before the next walkback step
                    x = (torch.rand(x.size()).to(device) <= x.data).float()

            elif denoise:
                Xc = salt_and_pepper(X.data, proba=0.3)
                Xr = ae(Xc)
                loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()

            else:
                # plain reconstruction: per-example summed squared error,
                # averaged over the batch
                Xr = ae(X)
                loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
            loss.backward()
            optim.step()
            avg_loss = avg_loss * gamma + loss.item() * (1 - gamma)
            if nb_updates % log_interval == 0:
                print('Epoch : {:05d} AvgTrainLoss: {:.6f}, Batch Loss : {:.6f}'.format(epoch, avg_loss, loss.item()))
                gr = grid_of_images_default(np.array(Xr.data.tolist()))
                imsave('{}/rec.png'.format(folder), gr)
                ae.apply(_save_weights)
                torch.save(ae, '{}/model.th'.format(folder))
            nb_updates += 1
|
|
|
|
|
def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=25, nb_generate=100, nb_active=160, tsne=False):
    """Generate samples from a trained autoencoder by iterative refinement.

    dataset: dataset name passed to load_dataset (used for image shape and,
        with tsne=True, for real examples to embed alongside generations).
    folder: output directory for the generated arrays/images.
    model_path: checkpoint to load; defaults to '{folder}/model.th'.
    nb_iter: number of refinement iterations.
    nb_generate: number of samples to generate (must be >= 450 when
        tsne=True, see the grid-embedding step below).
    nb_active: sparsity level forced onto the loaded model.
    tsne: additionally produce a t-SNE scatter plot and a grid embedding
        mixing real and generated examples.
    """
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
    dataset = load_dataset(dataset, split='train')
    x0, _ = dataset[0]
    # (channels, height, width) of one example
    c, h, w = x0.size()
    nb = nb_generate
    print('Load model...')
    if model_path is None:
        model_path = os.path.join(folder, "model.th")
    # load on CPU first so GPU-saved checkpoints work on CPU-only machines
    ae = torch.load(model_path, map_location="cpu")
    ae = ae.to(device)
    # NOTE(review): set unconditionally — assumes the loaded model exposes
    # (or tolerates) an `nb_active` attribute; confirm for non-sparse models
    ae.nb_active = nb_active
    def enc(X):
        # Encode a batch of images into feature vectors, by mini-batches,
        # returning a single CPU tensor of shape (N, feature_dim).
        batch_size = 64
        h_list = []
        for i in range(0, X.size(0), batch_size):
            x = X[i:i + batch_size]
            x = x.to(device)
            name = ae.__class__.__name__
            if name in ('ConvAE',):
                # max over one spatial axis, then flatten the rest
                h = ae.encode(x)
                h, _ = h.max(2)
                h = h.view((h.size(0), -1))
            elif name in ('DenseAE',):
                # NOTE(review): uses the flattened raw input as the code,
                # not ae.encode(x) — confirm this is intentional
                x = x.view(x.size(0), -1)
                h = x

            else:
                # fallback: raw flattened pixels
                h = x.view(x.size(0), -1)
            h = h.data.cpu()
            h_list.append(h)
        return torch.cat(h_list, 0)

    print('iterative refinement...')
    # g has shape (nb_iter, nb, c, ...spatial...); index 0 is pure noise
    g = iterative_refinement(
        ae,
        nb_iter=nb_iter,
        nb_examples=nb,
        w=w, h=h, c=c,
        batch_size=64
    )
    np.savez('{}/generated.npz'.format(folder), X=g.numpy())
    # grid of the first 100 chains: one row per iteration, one column per sample
    g_subset = g[:, 0:100]
    gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
    imsave('{}/gen_full_iters.png'.format(folder), (gr*255).astype("uint8") )

    # keep only the final iteration as "the" generated samples
    g = g[-1]
    print(g.shape)
    gr = grid_of_images_default(g.numpy())
    imsave('{}/gen_full.png'.format(folder), (gr*255).astype("uint8") )

    if tsne:
        from sklearn.manifold import TSNE
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=nb,
            shuffle=True,
            num_workers=1
        )
        print('Load data...')
        # one batch of nb real examples to compare against the generations
        X, y = next(iter(dataloader))
        print('Encode data...')
        xh = enc(X)
        print('Encode generated...')
        gh = enc(g)
        X = X.numpy()
        g = g.numpy()
        xh = xh.numpy()
        gh = gh.numpy()

        # stack real then generated: images in `a`, feature codes in `ah`;
        # generated samples get label -1 (plot_generated keys on this)
        a = np.concatenate((X, g), axis=0)
        ah = np.concatenate((xh, gh), axis=0)
        labels = np.array(y.tolist() + [-1] * len(g))
        sne = TSNE()
        print('fit tsne...')
        ah = sne.fit_transform(ah)
        print('grid embedding...')
        # grid_embedding needs a square count: 450 real + 450 generated = 900 = 30^2
        assert nb_generate >= 450
        asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
        ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
        rows = grid_embedding(ahsmall)
        asmall = asmall[rows]
        gr = grid_of_images_default(asmall)
        imsave('{}/sne_grid.png'.format(folder), (gr*255).astype("uint8") )

        fig = plt.figure(figsize=(10, 10))
        plot_dataset(ah, labels)
        plot_generated(ah, labels)
        plt.legend(loc='best')
        plt.axis('off')
        plt.savefig('{}/sne.png'.format(folder))
        plt.close(fig)
|
|
|
|
|
|
|
if __name__ == '__main__':
    # clize exposes `train` and `test` as CLI subcommands
    run([train, test])
|
|