Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
from rdkit import Chem, RDLogger | |
RDLogger.DisableLog("rdApp.*") | |
import re | |
import random | |
import logging | |
from rdkit import Chem | |
from typing import List, Tuple, Optional | |
random.seed(0) | |
import torch | |
# Lookup from the integer edge label found in an edge-type matrix to the
# RDKit bond type to create. Index 0 (None) is never used for bond creation:
# only non-zero edge labels are turned into bonds.
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

# Atomic number -> reference valence. Consulted while building molecules to
# decide whether an atom exceeding this valence by exactly one receives a
# +1 formal charge.
ATOM_VALENCY = {
    6: 4,  # C
    7: 3,  # N
    8: 2,  # O
    9: 1,  # F
    15: 3,  # P
    16: 2,  # S
    17: 1,  # Cl
    35: 1,  # Br
    53: 1,  # I
}

# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)
def check_polymer(smiles):
    """Return whether a SMILES string with optional ``*`` markers is usable.

    A string without ``*`` is accepted unconditionally. Otherwise every ``*``
    is replaced by ``[H]`` and the result must round-trip through ``get_mol``
    and ``mol2smiles``; if it does not, a warning is logged and ``False`` is
    returned.
    """
    if "*" not in smiles:
        return True

    capped = smiles.replace("*", "[H]")
    is_valid = mol2smiles(get_mol(capped)) is not None
    if not is_valid:
        logger.warning("Invalid polymerization point")
    return is_valid
def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]:
    """Decode a batch of generated graphs into SMILES strings.

    Each entry of ``molecule_list`` is unpacked as ``(atom_types, edge_types)``
    and processed as follows:

    1. build an RDKit molecule with ``build_molecule_with_partial_charges``;
    2. repair it with ``correct_mol``, first with ``connection=True`` and then
       with ``connection=False``; if both fail the unrepaired molecule is used;
    3. convert to SMILES (``mol2smiles``, then plain ``Chem.MolToSmiles``);
    4. re-parse the SMILES, keep the fragment with the most atoms, and screen
       the result with ``check_polymer``.

    Args:
        molecule_list: iterable of ``(atom_types, edge_types)`` pairs.
        atom_decoder: lookup passed through to the molecule builder.

    Returns:
        A list with exactly one entry per input graph, in order: the SMILES
        string, or ``None`` when no acceptable SMILES could be produced.
    """
    smiles_list = []
    for index, graph in enumerate(molecule_list):
        # Reset for every graph so the error fallback below can never pick up
        # the molecule that was built for a *previous* graph.
        mol_init = None
        try:
            atom_types, edge_types = graph
            mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)

            # Try to correct the molecule with connection=True, then False if needed
            for connection in (True, False):
                mol_conn, _ = correct_mol(mol_init, connection=connection)
                if mol_conn is not None:
                    break
            else:
                logger.warning("Failed to correct molecule %s", index)
                mol_conn = mol_init  # Fallback to initial molecule

            # Convert to SMILES
            smiles = mol2smiles(mol_conn)
            if not smiles:
                logger.warning(
                    "Failed to convert molecule %s to SMILES, falling back to RDKit MolToSmiles",
                    index,
                )
                smiles = Chem.MolToSmiles(mol_conn)

            result = None
            if not smiles:
                logger.warning("Failed to generate SMILES for molecule %s, appending None", index)
            else:
                mol = get_mol(smiles)
                if mol is None:
                    logger.warning("Failed to convert SMILES back to molecule for index %s", index)
                else:
                    # Get the largest fragment
                    mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
                    largest_mol = max(mol_frags, key=lambda m: m.GetNumAtoms())
                    largest_smiles = mol2smiles(largest_mol)
                    # Prefer the largest fragment; fall back to the full
                    # SMILES when the fragment is missing or a single character.
                    if largest_smiles and len(largest_smiles) > 1:
                        candidate = largest_smiles
                    else:
                        candidate = smiles
                    if check_polymer(candidate):
                        result = candidate
            smiles_list.append(result)
        except Exception as exc:
            # Deliberately broad: one bad graph must not abort the whole batch.
            logger.error("Error processing molecule %s: %s", index, exc)
            if mol_init is None:
                # The molecule for this graph was never built, so there is
                # nothing to fall back to.
                logger.error(
                    "All attempts failed for molecule %s: molecule could not be built", index
                )
                smiles_list.append(None)
                continue
            try:
                # Fallback to RDKit's MolToSmiles if everything else fails
                fallback_smiles = Chem.MolToSmiles(mol_init)
            except Exception as fallback_exc:
                logger.error("All attempts failed for molecule %s: %s", index, fallback_exc)
                smiles_list.append(None)
                continue
            if fallback_smiles:
                smiles_list.append(fallback_smiles)
                logger.warning("Used RDKit MolToSmiles fallback for molecule %s", index)
            else:
                smiles_list.append(None)
                logger.warning(
                    "RDKit MolToSmiles fallback failed for molecule %s, appending None", index
                )
    return smiles_list
