Source code for icolos.core.workflow_steps.schrodinger.fep_base

from pydantic import BaseModel
from icolos.core.containers.compound import Compound
from icolos.core.workflow_steps.schrodinger.base import StepSchrodingerBase
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path
from icolos.utils.enums.step_enums import StepFepPlusEnum
from typing import List
import time
import os
from icolos.core.workflow_steps.step import _LE

_SFE = StepFepPlusEnum()


class StepFEPBase(StepSchrodingerBase, BaseModel):
    """
    Base class containing common functionality for Schrodinger FEP+ workflows
    """

    def __init__(self, **data):
        super().__init__(**data)

    def _wait_for_single_file(
        self, tmp_dir, suffix, waiting_message, max_attempts=50000, poll_interval=15
    ):
        """
        Poll tmp_dir until exactly one file name ending in suffix is present.

        :param tmp_dir: directory to scan with os.listdir
        :param suffix: file name suffix to look for
        :param waiting_message: INFO message logged before each sleep
        :param max_attempts: maximum number of directory scans
        :param poll_interval: seconds to sleep between scans
        :return: the matching file name (relative to tmp_dir), or None if
            max_attempts scans went by without exactly one match
        """
        for _ in range(max_attempts):
            matches = [file for file in os.listdir(tmp_dir) if file.endswith(suffix)]
            # explicit check instead of `assert`: asserts are stripped under -O,
            # which would turn "not there yet" into an IndexError
            if len(matches) == 1:
                return matches[0]
            self._logger.log(waiting_message, _LE.INFO)
            time.sleep(poll_interval)
        return None

    def _parse_output(self, tmp_dir):
        """
        Read the final annotated map file from tmp_dir and store its bytes
        via self._add_data_to_generic, after clearing the generic file dict.

        :param tmp_dir: directory the job writes its output to
        """
        # pick up the final annotated map construction
        self.data.generic.clear_file_dict()
        self._logger.log("Reading output map.", _LE.INFO)
        # hold whilst the job data gets written to local fs
        file_name = self._wait_for_single_file(
            tmp_dir,
            _SFE.FMP_OUTPUT_FILE,
            "Output file has not yet appeared in the file system, sleeping and retrying...",
        )
        if file_name is None:
            # nothing sensible to register: do not hand a missing path / None
            # payload to the generic data store
            self._logger.log(
                "Output file did not appear in the file system before the retry limit was reached, no output map was attached",
                _LE.WARNING,
            )
            return
        with open(os.path.join(tmp_dir, file_name), "rb") as f:
            data = f.read()
        self._add_data_to_generic(file_name, data)

    def _extract_log_file_data(self, tmp_dir):
        """
        Parses FEP log file to extract edge and node properties

        The edge table is taken from below the last edge header up to the last
        node header, the node table from below the last node header up to the
        first data terminus marker. Both are handed to the respective
        _process_*_lines method.

        :param tmp_dir: directory the job writes its log file to
        :raises IndexError: if the log file is present but does not contain
            the expected header / terminus markers
        """
        # wait whilst job sits in the queue
        log_file = self._wait_for_single_file(
            tmp_dir,
            _SFE.LOGFILE,
            "Log file has not yet appeared in the file system, sleeping and retrying...",
        )
        if log_file is None:
            self._logger.log(
                "Log file did not appear in the file system before the retry limit was reached, no edge or node data was extracted",
                _LE.WARNING,
            )
            return

        with open(os.path.join(tmp_dir, log_file), "r") as f:
            lines = f.readlines()

        # headers may be printed more than once; the last occurrence is used
        edge_header_index = [
            idx for idx, s in enumerate(lines) if _SFE.EDGE_HEADER_LINE in s
        ][-1]
        node_header_index = [
            idx for idx, s in enumerate(lines) if _SFE.NODE_HEADER_LINE in s
        ][-1]
        end_of_data_index = [
            idx for idx, s in enumerate(lines) if _SFE.DATA_TERMINUS in s
        ][0]

        # data starts three lines below each header and stops one line
        # before the next section marker
        edge_data_lines = lines[edge_header_index + 3 : node_header_index - 1]
        node_data_lines = lines[node_header_index + 3 : end_of_data_index - 1]
        self._process_edge_lines(edge_data_lines)
        self._process_node_lines(node_data_lines)

    def _process_node_lines(self, data: List[str]) -> None:
        """
        Attach the per-node dG value from the log file to the compounds.

        The second whitespace-separated field of each line is the node
        identifier, the third the dG value. The compound index is the part of
        the identifier before the first ":" (same convention as the edge
        table), so multi-digit compound numbers are handled.

        :param data: node table lines from the log file
        """
        for entry in data:
            fields = entry.split()
            if len(fields) < 3:
                self._logger.log(
                    f"Line: {entry} from the logfile does not look like a node entry - skipping",
                    _LE.WARNING,
                )
                continue
            idx = fields[1]
            dG = fields[2]
            # attach dG tags to compound objects if present
            if self.data.compounds:
                # account for running this step compoundless
                compound_index = int(idx.split(":")[0])
                self.data.compounds[compound_index].get_enumerations()[
                    0
                ].get_conformers()[0].get_molecule().SetProp("dG", str(dG))
                self._logger.log(
                    f"dG directly from the output file for compound {idx} is {dG} ",
                    _LE.INFO,
                )

    def _parse_edge_line(self, line: str):
        """
        Parse one edge table line.

        :param line: raw line from the edge table
        :return: tuple (lig_from, lig_to, ddG, err), or None if the line does
            not have the expected layout
        """
        parts = line.split()
        try:
            # parse the compound info from the log file
            lig_from = int(parts[1].split(":")[0])
            lig_to = int(parts[3].split(":")[0])
            ddg_field = parts[4].split("+-")
            ddG = float(ddg_field[0])
            err = float(ddg_field[1])
        except (ValueError, IndexError):
            return None
        return lig_from, lig_to, ddG, err

    def _process_edge_lines(self, edge_data: List[str]) -> None:
        """
        Calibrate dG values using a reference compound and edge ddG from log file output,
        attach the resulting dG to each compound

        Builds a symmetric error matrix and an antisymmetric ddG matrix from
        the edge table, computes the lowest-error paths from ligand 0 to every
        other ligand and delegates to _construct_dg_per_compound.

        :param edge_data: edge table lines from the log file
        """
        edges = []
        for line in edge_data:
            parsed = self._parse_edge_line(line)
            if parsed is None:
                self._logger.log(
                    f"Line: {line} from the logfile contained an unexpected datatype - cannot process this edge - skipping",
                    _LE.WARNING,
                )
                continue
            edges.append(parsed)

        # calculate the max ligand index, accounting for ligands that may have been
        # skipped in previous steps, so can't rely on self.get_compounds()
        max_index = max(
            (max(lig_from, lig_to) for lig_from, lig_to, _, _ in edges), default=0
        )
        len_nodes = max_index + 1  # account for zero indexed ligands

        error_matrix = np.zeros((len_nodes, len_nodes))
        ddG_matrix = np.zeros((len_nodes, len_nodes))
        for lig_from, lig_to, ddG, err in edges:
            error_matrix[lig_from, lig_to] = err
            error_matrix[lig_to, lig_from] = err
            # the reverse perturbation carries the opposite sign
            ddG_matrix[lig_from, lig_to] = ddG
            ddG_matrix[lig_to, lig_from] = -ddG

        error_matrix = csr_matrix(error_matrix)
        # compute shortest path from one ligand to the anchor
        _, predecessors = shortest_path(
            error_matrix, directed=False, return_predecessors=True, indices=0
        )
        self._construct_dg_per_compound(
            ddG_matrix, predecessors, error_matrix, self.get_compounds()
        )

    def _accumulate_dg_along_path(
        self, comp_num: int, ref_dG, ddG, predecessors, error_matrix
    ) -> str:
        """
        Walk the predecessor chain from comp_num back to node 0, summing the
        ddG and error of every edge on the way.

        :param comp_num: node index to compute the value for (must not be 0)
        :param ref_dG: dG of the reference (node 0), used as the starting value
        :param ddG: matrix with ddG[i, j] the ddG of the edge i -> j
        :param predecessors: predecessor array returned by shortest_path
        :param error_matrix: matrix with the error of every edge
        :return: string formatted as "<dG>+-<err>", both rounded to 2 decimals
        :raises IndexError: if comp_num is outside the map or has no path to
            node 0
        """
        dG = ref_dG
        err = 0
        current = comp_num
        while True:
            prev_index = predecessors[current]
            if prev_index < 0:
                # shortest_path marks "no predecessor" with a negative
                # sentinel; make that explicit rather than relying on an
                # out-of-bounds negative index
                raise IndexError(f"node {comp_num} is not connected to node 0")
            dG += ddG[prev_index, current]
            err += error_matrix[prev_index, current]
            if prev_index == 0:
                break
            current = prev_index
        return str(round(dG, 2)) + "+-" + str(round(err, 2))

    def _construct_dg_per_compound(
        self,
        ddG: np.ndarray,
        predecessors: List,
        error_matrix: np.ndarray,
        compounds: List[Compound],
    ) -> None:
        """
        Calculate the calibrated binding free energy per compound using a reference value
        Attach calculated dG to compounds

        The reference value is read from self.settings.additional and defaults
        to 0.00 (with a warning) if it is not set. Compound number 0 receives
        the reference value itself, every other compound the reference value
        plus the summed ddG along its path to node 0, stored in the "map_dG"
        property of the first conformer of the first enumeration.

        :param ddG: matrix with ddG[i, j] the ddG of the edge i -> j
        :param predecessors: predecessor array returned by shortest_path
        :param error_matrix: matrix with the error of every edge
        :param compounds: compounds to annotate
        """
        try:
            ref_dG = self.settings.additional[_SFE.REFERENCE_DG]
        except KeyError:
            self._logger.log(
                "Expected to find a reference dG value for the lead compound, but none was found."
                "Defaulting to 0.00, you will need to apply a manual correction afterwards",
                _LE.WARNING,
            )
            ref_dG = 0.00

        for comp in compounds:
            idx = comp.get_compound_number()
            # check whether the compound appeared in the final map
            try:
                if idx == 0:
                    # the reference compound carries the (possibly defaulted)
                    # reference value itself
                    data = str(ref_dG)
                else:
                    data = self._accumulate_dg_along_path(
                        idx, ref_dG, ddG, predecessors, error_matrix
                    )
                # tag the compound that owns this compound number, not
                # whatever sits at list position idx
                molecule = (
                    comp.get_enumerations()[0].get_conformers()[0].get_molecule()
                )
            except IndexError:
                self._logger.log(
                    f"Compound {idx} was not found in the output map, it was likely dropped during the workflow",
                    _LE.WARNING,
                )
                continue

            molecule.SetProp("map_dG", data)
            if idx != 0:
                self._logger.log(
                    f"Calculated dG from spanning tree for compound {idx} is {data}",
                    _LE.INFO,
                )