import logging
import logging.config
import argparse
import pathlib
import pickle
import sys
import pandas as pd
from optunaz.config import LOG_CONFIG
# Apply the project's logging configuration as soon as this module is imported.
log_conf = LOG_CONFIG
logging.config.dictConfig(log_conf)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ArgsError(Exception):
    """Thrown when there is an issue with basic args at inference time."""
class UncertaintyError(Exception):
    """Thrown when uncertainty parameters are not set correctly at inference."""
class AuxCovariateMissing(Exception):
    """Thrown when a model trained with auxiliary (covariate) data gets none at inference."""
class PrecomputedError(Exception):
    """Raised when precomputed-descriptor inputs are missing or incomplete at runtime.

    Covers both a model trained with a precomputed descriptor that is not
    supplied at inference and a missing precomputed-descriptor argument.
    """
def validate_args(args):
    """Check that the parsed command-line options are mutually consistent.

    The checks are written as explicit ``raise`` statements rather than
    ``assert`` so that they still run when Python is started with ``-O``.

    Args:
        args: Parsed argument namespace. Reads ``predict_uncertainty``,
            ``predict_explain`` and ``uncertainty_quantile``.

    Raises:
        ArgsError: If uncertainty and explainability are both requested, if
            ``uncertainty_quantile`` is given without ``predict_uncertainty``,
            or if ``uncertainty_quantile`` lies outside ``[0.0, 1.0]``.
    """
    if args.predict_uncertainty and args.predict_explain:
        raise ArgsError(
            "Cannot provide both uncertainty and explainability at the same time"
        )
    if args.uncertainty_quantile is None:
        # The remaining checks only concern the quantile option.
        return
    if not args.predict_uncertainty:
        raise ArgsError(
            "Must predict with uncertainty to perform uncertainty_quantile"
        )
    # Written as "not (in range)" so that NaN is rejected as well.
    if not 0.0 <= args.uncertainty_quantile <= 1.0:
        raise ArgsError("uncertainty_quantile must range 0.0-1.0")
def validate_uncertainty(args, model):
    """Ensure the model can predict uncertainty when the user asked for it.

    Args:
        args: Parsed argument namespace; only ``predict_uncertainty`` is read.
        model: Loaded model; its ``predictor`` attribute is inspected for a
            ``predict_uncert`` attribute, and only when uncertainty is requested.

    Raises:
        UncertaintyError: If uncertainty is requested but the predictor has no
            ``predict_uncert`` attribute.
    """
    if not args.predict_uncertainty:
        return
    if hasattr(model.predictor, "predict_uncert"):
        return
    raise UncertaintyError("Uncertainty not availble for this model")
def check_precomp_args(args):
    """Verify that all precomputed-descriptor arguments were supplied.

    The checks are written as explicit ``raise`` statements rather than
    ``assert`` so that they still run when Python is started with ``-O``.

    Args:
        args: Parsed argument namespace. Reads ``input_precomputed_file``,
            ``input_precomputed_input_column`` and
            ``input_precomputed_response_column``.

    Raises:
        PrecomputedError: If the file is ``None`` or either column name is
            missing or empty. The first failing check determines the message.
    """
    if args.input_precomputed_file is None:
        raise PrecomputedError("Must supply precomputed descriptor parameters")
    if not args.input_precomputed_input_column:
        raise PrecomputedError(
            "Must supply input column for precomputed descriptor"
        )
    if not args.input_precomputed_response_column:
        raise PrecomputedError(
            "Must supply response column for precomputed descriptor"
        )
def set_inference_params(args, desc):
    """Pass the precomputed-descriptor arguments to ``desc`` if it accepts them.

    Args:
        args: Parsed argument namespace carrying the ``input_precomputed_*``
            options.
        desc: Descriptor object with a ``parameters`` attribute.

    Returns:
        ``True`` when inference parameters were set on the descriptor (or on
        the descriptor it wraps), ``False`` when neither exposes
        ``inference_parameters``.

    Raises:
        PrecomputedError: Propagated from ``check_precomp_args`` when the
            descriptor needs the arguments and they are incomplete.
    """
    # Scaled precomputed descriptors handled here: when the wrapped descriptor
    # exposes ``inference_parameters``, it becomes the target.
    wrapped = getattr(desc.parameters, "descriptor", None)
    target = wrapped if hasattr(wrapped, "inference_parameters") else desc

    if not hasattr(target, "inference_parameters"):
        return False

    check_precomp_args(args)
    target.inference_parameters(
        args.input_precomputed_file,
        args.input_precomputed_input_column,
        args.input_precomputed_response_column,
    )
    logging.info("Precomputed descriptor inference params set")
    return True
