"""Molecule scoring utilities.

Provides :class:`Evaluator`, which scores SMILES strings with a pre-trained
model over ECFP4 fingerprints, and a synthetic-accessibility (SA) score
implementation (:func:`calculateSAS` / :func:`calculateScore`).
"""

import gzip
import math
import os
import os.path as op
import pickle

import numpy as np
import pandas as pd
from joblib import dump, load, Parallel, delayed
from rdkit import Chem
from rdkit import DataStructs
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit.Chem import rdMolDescriptors
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error, roc_auc_score
from tqdm import tqdm

# Silence RDKit's error log (e.g. messages emitted for unparsable SMILES).
rdBase.DisableLog('rdApp.error')

# Number of bits in the Morgan fingerprint fed to the model.
_FP_NBITS = 2048


def process_smiles(smiles):
    """Convert one SMILES string into a fingerprint row and a validity flag.

    Args:
        smiles: SMILES string to parse.

    Returns:
        A ``(fingerprint, valid)`` tuple. ``fingerprint`` is an array of
        shape ``(1, 2048)``; it is all zeros when the SMILES cannot be
        parsed. ``valid`` is ``1`` for a parsable SMILES and ``0`` otherwise.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Evaluator.fingerprints_from_mol(mol), 1
    return np.zeros((1, _FP_NBITS)), 0


class Evaluator():
    """Scores molecules with a pre-trained model over ECFP4 fingerprints."""

    def __init__(self, model_path, task_name, n_jobs=2,
                 task_type='regression'):
        """Load the model stored at ``model_path``.

        Args:
            model_path: Path of a joblib-serialised model.
            task_name: Name of the task; stored on the instance.
            n_jobs: Stored on the instance; not used by this class itself.
            task_type: ``'regression'`` (default, the previous hard-coded
                behaviour) or a string containing ``'classification'``.
                Selects ``predict`` vs. ``predict_proba`` in ``__call__``
                and the metric stored in ``metric_func``.
        """
        self.n_jobs = n_jobs
        self.task_name = task_name
        self.task_type = task_type
        self.model_path = model_path
        if 'classification' in self.task_type:
            self.metric_func = roc_auc_score
        else:
            self.metric_func = mean_absolute_error
        self.model = load(model_path)

    def __call__(self, smiles_list):
        """Score a collection of SMILES strings.

        Args:
            smiles_list: Iterable of SMILES strings.

        Returns:
            A 1-D ``float32`` array with one score per input. Inputs that
            cannot be parsed are scored ``0``. An empty input yields an
            empty array.
        """
        smiles_list = list(smiles_list)
        if not smiles_list:
            # np.concatenate raises on an empty sequence; nothing to score.
            return np.zeros((0,), dtype=np.float32)

        fps, mask = zip(*(process_smiles(smiles) for smiles in smiles_list))
        fps = np.concatenate(fps, axis=0)

        if 'classification' in self.task_type:
            # Probability of the class in column 1 of predict_proba.
            scores = self.model.predict_proba(fps)[:, 1]
        else:
            scores = self.model.predict(fps)

        # Invalid SMILES were fed as all-zero fingerprints; zero their score.
        scores = scores * np.array(mask)
        return np.asarray(scores, dtype=np.float32)

    @classmethod
    def fingerprints_from_mol(cls, mol):
        """Return the ECFP4 bit fingerprint of ``mol`` as a ``(1, n)`` array.

        Args:
            mol: An RDKit molecule.

        Returns:
            A 2-D numpy array with a single row holding the fingerprint
            bits (radius 2, 2048 bits).
        """
        # use ECFP4
        features_vec = AllChem.GetMorganFingerprintAsBitVect(
            mol, 2, nBits=_FP_NBITS)
        features = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(features_vec, features)
        return features.reshape(1, -1)


###### SAS Score ######
_fscores = None


def readFragmentScores(name='fpscores'):
    """Load the fragment-score table into the module-level ``_fscores``.

    Args:
        name: Base path (without the ``.pkl.gz`` suffix) of the gzipped
            pickle. The default ``'fpscores'`` is resolved relative to the
            directory containing this file.
    """
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # NOTE: pickle executes arbitrary code on load; only point this at a
    # trusted fragment-score file.
    with gzip.open('%s.pkl.gz' % name) as handle:
        data = pickle.load(handle)

    # Each record is [score, fragment_id, fragment_id, ...].
    scores = {}
    for record in data:
        value = float(record[0])
        for fragment_id in record[1:]:
            scores[fragment_id] = value
    _fscores = scores


def numBridgeheadsAndSpiro(mol, ri=None):
    """Return ``(n_bridgehead_atoms, n_spiro_atoms)`` for ``mol``.

    Args:
        mol: An RDKit molecule.
        ri: Unused; kept for interface compatibility.
    """
    nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return nBridgehead, nSpiro


def calculateSAS(smiles_list, invalid_score=10.0):
    """Compute the SA score for each SMILES string.

    Args:
        smiles_list: Iterable of SMILES strings.
        invalid_score: Score assigned to SMILES that cannot be parsed or
            that describe a molecule without atoms. Defaults to ``10.0``,
            the worst value on the 1-10 scale.

    Returns:
        A ``float32`` numpy array with one score per input.
    """
    scores = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        score = calculateScore(mol) if mol is not None else None
        scores.append(invalid_score if score is None else score)
    return np.asarray(scores, dtype=np.float32)


def calculateScore(m):
    """Compute the SA score of a single molecule.

    Args:
        m: An RDKit molecule.

    Returns:
        A float between 1 and 10, or ``None`` when the molecule has no
        atoms (the fragment score is undefined in that case).
    """
    if not m.GetNumAtoms():
        # No fragments -> the per-fragment average below would divide by 0.
        return None

    if _fscores is None:
        readFragmentScores()

    # fragment score
    # <- 2 is the *radius* of the circular fingerprint
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, count in fps.items():
        nf += count
        # Fragments missing from the table get a contribution of -4.
        score1 += _fscores.get(bitId, -4) * count
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(
        Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    hasMacrocycle = any(len(ring) > 8 for ring in ri.AtomRings())

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are
    # present
    macrocyclePenalty = math.log10(2) if hasMacrocycle else 0.

    score2 = (0. - sizePenalty - stereoPenalty - spiroPenalty
              - bridgePenalty - macrocyclePenalty)

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthesise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore