liuganghuggingface commited on
Commit
7c67898
1 Parent(s): 71df9ee

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +354 -354
graph_decoder/diffusion_model.py CHANGED
@@ -11,25 +11,6 @@ from .transformer import Transformer
11
  from .molecule_utils import graph_to_smiles, check_valid
12
  from .visualize_utils import MolecularVisualization
13
 
14
- class GraphDiT(nn.Module):
15
- def __init__(
16
- self,
17
- model_config_path,
18
- data_info_path,
19
- model_dtype,
20
- ):
21
- super().__init__()
22
-
23
- def init_model(self, model_dir):
24
- pass
25
-
26
- def disable_grads(self):
27
- pass
28
-
29
- def generate(self, properties, guide_scale, num_nodes, number_chain_steps):
30
- return 0, 0
31
-
32
-
33
  # class GraphDiT(nn.Module):
34
  # def __init__(
35
  # self,
@@ -38,346 +19,365 @@ class GraphDiT(nn.Module):
38
  # model_dtype,
39
  # ):
40
  # super().__init__()
41
- # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
42
-
43
- # input_dims = data_info.input_dims
44
- # output_dims = data_info.output_dims
45
- # nodes_dist = data_info.nodes_dist
46
- # active_index = data_info.active_index
47
-
48
- # self.model_config = dm_cfg
49
- # self.data_info = data_info
50
- # self.T = dm_cfg.diffusion_steps
51
- # self.Xdim = input_dims["X"]
52
- # self.Edim = input_dims["E"]
53
- # self.ydim = input_dims["y"]
54
- # self.Xdim_output = output_dims["X"]
55
- # self.Edim_output = output_dims["E"]
56
- # self.ydim_output = output_dims["y"]
57
- # self.node_dist = nodes_dist
58
- # self.active_index = active_index
59
- # self.max_n_nodes = data_info.max_n_nodes
60
- # self.atom_decoder = data_info.atom_decoder
61
- # self.hidden_size = dm_cfg.hidden_size
62
- # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
63
-
64
- # self.denoiser = Transformer(
65
- # max_n_nodes=self.max_n_nodes,
66
- # hidden_size=dm_cfg.hidden_size,
67
- # depth=dm_cfg.depth,
68
- # num_heads=dm_cfg.num_heads,
69
- # mlp_ratio=dm_cfg.mlp_ratio,
70
- # drop_condition=dm_cfg.drop_condition,
71
- # Xdim=self.Xdim,
72
- # Edim=self.Edim,
73
- # ydim=self.ydim,
74
- # )
75
-
76
- # self.model_dtype = model_dtype
77
- # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
78
- # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
79
- # )
80
- # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
81
- # data_info.node_types.to(self.model_dtype)
82
- # )
83
- # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
84
- # data_info.edge_types.to(self.model_dtype)
85
- # )
86
- # x_marginals = x_marginals / x_marginals.sum()
87
- # e_marginals = e_marginals / e_marginals.sum()
88
-
89
- # xe_conditions = data_info.transition_E.to(self.model_dtype)
90
- # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
91
-
92
- # xe_conditions = xe_conditions.sum(dim=1)
93
- # ex_conditions = xe_conditions.t()
94
- # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
95
- # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
96
-
97
- # self.transition_model = utils.MarginalTransition(
98
- # x_marginals=x_marginals,
99
- # e_marginals=e_marginals,
100
- # xe_conditions=xe_conditions,
101
- # ex_conditions=ex_conditions,
102
- # y_classes=self.ydim_output,
103
- # n_nodes=self.max_n_nodes,
104
- # )
105
- # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
106
 
107
  # def init_model(self, model_dir):
108
- # model_file = os.path.join(model_dir, 'model.pt')
109
- # if os.path.exists(model_file):
110
- # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
111
- # else:
112
- # raise FileNotFoundError(f"Model file not found: {model_file}")
113
-
114
  # def disable_grads(self):
115
- # self.denoiser.disable_grads()
 
 
 
116
 
