import logging
import os
from dataclasses import dataclass
from typing import List, Dict
from typing import Optional, Union

import requests
from apischema import serialize
from optuna import Study
from optuna.trial import FrozenTrial

from optunaz.config.build_from_opt import remove_algo_hash
from optunaz.config.build_from_opt import buildconfig_from_trial
from optunaz.config.buildconfig import BuildConfig
from optunaz.config.optconfig import OptimizationConfig
from optunaz.evaluate import calibration_analysis
# Module-level logger; the tracking code in this module reports its failures through it as warnings.
logger = logging.getLogger(__name__)
@dataclass
class TrackingData:
    """Dataclass defining internal tracking format.

    One instance describes a single optimization trial. Note that creating an
    instance modifies the supplied ``buildconfig`` in place: its ``metadata``
    and ``settings.n_trials`` attributes are reset to ``None``.
    """

    trial_number: int
    trial_value: float
    scoring: str
    trial_state: str
    all_cv_test_scores: Dict[str, List[float]]
    buildconfig: BuildConfig
    algorith_hash: str  # (sic) spelling is part of the existing field name.

    def __post_init__(self):
        config = self.buildconfig
        # Metadata is not essential for tracking, so it is dropped.
        config.metadata = None
        # The number of trials is dropped as well.
        config.settings.n_trials = None
def removeprefix(line: str, prefix: str) -> str:
    """Return ``line`` with a leading ``prefix`` removed.

    Starting from Python 3.9, ``str`` has the method ``removeprefix()``.
    This module targets Python 3.7+, hence this stand-in.

    Args:
        line: The string to strip.
        prefix: The prefix to remove from the start of ``line``.

    Returns:
        ``line`` without the prefix if it starts with ``prefix``; otherwise
        ``line`` unchanged.
    """
    if line.startswith(prefix):
        return line[len(prefix):]
    # No match: return the input as-is (like str.removeprefix) rather than
    # implicitly returning None, which would break the declared ``-> str``.
    return line
def round_scores(test_scores):
    """Round every score in a mapping of score lists to three decimals.

    Args:
        test_scores: Mapping from a key to an iterable of numeric scores.

    Returns:
        A new dict with the same keys, each mapped to a list of the scores
        rounded to 3 digits.
    """
    rounded = {}
    for name, values in test_scores.items():
        rounded[name] = [round(value, ndigits=3) for value in values]
    return rounded
@dataclass
class InternalTrackingCallback:
    """Callback to track (log) Optimization progress using internal tracking format.

    Instances are called as ``callback(study, trial)``. Each call builds a
    ``TrackingData`` payload for the trial and POSTs it as JSON to
    ``optconfig.settings.tracking_rest_endpoint``.

    Reporting is best effort: errors raised while building or sending the
    payload are logged as warnings and are not propagated.

    Attributes:
        optconfig: Optimization configuration; provides the scoring name and
            the tracking endpoint URL.
        trial_number_offset: Value added to ``trial.number`` before reporting.
        timeout: Timeout in seconds passed to ``requests.post``, so that an
            unresponsive endpoint cannot block the optimization indefinitely.
    """

    optconfig: OptimizationConfig
    trial_number_offset: int
    timeout: float = 30.0

    @staticmethod
    def _rounded_trial_value(trial: FrozenTrial) -> float:
        """Return the trial's (first) objective value rounded to 3 digits, or NaN if absent."""
        if hasattr(trial, "values") and trial.values is not None:
            return round(trial.values[0], ndigits=3)
        if hasattr(trial, "value") and trial.value is not None:
            return round(trial.value, ndigits=3)
        return float("nan")

    def __call__(self, study: Study, trial: FrozenTrial) -> None:
        trial = remove_algo_hash(trial)
        try:
            buildconfig = buildconfig_from_trial(study, trial)
            data = TrackingData(
                trial_number=trial.number + self.trial_number_offset,
                trial_value=self._rounded_trial_value(trial),
                scoring=self.optconfig.settings.scoring,
                trial_state=trial.state.name,
                all_cv_test_scores=round_scores(trial.user_attrs["test_scores"]),
                buildconfig=buildconfig,
                algorith_hash=trial.user_attrs["alg_hash"],
            )
            json_data = serialize(data)
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": get_authorization_header(),
            }
            url = self.optconfig.settings.tracking_rest_endpoint
            try:
                response = requests.post(
                    url, json=json_data, headers=headers, timeout=self.timeout
                )
                # Surface HTTP 4xx/5xx as a warning instead of silently
                # ignoring a rejected report.
                response.raise_for_status()
            except Exception as e:
                logger.warning(f"Failed to report progress to {url}: {e}")
        except Exception as e:
            # Tracking must never abort the optimization itself.
            logger.warning(f"Failed to report progress: {e}")
