Source code for maize.steps.mai.misc.reinvent

"""Interface to REINVENT"""

from collections.abc import Callable
import json
import os
from pathlib import Path
import stat
import sys
import threading
from time import sleep
from typing import TYPE_CHECKING, Annotated, Any, NoReturn, cast
import pytest

import toml
import numpy as np
from numpy.typing import NDArray

from maize.core.node import Node
from maize.core.workflow import Workflow
from maize.core.interface import Input, Output, Parameter, FileParameter, Suffix, Flag
from maize.utilities.chem import IsomerCollection
from maize.utilities.execution import CommandRunner
from maize.utilities.testing import TestRig
from maize.utilities.io import setup_workflow


if TYPE_CHECKING:
    from maize.core.graph import Graph


# Log file REINVENT is told to write to (passed on the command line via `--log-filename`)
REINVENT_LOGFILE = Path("reinvent.log")
# Base name of the patched configuration; `_patch_config` swaps in the suffix of the user config
DEFAULT_PATCHED_CONFIG = Path("config.toml")
# Hand-over file for SMILES: written by the interceptor script, read and removed by the node
TEMP_SMILES_FILE = Path("inp.smi")
# Hand-over file for scores: written by the node, read and removed by the interceptor script
TEMP_SCORES_FILE = Path("out.json")
# Location the interceptor script is written to before REINVENT is started
INTERCEPTOR_FILE = Path("./intercept.py")
# Source code of the interceptor script. The script dumps everything it receives
# on standard input into the SMILES file, polls every 0.5 s until the scores file
# exists, prints the contents of the scores file to standard output and finally
# deletes the scores file. Both file names are baked in when this module is imported.
# NOTE(review): the script imports `shutil` without using it.
INTERCEPTOR = f"""#!/usr/bin/env python
from pathlib import Path
import shutil
from time import sleep
import sys

with Path("{TEMP_SMILES_FILE.as_posix()}").open("w") as file:
    file.writelines(sys.stdin.readlines())
while not Path("{TEMP_SCORES_FILE.as_posix()}").exists():
    sleep(0.5)
with Path("{TEMP_SCORES_FILE.as_posix()}").open("r") as file:
    print(file.read())
Path("{TEMP_SCORES_FILE.as_posix()}").unlink()
"""


def expose_reinvent(graph_type: type["Graph"]) -> Callable[[], None]:
    """
    Converts a subgraph with smiles input and score output to a REINVENT-compatible workflow.

    The subgraph must have a single input 'inp' of type `Input[list[str]]`
    and a single output 'out' of type `Output[NDArray[np.float32]]`.

    The returned callable reads all SMILES from standard input (unless a help
    flag is present on the command line), runs the workflow and prints the
    resulting scores to standard output as a JSON list.

    Parameters
    ----------
    graph_type
        The subgraph to convert

    Returns
    -------
    Callable[[], None]
        Runnable workflow

    """
    # Imported lazily, as in the original implementation
    from maize.steps.io import LoadData, Return

    def wrapped() -> None:
        """Build the wrapping workflow, execute it and print the scores."""
        flow = Workflow()
        smi = flow.add(LoadData[list[str]], name="smiles")
        core = flow.add(graph_type, name="core")
        sco = flow.add(Return[NDArray[np.float32]])

        assert hasattr(core, "inp")
        assert hasattr(core, "out")

        # Small hack to allow us to use help, but at the
        # same time allow reading in all input from stdin
        if all(flag not in sys.argv for flag in ("-h", "--help")):
            smiles = sys.stdin.readlines()
            smi.data.set(smiles)

        flow.connect_all((smi.out, core.inp), (core.out, sco.inp))
        flow.map(*core.all_parameters.values())

        setup_workflow(flow)

        if (scores := sco.get()) is not None:
            # `tolist()` converts numpy scalars to builtin floats; `list(scores)`
            # would leave `np.float32` items that `json` refuses to serialize.
            print(json.dumps(np.asarray(scores).tolist()))

    return wrapped
def _exception_handler(args: Any, /) -> NoReturn:
    """
    Re-raise an exception reported through `threading.excepthook`.

    Parameters
    ----------
    args
        Hook argument object exposing `exc_type`, `exc_value` and `exc_traceback`

    Raises
    ------
    BaseException
        Always; the exception described by `args`

    """
    if args.exc_value is not None:
        # Re-raise the original instance: rebuilding it via `exc_type(exc_value)`
        # fails for exception classes whose constructor needs several arguments.
        raise args.exc_value.with_traceback(args.exc_traceback)
    raise args.exc_type(args.exc_value).with_traceback(args.exc_traceback)