def validate_set_precomputed(args, model):
    """Configure precomputed-descriptor inference on ``model`` where applicable.

    Three cases are handled, in order:

    1. The model's descriptor (or the descriptor it wraps) accepts inference
       parameters: they are set and the model is returned.
    2. The descriptor's parameters have a ``descriptors`` collection: each
       member is offered the parameters. At most one member may accept them.
    3. Otherwise nothing is set. A warning is logged only when a complete set
       of precomputed arguments was supplied.

    Args:
        args: Parsed argument namespace carrying the ``input_precomputed_*``
            options.
        model: Loaded model; ``model.descriptor`` is inspected and may be
            mutated through ``set_inference_params``.

    Returns:
        The same ``model`` object.

    Raises:
        PrecomputedError: If more than one member descriptor accepts
            precomputed parameters, or (propagated from
            ``set_inference_params``) if required arguments are missing.
    """
    descriptor_str = model.descriptor.name
    if set_inference_params(args, model.descriptor):
        return model

    if hasattr(model.descriptor.parameters, "descriptors"):
        # Every member is visited before the count is checked; each call
        # returns a bool, so the sum is the number of precomputed descriptors.
        n_precomp = sum(
            set_inference_params(args, d)
            for d in model.descriptor.parameters.descriptors
        )
        if n_precomp == 0:
            if args.input_precomputed_file is not None:
                logging.warning(
                    f"{descriptor_str} has no Precomputed descriptors... ignoring precomputed descriptor parameters"
                )
        elif n_precomp > 1:
            raise PrecomputedError(
                "Inference for > 1 precomputed descriptor not currently available"
            )
        return model

    # No precomputed descriptor anywhere. check_precomp_args is used as a probe
    # for "did the user pass a complete set of precomputed arguments?".
    try:
        check_precomp_args(args)
    except PrecomputedError:
        # Arguments absent or incomplete: nothing is being ignored, stay quiet.
        pass
    else:
        logging.warning(
            f"Model was trained using {descriptor_str}... ignoring precomputed descriptor parameters"
        )
    return model
def validate_aux(args, model):
    """Require an auxiliary input column when the model was trained with one.

    The checks are written as explicit ``raise`` statements rather than
    ``assert`` so that they still run when Python is started with ``-O``.

    Args:
        args: Parsed argument namespace; only ``input_aux_column`` is read.
        model: Loaded model. Reads ``model.metadata["buildconfig"]`` (its
            ``data`` entry, if any) and ``model.aux_transform``.

    Raises:
        AuxCovariateMissing: If the build configuration names an ``aux_column``
            or the model has an auxiliary transform, and no
            ``input_aux_column`` was given.
    """
    # A build config without a "data" entry is treated as having no aux column
    # instead of failing on ``None.get``.
    data_config = model.metadata["buildconfig"].get("data") or {}
    if data_config.get("aux_column") and not args.input_aux_column:
        raise AuxCovariateMissing(
            "Model was trained with auxiliary data, please provide an input auxiliary column"
        )
    if model.aux_transform and not args.input_aux_column:
        raise AuxCovariateMissing(
            "Input auxiliary column required since it appears the model was trained with an auxiliary transform"
        )
