liuganghuggingface commited on
Commit
89df0a5
1 Parent(s): e513fba

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +12 -366
app.py CHANGED
@@ -5,375 +5,22 @@ import torch.nn as nn
5
  import random
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
 
8
 
9
- #####
10
-
11
- import os
12
- import yaml
13
- import json
14
-
15
- import torch.nn.functional as F
16
- from graph_decoder import diffusion_utils as utils
17
- from graph_decoder.molecule_utils import graph_to_smiles, check_valid
18
- from graph_decoder.transformer import Transformer
19
- from graph_decoder.visualize_utils import MolecularVisualization
20
-
21
- class GraphDiT(nn.Module):
22
- def __init__(
23
- self,
24
- model_config_path,
25
- data_info_path,
26
- model_dtype,
27
- ):
28
  super().__init__()
29
- pass
30
-
31
- # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
32
-
33
- # input_dims = data_info.input_dims
34
- # output_dims = data_info.output_dims
35
- # nodes_dist = data_info.nodes_dist
36
- # active_index = data_info.active_index
37
-
38
- # self.model_config = dm_cfg
39
- # self.data_info = data_info
40
- # self.T = dm_cfg.diffusion_steps
41
- # self.Xdim = input_dims["X"]
42
- # self.Edim = input_dims["E"]
43
- # self.ydim = input_dims["y"]
44
- # self.Xdim_output = output_dims["X"]
45
- # self.Edim_output = output_dims["E"]
46
- # self.ydim_output = output_dims["y"]
47
- # self.node_dist = nodes_dist
48
- # self.active_index = active_index
49
- # self.max_n_nodes = data_info.max_n_nodes
50
- # self.atom_decoder = data_info.atom_decoder
51
- # self.hidden_size = dm_cfg.hidden_size
52
- # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
53
-
54
- # self.denoiser = Transformer(
55
- # max_n_nodes=self.max_n_nodes,
56
- # hidden_size=dm_cfg.hidden_size,
57
- # depth=dm_cfg.depth,
58
- # num_heads=dm_cfg.num_heads,
59
- # mlp_ratio=dm_cfg.mlp_ratio,
60
- # drop_condition=dm_cfg.drop_condition,
61
- # Xdim=self.Xdim,
62
- # Edim=self.Edim,
63
- # ydim=self.ydim,
64
- # )
65
-
66
- # self.model_dtype = model_dtype
67
- # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
68
- # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
69
- # )
70
- # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
71
- # data_info.node_types.to(self.model_dtype)
72
- # )
73
- # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
74
- # data_info.edge_types.to(self.model_dtype)
75
- # )
76
- # x_marginals = x_marginals / x_marginals.sum()
77
- # e_marginals = e_marginals / e_marginals.sum()
78
-
79
- # xe_conditions = data_info.transition_E.to(self.model_dtype)
80
- # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
81
-
82
- # xe_conditions = xe_conditions.sum(dim=1)
83
- # ex_conditions = xe_conditions.t()
84
- # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
85
- # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
86
-
87
- # self.transition_model = utils.MarginalTransition(
88
- # x_marginals=x_marginals,
89
- # e_marginals=e_marginals,
90
- # xe_conditions=xe_conditions,
91
- # ex_conditions=ex_conditions,
92
- # y_classes=self.ydim_output,
93
- # n_nodes=self.max_n_nodes,
94
- # )
95
- # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
96
-
97
- def init_model(self, model_dir):
98
- model_file = os.path.join(model_dir, 'model.pt')
99
- if os.path.exists(model_file):
100
- self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
101
- else:
102
- raise FileNotFoundError(f"Model file not found: {model_file}")
103
-
104
- def disable_grads(self):
105
- self.denoiser.disable_grads()
106
-
107
- def forward(
108
- self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
109
- ):
110
- raise ValueError('Not Implement')
111
-
112
- def _forward(self, noisy_data, unconditioned=False):
113
- noisy_x, noisy_e, properties = (
114
- noisy_data["X_t"].to(self.model_dtype),
115
- noisy_data["E_t"].to(self.model_dtype),
116
- noisy_data["y_t"].to(self.model_dtype).clone(),
117
- )
118
- node_mask, timestep = (
119
- noisy_data["node_mask"],
120
- noisy_data["t"],
121
- )
122
-
123
- pred = self.denoiser(
124
- noisy_x,
125
- noisy_e,
126
- node_mask,
127
- properties,
128
- timestep,
129
- unconditioned=unconditioned,
130
- )
131
- return pred
132
-
133
- def apply_noise(self, X, E, y, node_mask):
134
- """Sample noise and apply it to the data."""
135
-
136
- # Sample a timestep t.
137
- # When evaluating, the loss for t=0 is computed separately
138
- lowest_t = 0 if self.training else 1
139
- t_int = torch.randint(
140
- lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
141
- ).to(
142
- self.model_dtype
143
- ) # (bs, 1)
144
- s_int = t_int - 1
145
-
146
- t_float = t_int / self.T
147
- s_float = s_int / self.T
148
-
149
- # beta_t and alpha_s_bar are used for denoising/loss computation
150
- beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
151
- alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
152
- alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
153
-
154
- Qtb = self.transition_model.get_Qt_bar(
155
- alpha_t_bar, X.device
156
- ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
157
-
158
- bs, n, d = X.shape
159
- X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
160
- prob_all = X_all @ Qtb.X
161
- probX = prob_all[:, :, : self.Xdim_output]
162
- probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
163
-
164
- sampled_t = utils.sample_discrete_features(
165
- probX=probX, probE=probE, node_mask=node_mask
166
- )
167
-
168
- X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
169
- E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
170
- assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
171
-
172
- y_t = y
173
- z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
174
-
175
- noisy_data = {
176
- "t_int": t_int,
177
- "t": t_float,
178
- "beta_t": beta_t,
179
- "alpha_s_bar": alpha_s_bar,
180
- "alpha_t_bar": alpha_t_bar,
181
- "X_t": z_t.X,
182
- "E_t": z_t.E,
183
- "y_t": z_t.y,
184
- "node_mask": node_mask,
185
- }
186
- return noisy_data
187
-
188
- @torch.no_grad()
189
- def generate(
190
- self,
191
- properties,
192
- device,
193
- guide_scale=1.,
194
- num_nodes=None,
195
- number_chain_steps=50,
196
- ):
197
- properties = [float('nan') if x is None else x for x in properties]
198
- properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
199
- batch_size = properties.size(0)
200
- assert batch_size == 1
201
- if num_nodes is None:
202
- num_nodes = self.node_dist.sample_n(batch_size, device)
203
- else:
204
- num_nodes = torch.LongTensor([num_nodes]).to(device)
205
-
206
- arange = (
207
- torch.arange(self.max_n_nodes, device=device)
208
- .unsqueeze(0)
209
- .expand(batch_size, -1)
210
- )
211
- node_mask = arange < num_nodes.unsqueeze(1)
212
-
213
- z_T = utils.sample_discrete_feature_noise(
214
- limit_dist=self.limit_dist, node_mask=node_mask
215
- )
216
- X, E = z_T.X, z_T.E
217
-
218
- assert (E == torch.transpose(E, 1, 2)).all()
219
-
220
- if number_chain_steps > 0:
221
- chain_X_size = torch.Size((number_chain_steps, X.size(1)))
222
- chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
223
- chain_X = torch.zeros(chain_X_size)
224
- chain_E = torch.zeros(chain_E_size)
225
 
226
- # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
227
- y = properties
228
- for s_int in reversed(range(0, self.T)):
229
- s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
230
- t_array = s_array + 1
231
- s_norm = s_array / self.T
232
- t_norm = t_array / self.T
233
 
234
- # Sample z_s
235
- sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
236
- s_norm, t_norm, X, E, y, node_mask, guide_scale, device
237
- )
238
- X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
239
-
240
- if number_chain_steps > 0:
241
- # Save the first keep_chain graphs
242
- write_index = (s_int * number_chain_steps) // self.T
243
- chain_X[write_index] = discrete_sampled_s.X[:1]
244
- chain_E[write_index] = discrete_sampled_s.E[:1]
245
-
246
- # Sample
247
- sampled_s = sampled_s.mask(node_mask, collapse=True)
248
- X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
249
-
250
- molecule_list = []
251
- n = num_nodes[0]
252
- atom_types = X[0, :n].cpu()
253
- edge_types = E[0, :n, :n].cpu()
254
- molecule_list.append([atom_types, edge_types])
255
- smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
256
-
257
- # Visualize Chains
258
- if number_chain_steps > 0:
259
- final_X_chain = X[:1]
260
- final_E_chain = E[:1]
261
-
262
- chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
263
- chain_E[0] = final_E_chain
264
-
265
- chain_X = utils.reverse_tensor(chain_X)
266
- chain_E = utils.reverse_tensor(chain_E)
267
-
268
- # Repeat last frame to see final sample better
269
- chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
270
- chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
271
- mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
272
- else:
273
- mol_img_list = []
274
-
275
- return smiles, mol_img_list
276
-
277
- def check_valid(self, smiles):
278
- return check_valid(smiles)
279
-
280
- def sample_p_zs_given_zt(
281
- self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
282
- ):
283
- """Samples from zs ~ p(zs | zt). Only used during sampling.
284
- if last_step, return the graph prediction as well"""
285
- bs, n, _ = X_t.shape
286
- beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
287
- alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
288
- alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
289
-
290
- # Neural net predictions
291
- noisy_data = {
292
- "X_t": X_t,
293
- "E_t": E_t,
294
- "y_t": properties,
295
- "t": t,
296
- "node_mask": node_mask,
297
- }
298
-
299
- def get_prob(noisy_data, unconditioned=False):
300
- pred = self._forward(noisy_data, unconditioned=unconditioned)
301
-
302
- # Normalize predictions
303
- pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
304
- pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
305
-
306
- # Retrieve transitions matrix
307
- Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
308
- Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
309
- Qt = self.transition_model.get_Qt(beta_t, device)
310
-
311
- Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
312
- predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
313
-
314
- unnormalized_probX_all = utils.reverse_diffusion(
315
- predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
316
- )
317
-
318
- unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
319
- unnormalized_prob_E = unnormalized_probX_all[
320
- :, :, self.Xdim_output :
321
- ].reshape(bs, n * n, -1)
322
-
323
- unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
324
- unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
325
-
326
- prob_X = unnormalized_prob_X / torch.sum(
327
- unnormalized_prob_X, dim=-1, keepdim=True
328
- ) # bs, n, d_t-1
329
- prob_E = unnormalized_prob_E / torch.sum(
330
- unnormalized_prob_E, dim=-1, keepdim=True
331
- ) # bs, n, d_t-1
332
- prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
333
-
334
- return prob_X, prob_E
335
-
336
- prob_X, prob_E = get_prob(noisy_data)
337
-
338
- ### Guidance
339
- if guide_scale != 1:
340
- uncon_prob_X, uncon_prob_E = get_prob(
341
- noisy_data, unconditioned=True
342
- )
343
- prob_X = (
344
- uncon_prob_X
345
- * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
346
- )
347
- prob_E = (
348
- uncon_prob_E
349
- * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
350
- )
351
- prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
352
- prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
353
-
354
- # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
355
- # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
356
-
357
- sampled_s = utils.sample_discrete_features(
358
- prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
359
- )
360
-
361
- X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
362
- E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
363
-
364
- assert (E_s == torch.transpose(E_s, 1, 2)).all()
365
- assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
366
-
367
- out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
368
- out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
369
-
370
- return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
371
- node_mask, collapse=True
372
- ).type_as(properties)
373
-
374
-
375
- #####
376
- # from graph_decoder.diffusion_model import GraphDiT
377
  def load_graph_decoder(path='model_labeled'):