@dataclass
class Datapoint:
    """A single test-set record: a molecule with its expected and predicted value."""

    smiles: str  # SMILES string of the molecule.
    expected: float  # Expected value for the molecule.
    predicted: float  # Value predicted for the molecule.
@dataclass
class Calpoint:
    """A single point of a calibration analysis."""

    bin_edges: float  # Bin edge this point belongs to.
    frac_true: float  # "True" fraction reported for the bin.
    frac_pred: float  # "Predicted" fraction reported for the bin.
@dataclass
class BuildTrackingData:
    """Dataclass defining internal Build tracking format.

    The union types are spelled with ``typing.Union`` / ``typing.Optional``
    rather than ``X | Y``: these annotations are evaluated when the class is
    created, and the ``|`` operator on typing generics raises ``TypeError``
    before Python 3.10.
    """

    response_column_name: str
    # Either the rounded scores, or a string when no scores are available.
    test_scores: Union[Dict[str, float], str]
    test_points: List[Datapoint]
    cal_points: Optional[List[Calpoint]]
def track_build(
    qptuna_model, buildconfig: BuildConfig, test_scores, *, timeout: float = 30.0
):
    """Report the test-set results of a built model to the tracking endpoint.

    Predicts the model's stored test set, pairs expected with predicted
    values, attempts a calibration analysis, and POSTs the result as JSON to
    ``buildconfig.settings.tracking_rest_endpoint``. Sending is best effort:
    a failed request is logged as a warning and not raised.

    Args:
        qptuna_model: Built model exposing ``predictor.test_smiles_``,
            ``predictor.test_aux_``, ``predictor.test_y_``,
            ``predict_from_smiles()`` and ``transform``.
        buildconfig: Build configuration; provides the response column name
            and the tracking endpoint URL.
        test_scores: Mapping of score name to value, or ``None``.
        timeout: Timeout in seconds passed to ``requests.post``, so that an
            unresponsive endpoint cannot block the caller indefinitely.

    Returns:
        None. If the model has no test set, a warning is logged and nothing
        is sent.
    """
    test_smiles = qptuna_model.predictor.test_smiles_
    test_aux = qptuna_model.predictor.test_aux_
    expected = qptuna_model.predictor.test_y_

    if test_smiles is None or len(test_smiles) < 1:
        logger.warning("No test set.")
        return

    # An empty string stands in for missing scores in the payload.
    rounded_test_scores = (
        {k: round(v, ndigits=3) for k, v in test_scores.items()}
        if test_scores is not None
        else ""
    )

    predicted = qptuna_model.predict_from_smiles(test_smiles, aux=test_aux)

    # Undo the model's transform on the expected values before pairing them
    # with the predictions.
    if qptuna_model.transform is not None:
        expected = qptuna_model.transform.reverse_transform(expected)

    test_points = [
        Datapoint(
            smiles=smi,
            expected=round(expval.item(), ndigits=3),  # item() converts numpy to float.
            predicted=round(predval.item(), ndigits=3),
        )
        for smi, expval, predval in zip(test_smiles, expected, predicted)
    ]

    try:
        cal_points = [
            Calpoint(
                bin_edges=round(bin_edges.item(), ndigits=3),
                frac_true=round(frac_true.item(), ndigits=3),
                frac_pred=round(frac_pred.item(), ndigits=3),
            )
            for bin_edges, frac_true, frac_pred in calibration_analysis(
                expected, predicted
            )
        ]
    except ValueError:
        # Calibration could not be computed for this data; send an empty
        # placeholder instead of failing the whole report.
        cal_points = ""

    data = BuildTrackingData(
        response_column_name=buildconfig.data.response_column,
        test_scores=rounded_test_scores,
        test_points=test_points,
        cal_points=cal_points,
    )
    json_data = serialize(data)

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": get_authorization_header(),
    }
    url = buildconfig.settings.tracking_rest_endpoint

    try:
        response = requests.post(url, json=json_data, headers=headers, timeout=timeout)
        # Surface HTTP 4xx/5xx as a warning instead of silently ignoring a
        # rejected report.
        response.raise_for_status()
    except Exception as e:
        logger.warning(f"Failed to report build results {json_data} to {url}: {e}")