liuganghuggingface commited on
Commit
96c15b8
1 Parent(s): 341f250

Update graph_decoder/diffusion_utils.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_utils.py +49 -49
graph_decoder/diffusion_utils.py CHANGED
@@ -1,52 +1,52 @@
1
- # import os
2
- # import json
3
- # import yaml
4
-
5
- # import torch
6
- # import numpy as np
7
- # from torch.nn import functional as F
8
- # from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops
9
- # from types import SimpleNamespace
10
-
11
- # def dict_to_namespace(d):
12
- # return SimpleNamespace(
13
- # **{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in d.items()}
14
- # )
15
-
16
- # class DataInfos:
17
- # def __init__(self, meta_filename="data.meta.json"):
18
- # self.all_targets = ['CH4', 'CO2', 'H2', 'N2', 'O2']
19
- # self.task_type = "gas_permeability"
20
- # if os.path.exists(meta_filename):
21
- # with open(meta_filename, "r") as f:
22
- # meta_dict = json.load(f)
23
- # else:
24
- # raise FileNotFoundError(f"Meta file {meta_filename} not found.")
25
-
26
- # self.active_atoms = meta_dict["active_atoms"]
27
- # self.max_n_nodes = meta_dict["max_node"]
28
- # self.original_max_n_nodes = meta_dict["max_node"]
29
- # self.n_nodes = torch.Tensor(meta_dict["n_atoms_per_mol_dist"])
30
- # self.edge_types = torch.Tensor(meta_dict["bond_type_dist"])
31
- # self.transition_E = torch.Tensor(meta_dict["transition_E"])
32
-
33
- # self.atom_decoder = meta_dict["active_atoms"]
34
- # node_types = torch.Tensor(meta_dict["atom_type_dist"])
35
- # active_index = (node_types > 0).nonzero().squeeze()
36
- # self.node_types = torch.Tensor(meta_dict["atom_type_dist"])[active_index]
37
- # self.nodes_dist = DistributionNodes(self.n_nodes)
38
- # self.active_index = active_index
39
-
40
- # val_len = 3 * self.original_max_n_nodes - 2
41
- # meta_val = torch.Tensor(meta_dict["valencies"])
42
- # self.valency_distribution = torch.zeros(val_len)
43
- # val_len = min(val_len, len(meta_val))
44
- # self.valency_distribution[:val_len] = meta_val[:val_len]
45
- # ## for all
46
- # self.input_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
47
- # self.output_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
48
- # # self.input_dims = {"X": 11, "E": 5, "y": 5}
49
- # # self.output_dims = {"X": 11, "E": 5, "y": 5}
50
 
51
  # def load_config(config_path, data_meta_info_path):
52
  # if not os.path.exists(config_path):
 
1
+ import os
2
+ import json
3
+ import yaml
4
+
5
+ import torch
6
+ import numpy as np
7
+ from torch.nn import functional as F
8
+ from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops
9
+ from types import SimpleNamespace
10
+
11
+ def dict_to_namespace(d):
12
+ return SimpleNamespace(
13
+ **{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in d.items()}
14
+ )
15
+
16
+ class DataInfos:
17
+ def __init__(self, meta_filename="data.meta.json"):
18
+ self.all_targets = ['CH4', 'CO2', 'H2', 'N2', 'O2']
19
+ self.task_type = "gas_permeability"
20
+ if os.path.exists(meta_filename):
21
+ with open(meta_filename, "r") as f:
22
+ meta_dict = json.load(f)
23
+ else:
24
+ raise FileNotFoundError(f"Meta file {meta_filename} not found.")
25
+
26
+ self.active_atoms = meta_dict["active_atoms"]
27
+ self.max_n_nodes = meta_dict["max_node"]
28
+ self.original_max_n_nodes = meta_dict["max_node"]
29
+ self.n_nodes = torch.Tensor(meta_dict["n_atoms_per_mol_dist"])
30
+ self.edge_types = torch.Tensor(meta_dict["bond_type_dist"])
31
+ self.transition_E = torch.Tensor(meta_dict["transition_E"])
32
+
33
+ self.atom_decoder = meta_dict["active_atoms"]
34
+ node_types = torch.Tensor(meta_dict["atom_type_dist"])
35
+ active_index = (node_types > 0).nonzero().squeeze()
36
+ self.node_types = torch.Tensor(meta_dict["atom_type_dist"])[active_index]
37
+ self.nodes_dist = DistributionNodes(self.n_nodes)
38
+ self.active_index = active_index
39
+
40
+ val_len = 3 * self.original_max_n_nodes - 2
41
+ meta_val = torch.Tensor(meta_dict["valencies"])
42
+ self.valency_distribution = torch.zeros(val_len)
43
+ val_len = min(val_len, len(meta_val))
44
+ self.valency_distribution[:val_len] = meta_val[:val_len]
45
+ ## for all
46
+ self.input_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
47
+ self.output_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
48
+ # self.input_dims = {"X": 11, "E": 5, "y": 5}
49
+ # self.output_dims = {"X": 11, "E": 5, "y": 5}
50
 
51
  # def load_config(config_path, data_meta_info_path):
52
  # if not os.path.exists(config_path):