# Source code for torch_molecule.utils.search

from enum import Enum
from typing import Dict, Any, Union, List, Tuple, NamedTuple

class ParameterType(Enum):
    """Kinds of hyperparameter a search space can contain.

    Every member selects which Optuna ``trial.suggest_*`` call is used to
    sample the parameter and how its value range is interpreted.
    """

    # Discrete choice from a list of options -> trial.suggest_categorical
    CATEGORICAL = "categorical"
    # Whole number between two bounds -> trial.suggest_int
    INTEGER = "integer"
    # Continuous value between two bounds, uniform scale -> trial.suggest_float
    FLOAT = "float"
    # Continuous value between two bounds, log scale -> trial.suggest_float(..., log=True)
    LOG_FLOAT = "log_float"
class ParameterSpec(NamedTuple):
    """Description of a single tunable hyperparameter.

    Attributes
    ----------
    param_type : ParameterType
        Sampling strategy used for this parameter.
    value_range : tuple or list
        For ``ParameterType.CATEGORICAL``: the sequence of allowed choices.
        For the numeric types: a ``(low, high)`` pair of bounds.
    """

    # Selects the Optuna suggest method applied to this parameter.
    param_type: ParameterType
    # Options (categorical) or a two-element bounds pair (numeric types).
    value_range: Union[Tuple[Any, Any], List[Any]]
def suggest_parameter(trial: Any, param_name: str, param_spec: ParameterSpec) -> Any:
    """Draw a value for one hyperparameter from an Optuna trial.

    The ``param_type`` of ``param_spec`` decides which ``trial.suggest_*``
    method is called and how ``param_spec.value_range`` is interpreted.

    Parameters
    ----------
    trial : optuna.Trial
        Trial object that performs the sampling.
    param_name : str
        Name under which the value is registered on the trial.
    param_spec : ParameterSpec
        Parameter kind plus its options (categorical) or ``(low, high)``
        bounds (integer / float / log-float).

    Returns
    -------
    Any
        Whatever the selected ``trial.suggest_*`` method returns.

    Raises
    ------
    ValueError
        If ``param_spec.param_type`` is not a known ``ParameterType``, or
        if a numeric range does not unpack into exactly two values.
    """
    spec_type = param_spec.param_type

    if spec_type == ParameterType.CATEGORICAL:
        # The whole option sequence is handed to Optuna unchanged.
        return trial.suggest_categorical(param_name, param_spec.value_range)

    if spec_type == ParameterType.INTEGER:
        low, high = param_spec.value_range
        return trial.suggest_int(param_name, low, high)

    if spec_type == ParameterType.FLOAT or spec_type == ParameterType.LOG_FLOAT:
        low, high = param_spec.value_range
        # LOG_FLOAT differs from FLOAT only by sampling on a log scale.
        if spec_type == ParameterType.LOG_FLOAT:
            return trial.suggest_float(param_name, low, high, log=True)
        return trial.suggest_float(param_name, low, high)

    raise ValueError(f"Unknown parameter type: {param_spec.param_type}")
def parse_list_params(params_str: Union[str, None]) -> Union[List[str], None]:
    """Split a comma-separated string into a list of item strings.

    Parameters
    ----------
    params_str : str or None
        Comma-separated items, e.g. ``"a,b,c"`` or ``"a, b, c"``.

    Returns
    -------
    list of str or None
        The items with surrounding whitespace removed. Empty entries (for
        example from a trailing comma, ``"a,b,"``, or from an empty input
        string) are dropped. ``None`` is passed through unchanged.
    """
    if params_str is None:
        return None
    # Strip each piece so "a, b" yields "b" rather than " b", and skip blank
    # pieces so "a,b," or "" do not produce empty-string items.
    pieces = (piece.strip() for piece in params_str.split(","))
    return [piece for piece in pieces if piece]