def _write_interceptor(file: Path, contents: str) -> None:
    """
    Creates a SMILES interceptor script

    Parameters
    ----------
    file
        Location to write the script to
    contents
        Source code of the script

    """
    with file.open("w") as script:
        script.write(contents)
    # Owner-only read / write / execute, REINVENT must be able to run the script
    os.chmod(file, mode=stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR)


def _patch_config(
    path: Path,
    weight: float = 1.0,
    low: float = 0.0,
    high: float = 1.0,
    k: float = 0.5,
    reverse: bool = False,
    min_epochs: int = 5,
    max_epochs: int = 10,
    batch_size: int = 128,
    maize_backend: bool = False,
) -> Path:
    """
    Patch the REINVENT config to allow interception of SMILES.

    The first stage of the configuration receives an `ExternalProcess` scoring
    component pointing at the interceptor script, its step limits are replaced
    and the batch size is overwritten. The result is written to a new file in
    the current working directory, keeping the format of the original.

    Parameters
    ----------
    path
        REINVENT configuration in JSON or TOML format
    weight
        Weight of the scoring component
    low
        Low threshold of the score transform
    high
        High threshold of the score transform
    k
        Slope of the score transform
    reverse
        Use a reverse sigmoid instead of a sigmoid transform
    min_epochs
        Value for the `min_steps` entry of the first stage
    max_epochs
        Value for the `max_steps` entry of the first stage
    batch_size
        Value for the `batch_size` entry of the parameters section
    maize_backend
        Accepted for interface compatibility, currently not read by this function

    Returns
    -------
    Path
        Path to the patched configuration file

    Raises
    ------
    IOError
        If the configuration has a suffix other than '.json' or '.toml'

    """
    score_conf = {
        "name": "maize",
        "weight": weight,
        "params": {
            "executable": "./intercept.py",
            "args": "",
        },
        "transform": {
            "low": low,
            "high": high,
            "k": k,
            "type": "reverse_sigmoid" if reverse else "sigmoid",
        },
    }

    with path.open() as file:
        if path.suffix == ".json":
            conf = json.load(file)
        elif path.suffix == ".toml":
            conf = toml.load(file)
        else:
            raise IOError(f"Unable to read REINVENT config '{path.as_posix()}'")

    conf["stage"][0]["scoring"]["component"]["ExternalProcess"] = {"endpoint": [score_conf]}
    conf["stage"][0]["min_steps"] = min_epochs
    conf["stage"][0]["max_steps"] = max_epochs
    conf["parameters"]["batch_size"] = batch_size

    patched_file = DEFAULT_PATCHED_CONFIG.with_suffix(path.suffix)
    with patched_file.open("w") as out:
        if path.suffix == ".json":
            json.dump(conf, out)
        elif path.suffix == ".toml":
            toml.dump(conf, out)

    return patched_file


class read_log:
    """
    Callable returning the part of a logfile that was added since the last call.

    Parameters
    ----------
    logfile
        Path to the logfile to follow

    """

    def __init__(self, logfile: Path) -> None:
        self.logfile = logfile
        # Number of lines that have already been reported
        self._new_lines = 0

    def __call__(self) -> str:
        """Read the ReInvent logfile and format"""
        with self.logfile.open() as log:
            lines = log.readlines()

        all_lines = len(lines)
        # Slice from the front: a negative offset of zero (no new lines)
        # would select the whole file and repeat the complete log.
        fresh = lines[self._new_lines :]
        msg = "\n"
        msg += "---------------- STDOUT ----------------\n"
        msg += "".join(fresh) + "\n"
        msg += "---------------- STDOUT ----------------\n"
        self._new_lines = all_lines
        return msg