def build_molecule_with_partial_charges(
    atom_types, edge_types, atom_decoder, verbose=False
):
    """Assemble an editable RDKit molecule from atom and edge labels.

    One atom is added per element of ``atom_types`` (decoded through
    ``atom_decoder``). Bonds are then added for every non-zero entry of the
    upper triangle of ``edge_types``, skipping the diagonal; the entry's value
    indexes ``bond_dict``. After each bond the valency is re-checked, and an
    N, O or S atom that exceeds its ``ATOM_VALENCY`` entry by exactly one is
    given a +1 formal charge.

    Args:
        atom_types: iterable of scalar tensors (each must support ``.item()``).
        edge_types: tensor accepted by ``torch.triu``; assumed to be a square
            2-D matrix of integer bond labels — TODO confirm against caller.
        atom_decoder: lookup from atom label to what ``Chem.Atom`` accepts.
        verbose: when true, print a trace of every atom and bond added.

    Returns:
        The ``Chem.RWMol`` that was built.
    """
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        atom_code = atom.item()
        mol.AddAtom(Chem.Atom(atom_decoder[atom_code]))
        if verbose:
            print("Atom added: ", atom_code, atom_decoder[atom_code])

    # Keep only the upper triangle so that each pair is visited once.
    edge_types = torch.triu(edge_types)
    for bond in torch.nonzero(edge_types):
        begin = bond[0].item()
        end = bond[1].item()
        if begin == end:
            # Diagonal entry: not a bond.
            continue

        order = edge_types[bond[0], bond[1]].item()
        mol.AddBond(begin, end, bond_dict[order])
        if verbose:
            print("bond added:", begin, end, order, bond_dict[order])

        # add formal charge to atom: e.g. [O+], [N+], [S+]
        # not support [O-], [N-], [S-], [NH+] etc.
        flag, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", flag, atomid_valence)
        if flag or len(atomid_valence) != 2:
            continue

        idx, v = atomid_valence[0], atomid_valence[1]
        an = mol.GetAtomWithIdx(idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", an)
        if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
            mol.GetAtomWithIdx(idx).SetFormalCharge(1)
    return mol
def correct_mol(mol, connection=False):
    """Repair valency violations by repeatedly lowering bond orders.

    On every pass the molecule is (optionally) run through
    ``connect_fragments`` and then checked with ``check_valency``. While the
    check fails, one bond of the offending atom is reduced by one order, or
    removed entirely when it was a single bond. Among that atom's bonds the
    highest-order non-aromatic one is chosen. ``mol`` is edited in place.

    Args:
        mol: editable molecule (must support ``RemoveBond`` / ``AddBond``).
        connection: when true, ``connect_fragments`` is applied on every pass
            before the valency check.

    Returns:
        ``(molecule, no_correct)``. ``molecule`` is the repaired molecule, or
        ``None`` when repair is impossible (fragments cannot be connected, the
        valency report is unusable, the atom has no bonds or only aromatic
        bonds, or RDKit raised while editing). ``no_correct`` is ``True`` when
        the input already passed the valency check before any change.
    """
    # Value that int(bond.GetBondType()) yields for aromatic bonds; such
    # bonds are never reduced.
    aromatic_code = 12

    no_correct, _ = check_valency(mol)

    while True:
        if connection:
            mol = connect_fragments(mol)
            if mol is None:
                return None, no_correct

        flag, atomid_valence = check_valency(mol)
        if flag:
            return mol, no_correct

        # Explicit guard instead of `assert`, which disappears under
        # `python -O`: exactly (atom index, valence) is required.
        if len(atomid_valence) != 2:
            return None, no_correct

        try:
            idx = atomid_valence[0]
            queue = []
            aromatic_count = 0
            for b in mol.GetAtomWithIdx(idx).GetBonds():
                bond_code = int(b.GetBondType())
                queue.append(
                    (b.GetIdx(), bond_code, b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                )
                if bond_code == aromatic_code:
                    aromatic_count += 1
            # Highest code first, so aromatic bonds (12) lead the list and
            # queue[aromatic_count] is the strongest non-aromatic bond.
            queue.sort(key=lambda tup: tup[1], reverse=True)

            # No bonds at all, or nothing but aromatic bonds: cannot reduce.
            if not queue or queue[-1][1] == aromatic_code:
                return None, no_correct

            _, bond_code, start, end = queue[aromatic_count]
            lowered = bond_code - 1
            mol.RemoveBond(start, end)
            if lowered >= 1:
                mol.AddBond(start, end, bond_dict[lowered])
        except Exception:
            # Best effort: any RDKit or lookup failure while editing means the
            # molecule cannot be corrected by this routine.
            return None, no_correct
def check_valid(smiles):
    """Return ``True`` when ``smiles`` parses and can be written back out.

    The input must yield a molecule from ``get_mol`` and that molecule must
    yield a SMILES string from ``mol2smiles``; otherwise ``False``.
    """
    mol = get_mol(smiles)
    return mol is not None and mol2smiles(mol) is not None
def get_mol(smiles_or_mol):
    """
    Loads SMILES/molecule into RDKit's object

    Anything that is not a ``str`` is returned unchanged. A string is parsed
    with ``Chem.MolFromSmiles`` and sanitized; ``None`` is returned for an
    empty string, an unparsable string, or when sanitization raises
    ``ValueError``.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None

    mol = Chem.MolFromSmiles(smiles_or_mol)
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return None
    return mol
def mol2smiles(mol):
    """Sanitize ``mol`` and return its SMILES string.

    Returns ``None`` when ``mol`` is ``None`` or when ``Chem.SanitizeMol``
    raises ``ValueError``. Errors from ``Chem.MolToSmiles`` itself are not
    caught.
    """
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            pass
        else:
            return Chem.MolToSmiles(mol)
    return None
def check_valency(mol):
    """Check atom valences and report the first offending atom.

    Runs ``Chem.SanitizeMol`` restricted to ``SANITIZE_PROPERTIES``.

    Returns:
        ``(True, None)`` when the check passes.
        ``(False, numbers)`` when RDKit raises ``ValueError``; ``numbers`` is
        every integer found in the error message after its first ``#``.
        Callers expect ``[atom_index, valence]`` — presumably from a message
        such as "Explicit valence for atom # 1 C, 5, is greater than
        permitted"; confirm against the installed RDKit version.
        ``(False, [])`` when the message contains no ``#`` or any other
        exception is raised.
    """
    try:
        # First attempt to sanitize with specific properties
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        _, marker, tail = str(err).partition("#")
        if not marker:
            # No atom marker in the message: nothing reliable to parse.
            return False, []
        return False, [int(token) for token in re.findall(r"\d+", tail)]
    except Exception:
        # Best effort: treat any unexpected failure as "invalid, no details".
        return False, []
    return True, None
##### connect fragments
def select_atom_with_available_valency(frag):
    """Pick one random non-hydrogen atom that still has implicit valence.

    The atoms of ``frag`` are shuffled with the module-level ``random`` state
    and the first one with atomic number > 1 and implicit valence > 0 is
    returned; ``None`` when no atom qualifies.
    """
    candidates = list(frag.GetAtoms())
    random.shuffle(candidates)
    return next(
        (
            candidate
            for candidate in candidates
            if candidate.GetAtomicNum() > 1 and candidate.GetImplicitValence() > 0
        ),
        None,
    )
def select_atoms_with_available_valency(frag):
    """Return every non-hydrogen atom of ``frag`` with implicit valence left.

    Atoms are returned in the order ``frag.GetAtoms()`` yields them; an atom
    qualifies when its atomic number is > 1 and its implicit valence is > 0.
    """
    available = []
    for atom in frag.GetAtoms():
        if atom.GetAtomicNum() <= 1:
            continue
        if atom.GetImplicitValence() > 0:
            available.append(atom)
    return available
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
    """Try joining ``frag`` to ``combined_mol`` with one single bond.

    Works on copies, so neither input is modified. The atoms of ``frag`` are
    appended to a copy of ``combined_mol``, a single bond is added between
    ``atom1`` (an atom of ``combined_mol``) and the appended copy of ``atom2``
    (an atom of ``frag``), one hydrogen is taken off each bonded atom, and the
    bonds of ``frag`` are re-created.

    Returns:
        The sanitized merged ``Chem.Mol``, or ``None`` when sanitization
        raises ``Chem.MolSanitizeException``.
    """
    # Make copies of the molecules to try the connection
    trial_combined_mol = Chem.RWMol(combined_mol)
    trial_frag = Chem.RWMol(frag)

    # Add the new fragment to the combined molecule with new indices
    new_indices = {}
    for frag_atom in trial_frag.GetAtoms():
        new_indices[frag_atom.GetIdx()] = trial_combined_mol.AddAtom(frag_atom)

    # Add the bond between the suitable atoms from each fragment
    anchor_idx = atom1.GetIdx()
    partner_idx = new_indices[atom2.GetIdx()]
    trial_combined_mol.AddBond(anchor_idx, partner_idx, Chem.BondType.SINGLE)

    # Adjust the hydrogen count of the connected atoms
    for atom_idx in (anchor_idx, partner_idx):
        bonded_atom = trial_combined_mol.GetAtomWithIdx(atom_idx)
        remaining_hs = bonded_atom.GetTotalNumHs() - 1
        bonded_atom.SetNumExplicitHs(max(0, remaining_hs))

    # Add bonds for the new fragment
    for bond in trial_frag.GetBonds():
        trial_combined_mol.AddBond(
            new_indices[bond.GetBeginAtomIdx()],
            new_indices[bond.GetEndAtomIdx()],
            bond.GetBondType(),
        )

    # Convert to a Mol object and try to sanitize it
    new_mol = Chem.Mol(trial_combined_mol)
    try:
        Chem.SanitizeMol(new_mol)
    except Chem.MolSanitizeException:
        return None  # If the molecule is not valid, return None
    return new_mol  # Return the new valid molecule
def connect_fragments(mol):
    """Merge all disconnected fragments of ``mol`` into one molecule.

    Fragments are attached one after another to the first fragment. For each
    fragment, every pairing of an available-valency atom in the molecule
    built so far with an available-valency atom in the fragment is tried
    until ``try_to_connect_fragments`` succeeds.

    Returns:
        ``mol`` itself when it has fewer than two fragments, the merged
        molecule when every fragment could be attached, or ``None`` as soon
        as one fragment cannot be attached.
    """
    # Get the separate fragments
    frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    if len(frags) < 2:
        return mol

    combined_mol = Chem.RWMol(frags[0])
    for frag in frags[1:]:
        # Select all atoms with available valency from both molecules
        atoms1 = select_atoms_with_available_valency(combined_mol)
        atoms2 = select_atoms_with_available_valency(frag)

        # Try to connect using all combinations of available valency atoms,
        # stopping at the first pairing that yields a valid molecule.
        joined = None
        for atom1 in atoms1:
            for atom2 in atoms2:
                joined = try_to_connect_fragments(combined_mol, frag, atom1, atom2)
                if joined is not None:
                    break
            if joined is not None:
                break

        if joined is None:
            # If no valid connections could be made with any of the atoms, return None
            return None
        combined_mol = joined
    return combined_mol
#### connect fragments