# Cédric Colas
# initial commit
# e775f6d
# raw
# history blame
# 10.2 kB
# NOTE: importing this module seeds torch's global RNG with 0.
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
# NOTE: importing this module also sets matplotlib's default figure DPI to 200.
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
# Default device name: 'cuda' when CUDA is available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_activation(activation):
    """Return the element-wise activation function registered under a name.

    Args:
        activation: One of ``'tanh'``, ``'relu'``, ``'mish'``, ``'sigmoid'``,
            ``'leakyrelu'`` or ``'exp'``.

    Returns:
        The matching callable from ``torch.nn.functional`` (``torch.exp`` for
        ``'exp'``).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Each name maps to (namespace, attribute) and is resolved only on
    # request, so an activation missing from the installed torch build can
    # only fail for callers that actually ask for it.
    registry = {
        'tanh': (F, 'tanh'),
        'relu': (F, 'relu'),
        'mish': (F, 'mish'),
        'sigmoid': (F, 'sigmoid'),
        'leakyrelu': (F, 'leaky_relu'),
        'exp': (torch, 'exp'),
    }
    try:
        namespace, attribute = registry[activation]
    except (KeyError, TypeError):
        # TypeError: unhashable values cannot be dict keys; report them as
        # unsupported names as well.
        raise ValueError(
            f"Unknown activation {activation!r}; expected one of {sorted(registry)}."
        ) from None
    return getattr(namespace, attribute)
class IngredientEncoder(nn.Module):
    """Multi-layer perceptron mapping an input vector to ``deepset_latent_dim``.

    Every layer except the last applies dropout and then the activation to
    the linear output. The last linear layer is returned as is: no dropout,
    no activation (its dropout module is created but never applied).
    """

    def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
        """Build the stack of linear layers.

        Args:
            input_dim: Size of the input vectors.
            deepset_latent_dim: Size of the output vectors.
            hidden_dims: Sizes of the hidden layers (any iterable of ints,
                possibly empty).
            activation: Activation name understood by ``get_activation``.
            dropout: Dropout probability used for every dropout module.
        """
        super(IngredientEncoder, self).__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        # list(...) lets callers pass a tuple of hidden sizes, which plain
        # list concatenation would reject.
        dims = [input_dim] + list(hidden_dims) + [deepset_latent_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        """Run ``x`` through the layers and return the raw last-layer output."""
        last_index = len(self.linears) - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            if i_layer != last_index:
                x = self.activation(dropout(x))
        return x
class DeepsetCocktailEncoder(nn.Module):
    """Deep-set encoder producing the mean and log-variance of a latent code.

    Each ingredient vector is encoded separately by an ``IngredientEncoder``,
    the encodings of one sample are pooled with a sum or a mean, and the
    pooled vector goes through a post-aggregation MLP followed by two linear
    heads (mean and log-variance).
    """

    def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                 hidden_dims_cocktail, latent_dim, aggregation, dropout):
        """Build the ingredient encoder, the post-aggregation MLP and the heads.

        Args:
            input_dim: Dimension of one ingredient representation + quantity.
            deepset_latent_dim: Dimension of the pooled deep-set vector.
            hidden_dims_ing: Hidden sizes of the per-ingredient encoder.
            activation: Activation name understood by ``get_activation``.
            hidden_dims_cocktail: Hidden sizes of the post-aggregation MLP
                (any iterable of ints, possibly empty).
            latent_dim: Dimension of the latent mean / log-variance.
            aggregation: ``'sum'`` or ``'mean'``; checked when ``forward`` runs.
            dropout: Dropout probability used for every dropout module.
        """
        super(DeepsetCocktailEncoder, self).__init__()
        self.input_dim = input_dim  # dimension of ingredient representation + quantity
        # encode each ingredient separately
        self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout)
        self.deepset_latent_dim = deepset_latent_dim  # dimension of the deepset aggregation
        self.aggregation = aggregation
        self.latent_dim = latent_dim
        # post aggregation network
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [deepset_latent_dim] + list(hidden_dims_cocktail)
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        # dims[-1] is the last hidden size, or deepset_latent_dim when no
        # hidden layer is requested (indexing the hidden sizes directly would
        # raise IndexError in that case).
        self.FC_mean = nn.Linear(dims[-1], latent_dim)
        self.FC_logvar = nn.Linear(dims[-1], latent_dim)
        self.softplus = nn.Softplus()  # not used by forward; kept as an attribute
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, nb_ingredients, x):
        """Encode a batch of ingredient sets.

        Args:
            nb_ingredients: Per-sample ingredient counts, indexable by sample
                (integer tensor or sequence of ints).
            x: 2-D tensor with one row per sample; ingredient ``j`` occupies
                columns ``[input_dim * j, input_dim * (j + 1))`` of its row.

        Returns:
            Tuple ``(mean, logvar)``, each with one row per sample.

        Raises:
            ValueError: If ``self.aggregation`` is neither ``'sum'`` nor
                ``'mean'``.
        """
        if self.aggregation == 'sum':
            pool = torch.sum
        elif self.aggregation == 'mean':
            pool = torch.mean
        else:
            raise ValueError(
                f"Unknown aggregation {self.aggregation!r}; expected 'sum' or 'mean'."
            )

        batch_size = x.shape[0]
        counts = [int(nb_ingredients[i]) for i in range(batch_size)]

        # reshape x in (total nb of ingredients, dim_ing_rep): the first
        # counts[i] ingredient slots of row i are contiguous columns.
        per_sample = [
            x[i, :count * self.input_dim].reshape(count, self.input_dim)
            for i, count in enumerate(counts)
        ]
        flat_ingredients = torch.cat(per_sample, dim=0)

        # encode ingredients in parallel
        ingredients_encodings = self.ingredient_encoder(flat_ingredients)
        assert ingredients_encodings.shape == (sum(counts), self.deepset_latent_dim)

        # aggregate the encodings belonging to each sample
        pooled = []
        index_first = 0
        for count in counts:
            index_last = index_first + count
            pooled.append(pool(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
            index_first = index_last
        x = torch.cat(pooled, dim=0)
        assert x.shape[0] == batch_size

        for layer, dropout in zip(self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        mean = self.FC_mean(x)
        logvar = self.FC_logvar(x)
        return mean, logvar
class Decoder(nn.Module):
    """Multi-layer perceptron decoding a latent vector into ``num_ingredients`` values.

    Every layer except the last applies dropout and then the activation; the
    last linear output is returned raw, optionally passed through a
    user-supplied filter.
    """

    def __init__(self, latent_dim, hidden_dims, num_ingredients, activation, dropout, filter_output=None):
        """Build the stack of linear layers.

        Args:
            latent_dim: Size of the latent input vectors.
            hidden_dims: Sizes of the hidden layers (any iterable of ints,
                possibly empty).
            num_ingredients: Size of the output vectors.
            activation: Activation name understood by ``get_activation``.
            dropout: Dropout probability used for every dropout module.
            filter_output: Optional callable applied to the output when
                ``forward`` is called with ``to_filter=True``.
        """
        super(Decoder, self).__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        # list(...) lets callers pass a tuple of hidden sizes.
        dims = [latent_dim] + list(hidden_dims) + [num_ingredients]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)
        self.filter = filter_output

    def forward(self, x, to_filter=False):
        """Decode ``x``; apply the output filter when ``to_filter`` is true.

        Raises:
            ValueError: If ``to_filter`` is true but no ``filter_output`` was
                given at construction time.
        """
        # Fail fast with an explicit message instead of the opaque
        # "'NoneType' object is not callable" raised after the full pass.
        if to_filter and self.filter is None:
            raise ValueError('to_filter=True requires a filter_output callable, but none was provided.')
        last_index = len(self.linears) - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            if i_layer != last_index:
                x = self.activation(dropout(x))
        if to_filter:
            x = self.filter(x)
        return x
class PredictorHead(nn.Module):
    """Single linear layer with an optional final activation."""

    def __init__(self, latent_dim, dim_output, final_activ):
        """Build the head.

        Args:
            latent_dim: Size of the input vectors.
            dim_output: Size of the output vectors.
            final_activ: Activation name understood by ``get_activation``, or
                ``None`` to return the raw linear output.
        """
        super(PredictorHead, self).__init__()
        self.linear = nn.Linear(latent_dim, dim_output)
        self.use_final_activ = final_activ is not None
        # Always define the attribute so it can be inspected even when no
        # activation is configured.
        self.final_activ = get_activation(final_activ) if self.use_final_activ else None

    def forward(self, x):
        """Apply the linear layer, then the final activation if one is set."""
        x = self.linear(x)
        if self.use_final_activ:
            x = self.final_activ(x)
        return x
class VAEModel(nn.Module):
    """Variational auto-encoder with auxiliary prediction heads on the latent code."""

    def __init__(self, encoder, decoder, auxiliaries_dict):
        """Assemble the model.

        Args:
            encoder: Module exposing ``latent_dim`` and returning
                ``(mean, logvar)``.
            decoder: Module mapping a latent vector to a reconstruction; it
                is called with a ``to_filter`` keyword by the forward passes.
            auxiliaries_dict: Maps each auxiliary name to a dict with keys
                ``'dim_output'`` and ``'final_activ'``. The name
                ``'taste_reps'`` gets a dedicated head
                (``self.taste_reps_decoder``); the others are stored, in
                sorted name order, in ``self.auxiliaries``.
        """
        super(VAEModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = self.encoder.latent_dim
        self.auxiliaries_str = []
        self.auxiliaries = nn.ModuleList()
        for aux_str in sorted(auxiliaries_dict):
            spec = auxiliaries_dict[aux_str]
            head = PredictorHead(self.latent_dim, spec['dim_output'], spec['final_activ'])
            if aux_str == 'taste_reps':
                self.taste_reps_decoder = head
            else:
                self.auxiliaries_str.append(aux_str)
                self.auxiliaries.append(head)

    def reparameterization(self, mean, logvar):
        """Sample ``z = mean + std * eps`` with ``std = exp(0.5 * logvar)``."""
        std = torch.exp(0.5 * logvar)
        # randn_like already matches std's device and dtype; moving the noise
        # to a global device could put it on a different device than mean/std.
        epsilon = torch.randn_like(std)  # sampling epsilon
        return mean + std * epsilon  # reparameterization trick

    def sample(self, n=1):
        """Decode ``n`` latent vectors drawn from a standard normal."""
        # Draw z where the decoder's parameters live so decoding does not
        # fail on a device or dtype mismatch.
        reference = next(self.decoder.parameters(), None)
        if reference is None:
            z = torch.randn(size=(n, self.latent_dim))
        else:
            z = torch.randn(size=(n, self.latent_dim), device=reference.device, dtype=reference.dtype)
        return self.decoder(z)

    def get_all_auxiliaries(self, x):
        """Return the outputs of every head in ``self.auxiliaries``, in order."""
        return [aux(x) for aux in self.auxiliaries]

    def get_auxiliary(self, z, aux_str):
        """Return the output of the head named ``aux_str``.

        Raises:
            ValueError: If ``aux_str`` is neither ``'taste_reps'`` nor a name
                listed in ``self.auxiliaries_str``.
        """
        if aux_str == 'taste_reps':
            return self.taste_reps_decoder(z)
        index = self.auxiliaries_str.index(aux_str)
        return self.auxiliaries[index](z)

    def forward_direct(self, x, aux_str=None, to_filter=False):
        """Encode ``x``, sample ``z`` and decode.

        Returns:
            ``(x_hat, z, mean, logvar, aux, names)``. With ``aux_str`` given,
            ``aux`` is that single head's output and ``names`` is
            ``[aux_str]``; otherwise ``aux`` is the list of all auxiliary
            outputs and ``names`` is ``self.auxiliaries_str``.
        """
        mean, logvar = self.encoder(x)
        z = self.reparameterization(mean, logvar)  # takes exponential function (log var -> std)
        # NOTE(review): the reconstruction is decoded from `mean`, not from the
        # sampled `z`; only the auxiliary heads see `z`. Confirm this is intended.
        x_hat = self.decoder(mean, to_filter=to_filter)
        if aux_str is not None:
            return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str]
        return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str

    def forward(self, nb_ingredients, x, aux_str=None, to_filter=False):
        """Deep-set forward pass, currently disabled by an unconditional assert.

        NOTE(review): the guard below makes the rest of the body unreachable
        unless Python runs with assertions disabled (``-O``). It looks
        deliberate, so it is kept; ``forward_direct`` is the working path.
        """
        assert False, 'VAEModel.forward is disabled; use forward_direct instead.'
        mean, logvar = self.encoder(nb_ingredients, x)
        z = self.reparameterization(mean, logvar)  # takes exponential function (log var -> std)
        x_hat = self.decoder(mean, to_filter=to_filter)
        if aux_str is not None:
            return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str]
        return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str
class SimpleEncoder(nn.Module):
    """MLP encoder returning the mean and log-variance of the latent code.

    Every hidden layer applies the linear map, then dropout, then the
    activation; two separate linear heads produce the mean and the
    log-variance from the last hidden representation.
    """

    def __init__(self, input_dim, hidden_dims, latent_dim, activation, dropout):
        """Build the hidden stack and the two output heads.

        Args:
            input_dim: Size of the input vectors.
            hidden_dims: Sizes of the hidden layers (any iterable of ints,
                possibly empty).
            latent_dim: Dimension of the latent mean / log-variance.
            activation: Activation name understood by ``get_activation``.
            dropout: Dropout probability used for every dropout module.
        """
        super(SimpleEncoder, self).__init__()
        self.latent_dim = latent_dim
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [input_dim] + list(hidden_dims)
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        # dims[-1] is the last hidden size, or input_dim when no hidden layer
        # is requested (indexing the hidden sizes directly would raise
        # IndexError in that case).
        self.FC_mean = nn.Linear(dims[-1], latent_dim)
        self.FC_logvar = nn.Linear(dims[-1], latent_dim)
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        """Return ``(mean, logvar)`` for the batch ``x``."""
        for layer, dropout in zip(self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        mean = self.FC_mean(x)
        logvar = self.FC_logvar(x)
        return mean, logvar
def get_vae_model(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                  hidden_dims_cocktail, hidden_dims_decoder, num_ingredients, latent_dim, aggregation, dropout,
                  auxiliaries_dict, filter_decoder_output):
    """Build a ``VAEModel`` from a ``SimpleEncoder`` and a ``Decoder``.

    The encoder takes ``num_ingredients`` inputs and uses
    ``hidden_dims_cocktail`` as its hidden sizes. ``input_dim``,
    ``deepset_latent_dim``, ``hidden_dims_ing`` and ``aggregation`` are
    accepted but not used here: they belong to the
    ``DeepsetCocktailEncoder`` variant, which this factory does not build.

    Returns:
        The assembled ``VAEModel``.
    """
    simple_encoder = SimpleEncoder(
        input_dim=num_ingredients,
        hidden_dims=hidden_dims_cocktail,
        latent_dim=latent_dim,
        activation=activation,
        dropout=dropout,
    )
    ingredient_decoder = Decoder(
        latent_dim=latent_dim,
        hidden_dims=hidden_dims_decoder,
        num_ingredients=num_ingredients,
        activation=activation,
        dropout=dropout,
        filter_output=filter_decoder_output,
    )
    return VAEModel(simple_encoder, ingredient_decoder, auxiliaries_dict)