Source code for optunaz.schemagen

import inspect
import json
from textwrap import dedent
from typing import Any, Optional

import apischema
from apischema import schema
from apischema.json_schema import deserialization_schema, JsonSchemaVersion
from apischema.schemas import Schema

from optunaz.config.optconfig import Algorithm
from optunaz.config.optconfig import OptimizationConfig
from optunaz.descriptors import MolDescriptor, RdkitDescriptor
from optunaz.utils.schema import (
    replacekey,
    addsibling,
    delsibling,
    copytitle,
    replaceenum,
    addtitles,
)

from optunaz.utils.preprocessing.splitter import Splitter
from optunaz.utils.preprocessing.deduplicator import Deduplicator


def doctitle(doc: str) -> tuple[None, None] | tuple[str, str]:
    """Returns the first line of the docstring as the title."""

    if not doc:
        return None, None
    fulltitle, sep, rest = doc.partition("\n")  # https://stackoverflow.com/a/11833277
    title = fulltitle.rstrip(".")
    description = dedent(rest).strip()
    return title, description
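
# Illustrative sketch, not part of the module API (the helper name
# "_example_doctitle" is hypothetical): doctitle() splits a docstring into a
# title (first line, trailing period stripped) and a dedented description.
def _example_doctitle():
    title, description = doctitle("First line.\n    Rest of\n    the doc.")
    assert title == "First line"
    assert description == "Rest of\nthe doc."
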
def type_base_schema(tp: Any) -> Optional[Schema]:
    """Adds title and description from docstrings.

    See https://wyfo.github.io/apischema/0.16/json_schema/#base-schema
    """

    # Add only to Algorithms and Descriptors, ignore all the rest.
    # Otherwise, it starts adding description to classes like 'int' and 'list'.
    # Could be enough to only ignore built-in classes,
    # and add descriptions to all classes under the optunaz/qptuna package.
    if not inspect.isclass(tp):
        return None
    if not issubclass(tp, (Algorithm, MolDescriptor, Splitter, Deduplicator)):
        return None
    if not hasattr(tp, "__doc__"):
        return None
    title, description = doctitle(tp.__doc__)
    return schema(
        title=title,
        description=description,
    )
apischema.settings.base_schema.type = type_base_schema
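
# Illustrative sketch, not part of the module (hypothetical helper): with the
# hook registered above, apischema consults type_base_schema for every type it
# meets while generating a schema. Non-whitelisted types come back as None and
# keep their plain schema; Algorithm/MolDescriptor/Splitter/Deduplicator
# subclasses get a Schema carrying their docstring title and description.
def _example_type_base_schema():
    assert type_base_schema(int) is None  # built-in classes are ignored
    assert type_base_schema("not a class") is None  # non-classes are ignored
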
def patch_schema_generic(schema):
    # Adds titles to all fields, even without schema(title=...).
    # TODO(alex): instead of calling addtitles, add proper titles to all fields.
    addtitles(schema)

    # Replace singleton enums with const.
    # For some reason, this was not needed in AZDock. A mystery.
    schema = replaceenum(schema)

    # Replace "anyOf" with "oneOf".
    schema = replacekey(schema)

    # Add "type": "object" to any elements that contain "oneOf": [...].
    schema = addsibling(schema)

    # Delete "type": "string" for "enum".
    schema = delsibling(schema, {"enum": "type"})

    # Delete most of the stuff for "const".
    schema = delsibling(schema, {"const": "type"})
    schema = delsibling(schema, {"const": "default"})
    schema = delsibling(schema, {"const": "title"})

    # Copy title from $refs into oneOf.
    schema = copytitle(schema, schema)

    return schema
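
# Sketch of the intended net effect of patch_schema_generic on two common
# fragments, inferred from the helper comments above (illustrative only):
#
#     {"enum": ["x"], "type": "string"}  ->  {"const": "x"}
#     {"anyOf": [...]}                   ->  {"type": "object", "oneOf": [...]}
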
def patch_schema_optunaz(schema):
    (
        schema.get("$defs", {})
        .get("MolData", {})
        .get("properties", {})
        .get("file_path", {})
    )["format"] = "uri"

    # Dataset
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("save_intermediate_files", {})
    )["const"] = True
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("intermediate_training_dataset_file", {})
    )["const"] = "{{run.path}}/intermediate_training_dataset_file.csv"
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("intermediate_test_dataset_file", {})
    )["const"] = "{{run.path}}/intermediate_test_dataset_file.csv"
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .pop("test_dataset_file", None)
    )
    # (
    #     schema.get("$defs", {})
    #     .get("Dataset", {})
    #     .get("properties", {})
    #     .get("training_dataset_file", {})
    # )["format"] = "file"
    (
        schema.get("$defs", {})
        .get("MolData", {})
        .get("properties", {})
        .get("file_path", {})
    )["format"] = "uri"

    # Root OptimizationConfig
    (
        schema.get("$defs", {})
        .get("OptimizationConfig", {})
        .get("properties", {})
        .pop("mode", None)
    )
    (
        schema.get("$defs", {})
        .get("OptimizationConfig", {})
        .get("properties", {})
        .pop("visualization", None)
    )
    # (
    #     schema.get("$defs", {})
    #     .get("OptimizationConfig", {})
    #     .get("properties", {})
    # )["mode"] = {
    #     "$ref": "#/$defs/ModelMode",
    #     "title": "Mode mode: regression or classification",
    #     "default": "regression",
    # }
    (schema.get("$defs", {}).get("OptimizationConfig", {}).get("properties", {}))[
        "slurmOptions"
    ] = {
        "title": "Slurm Options",
        "type": "object",
        "format": "slurm-summary",
        "description": "Resources and limits options - like CPU, Memory,"
        " time limits and reservations.",
        "properties": {
            "reservation": {
                "title": "Reinvent reservation.",
                "description": "Run the job on the nodes reserved for Reinvent.",
                "type": "boolean",
                "default": False,
            },
            "cpusPerTask": {
                "title": "Number of CPU cores",
                "description": "How many cores to request for the job."
                " Rule of thumb: use as many as the number of folds"
                " in cross-validation.",
                "type": "number",
                "default": 5,
            },
            "timeout": {
                "title": "Time limit for the job",
                "description": (
                    "How much time the job will have to finish."
                    " Acceptable time formats include"
                    " 'minutes', 'minutes:seconds', 'hours:minutes:seconds',"
                    " 'days-hours', 'days-hours:minutes'"
                    " and 'days-hours:minutes:seconds'."
                ),
                "type": "string",
                "pattern": (
                    "^("
                    "([0-9]+)"  # minutes
                    "|([0-9]+:[0-9]+)"  # minutes:seconds
                    "|([0-9]+:[0-9]+:[0-9]+)"  # hours:minutes:seconds
                    "|([0-9]+[-][0-9]+)"  # days-hours
                    "|([0-9]+[-][0-9]+:[0-9]+)"  # days-hours:minutes
                    "|([0-9]+[-][0-9]+:[0-9]+:[0-9]+)"  # days-hours:minutes:seconds
                    ")$"
                ),
                "default": "100:0:0",
            },
            "memPerCpu": {
                "title": "Memory allocation per CPU",
                "description": "How much memory should be allocated per core used."
                " Default units are megabytes."
                " Different units can be specified using the suffix [K|M|G|T]:"
                " e.g. 28G = 28 Gigabytes.",
                "type": "string",
                "pattern": "^[0-9]+[KMGT]?$",
                "default": "4G",
            },
            "nodes": {"const": 1},
            "ntasks": {"const": 1},
        },
    }
    (
        schema.get("$defs", {})
        .get("Settings", {})
        .get("properties", {})
        .pop("mode", None)
    )
    (schema.get("$defs", {}).get("Settings", {}).get("properties", {}))["n_jobs"] = {
        "const": -1
    }
    (schema.get("$defs", {}).get("Settings", {}).get("properties", {}))[
        "track_to_mlflow"
    ] = {"const": False}
    (schema.get("$defs", {}).get("Settings", {}).get("properties", {}))[
        "optuna_storage"
    ] = {"const": "sqlite:///{{run.path}}/optuna_storage.sqlite"}
    (
        schema.get("$defs", {})
        .get("Settings", {})
        .get("properties", {})
        .pop("shuffle", None)
    )
    (
        schema.get("$defs", {})
        .get("Settings", {})
        .get("properties", {})
        .pop("direction", None)
    )
    (
        schema.get("$defs", {})
        .get("Settings", {})
        .get("properties", {})
        .pop("scoring", None)
    )
    (
        schema.get("$defs", {})
        .get("Settings", {})
        .get("properties", {})
        .pop("tracking_rest_endpoint", None)
    )
    (schema.get("$defs", {}).get("Settings", {}))["format"] = "collapsed"
    (
        schema.get("$defs", {})
        .get("ScaledDescriptorParameters", {})
        .get("properties", {})
        .pop("scaler", None)
    )
    (
        schema.get("$defs", {})
        .get("PhyschemDescriptors", {})
        .get("properties", {})
        .pop("parameters", {})
    )
    (
        schema.get("$defs", {})
        .get("JazzyDescriptors", {})
        .get("properties", {})
        .pop("parameters", {})
    )
    (
        schema.get("$defs", {})
        .get("SmilesAndSideInfoFromFileParams", {})
        .get("properties", {})
        .get("file", {})
    )["format"] = "uri"
    (
        schema.get("$defs", {})
        .get("Stratified", {})
        .get("properties", {})
        .pop("bins", {})
    )
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("training_dataset_file", {})
    )[
        "format"
    ] = "qsartuna-file"  # this format emphasizes csv and sdf support with basic validation and automated column handling
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("response_type", {})
    )[
        "const"
    ] = None  # hide response_type on UI, not implemented yet
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("input_column", {})
    )[
        "format"
    ] = "single-string-select"  # for the QSARtuna model this field is a string, but the GUI generates values for it based on qsartuna-file
    (
        schema.get("$defs", {})
        .get("Dataset", {})
        .get("properties", {})
        .get("response_column", {})
    )[
        "format"
    ] = "single-string-select"  # for the QSARtuna model this field is a string, but the GUI generates values for it based on qsartuna-file
    (
        schema.get("$defs", {})
        .get("MolDescriptor", {})
        .get("anyOf", [])
        .remove({"$ref": "#/$defs/PrecomputedDescriptorFromFile"})
    )
    (
        schema.get("$defs", {})
        .get("RandomForestRegressorParams", {})
        .get("properties", {})
        .get("max_features", {})
        .get("items", {})
        .update(default="auto")
    )

    # Make GUI elements dependency for PTR.
    # First, copy "threshold" and "std" into "dependencies" object,
    # then delete them from the original place.
( schema.get("$defs", {}) .get("Dataset", {}) .update( dependencies={ "probabilistic_threshold_representation": { "oneOf": [ { "properties": { "probabilistic_threshold_representation": { "enum": [False] } } }, { "properties": { "probabilistic_threshold_representation": { "enum": [True] }, "probabilistic_threshold_representation_threshold": ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .get( "probabilistic_threshold_representation_threshold", {}, ) ), "probabilistic_threshold_representation_std": ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .get( "probabilistic_threshold_representation_std", {}, ) ), }, "required": [ "probabilistic_threshold_representation_threshold", "probabilistic_threshold_representation_std", ], }, ] }, } ) ) ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .pop("probabilistic_threshold_representation_std", None) ) ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .pop("probabilistic_threshold_representation_threshold", None) ) # Make GUI elements dependency for log transform. # First, copy "log" and "negate" into "dependencies" object, # then delete them from the original place. ( schema.get("$defs", {}) .get("Dataset", {}) .update( dependencies={ "log_transform": { "oneOf": [ {"properties": {"log_transform": {"enum": [False]}}}, { "properties": { "log_transform": {"enum": [True]}, "log_transform_base": ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .get( "log_transform_base", {}, ) ), "log_transform_negative": ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .get( "log_transform_negative", {}, ) ), "log_transform_unit_conversion": ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .get( "log_transform_unit_conversion", {}, ) ), }, "required": [ "log_transform_base", "log_transform_negative", "log_transform_unit_conversion", ], }, ] }, } ) ) ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .pop("log_transform_base", None) ) ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .pop("log_transform_negative", None) ) ( schema.get("$defs", {}) .get("Dataset", {}) .get("properties", {}) .pop("log_transform_unit_conversion", None) ) for alg in Algorithm.__subclasses__(): ( schema.get("$defs", {}) .get(alg.__name__, {}) .update(format="select-with-description") ) for descriptor in RdkitDescriptor.__subclasses__(): ( schema.get("$defs", {}) .get(descriptor.__name__, {}) .update(format="select-with-description") ) for descriptor in MolDescriptor.__subclasses__(): ( schema.get("$defs", {}) .get(descriptor.__name__, {}) .update(format="select-with-description") ) return schema
def main():
    schema = deserialization_schema(
        OptimizationConfig, all_refs=True, version=JsonSchemaVersion.DRAFT_2019_09
    )
    schema = patch_schema_optunaz(schema)
    schema = patch_schema_generic(schema)
    print(json.dumps(schema, indent=2))
if __name__ == "__main__":
    main()
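
# Typical invocation: run the module directly and redirect stdout to a file,
# e.g.
#
#     python -m optunaz.schemagen > schema.json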