117
- # def forward(
118
- # self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
119
- # ):
120
- # raise ValueError('Not Implement')
121
-
122
- # def _forward(self, noisy_data, unconditioned=False):
123
- # noisy_x, noisy_e, properties = (
124
- # noisy_data["X_t"].to(self.model_dtype),
125
- # noisy_data["E_t"].to(self.model_dtype),
126
- # noisy_data["y_t"].to(self.model_dtype).clone(),
127
- # )
128
- # node_mask, timestep = (
129
- # noisy_data["node_mask"],
130
- # noisy_data["t"],
131
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- # pred = self.denoiser(
134
- # noisy_x,
135
- # noisy_e,
136
- # node_mask,
137
- # properties,
138
- # timestep,
139
- # unconditioned=unconditioned,
140
- # )
141
- # return pred
142
-
143
- # def apply_noise(self, X, E, y, node_mask):
144
- # """Sample noise and apply it to the data."""
145
-
146
- # # Sample a timestep t.
147
- # # When evaluating, the loss for t=0 is computed separately
148
- # lowest_t = 0 if self.training else 1
149
- # t_int = torch.randint(
150
- # lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
151
- # ).to(
152
- # self.model_dtype
153
- # ) # (bs, 1)
154
- # s_int = t_int - 1
155
-
156
- # t_float = t_int / self.T
157
- # s_float = s_int / self.T
158
-
159
- # # beta_t and alpha_s_bar are used for denoising/loss computation
160
- # beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
161
- # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
162
- # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
163
-
164
- # Qtb = self.transition_model.get_Qt_bar(
165
- # alpha_t_bar, X.device
166
- # ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
167
-
168
- # bs, n, d = X.shape
169
- # X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
170
- # prob_all = X_all @ Qtb.X
171
- # probX = prob_all[:, :, : self.Xdim_output]
172
- # probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
173
-
174
- # sampled_t = utils.sample_discrete_features(
175
- # probX=probX, probE=probE, node_mask=node_mask
176
- # )
177
-
178
- # X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
179
- # E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
180
- # assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
181
-
182
- # y_t = y
183
- # z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
184
-
185
- # noisy_data = {
186
- # "t_int": t_int,
187
- # "t": t_float,
188
- # "beta_t": beta_t,
189
- # "alpha_s_bar": alpha_s_bar,
190
- # "alpha_t_bar": alpha_t_bar,
191
- # "X_t": z_t.X,
192
- # "E_t": z_t.E,
193
- # "y_t": z_t.y,
194
- # "node_mask": node_mask,
195
- # }
196
- # return noisy_data
197
-
198
- # @torch.no_grad()
199
- # def generate(
200
- # self,
201
- # properties,
202
- # guide_scale=1.,
203
- # num_nodes=None,
204
- # number_chain_steps=50,
205
- # ):
206
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
207
- # properties = [float('nan') if x is None else x for x in properties]
208
- # properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
209
- # batch_size = properties.size(0)
210
- # assert batch_size == 1
211
- # if num_nodes is None:
212
- # num_nodes = self.node_dist.sample_n(batch_size, device)
213
- # else:
214
- # num_nodes = torch.LongTensor([num_nodes]).to(device)
215
-
216
- # arange = (
217
- # torch.arange(self.max_n_nodes, device=device)
218
- # .unsqueeze(0)
219
- # .expand(batch_size, -1)
220
- # )
221
- # node_mask = arange < num_nodes.unsqueeze(1)
222
-
223
- # z_T = utils.sample_discrete_feature_noise(
224
- # limit_dist=self.limit_dist, node_mask=node_mask
225
- # )
226
- # X, E = z_T.X, z_T.E
227
-
228
- # assert (E == torch.transpose(E, 1, 2)).all()
229
-
230
- # if number_chain_steps > 0:
231
- # chain_X_size = torch.Size((number_chain_steps, X.size(1)))
232
- # chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
233
- # chain_X = torch.zeros(chain_X_size)
234
- # chain_E = torch.zeros(chain_E_size)
235
-
236
- # # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
237
- # y = properties
238
- # for s_int in reversed(range(0, self.T)):
239
- # s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
240
- # t_array = s_array + 1
241
- # s_norm = s_array / self.T
242
- # t_norm = t_array / self.T
243
-
244
- # # Sample z_s
245
- # sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
246
- # s_norm, t_norm, X, E, y, node_mask, guide_scale, device
247
- # )
248
- # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
249
 
