Spaces: Running on Zero
# Copyright 2024 the Llamole team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os | |
import yaml | |
import json | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from . import diffusion_utils as utils | |
from .molecule_utils import graph_to_smiles, check_valid | |
from .transformer import Transformer | |
from .visualize_utils import MolecularVisualization | |
class GraphDiT(nn.Module):
    """Graph Diffusion Transformer over discrete molecular graphs.

    Wraps a Transformer denoiser together with a predefined discrete noise
    schedule and marginal transition matrices. Provides training-time
    noising (``apply_noise``) and property-conditioned reverse-diffusion
    sampling (``generate``) that decodes the sampled graph to a SMILES
    string, with optional chain visualization.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Build the denoiser, noise schedule, and transition model.

        Args:
            model_config_path: Path to the YAML diffusion-model config.
            data_info_path: Path to dataset metadata (atom/bond statistics).
            model_dtype: torch dtype used for marginals and denoiser inputs.
        """
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps  # total number of diffusion steps
        self.Xdim = input_dims["X"]  # node (atom-type) input dim
        self.Edim = input_dims["E"]  # edge (bond-type) input dim
        self.ydim = input_dims["y"]  # property/condition input dim
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = data_info.nodes_dist  # distribution over graph sizes
        self.active_index = data_info.active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )
        self.model_dtype = model_dtype

        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Marginal distributions over atom and bond types; re-normalized a
        # second time for numerical safety after the dtype conversion.
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Node<->edge conditional transition statistics, restricted to the
        # active atom types, then row-normalized into conditional probabilities.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]
        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        # Limit (t -> infinity) distribution used to draw the initial noise.
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir, verbose=False):
        """Load pretrained denoiser weights from ``model_dir``/model.pt.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_file = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_file):
            # weights_only=True guards against arbitrary pickled objects.
            self.denoiser.load_state_dict(
                torch.load(model_file, map_location='cpu', weights_only=True)
            )
        else:
            raise FileNotFoundError(f"Model file not found: {model_file}")

        if verbose:
            print('GraphDiT Denoiser Model initialized.')
            print('Denoiser model:\n', self.denoiser)

    def save_pretrained(self, output_dir):
        """Save denoiser weights, model config, and dataset metadata."""
        # exist_ok avoids the exists()/makedirs() race.
        os.makedirs(output_dir, exist_ok=True)

        # Save model weights
        model_path = os.path.join(output_dir, 'model.pt')
        torch.save(self.denoiser.state_dict(), model_path)

        # Save model config
        config_path = os.path.join(output_dir, 'model_config.yaml')
        with open(config_path, 'w') as f:
            yaml.dump(vars(self.model_config), f)

        # Save data info
        data_info_path = os.path.join(output_dir, 'data.meta.json')
        data_info_dict = {
            "active_atoms": self.data_info.active_atoms,
            "max_node": self.data_info.max_n_nodes,
            "n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(),
            "bond_type_dist": self.data_info.edge_types.tolist(),
            "transition_E": self.data_info.transition_E.tolist(),
            "atom_type_dist": self.data_info.node_types.tolist(),
            "valencies": self.data_info.valency_distribution.tolist(),
        }
        with open(data_info_path, 'w') as f:
            json.dump(data_info_dict, f, indent=2)

        print('GraphDiT Model and configurations saved to:', output_dir)

    def disable_grads(self):
        """Freeze the denoiser's parameters (delegates to the Transformer)."""
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        """Training forward pass — intentionally not implemented here."""
        # Bug fix: was `raise ValueError('Not Implement')`; an unimplemented
        # method should raise NotImplementedError.
        raise NotImplementedError

    def _forward(self, noisy_data, unconditioned=False):
        """Run the denoiser on a noisy graph.

        Args:
            noisy_data: Dict with keys "X_t", "E_t", "y_t", "node_mask", "t"
                as produced by ``apply_noise``.
            unconditioned: If True, drop the property conditioning
                (used for classifier-free guidance).

        Returns:
            The denoiser's prediction placeholder (logits over X/E).
        """
        noisy_x, noisy_e, properties = (
            noisy_data["X_t"].to(self.model_dtype),
            noisy_data["E_t"].to(self.model_dtype),
            noisy_data["y_t"].to(self.model_dtype).clone(),
        )
        node_mask, timestep = (
            noisy_data["node_mask"],
            noisy_data["t"],
        )

        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data."""
        # Sample a timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(
            self.model_dtype
        )  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        # Nodes and (flattened) edges are noised jointly through one matmul.
        bs, n, d = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    def generate(
        self,
        properties,
        device,
        guide_scale=1.,
        num_nodes=None,
        number_chain_steps=50,
    ):
        """Sample one molecule conditioned on ``properties``.

        Args:
            properties: Sequence of property targets; ``None`` entries become
                NaN (treated as unconditioned dimensions downstream).
            device: Device on which to run sampling.
            guide_scale: Classifier-free guidance scale (1.0 = no guidance).
            num_nodes: Optional fixed node count; sampled from the dataset's
                size distribution when None.
            number_chain_steps: Number of intermediate frames to visualize;
                0 disables chain visualization.

        Returns:
            Tuple of (SMILES string, list of chain visualization images).
        """
        properties = [float('nan') if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        assert batch_size == 1

        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.LongTensor([num_nodes]).to(device)

        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        # Draw z_T from the limit (marginal) distribution.
        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E
        assert (E == torch.transpose(E, 1, 2)).all()

        if number_chain_steps > 0:
            chain_X_size = torch.Size((number_chain_steps, X.size(1)))
            chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
            chain_X = torch.zeros(chain_X_size)
            chain_E = torch.zeros(chain_E_size)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if number_chain_steps > 0:
                # Save the first keep_chain graphs
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Collapse one-hot features to discrete labels.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        molecule_list = []
        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list.append([atom_types, edge_types])
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        # Visualize Chains
        if number_chain_steps > 0:
            final_X_chain = X[:1]
            final_E_chain = E[:1]
            chain_X[0] = final_X_chain  # Overwrite last frame with the resulting X, E
            chain_E[0] = final_E_chain
            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)
            # Repeat last frame to see final sample better
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        """Return whether ``smiles`` is a chemically valid molecule."""
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Samples from zs ~ p(zs | zt). Only used during sampling.
        if last_step, return the graph prediction as well"""
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, unconditioned=False):
            # Posterior p(z_s | z_t) from the denoiser's x0-prediction and
            # the forward-process transition matrices.
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            # Retrieve transitions matrix
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
            Qt = self.transition_model.get_Qt(beta_t, device)

            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Guard against all-zero rows before normalizing.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        ### Guidance
        if guide_scale != 1:
            # Classifier-free guidance: sharpen the conditional posterior
            # against the unconditional one.
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, unconditioned=True
            )
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)