liuganghuggingface commited on
Commit
8957853
1 Parent(s): 6f3f870

Upload graph_decoder/molecule_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. graph_decoder/molecule_utils.py +355 -0
graph_decoder/molecule_utils.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the Llamole team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from rdkit import Chem, RDLogger
16
+
17
+ RDLogger.DisableLog("rdApp.*")
18
+
19
+ import re
20
+ import random
21
+ import logging
22
+ from rdkit import Chem
23
+ from typing import List, Tuple, Optional
24
+ random.seed(0)
25
+ import torch
26
+
27
+ bond_dict = [
28
+ None,
29
+ Chem.rdchem.BondType.SINGLE,
30
+ Chem.rdchem.BondType.DOUBLE,
31
+ Chem.rdchem.BondType.TRIPLE,
32
+ Chem.rdchem.BondType.AROMATIC,
33
+ ]
34
+
35
+ ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ def check_polymer(smiles):
40
+ if "*" in smiles:
41
+ monomer = smiles.replace("*", "[H]")
42
+ if mol2smiles(get_mol(monomer)) is None:
43
+ logger.warning(f"Invalid polymerization point")
44
+ return False
45
+ else:
46
+ return True
47
+ return True
48
+
49
+ def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]:
50
+
51
+ smiles_list = []
52
+ for index, graph in enumerate(molecule_list):
53
+ try:
54
+ atom_types, edge_types = graph
55
+ mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)
56
+
57
+ # Try to correct the molecule with connection=True, then False if needed
58
+ for connection in (True, False):
59
+ mol_conn, _ = correct_mol(mol_init, connection=connection)
60
+ if mol_conn is not None:
61
+ break
62
+ else:
63
+ logger.warning(f"Failed to correct molecule {index}")
64
+ mol_conn = mol_init # Fallback to initial molecule
65
+
66
+ # Convert to SMILES
67
+ smiles = mol2smiles(mol_conn)
68
+ if not smiles:
69
+ logger.warning(f"Failed to convert molecule {index} to SMILES, falling back to RDKit MolToSmiles")
70
+ smiles = Chem.MolToSmiles(mol_conn)
71
+
72
+ if smiles:
73
+ mol = get_mol(smiles)
74
+ if mol is not None:
75
+ # Get the largest fragment
76
+ mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
77
+ largest_mol = max(mol_frags, key=lambda m: m.GetNumAtoms())
78
+
79
+ largest_smiles = mol2smiles(largest_mol)
80
+ if largest_smiles and len(largest_smiles) > 1:
81
+ if check_polymer(largest_smiles):
82
+ smiles_list.append(largest_smiles)
83
+ else:
84
+ smiles_list.append(None)
85
+ elif check_polymer(smiles):
86
+ smiles_list.append(smiles)
87
+ else:
88
+ smiles_list.append(None)
89
+ else:
90
+ logger.warning(f"Failed to convert SMILES back to molecule for index {index}")
91
+ smiles_list.append(None)
92
+ else:
93
+ logger.warning(f"Failed to generate SMILES for molecule {index}, appending None")
94
+ smiles_list.append(None)
95
+
96
+ except Exception as e:
97
+ logger.error(f"Error processing molecule {index}: {str(e)}")
98
+ try:
99
+ # Fallback to RDKit's MolToSmiles if everything else fails
100
+ fallback_smiles = Chem.MolToSmiles(mol_init)
101
+ if fallback_smiles:
102
+ smiles_list.append(fallback_smiles)
103
+ logger.warning(f"Used RDKit MolToSmiles fallback for molecule {index}")
104
+ else:
105
+ smiles_list.append(None)
106
+ logger.warning(f"RDKit MolToSmiles fallback failed for molecule {index}, appending None")
107
+ except Exception as e2:
108
+ logger.error(f"All attempts failed for molecule {index}: {str(e2)}")
109
+ smiles_list.append(None)
110
+
111
+ return smiles_list
112
+
113
+ def build_molecule_with_partial_charges(
114
+ atom_types, edge_types, atom_decoder, verbose=False
115
+ ):
116
+ if verbose:
117
+ print("\nbuilding new molecule")
118
+
119
+ mol = Chem.RWMol()
120
+ for atom in atom_types:
121
+ a = Chem.Atom(atom_decoder[atom.item()])
122
+ mol.AddAtom(a)
123
+ if verbose:
124
+ print("Atom added: ", atom.item(), atom_decoder[atom.item()])
125
+
126
+ edge_types = torch.triu(edge_types)
127
+ all_bonds = torch.nonzero(edge_types)
128
+
129
+ for i, bond in enumerate(all_bonds):
130
+ if bond[0].item() != bond[1].item():
131
+ mol.AddBond(
132
+ bond[0].item(),
133
+ bond[1].item(),
134
+ bond_dict[edge_types[bond[0], bond[1]].item()],
135
+ )
136
+ if verbose:
137
+ print(
138
+ "bond added:",
139
+ bond[0].item(),
140
+ bond[1].item(),
141
+ edge_types[bond[0], bond[1]].item(),
142
+ bond_dict[edge_types[bond[0], bond[1]].item()],
143
+ )
144
+ # add formal charge to atom: e.g. [O+], [N+], [S+]
145
+ # not support [O-], [N-], [S-], [NH+] etc.
146
+ flag, atomid_valence = check_valency(mol)
147
+ if verbose:
148
+ print("flag, valence", flag, atomid_valence)
149
+ if flag:
150
+ continue
151
+ else:
152
+ if len(atomid_valence) == 2:
153
+ idx = atomid_valence[0]
154
+ v = atomid_valence[1]
155
+ an = mol.GetAtomWithIdx(idx).GetAtomicNum()
156
+ if verbose:
157
+ print("atomic num of atom with a large valence", an)
158
+ if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
159
+ mol.GetAtomWithIdx(idx).SetFormalCharge(1)
160
+ # print("Formal charge added")
161
+ else:
162
+ continue
163
+ return mol
164
+
165
+
166
+ def correct_mol(mol, connection=False):
167
+ #####
168
+ no_correct = False
169
+ flag, _ = check_valency(mol)
170
+ if flag:
171
+ no_correct = True
172
+
173
+ while True:
174
+ if connection:
175
+ mol_conn = connect_fragments(mol)
176
+ mol = mol_conn
177
+ if mol is None:
178
+ return None, no_correct
179
+ flag, atomid_valence = check_valency(mol)
180
+ if flag:
181
+ break
182
+ else:
183
+ try:
184
+ assert len(atomid_valence) == 2
185
+ idx = atomid_valence[0]
186
+ v = atomid_valence[1]
187
+ queue = []
188
+ check_idx = 0
189
+ for b in mol.GetAtomWithIdx(idx).GetBonds():
190
+ type = int(b.GetBondType())
191
+ queue.append(
192
+ (b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx())
193
+ )
194
+ if type == 12:
195
+ check_idx += 1
196
+ queue.sort(key=lambda tup: tup[1], reverse=True)
197
+
198
+ if queue[-1][1] == 12:
199
+ return None, no_correct
200
+ elif len(queue) > 0:
201
+ start = queue[check_idx][2]
202
+ end = queue[check_idx][3]
203
+ t = queue[check_idx][1] - 1
204
+ mol.RemoveBond(start, end)
205
+ if t >= 1:
206
+ mol.AddBond(start, end, bond_dict[t])
207
+ except Exception as e:
208
+ # print(f"An error occurred in correction: {e}")
209
+ return None, no_correct
210
+ return mol, no_correct
211
+
212
+ def check_valid(smiles):
213
+ mol = get_mol(smiles)
214
+ if mol is None:
215
+ return False
216
+ smiles = mol2smiles(mol)
217
+ if smiles is None:
218
+ return False
219
+ return True
220
+
221
+ def get_mol(smiles_or_mol):
222
+ """
223
+ Loads SMILES/molecule into RDKit's object
224
+ """
225
+ if isinstance(smiles_or_mol, str):
226
+ if len(smiles_or_mol) == 0:
227
+ return None
228
+ mol = Chem.MolFromSmiles(smiles_or_mol)
229
+ if mol is None:
230
+ return None
231
+ try:
232
+ Chem.SanitizeMol(mol)
233
+ except ValueError:
234
+ return None
235
+ return mol
236
+ return smiles_or_mol
237
+
238
+
239
+ def mol2smiles(mol):
240
+ if mol is None:
241
+ return None
242
+ try:
243
+ Chem.SanitizeMol(mol)
244
+ except ValueError:
245
+ return None
246
+ return Chem.MolToSmiles(mol)
247
+
248
+
249
+ def check_valency(mol):
250
+ try:
251
+ # First attempt to sanitize with specific properties
252
+ Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
253
+ return True, None
254
+ except ValueError as e:
255
+ e = str(e)
256
+ p = e.find("#")
257
+ e_sub = e[p:]
258
+ atomid_valence = list(map(int, re.findall(r"\d+", e_sub)))
259
+ return False, atomid_valence
260
+ except Exception as e:
261
+ # print(f"An unexpected error occurred: {e}")
262
+ return False, []
263
+
264
+
265
+ ##### connect fragements
266
+ def select_atom_with_available_valency(frag):
267
+ atoms = list(frag.GetAtoms())
268
+ random.shuffle(atoms)
269
+ for atom in atoms:
270
+ if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0:
271
+ return atom
272
+ return None
273
+
274
+
275
+ def select_atoms_with_available_valency(frag):
276
+ return [
277
+ atom
278
+ for atom in frag.GetAtoms()
279
+ if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0
280
+ ]
281
+
282
+
283
+ def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
284
+ # Make copies of the molecules to try the connection
285
+ trial_combined_mol = Chem.RWMol(combined_mol)
286
+ trial_frag = Chem.RWMol(frag)
287
+
288
+ # Add the new fragment to the combined molecule with new indices
289
+ new_indices = {
290
+ atom.GetIdx(): trial_combined_mol.AddAtom(atom)
291
+ for atom in trial_frag.GetAtoms()
292
+ }
293
+
294
+ # Add the bond between the suitable atoms from each fragment
295
+ trial_combined_mol.AddBond(
296
+ atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE
297
+ )
298
+
299
+ # Adjust the hydrogen count of the connected atoms
300
+ for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]:
301
+ atom = trial_combined_mol.GetAtomWithIdx(atom_idx)
302
+ num_h = atom.GetTotalNumHs()
303
+ atom.SetNumExplicitHs(max(0, num_h - 1))
304
+
305
+ # Add bonds for the new fragment
306
+ for bond in trial_frag.GetBonds():
307
+ trial_combined_mol.AddBond(
308
+ new_indices[bond.GetBeginAtomIdx()],
309
+ new_indices[bond.GetEndAtomIdx()],
310
+ bond.GetBondType(),
311
+ )
312
+
313
+ # Convert to a Mol object and try to sanitize it
314
+ new_mol = Chem.Mol(trial_combined_mol)
315
+ try:
316
+ Chem.SanitizeMol(new_mol)
317
+ return new_mol # Return the new valid molecule
318
+ except Chem.MolSanitizeException:
319
+ return None # If the molecule is not valid, return None
320
+
321
+
322
+ def connect_fragments(mol):
323
+ # Get the separate fragments
324
+ frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
325
+ if len(frags) < 2:
326
+ return mol
327
+
328
+ combined_mol = Chem.RWMol(frags[0])
329
+
330
+ for frag in frags[1:]:
331
+ # Select all atoms with available valency from both molecules
332
+ atoms1 = select_atoms_with_available_valency(combined_mol)
333
+ atoms2 = select_atoms_with_available_valency(frag)
334
+
335
+ # Try to connect using all combinations of available valency atoms
336
+ for atom1 in atoms1:
337
+ for atom2 in atoms2:
338
+ new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2)
339
+ if new_mol is not None:
340
+ # If a valid connection is made, update the combined molecule and break
341
+ combined_mol = new_mol
342
+ break
343
+ else:
344
+ # Continue if the inner loop didn't break (no valid connection found for atom1)
345
+ continue
346
+ # Break if the inner loop did break (valid connection found)
347
+ break
348
+ else:
349
+ # If no valid connections could be made with any of the atoms, return None
350
+ return None
351
+
352
+ return combined_mol
353
+
354
+
355
+ #### connect fragements