250
- # if number_chain_steps > 0:
251
- # # Save the first keep_chain graphs
252
- # write_index = (s_int * number_chain_steps) // self.T
253
- # chain_X[write_index] = discrete_sampled_s.X[:1]
254
- # chain_E[write_index] = discrete_sampled_s.E[:1]
255
-
256
- # # Sample
257
- # sampled_s = sampled_s.mask(node_mask, collapse=True)
258
- # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
259
-
260
- # molecule_list = []
261
- # n = num_nodes[0]
262
- # atom_types = X[0, :n].cpu()
263
- # edge_types = E[0, :n, :n].cpu()
264
- # molecule_list.append([atom_types, edge_types])
265
- # smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
266
-
267
- # # Visualize Chains
268
- # if number_chain_steps > 0:
269
- # final_X_chain = X[:1]
270
- # final_E_chain = E[:1]
271
-
272
- # chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
273
- # chain_E[0] = final_E_chain
274
-
275
- # chain_X = utils.reverse_tensor(chain_X)
276
- # chain_E = utils.reverse_tensor(chain_E)
277
-
278
- # # Repeat last frame to see final sample better
279
- # chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
280
- # chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
281
- # mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
282
- # else:
283
- # mol_img_list = []
284
-
285
- # return smiles, mol_img_list
286
-
287
- # def check_valid(self, smiles):
288
- # return check_valid(smiles)
289
 
290
- # def sample_p_zs_given_zt(
291
- # self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
292
- # ):
293
- # """Samples from zs ~ p(zs | zt). Only used during sampling.
294
- # if last_step, return the graph prediction as well"""
295
- # bs, n, _ = X_t.shape
296
- # beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
297
- # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
298
- # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
299
-
300
- # # Neural net predictions
301
- # noisy_data = {
302
- # "X_t": X_t,
303
- # "E_t": E_t,
304
- # "y_t": properties,
305
- # "t": t,
306
- # "node_mask": node_mask,
307
- # }
308
-
309
- # def get_prob(noisy_data, unconditioned=False):
310
- # pred = self._forward(noisy_data, unconditioned=unconditioned)
311
-
312
- # # Normalize predictions
313
- # pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
314
- # pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
315
-
316
- # # Retrieve transitions matrix
317
- # Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
318
- # Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
319
- # Qt = self.transition_model.get_Qt(beta_t, device)
320
-
321
- # Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
322
- # predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
323
-
324
- # unnormalized_probX_all = utils.reverse_diffusion(
325
- # predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
326
- # )
327
-
328
- # unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
329
- # unnormalized_prob_E = unnormalized_probX_all[
330
- # :, :, self.Xdim_output :
331
- # ].reshape(bs, n * n, -1)
332
-
333
- # unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
334
- # unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
335
-
336
- # prob_X = unnormalized_prob_X / torch.sum(
337
- # unnormalized_prob_X, dim=-1, keepdim=True
338
- # ) # bs, n, d_t-1
339
- # prob_E = unnormalized_prob_E / torch.sum(
340
- # unnormalized_prob_E, dim=-1, keepdim=True
341
- # ) # bs, n, d_t-1
342
- # prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
343
-
344
- # return prob_X, prob_E
345
-
346
- # prob_X, prob_E = get_prob(noisy_data)
347
-
348
- # ### Guidance
349
- # if guide_scale != 1:
350
- # uncon_prob_X, uncon_prob_E = get_prob(
351
- # noisy_data, unconditioned=True
352
- # )
353
- # prob_X = (
354
- # uncon_prob_X
355
- # * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
356
- # )
357
- # prob_E = (
358
- # uncon_prob_E
359
- # * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
360
- # )
361
- # prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
362
- # prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
363
-
364
- # # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
365
- # # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
366
-
367
- # sampled_s = utils.sample_discrete_features(
368
- # prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
369
- # )
370
-
371
- # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
372
- # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
373
-
374
- # assert (E_s == torch.transpose(E_s, 1, 2)).all()
375
- # assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
376
-
377
- # out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
378
- # out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
379
-
380
- # return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
381
- # node_mask, collapse=True
382
- # ).type_as(properties)
383
 
 
11
  from .molecule_utils import graph_to_smiles, check_valid
