Source code for bonafide.utils.base_featurizer

"""Base class for all feature factory classes."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from bonafide.utils.base_mixin import _BaseMixin
from bonafide.utils.helper_functions import get_function_or_method_name

if TYPE_CHECKING:
    from rdkit import Chem


[docs] class BaseFeaturizer(_BaseMixin): """Base class for all feature factory classes. All feature factory classes must inherit from this class. It provides the basic structure and workflow for generating and storing features through its ``__call__()`` method as well as additional helper methods for caching feature values. Attributes ---------- _err : Optional[str] The error message generated during feature calculation, if any. It is returned by the ``__call__()`` method. It is ``None`` if no error occurred. _out : Optional[Union[int, float, bool, str]] The output of the feature calculation (feature value for a given atom or bond of a given conformer) that is returned by the ``__call__()`` method. It is ``None`` if an error occurred. atom_bond_idx : int The index of the atom or bond for which the feature is requested. conformer_idx : int The index of the conformer in the molecule vault. conformer_name : str The name of the conformer for which the feature is requested. extraction_mode : str Indicator if the ``calculate()`` method of a respective feature factory calculates the features for all atoms or bonds of the molecule when called once ("multi") or only for a single atom or bond ("single"). It must be set in the child class. feature_cache : List[Dict[str, Dict[int, Optional[Union[str, bool, int, float]]]]] The cache of atom or bond features for each conformer. The individual list entries are dictionaries with the feature names as keys and dictionaries mapping atom indices to feature values as values. feature_name : str The name of the feature that is requested. feature_type : str The type of the feature that is requested, either "atom" or "bond". mol : rdkit.Chem.rdchem.Mol The RDKit molecule object of the conformer for which the feature is requested. results : Dict[int, Dict[str, Optional[Union[int, float, bool, str]]]] Dictionary for storing the results of the feature calculation. Its keys are the atom or bond indices, and the values are dictionaries with the feature name(s) as key(s) and their values. It is populated by the ``calculate()`` method implemented in the child classes (feature factory). """ atom_bond_idx: int conformer_idx: int conformer_name: str extraction_mode: str feature_cache: List[Dict[str, Dict[int, Optional[Union[str, bool, int, float]]]]] feature_name: str feature_type: str mol: Chem.rdchem.Mol results: Dict[int, Dict[str, Optional[Union[int, float, bool, str]]]] _out: Optional[Union[int, float, bool, str]] _err: Optional[str] def __init__(self) -> None: self.results = {} self._out = None self._err = None # Check if feature factory is correctly implemented self._check_requirements() def __call__( self, **kwargs: Any ) -> Tuple[Optional[Union[int, float, bool, str]], Optional[str]]: """Calculate the feature value for an atom or bond of a conformer. Initially, it is attempted to pull and return the requested data from the feature cache. If the data is not available, it is calculated. After that, all the data contained in ``results`` is written to the cache, and the requested data is pulled from there and returned. Finally, the output files generated during the calculation are saved (if requested by the user), and the working directory is deleted. If an unexpected error occurs during the feature calculation which is not captured by the ``_err`` attribute, it is logged and raised as ``RuntimeError``. Parameters ---------- **kwargs: Any Optional arguments that are set as attributes of the class instance. This allows passing different data to the child classes through the ``__call__()`` method. Returns ------- Tuple[Optional[Union[int, float, bool, str]], Optional[str]] A tuple containing the feature value (``None`` if an error occurred) and an error message (``None`` if no error occurred). """ # Set all attributes required for the feature calculation for attr_name, value in kwargs.items(): setattr(self, attr_name, value) _loc = f"{self.__class__.__name__}.calculate" _namespace = self.conformer_name[::-1].split("__", 1)[-1][::-1] # Try to get the data from the cache (in case it was already calculated) self._from_cache() if self._out is not None: return self._out, self._err # Set up a working directory self._setup_work_dir() # Try to calculate the feature. This will populate self.results and potentially self._err try: # self._check_requirements() ensures that the child class implements the calculate() # method; mypy does not recognize this, so we ignore the type error here self.calculate() # type: ignore[attr-defined] except Exception as e: _errmsg = ( f"An unexpected error occurred during the calculation of the " f"'{self.feature_name}' feature for the '{self.feature_type}' with index " f"{self.atom_bond_idx}: {e.__class__.__name__}: {e}." ) if _errmsg.endswith(".."): _errmsg = _errmsg[:-1] if _errmsg.endswith(".") is False: _errmsg += "." logging.error(f"'{_namespace}' | {_loc}()\n{_errmsg}") raise RuntimeError(f"{_loc}(): {_errmsg}") else: # Write the results to the cache and then get the data from it if self._err is None: self._to_cache() self._from_cache() # Save the potentially generated output files and return the data self._save_output_files() return self._out, self._err
[docs] def _check_requirements(self) -> None: """Check if the respective feature factory (child class) implements the required ``calculate()`` method and ``extraction_mode`` attribute. Returns ------- None """ _loc = f"{self.__class__.__name__}.{get_function_or_method_name()}" # Check if child class implements calculate method method_names = [ attr for attr in dir(self) if callable(getattr(self, attr)) is True and not attr.startswith("__") ] if "calculate" not in method_names: _errmsg = ( f"calculate() method must be implemented in child " f"class '{self.__class__.__name__}'." ) logging.error(f"'None' | {_loc}()\n{_errmsg}") raise NotImplementedError(f"{_loc}(): {_errmsg}") # Check if child class sets extraction_mode attribute if "extraction_mode" not in vars(self): _errmsg = ( "Attribute 'extraction_mode' must be set in child class " f"'{self.__class__.__name__}', either to 'single' or 'multi'." ) logging.error(f"'None' | {_loc}()\n{_errmsg}") raise AttributeError(f"{_loc}(): {_errmsg}") # Check if extraction_mode is set to either 'single' or 'multi' extraction_mode = str(getattr(self, "extraction_mode")).lower() if extraction_mode not in ["single", "multi"]: _errmsg = ( f"'extraction_mode' must be set either to 'single' or 'multi', but got " f"'{extraction_mode}' in class '{self.__class__.__name__}'." ) logging.error(f"'None' | {_loc}()\n{_errmsg}") raise ValueError(f"{_loc}(): {_errmsg}")
[docs] def _from_cache(self) -> None: """Attempt to retrieve the requested data from the feature cache. If the data is found in the cache, it is stored in the ``_out`` attribute. ``feature_cache`` is a list of cache dictionaries for the individual conformers. The keys of each dictionary are the feature names, and the values are dictionaries mapping atom or bond indices to feature values. Returns ------- None """ if self.feature_name in self.feature_cache[self.conformer_idx]: if self.atom_bond_idx in self.feature_cache[self.conformer_idx][self.feature_name]: self._out = self.feature_cache[self.conformer_idx][self.feature_name][ self.atom_bond_idx ]
[docs] def _to_cache(self) -> None: """Write the data contained in ``results`` to the feature cache. If the child class sets the ``extraction_mode`` attribute to "multi", this method expects all atom or bond indices to be present in ``results``. If indices are missing, the feature value is set to "_inaccessible" for all features found within ``results``. If certain features could not be calculated for specific atoms or bonds, those features are also set to "_inaccessible" for the respective indices. Returns ------- None """ # Skip if results dictionary is empty if self.results == {}: return # Add missing atom or bond indices to results dictionary in case no feature was calculated # for that specific atom or bond if self.extraction_mode == "multi": # Get all feature names present in results all_feature_names = set() for idx, data in self.results.items(): for feature_name in data: all_feature_names.add(feature_name) # Get all atom or bond indices of the molecule if self.feature_type == "atom": idx_list = [a.GetIdx() for a in self.mol.GetAtoms()] if self.feature_type == "bond": idx_list = [b.GetIdx() for b in self.mol.GetBonds()] # Add missing indices to results with value "_inaccessible" for all features found in # results for idx in idx_list: if idx not in self.results: self.results[idx] = {} for feature_name in all_feature_names: if feature_name not in self.results[idx]: self.results[idx][feature_name] = "_inaccessible" # Save the results to the feature cache for idx, data in self.results.items(): for feature_name, value in data.items(): if feature_name not in self.feature_cache[self.conformer_idx]: self.feature_cache[self.conformer_idx][feature_name] = {} self.feature_cache[self.conformer_idx][feature_name][idx] = value