# Source code for icolos.core.workflow_steps.prediction.predictor

import pickle
from copy import deepcopy

import numpy as np
from typing import List

from pydantic import BaseModel
from rdkit import Chem

from icolos.utils.general.icolos_exceptions import StepFailed, get_exception_message
from icolos.utils.enums.step_enums import StepPredictorEnum
from icolos.core.workflow_steps.io.base import StepIOBase
from icolos.core.workflow_steps.step import _LE

from icolos.utils.general.convenience_functions import *

# Module-level instance of the predictor step enum; StepPredictor reads its
# settings keys (MODEL_PATH, FEATURES, NAME_PREDICTED) and the default name for
# the predicted property (NAME_PREDICTED_DEFAULT) from it.
_SPE = StepPredictorEnum()


class StepPredictor(StepIOBase, BaseModel):
    """Workflow step that annotates conformers with a model prediction.

    A pickled model is loaded from the configured path, the configured feature
    properties are read from each valid conformer's molecule, and the model's
    prediction is written back to that molecule as a string property.
    """

    def __init__(self, **data):
        super().__init__(**data)

    @classmethod
    def _load_scikit_model(cls, model_path: str):
        """Unpickle and return the model stored at ``model_path``.

        The returned object is only required to offer ``predict(X=...)``;
        presumably it is a scikit-learn estimator (going by the method name).
        """
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # The path comes from the step configuration, so only point it at
        # model files from a trusted source.
        with open(model_path, "rb") as f:
            scikit_model = pickle.load(f)
        return scikit_model

    def _get_feature_values(
        self, conformer: Chem.Mol, feature_names: List[str]
    ) -> np.ndarray:
        """Collect the named properties of ``conformer`` as a single-row array.

        :param conformer: molecule whose properties hold the feature values
        :param feature_names: property names, in the order the model expects
        :return: array of shape ``(1, len(feature_names))`` with float values
        :raises KeyError: if a property is not set on the molecule
        :raises ValueError: if a property value cannot be converted to float
        """
        list_values = []
        for feature in feature_names:
            try:
                list_values.append(float(conformer.GetProp(feature)))
            except KeyError as e:
                self._logger.log(
                    f"Could not find feature / property, error message: {get_exception_message(e)}",
                    _LE.ERROR,
                )
                raise
            except ValueError as e:
                # The property exists but is not numeric; log it like the
                # missing-property case instead of failing without any trace.
                self._logger.log(
                    f"Could not convert feature / property {feature} to a number, error message: {get_exception_message(e)}",
                    _LE.ERROR,
                )
                raise

        # cast list to 2D array (one row holding all features)
        return np.array([list_values])

    def execute(self):
        """Predict the configured endpoint for every valid conformer.

        Reads the model path, the feature list and the name of the predicted
        property from ``self.settings.additional``, then stores the first
        element of ``model.predict`` on each conformer's molecule.

        :raises StepFailed: if the model path or the feature list is not set
        """
        # get parameters
        parameters = deepcopy(self.settings.additional)
        model_path = nested_get(parameters, _SPE.MODEL_PATH, default=None)
        feature_names = nested_get(parameters, _SPE.FEATURES, default=None)
        name_predicted = nested_get(
            parameters, _SPE.NAME_PREDICTED, default=_SPE.NAME_PREDICTED_DEFAULT
        )

        # check parameters; model_path and features are mandatory
        if model_path is None or feature_names is None:
            message = f"Parameters {_SPE.MODEL_PATH} (path to model) and {_SPE.FEATURES} (list with features) have to be set - abort."
            self._logger.log(message, _LE.ERROR)
            raise StepFailed(message)
        if name_predicted == _SPE.NAME_PREDICTED_DEFAULT:
            self._logger.log(
                f"Name of predicted property not specified, using default value {_SPE.NAME_PREDICTED_DEFAULT} instead (not recommended).",
                _LE.WARNING,
            )

        # load model from file and predict endpoint
        model = self._load_scikit_model(model_path=model_path)
        compounds = self.get_compounds()
        predicted = 0
        for compound in compounds:
            for enumeration in compound.get_enumerations():
                for conformer in enumeration.get_conformers():
                    if not self._input_object_valid(conformer):
                        continue
                    molecule = conformer.get_molecule()
                    f_values = self._get_feature_values(
                        conformer=molecule, feature_names=feature_names
                    )
                    # properties are stored as strings on the molecule
                    molecule.SetProp(
                        name_predicted, str(model.predict(X=f_values)[0])
                    )
                    predicted += 1
        self._logger.log(
            f"Predicted {name_predicted} for {predicted} conformers in {len(compounds)} compounds.",
            _LE.INFO,
        )