from pydantic import BaseModel
from icolos.core.containers.compound import Compound
from icolos.core.workflow_steps.schrodinger.base import StepSchrodingerBase
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path
from icolos.utils.enums.step_enums import StepFepPlusEnum
from typing import List
import time
import os
from icolos.core.workflow_steps.step import _LE
# Shared enum instance; supplies the FEP+ file suffixes, log-file marker strings
# and settings keys referenced throughout this module.
_SFE = StepFepPlusEnum()
class StepFEPBase(StepSchrodingerBase, BaseModel):
    """
    Base class containing common functionality for Schrodinger FEP+ workflows

    Offers helpers to pick up the output map written by a job, and to parse the
    job log file into per-compound ``dG`` / ``map_dG`` molecule properties.
    """

    def __init__(self, **data):
        super().__init__(**data)

    def _wait_for_single_file(
        self,
        tmp_dir,
        suffix: str,
        waiting_message: str,
        max_attempts: int = 50000,
        poll_interval: float = 15,
    ):
        """
        Poll a directory until exactly one file with the given suffix exists

        :param tmp_dir: directory to scan
        :param suffix: file name ending to look for
        :param waiting_message: INFO message logged after every unsuccessful scan
        :param max_attempts: number of scans before giving up
        :param poll_interval: seconds to sleep after an unsuccessful scan
        :return: the matching file name (relative to tmp_dir), or None if no
            unique match appeared within max_attempts scans
        """
        for _ in range(max_attempts):
            matches = [
                name for name in os.listdir(tmp_dir) if name.endswith(suffix)
            ]
            # an explicit check rather than `assert`: asserts are stripped under -O
            if len(matches) == 1:
                return matches[0]
            self._logger.log(waiting_message, _LE.INFO)
            time.sleep(poll_interval)
        return None

    def _parse_output(self, tmp_dir):
        """
        Read the output map file from tmp_dir and register it as generic data

        Clears the generic file dict, waits until exactly one file ending in
        ``_SFE.FMP_OUTPUT_FILE`` is present, reads it as bytes and passes the
        file name and content to ``self._add_data_to_generic``. If the file
        never appears, a warning is logged and nothing is registered.

        :param tmp_dir: directory the job writes its output to
        """
        # pick up the final annotated map construction
        self.data.generic.clear_file_dict()
        self._logger.log("Reading output map.", _LE.INFO)

        # hold whilst the job data gets written to local fs
        path = self._wait_for_single_file(
            tmp_dir,
            _SFE.FMP_OUTPUT_FILE,
            "Output file has not yet appeared in the file system, sleeping and retrying...",
        )
        if path is None:
            # previously a list and None were handed on after the retries ran out
            self._logger.log(
                f"Gave up waiting for the output map in {tmp_dir}, no output data was registered",
                _LE.WARNING,
            )
            return

        with open(os.path.join(tmp_dir, path), "rb") as f:
            data = f.read()
        self._add_data_to_generic(path, data)

    def _extract_log_file_data(self, tmp_dir):
        """
        Parses FEP log file to extract edge and node properties

        Waits until exactly one file ending in ``_SFE.LOGFILE`` is present in
        tmp_dir, locates the edge / node sections via the marker strings from
        ``_SFE`` and forwards the respective lines to ``_process_edge_lines``
        and ``_process_node_lines``. If the log file never appears, a warning
        is logged and nothing is parsed.

        :param tmp_dir: directory the job writes its log file to
        :raises IndexError: if one of the section markers is missing from the log
        """
        # wait whilst job sits in the queue
        log_file = self._wait_for_single_file(
            tmp_dir,
            _SFE.LOGFILE,
            "Log file has not yet appeared in the file system, sleeping and retrying...",
        )
        if log_file is None:
            self._logger.log(
                f"Gave up waiting for the log file in {tmp_dir}, no edge or node data was extracted",
                _LE.WARNING,
            )
            return

        with open(os.path.join(tmp_dir, log_file), "r") as f:
            lines = f.readlines()

        # use the last occurrence of each section header, the first of the terminus
        edge_header_index = [
            idx for idx, s in enumerate(lines) if _SFE.EDGE_HEADER_LINE in s
        ][-1]
        node_header_index = [
            idx for idx, s in enumerate(lines) if _SFE.NODE_HEADER_LINE in s
        ][-1]
        end_of_data_index = [
            idx for idx, s in enumerate(lines) if _SFE.DATA_TERMINUS in s
        ][0]

        # NOTE(review): the +3 / -1 offsets presumably skip table header and
        # separator lines around the data rows - confirm against a real log file
        edge_data_lines = lines[edge_header_index + 3 : node_header_index - 1]
        node_data_lines = lines[node_header_index + 3 : end_of_data_index - 1]

        self._process_edge_lines(edge_data_lines)
        self._process_node_lines(node_data_lines)

    def _process_node_lines(self, data: List[str]) -> None:
        """
        Attach the dG reported in each node line to the corresponding compound

        The second whitespace-separated field of a line is taken as the ligand
        identifier and the third as its dG. Lines with fewer than three fields
        are skipped with a warning.

        :param data: node lines taken from the log file
        """
        for entry in data:
            fields = entry.split()
            if len(fields) < 3:
                self._logger.log(
                    f"Line: {entry} from the logfile does not contain the expected node fields - skipping",
                    _LE.WARNING,
                )
                continue
            idx = fields[1]
            dG = fields[2]
            # attach dG tags to compound objects if present
            if self.data.compounds:
                # account for running this step compoundless
                # take everything before the first ":" (as done for the edges);
                # only using the first character breaks for compound numbers >= 10
                compound_number = int(idx.split(":")[0])
                # NOTE(review): compounds are looked up by list position here -
                # this assumes position and compound number agree, confirm
                self.data.compounds[compound_number].get_enumerations()[
                    0
                ].get_conformers()[0].get_molecule().SetProp("dG", str(dG))
            self._logger.log(
                f"dG directly from the output file for compound {idx} is {dG} ",
                _LE.INFO,
            )

    def _process_edge_lines(self, edge_data: List[str]) -> None:
        """
        Calibrate dG values using a reference compound and edge ddG from log file output

        Builds an antisymmetric ddG matrix and a symmetric error matrix from the
        edge lines, computes the lowest-error path from every ligand to ligand 0
        and hands the result to ``_construct_dg_per_compound``. Lines that can
        not be parsed are skipped with a warning.

        :param edge_data: edge lines taken from the log file
        """
        # calculate the max ligand index from the edges themselves, accounting for
        # ligands that may have been skipped in previous steps, so can't rely on
        # self.get_compounds()
        max_index = 0
        edges = []
        for line in edge_data:
            parts = line.split()
            try:
                # parse the compound info from the log file
                lig_from = int(parts[1].split(":")[0])
                lig_to = int(parts[3].split(":")[0])
                # both ligands count towards the matrix size even if the
                # ddG value below turns out to be unreadable
                max_index = max(max_index, lig_from, lig_to)
                values = parts[4].split("+-")
                ddG = float(values[0])
                err = float(values[1])
            except (ValueError, IndexError):
                self._logger.log(
                    f"Line: {line} from the logfile contained an unexpected datatype - cannot process this edge - skipping",
                    _LE.WARNING,
                )
                continue
            edges.append((lig_from, lig_to, ddG, err))

        len_nodes = max_index + 1  # account for zero indexed ligands
        error_matrix = np.zeros((len_nodes, len_nodes))
        ddG_matrix = np.zeros((len_nodes, len_nodes))
        for lig_from, lig_to, ddG, err in edges:
            error_matrix[lig_from, lig_to] = err
            error_matrix[lig_to, lig_from] = err
            ddG_matrix[lig_from, lig_to] = ddG
            ddG_matrix[lig_to, lig_from] = -ddG

        error_matrix = csr_matrix(error_matrix)
        # compute shortest path from one ligand to the anchor
        _, predecessors = shortest_path(
            error_matrix, directed=False, return_predecessors=True, indices=0
        )
        self._construct_dg_per_compound(
            ddG_matrix, predecessors, error_matrix, self.get_compounds()
        )

    def _construct_dg_per_compound(
        self,
        ddG: np.ndarray,
        predecessors: List,
        error_matrix: np.ndarray,
        compounds: List[Compound],
    ) -> None:
        """
        Calculate the calibrated binding free energy per compound using a reference value
        Attach calculated dG to compounds

        The compound with number 0 is the reference and receives the reference
        dG as is. For every other compound, the ddG and error values are summed
        along its predecessor path back to compound 0 and written to the
        "map_dG" property as "<dG>+-<error>", both rounded to two decimals.

        :param ddG: matrix holding the ddG for going from ligand i to ligand j at [i, j]
        :param predecessors: for each ligand, the previous ligand on its path to ligand 0
        :param error_matrix: edge errors, indexable by [i, j]
        :param compounds: compounds to be tagged
        """
        try:
            ref_dG = self.settings.additional[_SFE.REFERENCE_DG]
        except KeyError:
            self._logger.log(
                "Expected to find a reference dG value for the lead compound, but none was found."
                "Defaulting to 0.00, you will need to apply a manual correction afterwards",
                _LE.WARNING,
            )
            ref_dG = 0.00

        def _sum_along_path(comp_num: int):
            """Return (dG, error) for comp_num, or None if it is not connected to ligand 0."""
            dG, err = ref_dG, 0
            node = comp_num
            while node != 0:
                if not 0 <= node < len(predecessors):
                    return None
                prev_index = predecessors[node]
                # shortest_path marks ligands without a path by a negative predecessor
                if prev_index < 0:
                    return None
                dG += ddG[prev_index, node]
                err += error_matrix[prev_index, node]
                node = prev_index
            return dG, err

        for comp in compounds:
            idx = comp.get_compound_number()
            not_found_message = f"Compound {idx} was not found in the output map, it was likely dropped during the workflow"
            try:
                if idx == 0:
                    # reference compound: use the (possibly defaulted) reference
                    # value rather than reading the settings a second time
                    comp.get_enumerations()[0].get_conformers()[
                        0
                    ].get_molecule().SetProp("map_dG", str(ref_dG))
                    continue

                # check whether the compound appeared in the final map
                result = _sum_along_path(idx)
                if result is None:
                    self._logger.log(not_found_message, _LE.WARNING)
                    continue

                dG, err = result
                data = str(round(dG, 2)) + "+-" + str(round(err, 2))
                # tag the compound itself; the compound number is not
                # necessarily its position in the list
                comp.get_enumerations()[0].get_conformers()[
                    0
                ].get_molecule().SetProp("map_dG", data)
                self._logger.log(
                    f"Calculated dG from spanning tree for compound {idx} is {data}",
                    _LE.INFO,
                )
            except IndexError:
                self._logger.log(not_found_message, _LE.WARNING)
                continue