12
  from .visualize_utils import MolecularVisualization
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # class GraphDiT(nn.Module):
15
  # def __init__(
16
  # self,
 
19
  # model_dtype,
20
  # ):
21
  # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # def init_model(self, model_dir):
24
+ # pass
25
+
 
 
 
 
26
  # def disable_grads(self):
27
+ # pass
28
+
29
+ # def generate(self, properties, guide_scale, num_nodes, number_chain_steps):
30
+ # return 0, 0
31
 
32
+
33
+ class GraphDiT(nn.Module):
34
+ def __init__(
35
+ self,
36
+ model_config_path,
37
+ data_info_path,
38
+ model_dtype,
39
+ ):
40
+ super().__init__()
41
+ dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
42
+
43
+ input_dims = data_info.input_dims
44
+ output_dims = data_info.output_dims
45
+ nodes_dist = data_info.nodes_dist
46
+ active_index = data_info.active_index
47
+
48
+ self.model_config = dm_cfg
49
+ self.data_info = data_info
50
+ self.T = dm_cfg.diffusion_steps
51
+ self.Xdim = input_dims["X"]
52
+ self.Edim = input_dims["E"]
53
+ self.ydim = input_dims["y"]
54
+ self.Xdim_output = output_dims["X"]
55
+ self.Edim_output = output_dims["E"]
56
+ self.ydim_output = output_dims["y"]
57
+ self.node_dist = nodes_dist
58
+ self.active_index = active_index
59
+ self.max_n_nodes = data_info.max_n_nodes
60
+ self.atom_decoder = data_info.atom_decoder
61
+ self.hidden_size = dm_cfg.hidden_size
62
+ self.mol_visualizer = MolecularVisualization(self.atom_decoder)
63
+
64
+ self.denoiser = Transformer(
65
+ max_n_nodes=self.max_n_nodes,
66
+ hidden_size=dm_cfg.hidden_size,
67
+ depth=dm_cfg.depth,
68
+ num_heads=dm_cfg.num_heads,
69
+ mlp_ratio=dm_cfg.mlp_ratio,
70
+ drop_condition=dm_cfg.drop_condition,
71
+ Xdim=self.Xdim,
72
+ Edim=self.Edim,
73
+ ydim=self.ydim,
74
+ )
75
+
76
+ self.model_dtype = model_dtype
77
+ self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
78
+ dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
79
+ )
80
+ x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
81
+ data_info.node_types.to(self.model_dtype)
82
+ )
83
+ e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
84
+ data_info.edge_types.to(self.model_dtype)
85
+ )
86
+ x_marginals = x_marginals / x_marginals.sum()
87
+ e_marginals = e_marginals / e_marginals.sum()
88
+
89
+ xe_conditions = data_info.transition_E.to(self.model_dtype)
90
+ xe_conditions = xe_conditions[self.active_index][:, self.active_index]
91
+
92
+ xe_conditions = xe_conditions.sum(dim=1)
93
+ ex_conditions = xe_conditions.t()
94
+ xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
95
+ ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
96
+
97
+ self.transition_model = utils.MarginalTransition(
98
+ x_marginals=x_marginals,
99
+ e_marginals=e_marginals,
100
+ xe_conditions=xe_conditions,
101
+ ex_conditions=ex_conditions,
102
+ y_classes=self.ydim_output,
103
+ n_nodes=self.max_n_nodes,
104
+ )
105
+ self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
106
+
107
+ def init_model(self, model_dir):
108
+ model_file = os.path.join(model_dir, 'model.pt')
109
+ if os.path.exists(model_file):
110
+ self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
111
+ else:
112
+ raise FileNotFoundError(f"Model file not found: {model_file}")
113
+
114
+ def disable_grads(self):
115
+ self.denoiser.disable_grads()
116
+
117
+ def forward(
118
+ self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
119
+ ):
120
+ raise ValueError('Not Implement')
121
+
122
+ def _forward(self, noisy_data, unconditioned=False):
123
+ noisy_x, noisy_e, properties = (
124
+ noisy_data["X_t"].to(self.model_dtype),
125
+ noisy_data["E_t"].to(self.model_dtype),
126
+ noisy_data["y_t"].to(self.model_dtype).clone(),
127
+ )
128
+ node_mask, timestep = (
129
+ noisy_data["node_mask"],
130
+ noisy_data["t"],
131
+ )
132
 
