import numpy as np
import copy
import pandas as pd
import json
from apischema import serialize
from optunaz.config import ModelMode
from rdkit import Chem
from rdkit.Chem import AllChem
from joblib import Parallel, delayed, effective_n_jobs

def get_ecfp_fpinfo(m, descriptor):
    """Return the ecfp info for a compound mol"""
    info = {}
    # the fingerprint call populates info (bit -> (atom index, radius) tuples) as a side effect
    fp = AllChem.GetMorganFingerprintAsBitVect(
        m,
        radius=descriptor.parameters.radius,
        nBits=descriptor.parameters.nBits,
        bitInfo=info,
    )
    return info

def get_ecfpcount_fpinfo(m, descriptor):
    """Return the ecfp_count info for a compound mol"""
    info = {}
    # as above, the hashed fingerprint call populates info via bitInfo as a side effect
    fp = AllChem.GetHashedMorganFingerprint(
        m,
        radius=descriptor.parameters.radius,
        nBits=descriptor.parameters.nBits,
        useFeatures=descriptor.parameters.useFeatures,
        bitInfo=info,
    )
    return info
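
# Illustrative sketch (not part of the original module): what the bitInfo dict
# returned by the two helpers above looks like for a single molecule. The phenol
# SMILES is an arbitrary example; each bit maps to one or more
# (atom_index, radius) tuples describing the environment that set it.
def _demo_bitinfo():
    mol = Chem.MolFromSmiles("c1ccccc1O")
    bit_info = {}
    AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048, bitInfo=bit_info)
    # e.g. {389: ((6, 1),), 1873: ((0, 0), (2, 0)), ...}
    return bit_info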

def explain_ECFP(len_feats, estimator, descriptor):
    """Explain ECFPs using train atom environments"""
    ret = np.empty(len_feats, dtype="<U50")
    # enumerate through each important feature
    for feat_idx in range(len_feats):
        this_feat_explained = False
        # enumerate through the training set searching for a compound with the feature
        for mol_idx, mol in enumerate(estimator.X_):
            if not this_feat_explained:
                # the feature is present in the mol if the bit is > 0
                if mol[feat_idx] > 0:
                    m = Chem.MolFromSmiles(estimator.train_smiles_[mol_idx])
                    if descriptor.name == "ECFP":
                        info = get_ecfp_fpinfo(m, descriptor)[feat_idx]
                    elif descriptor.name == "ECFP_counts":
                        info = get_ecfpcount_fpinfo(m, descriptor)[feat_idx]
                    # enumerate atom matches, breaking when a valid SMILES is produced
                    for atom, radius in info:
                        env = Chem.FindAtomEnvironmentOfRadiusN(m, radius, atom)
                        amap = {}
                        submol = Chem.PathToSubmol(m, env, atomMap=amap)
                        try:
                            feat_smi = Chem.MolToSmiles(
                                submol, rootedAtAtom=amap[atom], canonical=False
                            )
                            ret[feat_idx] = feat_smi
                            this_feat_explained = True
                            # keep trying other matches if the feature SMILES is blank
                            if feat_smi != "":
                                break
                        # sometimes MolToSmiles fails, e.g. the atom is missing from the submol atom map
                        except KeyError:
                            pass
            # feature is already explained, so stop scanning the training set
            else:
                break
    return ret
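
# Illustrative sketch (not part of the original module): the same
# environment-to-SMILES step used in the loop above, isolated for a single
# (atom, radius) match so the fragment written into ret can be inspected.
# The SMILES, atom index and radius are arbitrary example values.
def _demo_environment_smiles(smiles="c1ccccc1O", atom=6, radius=1):
    m = Chem.MolFromSmiles(smiles)
    env = Chem.FindAtomEnvironmentOfRadiusN(m, radius, atom)
    amap = {}
    submol = Chem.PathToSubmol(m, env, atomMap=amap)
    # rootedAtAtom expects the submol index of the central atom, taken from amap
    return Chem.MolToSmiles(submol, rootedAtAtom=amap[atom], canonical=False)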

def get_fp_info(exp_df, estimator, descript, fp_idx, strt_idx=None):
    """Get ECFP SMILES environments, or Physchem/Jazzy names, when available"""
    info = []
    if "ECFP" in descript.name:
        info = explain_ECFP(fp_idx, estimator, descript)
    elif "Physchem" in descript.name:
        try:
            info = descript.parameters.rdkit_names
        # ScaledDescriptor wraps the underlying descriptor one level deeper
        except AttributeError:
            info = descript.parameters.descriptor.parameters.rdkit_names
    elif "Jazzy" in descript.name:
        try:
            info = descript.parameters.jazzy_names
        except AttributeError:
            info = descript.parameters.descriptor.parameters.jazzy_names
    if len(info) > 0:
        if strt_idx is not None:
            exp_df.loc[strt_idx : strt_idx + fp_idx - 1, "info"] = info
        else:
            exp_df["info"] = info
    return

def runShap(estimator, X_pred, mode):
    """Explain model predictions using the SHAP auto explainer or KernelExplainer"""
    import shap

    # see if shap can pick an explainer automatically
    try:
        try:
            explainer = shap.Explainer(estimator, estimator.X_)
        # deal with estimators that require the inference method itself
        except TypeError:
            if mode == ModelMode.REGRESSION:
                explainer = shap.Explainer(estimator.predict, estimator.X_)
            if mode == ModelMode.CLASSIFICATION:
                explainer = shap.Explainer(estimator.predict_proba, estimator.X_)
        shap_values = np.abs(np.array(explainer.shap_values(np.array(X_pred))))
    # fall back to KernelExplainer for other models
    except AttributeError:
        if mode == ModelMode.REGRESSION:
            explainer = shap.KernelExplainer(estimator.predict, estimator.X_)
        if mode == ModelMode.CLASSIFICATION:
            explainer = shap.KernelExplainer(estimator.predict_proba, estimator.X_)
        shap_values = np.abs(
            np.array(explainer.shap_values(np.array(X_pred[:1]), nsamples="auto"))
        )
    return shap_values

def ShapExplainer(estimator, X_pred, mode, descriptor):
    """
    Run SHAP and populate the explainability dataframe
    """
    shap_values = runShap(estimator, X_pred, mode)
    descriptor_ = copy.deepcopy(descriptor)
    # process the shap_values shapes
    if len(shap_values.shape) == 3:
        if shap_values.shape[0] == 2:
            # if the explainer explains both classes, take the active [1] class
            shap_values = shap_values[1]
        elif shap_values.shape[0] == 1:
            # sometimes values are wrapped and require [0]
            shap_values = shap_values[0]
    # if multiple inputs are provided, average the importance across predictions
    if len(shap_values.shape) > 1:
        shap_values = np.mean(shap_values, axis=0)
    exp_df = pd.DataFrame(
        data={
            "shap_value": shap_values,
            "descriptor": np.nan,
            "bit": np.nan,
            "info": np.nan,
        }
    )
    # process single descriptors
    if descriptor_.name != "CompositeDescriptor":
        if descriptor_.name == "ScaledDescriptor":
            descriptor_.name += f"_{descriptor_.parameters.descriptor.name}"
        exp_df["descriptor"] = descriptor_.name
        exp_df["bit"] = range(len(shap_values))
        get_fp_info(exp_df, estimator, descriptor_, len(shap_values))
    # process CompositeDescriptor by annotating one slice of the dataframe per sub-descriptor
    else:
        fp_info = descriptor_.fp_info()
        strt_idx = 0
        info = None
        for descript in descriptor_.parameters.descriptors:
            fp_idx = fp_info[json.dumps(serialize(descript))]
            if descript.name == "ScaledDescriptor":
                descript.name += f"_{descript.parameters.descriptor.name}"
            for col, col_value in [
                ("descriptor", descript.name),
                ("bit", range(1, fp_idx + 1, 1)),
            ]:
                exp_df.loc[strt_idx : strt_idx + fp_idx - 1, col] = col_value
            get_fp_info(exp_df, estimator, descript, fp_idx, strt_idx=strt_idx)
            strt_idx += fp_idx
    return exp_df.sort_values("shap_value", ascending=False)

def ExplainPreds(estimator, X_pred, mode, descriptor):
    """Explain predictions using either SHAP (shallow models) or ChemProp interpret"""
    if hasattr(estimator, "interpret"):
        n_cores = effective_n_jobs(-1)
        if n_cores == 1:
            return estimator.interpret(X_pred)
        else:
            # interpret each query molecule in parallel and concatenate the results
            return pd.concat(
                Parallel(n_jobs=n_cores)(
                    delayed(estimator.interpret)([X]) for X in X_pred
                )
            )
    else:
        return ShapExplainer(estimator, X_pred, mode, descriptor)
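
# Usage sketch (hypothetical, not part of the original module): assumes a
# fitted QSARtuna estimator exposing X_ / train_smiles_ plus the descriptor
# config used to build its features, as the functions above expect.
def _demo_explain(estimator, X_pred, descriptor):
    exp_df = ExplainPreds(estimator, X_pred, ModelMode.REGRESSION, descriptor)
    # rows are ranked by mean |SHAP|; "descriptor", "bit" and "info" identify each feature
    return exp_df.head(10)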