def main():
    """Command-line entry point: predict responses for SMILES with a saved model.

    Parses the command line, loads the pickled model, validates the requested
    options against it, runs the prediction, and writes the resulting table as
    CSV to ``--output-prediction-csv-file`` or to standard output.

    SMILES come from ``--input-smiles-csv-file`` when given; otherwise any
    command-line tokens the parser does not recognise are used as SMILES.
    With neither, the process exits with status 1.

    Raises:
        ArgsError, UncertaintyError, PrecomputedError, AuxCovariateMissing:
            Propagated from the validation helpers.
    """
    parser = argparse.ArgumentParser(
        description="Predict responses for a given OptunaAZ model"
    )

    # fmt: off
    requiredNamed = parser.add_argument_group('required named arguments')
    requiredNamed.add_argument("--model-file", type=pathlib.Path, help="Model file name", required=True)
    parser.add_argument("--input-smiles-csv-file", type=pathlib.Path, help="Name of input CSV file with Input SMILES")
    parser.add_argument("--input-smiles-csv-column", type=str, help="Column name of SMILES column in input CSV file", default="SMILES")
    parser.add_argument("--input-aux-column", type=str, help="Column name of auxiliary descriptors in input CSV file", default=None)
    parser.add_argument("--input-precomputed-file", type=str, help="Filename of precomputed descriptors input CSV file", default=None)
    parser.add_argument("--input-precomputed-input-column", type=str, help="Column name of precomputed descriptors identifier", default=None)
    parser.add_argument("--input-precomputed-response-column", type=str, help="Column name of precomputed descriptors response column", default=None)
    parser.add_argument("--output-prediction-csv-column", type=str, help="Column name of prediction column in output CSV file", default="Prediction")
    parser.add_argument("--output-prediction-csv-file", type=str, help="Name of output CSV file")
    parser.add_argument("--predict-uncertainty", action="store_true", help="Predict with uncertainties (model must provide this functionality)")
    parser.add_argument("--predict-explain", action="store_true", help="Predict with SHAP or ChemProp explainability")
    # NOTE: unlike the other flags, this one is spelled with an underscore;
    # the spelling is kept so existing command lines keep working.
    parser.add_argument("--uncertainty_quantile", type=float, help="Apply uncertainty threshold to predictions", default=None)
    # fmt: on

    # Unrecognised tokens are kept: they serve as SMILES when no CSV is given.
    args, leftovers = parser.parse_known_args()

    validate_args(args)

    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only load model files that come from a trusted source.
    with open(args.model_file, "rb") as f:
        model = pickle.load(f)

    validate_uncertainty(args, model)
    model = validate_set_precomputed(args, model)
    validate_aux(args, model)

    incolumn = args.input_smiles_csv_column
    outcolumn = args.output_prediction_csv_column

    if args.input_smiles_csv_file is not None:
        df = pd.read_csv(args.input_smiles_csv_file, skipinitialspace=True)
    elif leftovers:
        df = pd.DataFrame({incolumn: leftovers})
    else:
        logging.info("No SMILES specified, exiting.")
        # sys.exit is used because the bare ``exit`` builtin is installed by
        # the ``site`` module and is not guaranteed to exist (e.g. ``python -S``).
        sys.exit(1)

    aux = df[args.input_aux_column] if args.input_aux_column is not None else None

    if args.predict_explain:
        # In explain mode the model's return value replaces the input table.
        df = model.predict_from_smiles(
            df[incolumn],
            explain=args.predict_explain,
            aux=aux,
            aux_transform=model.aux_transform,
        )
    elif args.predict_uncertainty:
        pred, unc_pred = model.predict_from_smiles(
            df[incolumn],
            uncert=args.predict_uncertainty,
            aux=aux,
            aux_transform=model.aux_transform,
        )
        uncert_column = f"{outcolumn}_uncert"
        df[outcolumn] = pred
        df[uncert_column] = unc_pred
        if args.uncertainty_quantile is not None:
            # Keep only rows whose uncertainty is strictly above the requested
            # quantile, most uncertain first.
            uncert_thr = df[uncert_column].quantile(args.uncertainty_quantile)
            df = df[df[uncert_column] > uncert_thr].sort_values(
                uncert_column, ascending=False
            )
    else:
        df[outcolumn] = model.predict_from_smiles(
            df[incolumn], aux=aux, aux_transform=model.aux_transform
        )

    if args.output_prediction_csv_file is None:
        args.output_prediction_csv_file = sys.stdout
    df.to_csv(args.output_prediction_csv_file, index=False, float_format="%g")