class ReInvent(Node):
    """
    Runs REINVENT in a staged learning context.

    This node works by starting a REINVENT process with a special
    'intercepting' external process to score the sampled SMILES.
    This interceptor simply accepts the SMILES on standard input
    and writes them to a location known by the node. The node then
    reads these SMILES and sends them to the output. The node then
    waits for scores to be received on the input, and writes them
    to a location known by the interceptor. The interceptor waits
    for the scores to be written and then reads them in to pass them
    to REINVENT on standard output. REINVENT can then perform its
    likelihood update and the cycle repeats until REINVENT exits
    or the maximum number of iterations is reached.

    """

    required_callables = ["reinvent"]
    """
    Requires REINVENT to be installed in a separate python environment
    and ideally be specified as an interpreter - script pair.

    """

    inp: Input[NDArray[np.float32]] = Input(optional=True)
    """Raw score input for the likelihood update"""

    out: Output[list[str]] = Output()
    """SMILES string output"""

    configuration: FileParameter[Annotated[Path, Suffix("toml", "json")]] = FileParameter()
    """ReInvent configuration file"""

    min_epoch: Parameter[int] = Parameter(default=5)
    """Minimum number of epochs to run"""

    max_epoch: Parameter[int] = Parameter(default=50)
    """Maximum number of epochs to run"""

    weight: Parameter[float] = Parameter(default=1.0)
    """Weight of the maize scoring component"""

    low: Parameter[float] = Parameter(default=0.0)
    """Low threshold for the sigmoid score transformation"""

    high: Parameter[float] = Parameter(default=1.0)
    """High threshold for the sigmoid score transformation"""

    k: Parameter[float] = Parameter(default=0.5)
    """Slope for the sigmoid score transformation"""

    reverse: Flag = Flag(default=False)
    """Whether to use a reverse sigmoid score transform"""

    batch_size: Parameter[int] = Parameter(default=128)
    """ReInvent batch size"""

    maize_backend: Flag = Flag(default=False)
    """Whether to use a special maize backend in Reinvent to enable weighted scores"""

    def _handle_smiles(self, worker: threading.Thread) -> None:
        """
        Wait for the interceptor to hand over SMILES and forward them to the output.

        Returns without sending anything if the REINVENT worker
        thread finishes before a SMILES file shows up.

        Parameters
        ----------
        worker
            Thread running the REINVENT command

        """
        self.logger.debug("Waiting for SMILES from Reinvent")
        while not TEMP_SMILES_FILE.exists():
            sleep(0.5)
            if not worker.is_alive():
                self.logger.debug("Reinvent has completed, exiting...")
                return

        # NOTE(review): the interceptor does not write this file atomically, so
        # a partially written file could in principle be picked up here.
        with TEMP_SMILES_FILE.open("r") as file:
            smiles = [smi.strip() for smi in file.readlines()]
        self.logger.debug("Sending SMILES")
        self.out.send(smiles)
        TEMP_SMILES_FILE.unlink()

    def run(self) -> None:
        """Start REINVENT and exchange SMILES and scores until it is done."""
        # Create the interceptor fake external process
        _write_interceptor(INTERCEPTOR_FILE, INTERCEPTOR)

        # NOTE(review): not read anywhere in this class, presumably used by the base class
        self.max_steps = self.max_epoch.value
        config = _patch_config(
            self.configuration.filepath,
            weight=self.weight.value,
            low=self.low.value,
            high=self.high.value,
            k=self.k.value,
            reverse=self.reverse.value,
            min_epochs=self.min_epoch.value,
            max_epochs=self.max_epoch.value,
            batch_size=self.batch_size.value,
            maize_backend=self.maize_backend.value,
        )
        command = (
            f"{self.runnable['reinvent']} "
            f"--log-filename {REINVENT_LOGFILE.as_posix()} "
            f"-f {config.suffix.strip('.')} {config.as_posix()}"
        )

        # This allows us to keep track of the most recent REINVENT logs
        readlog = read_log(REINVENT_LOGFILE)

        # Have to instantiate the executor in this process, as doing so in the subthread
        # will cause a python error due to the use of signals outside of the main thread.
        cmd = CommandRunner(working_dir=self.work_dir, rm_config=self.config.batch_config)
        threading.excepthook = _exception_handler

        # Start REINVENT in a separate thread (subprocess
        # starts a separate GIL-independent process)
        worker = threading.Thread(target=lambda: cmd.run(command))
        worker.start()
        self.logger.debug("Starting REINVENT worker with TID %s", worker.native_id)

        # Get the first set of SMILES
        self._handle_smiles(worker=worker)

        # Point to the tensorboard log directory if the user wants it
        tb_logs = list(self.work_dir.glob("tb_logs*"))
        if tb_logs:
            self.logger.info(
                "Tensorboard logs can be found at %s",
                tb_logs[-1].absolute().as_posix(),
            )

        # The interceptor starts reading as soon as the scores file exists, so the
        # scores are written to a sibling file first and then renamed into place.
        staging_file = TEMP_SCORES_FILE.with_name(TEMP_SCORES_FILE.name + ".tmp")

        for epoch in range(self.max_epoch.value):
            # Reinvent may terminate early
            self.logger.debug("Checking if Reinvent is still running")
            if not worker.is_alive():
                break

            self.logger.info("ReInvent output: %s", readlog())

            self.logger.debug("Waiting for scores")
            scores = self.inp.receive()

            # `tolist()` yields builtin floats, `np.float32` items are not JSON serializable
            predictions = np.asarray(scores).tolist()
            scores_data = {"version": 4, "payload": {"predictions": predictions}}
            self.logger.debug("Writing '%s'", scores_data)
            with staging_file.open("w") as file:
                json.dump(scores_data, file)
            staging_file.replace(TEMP_SCORES_FILE)

            self._handle_smiles(worker=worker)
            self.logger.info("Sent new batch of SMILES, epoch %s", epoch)

        self.logger.info("Loop complete, stopping worker with TID %s", worker.native_id)
        worker.join(timeout=5)
