import json
import os
import sys
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
from pydantic import BaseModel, PrivateAttr
from rdkit import Chem

# from icolos.core.composite_agents.workflow import WorkflowData
from icolos.core.containers.compound import Compound, Conformer
from icolos.core.step_utils.input_preparator import StepData
from icolos.core.step_utils.run_variables_resolver import RunVariablesResolver
from icolos.loggers.steplogger import StepLogger
from icolos.utils.enums.logging_enums import LoggingConfigEnum
from icolos.utils.enums.step_enums import StepBaseEnum, StepGromacsEnum
from icolos.utils.enums.write_out_enums import WriteOutEnum
# module-level enum/configuration singletons shared by all classes below
_WE = WriteOutEnum()
_LE = LoggingConfigEnum()
_SBE = StepBaseEnum
_SGE = StepGromacsEnum()
class StepWriteoutCompoundAggregationParameters(BaseModel):
    """Aggregation settings for compound write-out (e.g. best conformer per compound)."""

    # aggregation mode; defaults to writing out all conformers
    mode: _SBE = _SBE.WRITEOUT_COMP_AGGREGATION_MODE_ALL
    # if True, the highest value of the aggregation key is considered best
    highest_is_best: bool = True
    # tag used to rank conformers; explicit Optional ("key: str = None" only
    # works via pydantic v1's implicit Optional conversion)
    key: Optional[str] = None
class StepWriteoutCompoundParameters(BaseModel):
    """Configuration for compound-based write-out."""

    # which level to write out (conformers / enumerations / extra data)
    category: _SBE
    aggregation: StepWriteoutCompoundAggregationParameters = (
        StepWriteoutCompoundAggregationParameters()
    )
    # key of the extra-data entry to write (used by the extra-data category);
    # explicit Optional instead of relying on pydantic v1's implicit conversion
    key: Optional[str] = None
    # restrict tabular/JSON write-out to these tags; None means use all tags
    selected_tags: Optional[List[str]] = None
class StepWriteoutGenericParameters(BaseModel):
    """Configuration for generic (file-based) write-out."""

    # file extension of the generic files to write out
    # (used with get_files_by_extension)
    key: str
class StepWriteoutGromacsParameters(BaseModel):
    """Configuration for GROMACS state write-out."""

    # comma-separated list of gmx state keys (topol, ndx, xtc, ...) to write out
    key: str
class StepWriteoutDestinationParameters(BaseModel):
    """Destination settings (where and how the write-out output is produced)."""

    # target path / endpoint; explicit Optional instead of relying on
    # pydantic v1's implicit "str = None" -> Optional conversion
    resource: Optional[str] = None
    # destination kind (file, REINVENT, directory, ...)
    type: _SBE = _SBE.WRITEOUT_DESTINATION_TYPE_FILE
    format: _SBE = _SBE.FORMAT_TXT
    # if True, all compounds are merged into a single output file
    merge: bool = True
    # file-naming mode for generic write-out (base name / automatic / dir)
    mode: _SBE = _SBE.WRITEOUT_DESTINATION_BASE_NAME
class StepWriteoutParameters(BaseModel):
    """Top-level write-out parameters; exactly one sub-block should be specified."""

    compounds: Optional[StepWriteoutCompoundParameters] = None
    generic: Optional[StepWriteoutGenericParameters] = None
    gmx_state: Optional[StepWriteoutGromacsParameters] = None
    destination: Optional[StepWriteoutDestinationParameters] = None
class WriteOutHandler(BaseModel):
    """Performs write-out of compound, generic or GROMACS state data according
    to its configuration block."""

    # parsed write-out configuration for this block
    config: StepWriteoutParameters
    # step data (compounds / generic files / gmx state), set via set_data()
    data: StepData = None
    # optional reference to the enclosing workflow's data object
    workflow_data: BaseModel = None

    class Config:
        underscore_attrs_are_private = True

    _logger = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._logger = StepLogger()

    def set_data(self, data: StepData):
        # deep-copy so later mutation of the step's data does not affect the write-out
        self.data = deepcopy(data)

    def set_workflow_data(self, data):
        self.workflow_data = data

    def get_data(self) -> StepData:
        return self.data
