Source code for bonafide.features.symmetric_sites

"""Identification of symmetry equivalent positions in 2D mol objects."""

import copy
from collections import defaultdict
from typing import Dict, List, Union, cast

from rdkit import Chem

from bonafide.utils.base_featurizer import BaseFeaturizer


[docs] class Bonafide2DAtomIsSymmetricTo(BaseFeaturizer): """Feature factory for the 2D atom feature "is_symmetric_to", implemented within this package. The index of this feature is 30 (see the ``list_atom_features()`` and ``list_bond_features()`` method). The corresponding configuration settings can be found under "bonafide.symmetry" in the _feature_config.toml file. """ includeAtomMaps: bool includeChirality: bool includeChiralPresence: bool includeIsotopes: bool reduce_to_canonical: bool def __init__(self) -> None: self.extraction_mode = "multi" super().__init__()
[docs] def calculate(self) -> None: """Calculate the ``bonafide2D-atom-is_symmetric_to`` feature.""" # Rank the atoms based on their canonical ranks (symmetry) canonical_rank_list = list( Chem.CanonicalRankAtoms( mol=self.mol, breakTies=False, includeChirality=self.includeChirality, includeIsotopes=self.includeIsotopes, includeAtomMaps=self.includeAtomMaps, includeChiralPresence=self.includeChiralPresence, ) ) # Get dictionary of symmetry equivalent sites sites_ = defaultdict(list) for atom_idx, rank_idx in enumerate(canonical_rank_list): sites_[rank_idx].append(atom_idx) sites = {atom_indices[0]: atom_indices for atom_indices in sites_.values()} # Handle missing indices in initial sites dictionary new_sites = cast(Dict[int, Union[List[int], str]], copy.deepcopy(sites)) for rank_idx, idx_list in sites.items(): if len(idx_list) == 1: continue for atom_idx in idx_list: if atom_idx not in sites: if self.reduce_to_canonical is True: new_sites[atom_idx] = "_inaccessible" else: new_sites[atom_idx] = idx_list # Write values to the results dictionary for atom_idx, idx_list2 in new_sites.items(): if isinstance(idx_list2, str) and idx_list2 == "_inaccessible": self.results[atom_idx] = {self.feature_name: "_inaccessible"} else: self.results[atom_idx] = {self.feature_name: ",".join([str(i) for i in idx_list2])}
[docs] class Bonafide2DBondIsSymmetricTo(BaseFeaturizer): """Feature factory for the 2D bond feature "is_symmetric_to", implemented within this package. The index of this feature is 52 (see the ``list_atom_features()`` and ``list_bond_features()`` method). The corresponding configuration settings can be found under "bonafide.symmetry" in the _feature_config.toml file. """ includeAtomMaps: bool includeChirality: bool includeChiralPresence: bool includeIsotopes: bool reduce_to_canonical: bool def __init__(self) -> None: self.extraction_mode = "multi" super().__init__()
[docs] def calculate(self) -> None: """Calculate the ``bonafide2D-bond-is_symmetric_to`` feature.""" # Rank the atoms based on their canonical ranks (symmetry) canonical_rank_list = list( Chem.CanonicalRankAtoms( mol=self.mol, breakTies=False, includeChirality=self.includeChirality, includeIsotopes=self.includeIsotopes, includeAtomMaps=self.includeAtomMaps, includeChiralPresence=self.includeChiralPresence, ) ) # Get dictionary of symmetry equivalent atom sites atom_sites_ = defaultdict(list) for atom_idx, rank_idx in enumerate(canonical_rank_list): atom_sites_[rank_idx].append(atom_idx) atom_sites = {atom_indices[0]: atom_indices for atom_indices in atom_sites_.values()} # Get dictionary of symmetry equivalent bond sites defined by the atom rank indices # of their begin and end atoms bond_sites_ = defaultdict(list) for bond in self.mol.GetBonds(): rank_begin_idx = self._get_rank_idx(rank_dict=atom_sites, idx=bond.GetBeginAtomIdx()) rank_end_idx = self._get_rank_idx(rank_dict=atom_sites, idx=bond.GetEndAtomIdx()) bond_id = [rank_begin_idx, rank_end_idx] bond_id.sort() bond_id_str = "-".join([str(x) for x in bond_id]) bond_sites_[bond_id_str].append(bond.GetIdx()) bond_sites = {indices[0]: indices for indices in bond_sites_.values()} # Handle missing indices in bond sites dictionary new_sites = cast(Dict[int, Union[List[int], str]], copy.deepcopy(bond_sites)) for rank_idx, idx_list in bond_sites.items(): if len(idx_list) == 1: continue for bond_idx in idx_list: if bond_idx not in bond_sites: if self.reduce_to_canonical is True: new_sites[bond_idx] = "_inaccessible" else: new_sites[bond_idx] = idx_list # Write values to the results dictionary for bond_idx, idx_list2 in new_sites.items(): if isinstance(idx_list2, str) and idx_list2 == "_inaccessible": self.results[bond_idx] = {self.feature_name: "_inaccessible"} else: self.results[bond_idx] = {self.feature_name: ",".join([str(i) for i in idx_list2])}
[docs] @staticmethod def _get_rank_idx(rank_dict: Dict[int, List[int]], idx: int) -> int: """Get the rank index for a given atom index from the rank dictionary. Parameters ---------- rank_dict : Dict[int, List[int]] The rank dictionary mapping rank indices to lists of atom indices. idx : int The atom index for which to find the rank index. Returns ------- int The rank index corresponding to the given atom index. """ for rank_idx, atom_indices in rank_dict.items(): if idx in atom_indices: return rank_idx return -1 # Return -1 if the atom index is not found in the rank dictionary