liuganghuggingface committed on
Commit 7044315
1 Parent(s): 05ae67e

Update graph_decoder/diffusion_model.py

Files changed (1)
  1. graph_decoder/diffusion_model.py +13 -17
graph_decoder/diffusion_model.py CHANGED
@@ -43,18 +43,17 @@ class GraphDiT(nn.Module):
         self.hidden_size = dm_cfg.hidden_size
         self.mol_visualizer = MolecularVisualization(self.atom_decoder)
 
-        # self.denoiser = Transformer(
-        #     max_n_nodes=self.max_n_nodes,
-        #     hidden_size=dm_cfg.hidden_size,
-        #     depth=dm_cfg.depth,
-        #     num_heads=dm_cfg.num_heads,
-        #     mlp_ratio=dm_cfg.mlp_ratio,
-        #     drop_condition=dm_cfg.drop_condition,
-        #     Xdim=self.Xdim,
-        #     Edim=self.Edim,
-        #     ydim=self.ydim,
-        # )
-        self.denoiser = None
+        self.denoiser = Transformer(
+            max_n_nodes=self.max_n_nodes,
+            hidden_size=dm_cfg.hidden_size,
+            depth=dm_cfg.depth,
+            num_heads=dm_cfg.num_heads,
+            mlp_ratio=dm_cfg.mlp_ratio,
+            drop_condition=dm_cfg.drop_condition,
+            Xdim=self.Xdim,
+            Edim=self.Edim,
+            ydim=self.ydim,
+        )
 
         self.model_dtype = model_dtype
         self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
@@ -90,14 +89,12 @@
     def init_model(self, model_dir):
         model_file = os.path.join(model_dir, 'model.pt')
         if os.path.exists(model_file):
-            # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
-            pass
+            self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
         else:
             raise FileNotFoundError(f"Model file not found: {model_file}")
 
     def disable_grads(self):
-        pass
-        # self.denoiser.disable_grads()
+        self.denoiser.disable_grads()
 
     def forward(
         self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
@@ -193,7 +190,6 @@ class GraphDiT(nn.Module):
         properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
         batch_size = properties.size(0)
         assert batch_size == 1
-        # print('self.denoiser.dtype', self.model_dtype)
         if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
         else:
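
For readers tracing the change: the restored self.denoiser.disable_grads() call implies that the Transformer denoiser exposes a helper that freezes its parameters. That implementation is not part of this commit; the following is only a minimal sketch of the usual PyTorch freeze-all-parameters pattern such a helper would follow.

import torch.nn as nn

class Transformer(nn.Module):
    # Sketch only: the real Transformer in graph_decoder defines the full
    # denoiser architecture; this shows just the assumed disable_grads() helper.
    def disable_grads(self):
        # Freeze every parameter so the denoiser is excluded from gradient
        # updates, e.g. when the model is used purely for sampling/inference.
        for param in self.parameters():
            param.requires_grad = False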
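Note also that the restored checkpoint loading passes weights_only=True to torch.load. In recent PyTorch releases this flag restricts unpickling to tensors and other primitive containers, which is safer for shared checkpoints; it assumes model.pt is a plain state dict compatible with the now-uncommented Transformer configuration.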