def _handle_destination_type(self):
if self.config.destination.type.lower() in (
_SBE.WRITEOUT_DESTINATION_TYPE_FILE,
_SBE.WRITEOUT_DESTINATION_TYPE_REINVENT,
_SBE.WRITEOUT_DESTINATION_DIR,
):
return self.config.destination.resource
elif (
self.config.destination.type.lower() == _SBE.WRITEOUT_DESTINATION_TYPE_REST
):
raise ValueError("REST end-point destination type not supported yet.")
raise ValueError(
f"Destination type {self.config.destination.type} not supported."
)
def _write_compounds(self):
resource = self._handle_destination_type()
resolver = RunVariablesResolver()
if self.config.compounds.category == _SBE.WRITEOUT_COMP_CATEGORY_CONFORMERS:
if self.config.destination.format.upper() == _SBE.FORMAT_CSV:
if self.config.destination.format.upper() != _SBE.FORMAT_CSV:
raise NotImplementedError(
"Only supporting CSV write-out format for tabular data."
)
self._writeout_tabular()
elif self.config.destination.format.upper() == _SBE.FORMAT_JSON:
self._writeout_reinvent()
elif self.config.destination.format.upper() == _SBE.FORMAT_SDF:
def _write_compounds(compounds: List[Compound], resource: str):
# TODO: deal with resolving resources differently (also for writing enumerations below)
resource_resolved = resource
for compound in compounds:
for enum in compound.get_enumerations():
if len(enum.get_conformers()) > 0:
resource_resolved = resolver.resolve(resource, enum[0])
break
self._make_folder(resource_resolved)
writer = Chem.SDWriter(resource_resolved)
written = 0
for comp in compounds:
for enum in comp:
for conf in enum:
molecule = conf.get_molecule()
if (
comp.get_name() is not None
and comp.get_name() != ""
):
molecule.SetProp(_WE.COMPOUND_NAME, comp.get_name())
molecule.SetProp(
_WE.RDKIT_NAME, conf.get_index_string()
)
molecule.SetProp(
_WE.INDEX_STRING, conf.get_index_string()
)
writer.write(molecule)
written += 1
writer.close()
self._logger.log(
f"Wrote {written} conformers to file {resource_resolved}.",
_LE.DEBUG,
)
# TODO: At the moment, this only splits at the compound level (taking the first conformer for resolving),
if self.config.destination.merge:
_write_compounds(self.data.compounds, resource=resource)
else:
for comp in self.data.compounds:
_write_compounds([comp], resource)
elif self.config.compounds.category == _SBE.WRITEOUT_COMP_CATEGORY_ENUMERATIONS:
if not self.config.destination.format.upper() == _SBE.FORMAT_SDF:
raise NotImplementedError(
"This write-out is not supported for enumerations."
)
else:
def _write_compounds(compounds: List[Compound], resource: str):
# TODO: deal with resolving resources differently (also for writing conformers above)
resource_resolved = resource
for compound in compounds:
if len(compound.get_enumerations()) > 0:
resource_resolved = resolver.resolve(resource, compounds[0])
break
self._make_folder(resource_resolved)
writer = Chem.SDWriter(resource_resolved)
written = 0
for comp in compounds:
for enum in comp:
molecule = enum.get_molecule()
if comp.get_name() is not None and comp.get_name() != "":
molecule.SetProp(_WE.COMPOUND_NAME, comp.get_name())
molecule.SetProp(_WE.RDKIT_NAME, enum.get_index_string())
molecule.SetProp(_WE.INDEX_STRING, enum.get_index_string())
writer.write(molecule)
written += 1
writer.close()
self._logger.log(
f"Wrote {written} enumeration molecules to file {resource_resolved}.",
_LE.DEBUG,
)
if self.config.destination.merge:
_write_compounds(self.data.compounds, resource=resource)
else:
for comp in self.data.compounds:
_write_compounds([comp], resource)
elif self.config.compounds.category == _SBE.WRITEOUT_COMP_CATEGORY_EXTRADATA:
if self.config.destination.format.upper() != _SBE.FORMAT_TXT:
raise ValueError(
f"For writing out extra-data (attached to conformers), only TXT is supported as format."
)
# TODO: Does merging here makes any sense?
for comp in self.data.compounds:
for enum in comp:
for conf in enum:
resource_resolved = resolver.resolve(resource, conf)
self._make_folder(resource_resolved)
with open(resource_resolved, "w") as f:
content = conf.get_extra_data()[self.config.compounds.key]
if isinstance(content, list):
for line in content:
f.write(line.rstrip("\n") + "\n")
elif isinstance(content, str):
f.write(content)
else:
raise ValueError(
"Extra data must be either a string or a list of strings."
)
else:
raise ValueError(f"{self.config.compounds.category} not supported.")
def _write_generic_data(self):
# type and format do not apply here, simply overwrite defaults
self.config.destination.type = _SBE.WRITEOUT_DESTINATION_TYPE_FILE
self.config.destination.format = _SBE.FORMAT_TXT
resource = self._handle_destination_type()
self._make_folder(resource)
if self.config.destination.mode == _SBE.WRITEOUT_DESTINATION_DIR:
# The output path should be a directory only
assert not os.path.isfile(resource)
os.makedirs(resource, exist_ok=True)
# write out all files from that step with the required extension
for idx, file in enumerate(
self.data.generic.get_files_by_extension(self.config.generic.key)
):
if self.config.destination.mode == _SBE.WRITEOUT_DESTINATION_BASE_NAME:
parts = resource.split(".")
resource = parts[0] + f"_{idx}." + parts[1]
file.write(resource, join=False)
elif self.config.destination.mode == _SBE.WRITEOUT_DESTINATION_AUTOMATIC:
# take the original file name from the step (these tend not to be very descriptive)
parts = file.get_file_name().split(".")
file_name = parts[0] + f"_{idx}." + parts[1]
resource = os.path.join("/".join(resource.split("/")[:-1]), file_name)
file.write(resource, join=False)
elif self.config.destination.mode == _SBE.WRITEOUT_DESTINATION_DIR:
assert os.path.isdir(resource)
file.write(resource, join=True, final_writeout=True)
def _write_gromacs_data(self):
"""
Handle writeout from gromacs topology state
"""
self.config.destination.type = _SBE.WRITEOUT_DESTINATION_TYPE_FILE
self.config.destination.format = _SBE.FORMAT_TXT
self.config.destination.type = _SBE.WRITEOUT_DESTINATION_DIR
resource = self._handle_destination_type()
os.makedirs(resource, exist_ok=True)
writeout_keys = map(lambda s: s.strip(), self.config.gmx_state.key.split(","))
for key in writeout_keys:
if key == _SGE.FIELD_KEY_TOPOL:
self.data.gmx_state.write_topol(resource)
elif key == _SGE.FIELD_KEY_NDX:
self.data.gmx_state.write_ndx(resource)
elif key == _SGE.PROPS:
self.data.gmx_state.write_props(resource)
elif key == _SGE.FIELD_KEY_LOG:
self.data.gmx_state.write_log(resource)
elif key == _SGE.FIELD_KEY_TPR:
self.data.gmx_state.write_tpr(resource)
elif key == _SGE.FIELD_KEY_CPT:
self.data.gmx_state.write_cpt(resource)
elif key == _SGE.FIELD_KEY_EDR:
self.data.gmx_state.write_edr(resource)
elif key == _SGE.FIELD_KEY_XTC:
# if we have multiple trajectories, write them out sequentially, with index attached
if len(self.data.gmx_state.trajectories.keys()) > 1:
for k, v in self.data.gmx_state.trajectories.items():
parts = v.get_file_name().split(".")
file_name = parts[0] + "_" + str(k) + "." + parts[1]
self.data.gmx_state.write_trajectory(
resource, file=file_name, index=k
)
else:
self.data.gmx_state.write_trajectory(resource)
elif key == _SGE.FIELD_KEY_STRUCTURE:
if len(self.data.gmx_state.structures.keys()) > 1:
for k, v in self.data.gmx_state.structures.items():
parts = v.get_file_name().split(".")
file_name = parts[0] + "_" + str(k) + "." + parts[1]
self.data.gmx_state.write_structure(
resource, file=file_name, index=k
)
else:
self.data.gmx_state.write_structure(resource)
else:
raise ValueError(
f"Gromacs file of type {key} is not supported for writeout"
)
[docs] def write(self):
if (
self.config.compounds is not None
and self.config.generic is not None
and self.config.gromacs_state is not None
):
raise ValueError("Only specify one type of writeout per block!")
if self.config.compounds is not None:
self._write_compounds()
elif self.config.generic is not None:
self._write_generic_data()
elif self.config.gmx_state is not None:
self._write_gromacs_data()
else:
raise ValueError(
"Either compounds, generic or gromacs data has to be specified."
)
def _writeout_reinvent(self):
def _get_conf_by_comp_name(confs: List[Conformer], comp_name: str) -> Conformer:
# assumes there is at most 1 conformer / compound left at this stage, as is required by REINVENT
for conf in confs:
if conf.get_compound_name() == comp_name:
return conf
return None
dict_result = {_WE.JSON_RESULTS: []}
tags = self._get_selected_tags()
# add names, including those for which no conformer has been obtained
dict_result[_WE.JSON_NAMES] = [comp.get_name() for comp in self.data.compounds]
# do aggregation (might remove conformers)
confs_unrolled = self._apply_aggregation(self.data.compounds)
# add values (derived from molecule tags)
# TODO: if no conformers are left, we need to write out an empty JSON that tells REINVENT that none worked
for tag in tags:
values = []
for comp_name in dict_result[_WE.JSON_NAMES]:
conf = _get_conf_by_comp_name(confs=confs_unrolled, comp_name=comp_name)
if conf is not None:
try:
value = conf.get_molecule().GetProp(tag)
except KeyError:
value = _WE.JSON_NA
else:
value = _WE.JSON_NA
values.append(value.strip())
dict_result[_WE.JSON_RESULTS].append(
{_WE.JSON_VALUES_KEY: tag, _WE.JSON_VALUES: values}
)
# TODO: refactor that part
resource = self._handle_destination_type()
if len(confs_unrolled) > 0:
resolver = RunVariablesResolver()
resource_resolved = resolver.resolve(resource, confs_unrolled[0])
else:
resource_resolved = resource
self._logger.log(
f"No conformers obtained, write-out resource resolving disabled.",
_LE.WARNING,
)
self._make_folder(resource_resolved)
# write-out according to destination type
# TODO: there seems to be an issue here, when multiple write-out blocks are specified and no conformers are
# left: only the first block gets executed and if that's not the REINVENT one, the run will crash
if self.config.destination.type.lower() in (
_SBE.WRITEOUT_DESTINATION_TYPE_REINVENT,
_SBE.WRITEOUT_DESTINATION_TYPE_FILE,
):
with open(resource_resolved, "w") as f:
json.dump(dict_result, f, indent=4)
elif self.config.destination.type.lower() in (
_SBE.WRITEOUT_DESTINATION_TYPE_STDOUT,
_SBE.WRITEOUT_DESTINATION_TYPE_STDERR,
):
json.dump(dict_result, resource_resolved, indent=4)
else:
raise ValueError(
f"Destination type {self.config.destination.type} not supported for this function."
)
def _get_selected_tags(self) -> List[str]:
# this function returns a list of tags (strings) that are to be considered for e.g. tabular write-out
# if the respective configuration field is set to "None", use all tags (over all compounds in a batch)
if self.config.compounds.selected_tags is not None:
if isinstance(self.config.compounds.selected_tags, list):
list_tags = self.config.compounds.selected_tags
elif isinstance(self.config.compounds.selected_tags, str):
list_tags = [self.config.compounds.selected_tags]
else:
raise ValueError(
f'Tag selection "{self.config.compounds.selected_tags}" set to illegal value.'
)
else:
# get all tags for all compounds
list_tags = []
for comp in self.data.compounds:
for enum in comp:
for conf in enum:
list_tags = list_tags + list(conf.get_molecule().GetPropNames())
list_tags = list(set(list_tags))
return list_tags
def _initialize_dict_csv(
self, keys: List[str], nrow: int, fill_value=np.NaN
) -> OrderedDict:
return_dict = OrderedDict()
for key in keys:
return_dict[key] = [fill_value for _ in range(nrow)]
return return_dict
def _apply_aggregation(self, compounds: List[Compound]) -> List[Conformer]:
if (
self.config.compounds.aggregation.mode
== _SBE.WRITEOUT_COMP_AGGREGATION_MODE_ALL
):
return self._unroll_conformers(compounds)
confs_remaining = []
if (
self.config.compounds.aggregation.mode
== _SBE.WRITEOUT_COMP_AGGREGATION_MODE_BESTPERENUMERATION
):
raise NotImplementedError("Best per enumeration is not yet implemented.")
elif (
self.config.compounds.aggregation.mode
== _SBE.WRITEOUT_COMP_AGGREGATION_MODE_BESTPERCOMPOUND
):
for comp in compounds:
unrolled_conformers = self._unroll_conformers([comp])
if len(unrolled_conformers) == 0:
continue
values = []
for conf in unrolled_conformers:
try:
values.append(
float(
conf.get_molecule().GetProp(
self.config.compounds.aggregation.key
)
)
)
except KeyError as e:
self._logger.log(
f"Error {e} for conf {conf.get_index_string()}, setting value to 0.00 in writeout!",
_LE.WARNING,
)
values.append(0.00)
# values = [
# float(
# conf.get_molecule().GetProp(
# self.config.compounds.aggregation.key
# )
# )
# for conf in unrolled_conformers
# ]
index_best = (
values.index(max(values))
if self.config.compounds.aggregation.highest_is_best
else values.index(min(values))
)
confs_remaining.append(unrolled_conformers[index_best])
return confs_remaining
def _unroll_conformers(self, compounds: List[Compound]) -> List[Conformer]:
result = []
for comp in compounds:
for enum in comp:
for conf in enum:
result.append(conf)
return result
    def _writeout_tabular(self):
        """Write aggregated conformer tags to CSV, one row per remaining conformer.

        Raises ValueError when no conformers are available after aggregation.
        """
        # get all tags of the molecules that are to be considered
        tags = self._get_selected_tags()
        # remove the compound_name and _Name, as they will be specifically added at the beginning
        if _WE.COMPOUND_NAME in tags:
            tags.remove(_WE.COMPOUND_NAME)
        if _WE.RDKIT_NAME in tags:
            tags.remove(_WE.RDKIT_NAME)
        # do aggregation (might remove conformers)
        confs_unrolled = self._apply_aggregation(self.data.compounds)
        # initialize a dictionary with all tags as keys and filled with NA for every position
        dict_result = self._initialize_dict_csv(
            keys=[
                _WE.RDKIT_NAME,
                _WE.COMPOUND_NAME,
                "original_smiles",
                "enumerated_smiles",
            ]
            + tags,
            nrow=len(confs_unrolled),
        )
        # resolve resource
        # TODO: refactor that part
        resource = self._handle_destination_type()
        resolver = RunVariablesResolver()
        if len(confs_unrolled) == 0:
            raise ValueError("No conformers found.")
        # run variables are resolved against the first conformer only
        resource_resolved = resolver.resolve(resource, confs_unrolled[0])
        self._make_folder(resource_resolved)
        # populate the dictionary with the values (if present)
        for irow in range(len(confs_unrolled)):
            # add the internal Icolos identifier
            conf = confs_unrolled[irow]
            dict_result[_WE.RDKIT_NAME][irow] = conf.get_index_string()
            # add the compound name, if specified
            name = conf.get_compound_name()
            dict_result[_WE.COMPOUND_NAME][irow] = "" if name is None else name
            dict_result["original_smiles"][
                irow
            ] = conf.get_enumeration_object().get_original_smile()
            dict_result["enumerated_smiles"][irow] = Chem.rdmolfiles.MolToSmiles(
                conf.get_molecule()
            )
            # tags missing on a molecule are written as NaN (empty cell in CSV)
            for tag in tags:
                try:
                    value = conf.get_molecule().GetProp(tag).strip()
                except KeyError:
                    value = np.nan
                dict_result[tag][irow] = value
        # do the writeout (after sanitation)
        df_result = pd.DataFrame.from_dict(dict_result)
        df_result = self._sanitize_df_columns(df=df_result)
        df_result.to_csv(
            path_or_buf=resource_resolved,
            sep=",",
            na_rep="",
            header=True,
            index=False,
            mode="w",
            quoting=None,
        )
        self._logger.log(
            f"Wrote data frame with {len(confs_unrolled)} rows and {len(tags)} columns to file {resource_resolved}.",
            _LE.DEBUG,
        )
def _sanitize_df_columns(self, df: pd.DataFrame) -> pd.DataFrame:
cols_before = df.columns.to_list()
df.columns = (
df.columns.str.strip()
.str.replace(" ", "_")
.str.replace("(", "")
.str.replace(")", "")
.str.replace("/", "_")
.str.replace("[", "")
.str.replace("]", "")
)
for col_before, col_after in zip(cols_before, df.columns.to_list()):
if col_before != col_after:
self._logger.log(
f"Sanitized column name {col_before} to {col_after}.", _LE.WARNING
)
return df
def _make_folder(self, path):
if isinstance(path, str):
if not os.path.isdir(path):
path = os.path.dirname(path)
Path(path).mkdir(parents=True, exist_ok=True)