class ReinventEntry(Node):
    """
    Specialized entrypoint for the REINVENT - Maize interface.

    Loads the JSON file holding the generated SMILES together with
    additional metadata and forwards both, so that they can be
    consumed by more generic workflows.

    Examples
    --------

    .. literalinclude:: ../../docs/reinvent-interface-example.yml
       :language: yaml
       :linenos:

    """

    data: FileParameter[Annotated[Path, Suffix("json")]] = FileParameter()
    """JSON input from Maize REINVENT scoring component"""

    out: Output[list[str]] = Output()
    """SMILES output"""

    out_metadata: Output[dict[str, Any]] = Output(optional=True)
    """Any additional metadata passed on by REINVENT"""

    def run(self) -> None:
        """Parse the JSON file and send the SMILES followed by the metadata."""
        with self.data.filepath.open() as handle:
            payload = json.load(handle)

        smiles = payload["smiles"]
        self.logger.debug("Received %s smiles", len(smiles))
        self.out.send(smiles)
        self.out_metadata.send(payload["metadata"])
class ReinventExit(Node):
    """
    Specialized exitpoint for the REINVENT - Maize interface.

    Creates a JSON file containing scores and relevances, which
    can be read in by the Maize scoring component in REINVENT.

    """

    inp: Input[list[IsomerCollection]] = Input()
    """Scored molecule input"""

    data: FileParameter[Annotated[Path, Suffix("json")]] = FileParameter(exist_required=False)
    """JSON output for Maize REINVENT scoring component"""

    def run(self) -> None:
        """Receive the scored molecules and write scores and relevances as JSON."""
        mols = self.inp.receive()
        self.logger.debug("Received %s mols", len(mols))

        # Add per-mol relevances if we have them
        relevance = [1.0 for _ in mols]
        if all(iso.has_tag("relevance") for mol in mols for iso in mol.molecules):
            relevance = [
                # The check above is vacuously true for a molecule without isomers,
                # fall back to the neutral relevance instead of failing in `max`.
                max(
                    (float(cast(float, iso.get_tag("relevance"))) for iso in mol.molecules),
                    default=1.0,
                )
                for mol in mols
            ]

        scores = [mol.best_score for mol in mols]
        self.logger.info("Sending %s scores: %s", len(mols), scores)
        data = dict(scores=scores, relevances=relevance)
        with self.data.filepath.open("w") as out:
            json.dump(data, out)
@pytest.fixture
def reinvent_config(shared_datadir: Path) -> Path:
    """Template REINVENT configuration from the shared test data."""
    return shared_datadir.joinpath("input-intercept.toml")


@pytest.fixture
def prior(shared_datadir: Path) -> Path:
    """Prior model file from the shared test data."""
    return shared_datadir.joinpath("random.prior.new")


@pytest.fixture
def agent(shared_datadir: Path) -> Path:
    """Agent model file from the shared test data (same file as the prior)."""
    return shared_datadir.joinpath("random.prior.new")


@pytest.fixture
def patch_config(prior: Path, agent: Path, reinvent_config: Path, tmp_path: Path) -> Path:
    """Copy of the template configuration pointing at absolute model paths."""
    settings = toml.loads(reinvent_config.read_text())
    settings["parameters"]["prior_file"] = prior.absolute().as_posix()
    settings["parameters"]["agent_file"] = agent.absolute().as_posix()

    target = tmp_path / "conf.toml"
    with target.open("w") as handle:
        toml.dump(settings, handle)
    return target


def test_reinvent(temp_working_dir: Any, test_config: Any, patch_config: Any) -> None:
    """Run the node for a few epochs with random scores and check the SMILES batches."""
    n_epochs = 5
    n_batch = 8
    scores = []
    for _ in range(n_epochs):
        scores.append(np.random.rand(n_batch))

    rig = TestRig(ReInvent, config=test_config)
    res = rig.setup_run(
        parameters={
            "configuration": patch_config,
            "min_epoch": 3,
            "max_epoch": n_epochs,
            "batch_size": n_batch,
        },
        inputs={"inp": scores},
    )
    data = res["out"].flush(timeout=20)

    assert len(data) == n_epochs
    first_batch_size = len(data[0])
    assert 1 < first_batch_size <= n_batch