liuganghuggingface commited on
Commit
9e636c4
1 Parent(s): 959d452

Upload graph_decoder/diffusion_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +422 -0
graph_decoder/diffusion_model.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the Llamole team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import yaml
17
+ import json
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from . import diffusion_utils as utils
24
+ from .molecule_utils import graph_to_smiles, check_valid
25
+ from .transformer import Transformer
26
+ from .visualize_utils import MolecularVisualization
27
+
28
+ class GraphDiT(nn.Module):
29
+ def __init__(
30
+ self,
31
+ model_config_path,
32
+ data_info_path,
33
+ model_dtype,
34
+ ):
35
+ super().__init__()
36
+
37
+ dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
38
+
39
+ input_dims = data_info.input_dims
40
+ output_dims = data_info.output_dims
41
+ nodes_dist = data_info.nodes_dist
42
+ active_index = data_info.active_index
43
+
44
+ self.model_config = dm_cfg
45
+ self.data_info = data_info
46
+ self.T = dm_cfg.diffusion_steps
47
+ self.Xdim = input_dims["X"]
48
+ self.Edim = input_dims["E"]
49
+ self.ydim = input_dims["y"]
50
+ self.Xdim_output = output_dims["X"]
51
+ self.Edim_output = output_dims["E"]
52
+ self.ydim_output = output_dims["y"]
53
+ self.node_dist = nodes_dist
54
+ self.active_index = active_index
55
+ self.max_n_nodes = data_info.max_n_nodes
56
+ self.atom_decoder = data_info.atom_decoder
57
+ self.hidden_size = dm_cfg.hidden_size
58
+ self.mol_visualizer = MolecularVisualization(self.atom_decoder)
59
+
60
+ self.denoiser = Transformer(
61
+ max_n_nodes=self.max_n_nodes,
62
+ hidden_size=dm_cfg.hidden_size,
63
+ depth=dm_cfg.depth,
64
+ num_heads=dm_cfg.num_heads,
65
+ mlp_ratio=dm_cfg.mlp_ratio,
66
+ drop_condition=dm_cfg.drop_condition,
67
+ Xdim=self.Xdim,
68
+ Edim=self.Edim,
69
+ ydim=self.ydim,
70
+ )
71
+ self.model_dtype = model_dtype
72
+ # self.device = next(self.denoiser.parameters()).device
73
+
74
+ # model_params = torch.load(model_params_path, map_location='cpu')
75
+ # self.denoiser.load_state_dict(model_params)
76
+
77
+ self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
78
+ dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
79
+ )
80
+ x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
81
+ data_info.node_types.to(self.model_dtype)
82
+ )
83
+ e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
84
+ data_info.edge_types.to(self.model_dtype)
85
+ )
86
+ x_marginals = x_marginals / x_marginals.sum()
87
+ e_marginals = e_marginals / e_marginals.sum()
88
+
89
+ xe_conditions = data_info.transition_E.to(self.model_dtype)
90
+ xe_conditions = xe_conditions[self.active_index][:, self.active_index]
91
+
92
+ xe_conditions = xe_conditions.sum(dim=1)
93
+ ex_conditions = xe_conditions.t()
94
+ xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
95
+ ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
96
+
97
+ self.transition_model = utils.MarginalTransition(
98
+ x_marginals=x_marginals,
99
+ e_marginals=e_marginals,
100
+ xe_conditions=xe_conditions,
101
+ ex_conditions=ex_conditions,
102
+ y_classes=self.ydim_output,
103
+ n_nodes=self.max_n_nodes,
104
+ )
105
+ self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
106
+
107
+ # def to(self, *args, **kwargs):
108
+ # self = super().to(*args, **kwargs)
109
+ # self.model_dtype = next(self.denoiser.parameters()).dtype
110
+ # return self
111
+
112
+ def init_model(self, model_dir, verbose=False):
113
+ model_file = os.path.join(model_dir, 'model.pt')
114
+ if os.path.exists(model_file):
115
+ self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
116
+ else:
117
+ raise FileNotFoundError(f"Model file not found: {model_file}")
118
+
119
+ if verbose:
120
+ print('GraphDiT Denoiser Model initialized.')
121
+ print('Denoiser model:\n', self.denoiser)
122
+
123
+ def save_pretrained(self, output_dir):
124
+ if not os.path.exists(output_dir):
125
+ os.makedirs(output_dir)
126
+
127
+ # Save model
128
+ model_path = os.path.join(output_dir, 'model.pt')
129
+ torch.save(self.denoiser.state_dict(), model_path)
130
+
131
+ # Save model config
132
+ config_path = os.path.join(output_dir, 'model_config.yaml')
133
+ with open(config_path, 'w') as f:
134
+ yaml.dump(vars(self.model_config), f)
135
+
136
+ # Save data info
137
+ data_info_path = os.path.join(output_dir, 'data.meta.json')
138
+ data_info_dict = {
139
+ "active_atoms": self.data_info.active_atoms,
140
+ "max_node": self.data_info.max_n_nodes,
141
+ "n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(),
142
+ "bond_type_dist": self.data_info.edge_types.tolist(),
143
+ "transition_E": self.data_info.transition_E.tolist(),
144
+ "atom_type_dist": self.data_info.node_types.tolist(),
145
+ "valencies": self.data_info.valency_distribution.tolist()
146
+ }
147
+ with open(data_info_path, 'w') as f:
148
+ json.dump(data_info_dict, f, indent=2)
149
+
150
+ print('GraphDiT Model and configurations saved to:', output_dir)
151
+
152
+ def disable_grads(self):
153
+ self.denoiser.disable_grads()
154
+
155
+ def forward(
156
+ self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
157
+ ):
158
+ raise ValueError('Not Implement')
159
+
160
+ def _forward(self, noisy_data, unconditioned=False):
161
+ noisy_x, noisy_e, properties = (
162
+ noisy_data["X_t"].to(self.model_dtype),
163
+ noisy_data["E_t"].to(self.model_dtype),
164
+ noisy_data["y_t"].to(self.model_dtype).clone(),
165
+ )
166
+ node_mask, timestep = (
167
+ noisy_data["node_mask"],
168
+ noisy_data["t"],
169
+ )
170
+
171
+ pred = self.denoiser(
172
+ noisy_x,
173
+ noisy_e,
174
+ node_mask,
175
+ properties,
176
+ timestep,
177
+ unconditioned=unconditioned,
178
+ )
179
+ return pred
180
+
181
+ def apply_noise(self, X, E, y, node_mask):
182
+ """Sample noise and apply it to the data."""
183
+
184
+ # Sample a timestep t.
185
+ # When evaluating, the loss for t=0 is computed separately
186
+ lowest_t = 0 if self.training else 1
187
+ t_int = torch.randint(
188
+ lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
189
+ ).to(
190
+ self.model_dtype
191
+ ) # (bs, 1)
192
+ s_int = t_int - 1
193
+
194
+ t_float = t_int / self.T
195
+ s_float = s_int / self.T
196
+
197
+ # beta_t and alpha_s_bar are used for denoising/loss computation
198
+ beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
199
+ alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
200
+ alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
201
+
202
+ Qtb = self.transition_model.get_Qt_bar(
203
+ alpha_t_bar, X.device
204
+ ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
205
+
206
+ bs, n, d = X.shape
207
+ X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
208
+ prob_all = X_all @ Qtb.X
209
+ probX = prob_all[:, :, : self.Xdim_output]
210
+ probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
211
+
212
+ sampled_t = utils.sample_discrete_features(
213
+ probX=probX, probE=probE, node_mask=node_mask
214
+ )
215
+
216
+ X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
217
+ E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
218
+ assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
219
+
220
+ y_t = y
221
+ z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
222
+
223
+ noisy_data = {
224
+ "t_int": t_int,
225
+ "t": t_float,
226
+ "beta_t": beta_t,
227
+ "alpha_s_bar": alpha_s_bar,
228
+ "alpha_t_bar": alpha_t_bar,
229
+ "X_t": z_t.X,
230
+ "E_t": z_t.E,
231
+ "y_t": z_t.y,
232
+ "node_mask": node_mask,
233
+ }
234
+ return noisy_data
235
+
236
+ @torch.no_grad()
237
+ def generate(
238
+ self,
239
+ properties,
240
+ device,
241
+ guide_scale=1.,
242
+ num_nodes=None,
243
+ number_chain_steps=50,
244
+ ):
245
+ properties = [float('nan') if x is None else x for x in properties]
246
+ properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
247
+ batch_size = properties.size(0)
248
+ assert batch_size == 1
249
+ # print('self.denoiser.dtype', self.model_dtype)
250
+ if num_nodes is None:
251
+ num_nodes = self.node_dist.sample_n(batch_size, device)
252
+ else:
253
+ num_nodes = torch.LongTensor([num_nodes]).to(device)
254
+
255
+ arange = (
256
+ torch.arange(self.max_n_nodes, device=device)
257
+ .unsqueeze(0)
258
+ .expand(batch_size, -1)
259
+ )
260
+ node_mask = arange < num_nodes.unsqueeze(1)
261
+
262
+ z_T = utils.sample_discrete_feature_noise(
263
+ limit_dist=self.limit_dist, node_mask=node_mask
264
+ )
265
+ X, E = z_T.X, z_T.E
266
+
267
+ assert (E == torch.transpose(E, 1, 2)).all()
268
+
269
+ if number_chain_steps > 0:
270
+ chain_X_size = torch.Size((number_chain_steps, X.size(1)))
271
+ chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
272
+ chain_X = torch.zeros(chain_X_size)
273
+ chain_E = torch.zeros(chain_E_size)
274
+
275
+ # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
276
+ y = properties
277
+ for s_int in reversed(range(0, self.T)):
278
+ s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
279
+ t_array = s_array + 1
280
+ s_norm = s_array / self.T
281
+ t_norm = t_array / self.T
282
+
283
+ # Sample z_s
284
+ sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
285
+ s_norm, t_norm, X, E, y, node_mask, guide_scale, device
286
+ )
287
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
288
+
289
+ if number_chain_steps > 0:
290
+ # Save the first keep_chain graphs
291
+ write_index = (s_int * number_chain_steps) // self.T
292
+ chain_X[write_index] = discrete_sampled_s.X[:1]
293
+ chain_E[write_index] = discrete_sampled_s.E[:1]
294
+
295
+ # Sample
296
+ sampled_s = sampled_s.mask(node_mask, collapse=True)
297
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
298
+
299
+ molecule_list = []
300
+ n = num_nodes[0]
301
+ atom_types = X[0, :n].cpu()
302
+ edge_types = E[0, :n, :n].cpu()
303
+ molecule_list.append([atom_types, edge_types])
304
+ smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
305
+
306
+ # Visualize Chains
307
+ if number_chain_steps > 0:
308
+ final_X_chain = X[:1]
309
+ final_E_chain = E[:1]
310
+
311
+ chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
312
+ chain_E[0] = final_E_chain
313
+
314
+ chain_X = utils.reverse_tensor(chain_X)
315
+ chain_E = utils.reverse_tensor(chain_E)
316
+
317
+ # Repeat last frame to see final sample better
318
+ chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
319
+ chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
320
+ mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
321
+ else:
322
+ mol_img_list = []
323
+
324
+ return smiles, mol_img_list
325
+
326
+ def check_valid(self, smiles):
327
+ return check_valid(smiles)
328
+
329
+ def sample_p_zs_given_zt(
330
+ self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
331
+ ):
332
+ """Samples from zs ~ p(zs | zt). Only used during sampling.
333
+ if last_step, return the graph prediction as well"""
334
+ bs, n, _ = X_t.shape
335
+ beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
336
+ alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
337
+ alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
338
+
339
+ # Neural net predictions
340
+ noisy_data = {
341
+ "X_t": X_t,
342
+ "E_t": E_t,
343
+ "y_t": properties,
344
+ "t": t,
345
+ "node_mask": node_mask,
346
+ }
347
+
348
+ def get_prob(noisy_data, unconditioned=False):
349
+ pred = self._forward(noisy_data, unconditioned=unconditioned)
350
+
351
+ # Normalize predictions
352
+ pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
353
+ pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
354
+
355
+ # Retrieve transitions matrix
356
+ Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
357
+ Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
358
+ Qt = self.transition_model.get_Qt(beta_t, device)
359
+
360
+ Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
361
+ predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
362
+
363
+ unnormalized_probX_all = utils.reverse_diffusion(
364
+ predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
365
+ )
366
+
367
+ unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
368
+ unnormalized_prob_E = unnormalized_probX_all[
369
+ :, :, self.Xdim_output :
370
+ ].reshape(bs, n * n, -1)
371
+
372
+ unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
373
+ unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
374
+
375
+ prob_X = unnormalized_prob_X / torch.sum(
376
+ unnormalized_prob_X, dim=-1, keepdim=True
377
+ ) # bs, n, d_t-1
378
+ prob_E = unnormalized_prob_E / torch.sum(
379
+ unnormalized_prob_E, dim=-1, keepdim=True
380
+ ) # bs, n, d_t-1
381
+ prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
382
+
383
+ return prob_X, prob_E
384
+
385
+ prob_X, prob_E = get_prob(noisy_data)
386
+
387
+ ### Guidance
388
+ if guide_scale != 1:
389
+ uncon_prob_X, uncon_prob_E = get_prob(
390
+ noisy_data, unconditioned=True
391
+ )
392
+ prob_X = (
393
+ uncon_prob_X
394
+ * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
395
+ )
396
+ prob_E = (
397
+ uncon_prob_E
398
+ * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
399
+ )
400
+ prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
401
+ prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
402
+
403
+ # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
404
+ # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
405
+
406
+ sampled_s = utils.sample_discrete_features(
407
+ prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
408
+ )
409
+
410
+ X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
411
+ E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
412
+
413
+ assert (E_s == torch.transpose(E_s, 1, 2)).all()
414
+ assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
415
+
416
+ out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
417
+ out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
418
+
419
+ return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
420
+ node_mask, collapse=True
421
+ ).type_as(properties)
422
+