liuganghuggingface commited on
Commit
19ae33a
1 Parent(s): 26a7403

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +67 -66
graph_decoder/diffusion_model.py CHANGED
@@ -19,72 +19,73 @@ class GraphDiT(nn.Module):
19
  model_dtype,
20
  ):
21
  super().__init__()
22
-
23
- dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
24
-
25
- input_dims = data_info.input_dims
26
- output_dims = data_info.output_dims
27
- nodes_dist = data_info.nodes_dist
28
- active_index = data_info.active_index
29
-
30
- self.model_config = dm_cfg
31
- self.data_info = data_info
32
- self.T = dm_cfg.diffusion_steps
33
- self.Xdim = input_dims["X"]
34
- self.Edim = input_dims["E"]
35
- self.ydim = input_dims["y"]
36
- self.Xdim_output = output_dims["X"]
37
- self.Edim_output = output_dims["E"]
38
- self.ydim_output = output_dims["y"]
39
- self.node_dist = nodes_dist
40
- self.active_index = active_index
41
- self.max_n_nodes = data_info.max_n_nodes
42
- self.atom_decoder = data_info.atom_decoder
43
- self.hidden_size = dm_cfg.hidden_size
44
- self.mol_visualizer = MolecularVisualization(self.atom_decoder)
45
-
46
- self.denoiser = Transformer(
47
- max_n_nodes=self.max_n_nodes,
48
- hidden_size=dm_cfg.hidden_size,
49
- depth=dm_cfg.depth,
50
- num_heads=dm_cfg.num_heads,
51
- mlp_ratio=dm_cfg.mlp_ratio,
52
- drop_condition=dm_cfg.drop_condition,
53
- Xdim=self.Xdim,
54
- Edim=self.Edim,
55
- ydim=self.ydim,
56
- )
57
-
58
- self.model_dtype = model_dtype
59
- self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
60
- dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
61
- )
62
- x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
63
- data_info.node_types.to(self.model_dtype)
64
- )
65
- e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
66
- data_info.edge_types.to(self.model_dtype)
67
- )
68
- x_marginals = x_marginals / x_marginals.sum()
69
- e_marginals = e_marginals / e_marginals.sum()
70
-
71
- xe_conditions = data_info.transition_E.to(self.model_dtype)
72
- xe_conditions = xe_conditions[self.active_index][:, self.active_index]
73
-
74
- xe_conditions = xe_conditions.sum(dim=1)
75
- ex_conditions = xe_conditions.t()
76
- xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
77
- ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
78
-
79
- self.transition_model = utils.MarginalTransition(
80
- x_marginals=x_marginals,
81
- e_marginals=e_marginals,
82
- xe_conditions=xe_conditions,
83
- ex_conditions=ex_conditions,
84
- y_classes=self.ydim_output,
85
- n_nodes=self.max_n_nodes,
86
- )
87
- self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
 
88
 
89
  def init_model(self, model_dir):
90
  model_file = os.path.join(model_dir, 'model.pt')
 
19
  model_dtype,
20
  ):
21
  super().__init__()
22
+ pass
23
+
24
+ # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
25
+
26
+ # input_dims = data_info.input_dims
27
+ # output_dims = data_info.output_dims
28
+ # nodes_dist = data_info.nodes_dist
29
+ # active_index = data_info.active_index
30
+
31
+ # self.model_config = dm_cfg
32
+ # self.data_info = data_info
33
+ # self.T = dm_cfg.diffusion_steps
34
+ # self.Xdim = input_dims["X"]
35
+ # self.Edim = input_dims["E"]
36
+ # self.ydim = input_dims["y"]
37
+ # self.Xdim_output = output_dims["X"]
38
+ # self.Edim_output = output_dims["E"]
39
+ # self.ydim_output = output_dims["y"]
40
+ # self.node_dist = nodes_dist
41
+ # self.active_index = active_index
42
+ # self.max_n_nodes = data_info.max_n_nodes
43
+ # self.atom_decoder = data_info.atom_decoder
44
+ # self.hidden_size = dm_cfg.hidden_size
45
+ # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
46
+
47
+ # self.denoiser = Transformer(
48
+ # max_n_nodes=self.max_n_nodes,
49
+ # hidden_size=dm_cfg.hidden_size,
50
+ # depth=dm_cfg.depth,
51
+ # num_heads=dm_cfg.num_heads,
52
+ # mlp_ratio=dm_cfg.mlp_ratio,
53
+ # drop_condition=dm_cfg.drop_condition,
54
+ # Xdim=self.Xdim,
55
+ # Edim=self.Edim,
56
+ # ydim=self.ydim,
57
+ # )
58
+
59
+ # self.model_dtype = model_dtype
60
+ # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
61
+ # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
62
+ # )
63
+ # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
64
+ # data_info.node_types.to(self.model_dtype)
65
+ # )
66
+ # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
67
+ # data_info.edge_types.to(self.model_dtype)
68
+ # )
69
+ # x_marginals = x_marginals / x_marginals.sum()
70
+ # e_marginals = e_marginals / e_marginals.sum()
71
+
72
+ # xe_conditions = data_info.transition_E.to(self.model_dtype)
73
+ # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
74
+
75
+ # xe_conditions = xe_conditions.sum(dim=1)
76
+ # ex_conditions = xe_conditions.t()
77
+ # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
78
+ # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
79
+
80
+ # self.transition_model = utils.MarginalTransition(
81
+ # x_marginals=x_marginals,
82
+ # e_marginals=e_marginals,
83
+ # xe_conditions=xe_conditions,
84
+ # ex_conditions=ex_conditions,
85
+ # y_classes=self.ydim_output,
86
+ # n_nodes=self.max_n_nodes,
87
+ # )
88
+ # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
89
 
90
  def init_model(self, model_dir):
91
  model_file = os.path.join(model_dir, 'model.pt')