liuganghuggingface committed
Commit 05ae67e
Parent(s): 26a9b58

Update graph_decoder/diffusion_model.py

Files changed (1):
  1. graph_decoder/diffusion_model.py +20 -55
graph_decoder/diffusion_model.py CHANGED
@@ -43,19 +43,20 @@ class GraphDiT(nn.Module):
         self.hidden_size = dm_cfg.hidden_size
         self.mol_visualizer = MolecularVisualization(self.atom_decoder)
 
-        self.denoiser = Transformer(
-            max_n_nodes=self.max_n_nodes,
-            hidden_size=dm_cfg.hidden_size,
-            depth=dm_cfg.depth,
-            num_heads=dm_cfg.num_heads,
-            mlp_ratio=dm_cfg.mlp_ratio,
-            drop_condition=dm_cfg.drop_condition,
-            Xdim=self.Xdim,
-            Edim=self.Edim,
-            ydim=self.ydim,
-        )
+        # self.denoiser = Transformer(
+        #     max_n_nodes=self.max_n_nodes,
+        #     hidden_size=dm_cfg.hidden_size,
+        #     depth=dm_cfg.depth,
+        #     num_heads=dm_cfg.num_heads,
+        #     mlp_ratio=dm_cfg.mlp_ratio,
+        #     drop_condition=dm_cfg.drop_condition,
+        #     Xdim=self.Xdim,
+        #     Edim=self.Edim,
+        #     ydim=self.ydim,
+        # )
+        self.denoiser = None
+
         self.model_dtype = model_dtype
-
         self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
             dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
         )
@@ -86,53 +87,17 @@ class GraphDiT(nn.Module):
         )
         self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
 
-    # def to(self, *args, **kwargs):
-    #     self = super().to(*args, **kwargs)
-    #     self.model_dtype = next(self.denoiser.parameters()).dtype
-    #     return self
-
-    def init_model(self, model_dir, verbose=False):
+    def init_model(self, model_dir):
         model_file = os.path.join(model_dir, 'model.pt')
         if os.path.exists(model_file):
-            self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
+            # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
+            pass
         else:
             raise FileNotFoundError(f"Model file not found: {model_file}")
-
-        if verbose:
-            print('GraphDiT Denoiser Model initialized.')
-            print('Denoiser model:\n', self.denoiser)
-
-    def save_pretrained(self, output_dir):
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-
-        # Save model
-        model_path = os.path.join(output_dir, 'model.pt')
-        torch.save(self.denoiser.state_dict(), model_path)
-
-        # Save model config
-        config_path = os.path.join(output_dir, 'model_config.yaml')
-        with open(config_path, 'w') as f:
-            yaml.dump(vars(self.model_config), f)
-
-        # Save data info
-        data_info_path = os.path.join(output_dir, 'data.meta.json')
-        data_info_dict = {
-            "active_atoms": self.data_info.active_atoms,
-            "max_node": self.data_info.max_n_nodes,
-            "n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(),
-            "bond_type_dist": self.data_info.edge_types.tolist(),
-            "transition_E": self.data_info.transition_E.tolist(),
-            "atom_type_dist": self.data_info.node_types.tolist(),
-            "valencies": self.data_info.valency_distribution.tolist()
-        }
-        with open(data_info_path, 'w') as f:
-            json.dump(data_info_dict, f, indent=2)
-
-        print('GraphDiT Model and configurations saved to:', output_dir)
 
     def disable_grads(self):
-        self.denoiser.disable_grads()
+        pass
+        # self.denoiser.disable_grads()
 
     def forward(
         self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
@@ -215,7 +180,7 @@ class GraphDiT(nn.Module):
         }
         return noisy_data
 
-    # @torch.no_grad()
+    @torch.no_grad()
     def generate(
         self,
         properties,
@@ -307,7 +272,7 @@ class GraphDiT(nn.Module):
 
     def check_valid(self, smiles):
         return check_valid(smiles)
-
+
     def sample_p_zs_given_zt(
         self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
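
Summary of the change, for readers skimming the diff: the Transformer denoiser is no longer constructed in __init__ (the constructor call is commented out and self.denoiser is set to None), init_model now only verifies that model.pt exists instead of loading it, disable_grads becomes a no-op, and the @torch.no_grad() decorator on generate is uncommented. Below is a minimal, self-contained sketch (a toy module, not the repository's GraphDiT) of what that decorator does: it turns off autograd tracking for the decorated call, which is the usual choice for sampling code that never needs gradients.

import torch
import torch.nn as nn

class TinyDenoiser(nn.Module):
    # Toy stand-in used only to illustrate the decorator's effect.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()
    def generate(self, x):
        # No computation graph is built inside this call.
        return self.proj(x)

    def forward(self, x):
        # A regular forward pass still tracks gradients.
        return self.proj(x)

model = TinyDenoiser()
x = torch.randn(2, 8, requires_grad=True)
print(model.generate(x).requires_grad)  # False
print(model(x).requires_grad)           # True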
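
The loading call that this commit comments out in init_model used torch.load with map_location='cpu' and weights_only=True. A small hedged sketch of that pattern with a stand-in nn.Linear (the file name mirrors the 'model.pt' convention in the diff; the real checkpoint holds the denoiser's state dict):

import torch
import torch.nn as nn

# Save a state dict, as the removed save_pretrained did for the denoiser.
layer = nn.Linear(4, 4)
torch.save(layer.state_dict(), 'model.pt')

# Restore it; weights_only=True limits unpickling to tensors and plain
# containers, avoiding execution of arbitrary pickled objects.
restored = nn.Linear(4, 4)
state_dict = torch.load('model.pt', map_location='cpu', weights_only=True)
restored.load_state_dict(state_dict)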