# evaluator.py -- uploaded by liuganghuggingface with huggingface_hub
# ("Upload evaluator.py with huggingface_hub", commit 824a00b, verified; 4.91 kB).
import math, os
import pickle
import os.path as op
import numpy as np
import pandas as pd
from joblib import dump, load, Parallel, delayed
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error, roc_auc_score
from sklearn.base import BaseEstimator
from tqdm import tqdm
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import rdMolDescriptors
# Silence RDKit messages logged on the 'rdApp.error' channel. Presumably this
# keeps parse errors for invalid SMILES (which are tolerated by the scoring code
# below) out of the console -- confirm if other error output matters.
rdBase.DisableLog('rdApp.error')
def process_smiles(smiles):
    """Featurize a single SMILES string.

    Returns a ``(fingerprint, is_valid)`` pair. When RDKit parses the string,
    the fingerprint is the ECFP4 row from ``Evaluator.fingerprints_from_mol``
    and the flag is 1. Otherwise the fingerprint is a ``(1, 2048)`` zero row
    and the flag is 0.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        # Unparseable input: placeholder row plus an "invalid" flag.
        return np.zeros((1, 2048)), 0
    return Evaluator.fingerprints_from_mol(parsed), 1
class Evaluator():
    """Score molecules with a pre-trained model on ECFP4 fingerprints.

    Each SMILES string is converted to a Morgan (radius 2) bit vector of
    ``N_BITS`` bits and fed to a model loaded with ``joblib.load``. For a
    classification ``task_type`` the score is the positive-class probability
    from ``predict_proba``; otherwise it is the output of ``predict``. SMILES
    that RDKit cannot parse get a score of 0.
    """

    # Fingerprint length, shared by featurization and the invalid-SMILES
    # placeholder row so the two can never disagree.
    N_BITS = 2048

    def __init__(self, model_path, task_name, n_jobs=2, task_type='regression'):
        """
        Args:
            model_path: path of a joblib-serialized model exposing ``predict``
                (and ``predict_proba`` for classification).
            task_name: name of the scored property; stored as ``self.task_name``.
            n_jobs: stored as ``self.n_jobs``; not used by the scoring code here.
            task_type: a string containing ``'classification'`` selects
                ``predict_proba`` and ROC-AUC as ``metric_func``. Any other value
                selects ``predict`` and mean absolute error. Defaults to
                ``'regression'``.
        """
        self.n_jobs = n_jobs
        self.task_name = task_name
        self.task_type = task_type
        self.model_path = model_path
        self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error
        # NOTE: joblib.load unpickles the file -- only load trusted model files.
        self.model = load(model_path)

    def __call__(self, smiles_list):
        """Score a batch of SMILES strings.

        Args:
            smiles_list: iterable of SMILES strings.

        Returns:
            ``np.float32`` array with one score per input. Entries for
            unparseable SMILES are 0. An empty input yields an empty array.
        """
        fps = []
        mask = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            is_valid = mol is not None
            mask.append(int(is_valid))
            # Invalid SMILES get an all-zero placeholder row so that rows stay
            # aligned with the inputs.
            fps.append(self.fingerprints_from_mol(mol) if is_valid else np.zeros((1, self.N_BITS)))
        if not fps:
            # np.concatenate raises ValueError on an empty sequence.
            return np.zeros(0, dtype=np.float32)
        fps = np.concatenate(fps, axis=0)
        if 'classification' in self.task_type:
            scores = self.model.predict_proba(fps)[:, 1]
        else:
            scores = self.model.predict(fps)
        # Zero out predictions made for the placeholder rows.
        scores = scores * np.array(mask)
        return np.float32(scores)

    @classmethod
    def fingerprints_from_mol(cls, mol):  # use ECFP4
        """Return the radius-2 Morgan bit vector of ``mol`` as a ``(1, N_BITS)`` float array."""
        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=cls.N_BITS)
        features = np.zeros((1,))
        # ConvertToNumpyArray fills (and resizes) ``features`` to the bit-vector length.
        DataStructs.ConvertToNumpyArray(features_vec, features)
        return features.reshape(1, -1)
###### SAS Score ######
# Module-level cache of fragment contribution scores. It is filled lazily by
# readFragmentScores() and read by calculateScore().
_fscores = None
def readFragmentScores(name='fpscores'):
    """Load fragment scores into the module-level ``_fscores`` cache.

    Reads the gzipped pickle ``<name>.pkl.gz``. For the default name, the
    file is resolved relative to this module's directory. Each pickled record
    is a sequence whose first item is a score and whose remaining items are
    fragment (fingerprint bit) ids. Every id is mapped to
    ``float(record[0])``; a later record overwrites an earlier one for the
    same id.

    Args:
        name: path of the score file without the ``.pkl.gz`` suffix.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # NOTE: pickle.load can execute arbitrary code -- only point this at the
    # trusted score file. The context manager closes the gzip handle.
    with gzip.open('%s.pkl.gz' % name) as fh:
        data = pickle.load(fh)
    outDict = {}
    for record in data:
        score = float(record[0])
        for fragment_id in record[1:]:
            outDict[fragment_id] = score
    _fscores = outDict
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return ``(bridgehead_atom_count, spiro_atom_count)`` for ``mol``.

    ``ri`` is accepted for call compatibility but is not used.
    """
    spiro_atoms = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_atoms = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    counts = (bridgehead_atoms, spiro_atoms)
    return counts
def calculateSAS(smiles_list):
    """Compute synthetic-accessibility scores for a list of SMILES strings.

    Args:
        smiles_list: iterable of SMILES strings.

    Returns:
        ``np.float32`` array with one ``calculateScore`` value per input
        (empty for empty input).

    Raises:
        ValueError: if RDKit cannot parse one of the SMILES strings.
    """
    scores = []
    for idx, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Fail fast with a clear message instead of handing None to
            # calculateScore, where RDKit would raise an opaque argument error.
            raise ValueError("calculateSAS: invalid SMILES at index %d: %r" % (idx, smiles))
        scores.append(calculateScore(mol))
    return np.float32(scores)
def calculateScore(m):
    """Compute the synthetic-accessibility (SA) score of an RDKit molecule.

    The raw score is the sum of three terms:
      * a fragment score: the count-weighted average of ``_fscores`` values
        over the radius-2 Morgan fragments of ``m`` (unknown fragments
        contribute -4);
      * a complexity penalty for size, chiral centres, spiro and bridgehead
        atoms, and macrocycles (rings with more than 8 atoms);
      * a symmetry correction for molecules with fewer distinct fragments
        than atoms.
    The raw value is then mapped onto a 1..10 scale, where lower means easier.

    Args:
        m: RDKit ``Mol``.

    Returns:
        float in [1, 10], or ``nan`` when ``m`` yields no fingerprint
        fragments (e.g. a molecule with no atoms), since the fragment average
        is undefined.
    """
    if _fscores is None:
        readFragmentScores()
    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    if not fps:
        # No fragments -> nf would be 0 and the average below would raise
        # ZeroDivisionError.
        return math.nan
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        score1 += _fscores.get(bitId, -4) * v
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = sum(1 for ring in ri.AtomRings() if len(ring) > 8)

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    # macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthesise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (named raw_min/raw_max so the builtins min/max are not shadowed)
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0
    return sascore