378
  model = GraphDiT(
379
  model_config_path=f"{path}/config.yaml",
@@ -384,7 +31,6 @@ def load_graph_decoder(path='model_labeled'):
384
  # model.disable_grads()
385
  return model
386
 
387
-
388
  ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
389
 
390
  def generate_random_smiles(length=10):
 
5
  import random
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
+ from graph_decoder.diffusion_model import GraphDiT
9
 
10
+ class RandomPolymerGenerator(nn.Module):
11
+ def __init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  super().__init__()
13
+ self.fc1 = nn.Linear(5, 64)
14
+ self.fc2 = nn.Linear(64, 128)
15
+ self.fc3 = nn.Linear(128, 256)
16
+ self.fc4 = nn.Linear(256, 100) # Output size set to 100 for simplicity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def forward(self, x):
19
+ x = torch.relu(self.fc1(x))
20
+ x = torch.relu(self.fc2(x))
21
+ x = torch.relu(self.fc3(x))
22
+ return torch.sigmoid(self.fc4(x))
 
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def load_graph_decoder(path='model_labeled'):
25
  model = GraphDiT(
26
  model_config_path=f"{path}/config.yaml",
 
31
  # model.disable_grads()
32
  return model
33
 
 
34
  ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
35
 
36
  def generate_random_smiles(length=10):