133
+ pred = self.denoiser(
134
+ noisy_x,
135
+ noisy_e,
136
+ node_mask,
137
+ properties,
138
+ timestep,
139
+ unconditioned=unconditioned,
140
+ )
141
+ return pred
142
+
143
+ def apply_noise(self, X, E, y, node_mask):
144
+ """Sample noise and apply it to the data."""
145
+
146
+ # Sample a timestep t.
147
+ # When evaluating, the loss for t=0 is computed separately
148
+ lowest_t = 0 if self.training else 1
149
+ t_int = torch.randint(
150
+ lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
151
+ ).to(
152
+ self.model_dtype
153
+ ) # (bs, 1)
154
+ s_int = t_int - 1
155
+
156
+ t_float = t_int / self.T
157
+ s_float = s_int / self.T
158
+
159
+ # beta_t and alpha_s_bar are used for denoising/loss computation
160
+ beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
161
+ alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
162
+ alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
163
+
164
+ Qtb = self.transition_model.get_Qt_bar(
165
+ alpha_t_bar, X.device
166
+ ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
167
+
168
+ bs, n, d = X.shape
169
+ X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
170
+ prob_all = X_all @ Qtb.X
171
+ probX = prob_all[:, :, : self.Xdim_output]
172
+ probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
173
+
174
+ sampled_t = utils.sample_discrete_features(
175
+ probX=probX, probE=probE, node_mask=node_mask
176
+ )
177
+
178
+ X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
179
+ E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
180
+ assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
181
+
182
+ y_t = y
183
+ z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
184
+
185
+ noisy_data = {
186
+ "t_int": t_int,
187
+ "t": t_float,
188
+ "beta_t": beta_t,
189
+ "alpha_s_bar": alpha_s_bar,
190
+ "alpha_t_bar": alpha_t_bar,
191
+ "X_t": z_t.X,
192
+ "E_t": z_t.E,
193
+ "y_t": z_t.y,
194
+ "node_mask": node_mask,
195
+ }
196
+ return noisy_data
197
+
198
+ @torch.no_grad()
199
+ def generate(
200
+ self,
201
+ properties,
202
+ guide_scale=1.,
203
+ num_nodes=None,
204
+ number_chain_steps=50,
205
+ ):
206
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
207
+ properties = [float('nan') if x is None else x for x in properties]
208
+ properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
209
+ batch_size = properties.size(0)
210
+ assert batch_size == 1
211
+ if num_nodes is None:
212
+ num_nodes = self.node_dist.sample_n(batch_size, device)
213
+ else:
214
+ num_nodes = torch.LongTensor([num_nodes]).to(device)
215
+
216
+ arange = (
217
+ torch.arange(self.max_n_nodes, device=device)
218
+ .unsqueeze(0)
219
+ .expand(batch_size, -1)
220
+ )
221
+ node_mask = arange < num_nodes.unsqueeze(1)
222
+
223
+ z_T = utils.sample_discrete_feature_noise(
224
+ limit_dist=self.limit_dist, node_mask=node_mask
225
+ )
226
+ X, E = z_T.X, z_T.E
227
+
228
+ assert (E == torch.transpose(E, 1, 2)).all()
229
+
230
+ if number_chain_steps > 0:
231
+ chain_X_size = torch.Size((number_chain_steps, X.size(1)))
232
+ chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
233
+ chain_X = torch.zeros(chain_X_size)
234
+ chain_E = torch.zeros(chain_E_size)
235
+
236
+ # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
237
+ y = properties
238
+ for s_int in reversed(range(0, self.T)):
239
+ s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
240
+ t_array = s_array + 1
241
+ s_norm = s_array / self.T
242
+ t_norm = t_array / self.T
243
+
244
+ # Sample z_s
245
+ sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
246
+ s_norm, t_norm, X, E, y, node_mask, guide_scale, device
247
+ )
248
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
249
 
250
+ if number_chain_steps > 0:
251
+ # Save the first keep_chain graphs
252
+ write_index = (s_int * number_chain_steps) // self.T
253
+ chain_X[write_index] = discrete_sampled_s.X[:1]
254
+ chain_E[write_index] = discrete_sampled_s.E[:1]
255
+
256
+ # Sample
257
+ sampled_s = sampled_s.mask(node_mask, collapse=True)
258
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
259
+
260
+ molecule_list = []
261
+ n = num_nodes[0]
262
+ atom_types = X[0, :n].cpu()
263
+ edge_types = E[0, :n, :n].cpu()
264
+ molecule_list.append([atom_types, edge_types])
265
+ smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
266
+
267
+ # Visualize Chains
268
+ if number_chain_steps > 0:
269
+ final_X_chain = X[:1]
270
+ final_E_chain = E[:1]
271
+
272
+ chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
273
+ chain_E[0] = final_E_chain
274
+
275
+ chain_X = utils.reverse_tensor(chain_X)
276
+ chain_E = utils.reverse_tensor(chain_E)
277
+
278
+ # Repeat last frame to see final sample better
279
+ chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
280
+ chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
281
+ mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
282
+ else:
283
+ mol_img_list = []
284
+
285
+ return smiles, mol_img_list
286
+
287
+ def check_valid(self, smiles):
288
+ return check_valid(smiles)
289
 
290
+ def sample_p_zs_given_zt(
291
+ self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
292
+ ):
293
+ """Samples from zs ~ p(zs | zt). Only used during sampling.
294
+ if last_step, return the graph prediction as well"""
295
+ bs, n, _ = X_t.shape
296
+ beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
297
+ alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
298
+ alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
299
+
300
+ # Neural net predictions
301
+ noisy_data = {
302
+ "X_t": X_t,
303
+ "E_t": E_t,
304
+ "y_t": properties,
305
+ "t": t,
306
+ "node_mask": node_mask,
307
+ }
308
+
309
+ def get_prob(noisy_data, unconditioned=False):
310
+ pred = self._forward(noisy_data, unconditioned=unconditioned)
311
+
312
+ # Normalize predictions
313
+ pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
314
+ pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
315
+
316
+ # Retrieve transitions matrix
317
+ Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
318
+ Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
319
+ Qt = self.transition_model.get_Qt(beta_t, device)
320
+
321
+ Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
322
+ predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
323
+
324
+ unnormalized_probX_all = utils.reverse_diffusion(
325
+ predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
326
+ )
327
+
328
+ unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
329
+ unnormalized_prob_E = unnormalized_probX_all[
330
+ :, :, self.Xdim_output :
331
+ ].reshape(bs, n * n, -1)
332
+
333
+ unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
334
+ unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
335
+
336
+ prob_X = unnormalized_prob_X / torch.sum(
337
+ unnormalized_prob_X, dim=-1, keepdim=True
338
+ ) # bs, n, d_t-1
339
+ prob_E = unnormalized_prob_E / torch.sum(
340
+ unnormalized_prob_E, dim=-1, keepdim=True
341
+ ) # bs, n, d_t-1
342
+ prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
343
+
344
+ return prob_X, prob_E
345
+
346
+ prob_X, prob_E = get_prob(noisy_data)
347
+
348
+ ### Guidance
349
+ if guide_scale != 1:
350
+ uncon_prob_X, uncon_prob_E = get_prob(
351
+ noisy_data, unconditioned=True
352
+ )
353
+ prob_X = (
354
+ uncon_prob_X
355
+ * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
356
+ )
357
+ prob_E = (
358
+ uncon_prob_E
359
+ * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
360
+ )
361
+ prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
362
+ prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
363
+
364
+ # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
365
+ # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
366
+
367
+ sampled_s = utils.sample_discrete_features(
368
+ prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
369
+ )
370
+
371
+ X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
372
+ E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
373
+
374
+ assert (E_s == torch.transpose(E_s, 1, 2)).all()
375
+ assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
376
+
377
+ out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
378
+ out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
379
+
380
+ return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
381
+ node_mask, collapse=True
382
+ ).type_as(properties)
383