# Copyright 2024 the Llamole team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from rdkit import Chem, RDLogger RDLogger.DisableLog("rdApp.*") import re import random import logging from rdkit import Chem from typing import List, Tuple, Optional random.seed(0) import torch bond_dict = [ None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC, ] ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1} logger = logging.getLogger(__name__) def check_polymer(smiles): if "*" in smiles: monomer = smiles.replace("*", "[H]") if mol2smiles(get_mol(monomer)) is None: logger.warning(f"Invalid polymerization point") return False else: return True return True def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]: smiles_list = [] for index, graph in enumerate(molecule_list): try: atom_types, edge_types = graph mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder) # Try to correct the molecule with connection=True, then False if needed for connection in (True, False): mol_conn, _ = correct_mol(mol_init, connection=connection) if mol_conn is not None: break else: logger.warning(f"Failed to correct molecule {index}") mol_conn = mol_init # Fallback to initial molecule # Convert to SMILES smiles = mol2smiles(mol_conn) if not smiles: logger.warning(f"Failed to convert molecule {index} to SMILES, falling back to RDKit MolToSmiles") smiles = Chem.MolToSmiles(mol_conn) if smiles: mol = get_mol(smiles) if mol is not None: # Get the largest fragment mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False) largest_mol = max(mol_frags, key=lambda m: m.GetNumAtoms()) largest_smiles = mol2smiles(largest_mol) if largest_smiles and len(largest_smiles) > 1: if check_polymer(largest_smiles): smiles_list.append(largest_smiles) else: smiles_list.append(None) elif check_polymer(smiles): smiles_list.append(smiles) else: smiles_list.append(None) else: logger.warning(f"Failed to convert SMILES back to molecule for index {index}") smiles_list.append(None) else: logger.warning(f"Failed to generate SMILES for molecule {index}, appending None") smiles_list.append(None) except Exception as e: logger.error(f"Error processing molecule {index}: {str(e)}") try: # Fallback to RDKit's MolToSmiles if everything else fails fallback_smiles = Chem.MolToSmiles(mol_init) if fallback_smiles: smiles_list.append(fallback_smiles) logger.warning(f"Used RDKit MolToSmiles fallback for molecule {index}") else: smiles_list.append(None) logger.warning(f"RDKit MolToSmiles fallback failed for molecule {index}, appending None") except Exception as e2: logger.error(f"All attempts failed for molecule {index}: {str(e2)}") smiles_list.append(None) return smiles_list def build_molecule_with_partial_charges( atom_types, edge_types, atom_decoder, verbose=False ): if verbose: print("\nbuilding new molecule") mol = Chem.RWMol() for atom in atom_types: a = Chem.Atom(atom_decoder[atom.item()]) mol.AddAtom(a) if verbose: print("Atom added: ", atom.item(), atom_decoder[atom.item()]) edge_types = torch.triu(edge_types) all_bonds = torch.nonzero(edge_types) for i, bond in enumerate(all_bonds): if bond[0].item() != bond[1].item(): mol.AddBond( bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()], ) if verbose: print( "bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(), bond_dict[edge_types[bond[0], bond[1]].item()], ) # add formal charge to atom: e.g. [O+], [N+], [S+] # not support [O-], [N-], [S-], [NH+] etc. flag, atomid_valence = check_valency(mol) if verbose: print("flag, valence", flag, atomid_valence) if flag: continue else: if len(atomid_valence) == 2: idx = atomid_valence[0] v = atomid_valence[1] an = mol.GetAtomWithIdx(idx).GetAtomicNum() if verbose: print("atomic num of atom with a large valence", an) if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1: mol.GetAtomWithIdx(idx).SetFormalCharge(1) # print("Formal charge added") else: continue return mol def correct_mol(mol, connection=False): ##### no_correct = False flag, _ = check_valency(mol) if flag: no_correct = True while True: if connection: mol_conn = connect_fragments(mol) mol = mol_conn if mol is None: return None, no_correct flag, atomid_valence = check_valency(mol) if flag: break else: try: assert len(atomid_valence) == 2 idx = atomid_valence[0] v = atomid_valence[1] queue = [] check_idx = 0 for b in mol.GetAtomWithIdx(idx).GetBonds(): type = int(b.GetBondType()) queue.append( (b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()) ) if type == 12: check_idx += 1 queue.sort(key=lambda tup: tup[1], reverse=True) if queue[-1][1] == 12: return None, no_correct elif len(queue) > 0: start = queue[check_idx][2] end = queue[check_idx][3] t = queue[check_idx][1] - 1 mol.RemoveBond(start, end) if t >= 1: mol.AddBond(start, end, bond_dict[t]) except Exception as e: # print(f"An error occurred in correction: {e}") return None, no_correct return mol, no_correct def check_valid(smiles): mol = get_mol(smiles) if mol is None: return False smiles = mol2smiles(mol) if smiles is None: return False return True def get_mol(smiles_or_mol): """ Loads SMILES/molecule into RDKit's object """ if isinstance(smiles_or_mol, str): if len(smiles_or_mol) == 0: return None mol = Chem.MolFromSmiles(smiles_or_mol) if mol is None: return None try: Chem.SanitizeMol(mol) except ValueError: return None return mol return smiles_or_mol def mol2smiles(mol): if mol is None: return None try: Chem.SanitizeMol(mol) except ValueError: return None return Chem.MolToSmiles(mol) def check_valency(mol): try: # First attempt to sanitize with specific properties Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) return True, None except ValueError as e: e = str(e) p = e.find("#") e_sub = e[p:] atomid_valence = list(map(int, re.findall(r"\d+", e_sub))) return False, atomid_valence except Exception as e: # print(f"An unexpected error occurred: {e}") return False, [] ##### connect fragements def select_atom_with_available_valency(frag): atoms = list(frag.GetAtoms()) random.shuffle(atoms) for atom in atoms: if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0: return atom return None def select_atoms_with_available_valency(frag): return [ atom for atom in frag.GetAtoms() if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0 ] def try_to_connect_fragments(combined_mol, frag, atom1, atom2): # Make copies of the molecules to try the connection trial_combined_mol = Chem.RWMol(combined_mol) trial_frag = Chem.RWMol(frag) # Add the new fragment to the combined molecule with new indices new_indices = { atom.GetIdx(): trial_combined_mol.AddAtom(atom) for atom in trial_frag.GetAtoms() } # Add the bond between the suitable atoms from each fragment trial_combined_mol.AddBond( atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE ) # Adjust the hydrogen count of the connected atoms for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]: atom = trial_combined_mol.GetAtomWithIdx(atom_idx) num_h = atom.GetTotalNumHs() atom.SetNumExplicitHs(max(0, num_h - 1)) # Add bonds for the new fragment for bond in trial_frag.GetBonds(): trial_combined_mol.AddBond( new_indices[bond.GetBeginAtomIdx()], new_indices[bond.GetEndAtomIdx()], bond.GetBondType(), ) # Convert to a Mol object and try to sanitize it new_mol = Chem.Mol(trial_combined_mol) try: Chem.SanitizeMol(new_mol) return new_mol # Return the new valid molecule except Chem.MolSanitizeException: return None # If the molecule is not valid, return None def connect_fragments(mol): # Get the separate fragments frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) if len(frags) < 2: return mol combined_mol = Chem.RWMol(frags[0]) for frag in frags[1:]: # Select all atoms with available valency from both molecules atoms1 = select_atoms_with_available_valency(combined_mol) atoms2 = select_atoms_with_available_valency(frag) # Try to connect using all combinations of available valency atoms for atom1 in atoms1: for atom2 in atoms2: new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2) if new_mol is not None: # If a valid connection is made, update the combined molecule and break combined_mol = new_mol break else: # Continue if the inner loop didn't break (no valid connection found for atom1) continue # Break if the inner loop did break (valid connection found) break else: # If no valid connections could be made with any of the atoms, return None return None return combined_mol #### connect fragements