import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal
import warnings
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from .model import GNN
from ...base import BaseMolecularPredictor
from ...utils import graph_from_smiles
from ...utils.search import (
suggest_parameter,
ParameterSpec,
ParameterType,
parse_list_params,
)
# Dictionary mapping parameter names to their types and ranges
DEFAULT_GNN_SEARCH_SPACES: Dict[str, ParameterSpec] = {
# Model architecture parameters
"gnn_type": ParameterSpec(
ParameterType.CATEGORICAL, ["gin-virtual", "gcn-virtual", "gin", "gcn"]
),
"norm_layer": ParameterSpec(
ParameterType.CATEGORICAL,
[
"batch_norm",
"layer_norm",
"instance_norm",
"graph_norm",
"size_norm",
"pair_norm",
],
),
"graph_pooling": ParameterSpec(ParameterType.CATEGORICAL, ["mean", "sum", "max"]),
"augmented_feature": ParameterSpec(ParameterType.CATEGORICAL, ["maccs,morgan", "maccs", "morgan", None]),
# Integer-valued parameters
"num_layer": ParameterSpec(ParameterType.INTEGER, (2, 8)),
"hidden_size": ParameterSpec(ParameterType.INTEGER, (64, 512)),
# Float-valued parameters with linear scale
"drop_ratio": ParameterSpec(ParameterType.FLOAT, (0.0, 0.75)),
"scheduler_factor": ParameterSpec(ParameterType.FLOAT, (0.1, 0.5)),
# Float-valued parameters with log scale
"learning_rate": ParameterSpec(ParameterType.LOG_FLOAT, (1e-5, 1e-2)),
"weight_decay": ParameterSpec(ParameterType.LOG_FLOAT, (1e-8, 1e-3)),
}
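
# Example (hypothetical values): a custom search space passed to ``autofit`` may
# override any subset of the defaults above, e.g.:
#
#     custom_space = {
#         "hidden_size": ParameterSpec(ParameterType.INTEGER, (128, 256)),
#         "learning_rate": ParameterSpec(ParameterType.LOG_FLOAT, (1e-4, 1e-3)),
#     }
#     predictor.autofit(X_train, y_train, search_parameters=custom_space, n_trials=20)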
class GNNMolecularPredictor(BaseMolecularPredictor):
"""This predictor implements a GNN model for molecular property prediction tasks.
Parameters
----------
num_task : int, default=1
Number of prediction tasks.
task_type : str, default="regression"
Type of prediction task, either "regression" or "classification".
num_layer : int, default=5
Number of GNN layers.
hidden_size : int, default=300
Dimension of hidden node features.
gnn_type : str, default="gin-virtual"
Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
drop_ratio : float, default=0.5
Dropout probability.
norm_layer : str, default="batch_norm"
Type of normalization layer to use. One of ["batch_norm", "layer_norm", "instance_norm", "graph_norm", "size_norm", "pair_norm"].
graph_pooling : str, default="sum"
Method for aggregating node features to graph-level representations. One of ["sum", "mean", "max"].
    augmented_feature : list of str or None, default=None
        Additional molecular fingerprints to use as features; valid entries are
        "morgan" and "maccs". They are concatenated with the graph representation
        after pooling, e.g. ``["morgan", "maccs"]`` or None.
batch_size : int, default=128
Number of samples per batch for training.
epochs : int, default=500
Maximum number of training epochs.
loss_criterion : callable, optional
Loss function for training.
evaluate_criterion : str or callable, optional
Metric for model evaluation.
evaluate_higher_better : bool, optional
Whether higher values of the evaluation metric are better.
learning_rate : float, default=0.001
Learning rate for optimizer.
grad_clip_value : float, optional
Maximum norm of gradients for gradient clipping.
weight_decay : float, default=0.0
L2 regularization strength.
patience : int, default=50
Number of epochs to wait for improvement before early stopping.
use_lr_scheduler : bool, default=False
Whether to use learning rate scheduler.
scheduler_factor : float, default=0.5
Factor by which to reduce learning rate when plateau is reached.
scheduler_patience : int, default=5
Number of epochs with no improvement after which learning rate will be reduced.
verbose : bool, default=False
Whether to print progress information during training.
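
    Examples
    --------
    A minimal usage sketch; the SMILES strings and target values below are
    illustrative placeholders rather than data shipped with this package.

    >>> predictor = GNNMolecularPredictor(
    ...     num_task=1,
    ...     task_type="regression",
    ...     augmented_feature=["maccs", "morgan"],
    ...     epochs=10,
    ... )
    >>> predictor.fit(["CCO", "c1ccccc1", "CC(=O)O"], [0.50, 1.20, 0.75])  # doctest: +SKIP
    >>> predictor.predict(["CCN"])["prediction"].shape  # doctest: +SKIP
    (1, 1)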
"""
def __init__(
self,
# Core model parameters
num_task: int = 1,
task_type: str = "regression",
# GNN architecture parameters
num_layer: int = 5,
hidden_size: int = 300,
gnn_type: str = "gin-virtual",
drop_ratio: float = 0.5,
norm_layer: str = "batch_norm",
graph_pooling: str = "sum",
        augmented_feature: Optional[List[Literal["morgan", "maccs"]]] = None,
# Training parameters
batch_size: int = 128,
epochs: int = 500,
learning_rate: float = 0.001,
weight_decay: float = 0.0,
grad_clip_value: Optional[float] = None,
patience: int = 50,
# Learning rate scheduler parameters
use_lr_scheduler: bool = False,
scheduler_factor: float = 0.5,
scheduler_patience: int = 5,
# Loss and evaluation parameters
loss_criterion: Optional[Callable] = None,
evaluate_criterion: Optional[Union[str, Callable]] = None,
evaluate_higher_better: Optional[bool] = None,
# General parameters
verbose: bool = False,
device: Optional[Union[torch.device, str]] = None,
model_name: str = "GNNMolecularPredictor"
):
super().__init__(
device=device,
model_name=model_name,
num_task=num_task,
task_type=task_type,
)
# Core model parameters
self.num_layer = num_layer
self.hidden_size = hidden_size
self.gnn_type = gnn_type
self.drop_ratio = drop_ratio
self.norm_layer = norm_layer
self.graph_pooling = graph_pooling
self.augmented_feature = augmented_feature
# Training parameters
self.batch_size = batch_size
self.epochs = epochs
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.grad_clip_value = grad_clip_value
self.patience = patience
# Learning rate scheduler parameters
self.use_lr_scheduler = use_lr_scheduler
self.scheduler_factor = scheduler_factor
self.scheduler_patience = scheduler_patience
# Loss and evaluation parameters
self.loss_criterion = loss_criterion
self.evaluate_criterion = evaluate_criterion
self.evaluate_higher_better = evaluate_higher_better
# General parameters
self.verbose = verbose
# Training state
self.fitting_loss = list()
self.fitting_epoch = 0
self.model_class = GNN
if self.augmented_feature is not None:
valid_augmented_feature = {"morgan", "maccs"}
invalid_fps = set(self.augmented_feature) - valid_augmented_feature
if invalid_fps:
raise ValueError(
f"Invalid augmented types: {invalid_fps}. "
f"Valid options are: {list(valid_augmented_feature)}"
)
# Setup loss criterion and evaluation
if self.loss_criterion is None:
self.loss_criterion = self._load_default_criterion()
self._setup_evaluation(self.evaluate_criterion, self.evaluate_higher_better)
if self.norm_layer not in ["batch_norm", "layer_norm", "instance_norm", "graph_norm", "size_norm", "pair_norm"]:
raise ValueError(f"Invalid norm_layer: {self.norm_layer}. Valid options are: batch_norm, layer_norm, instance_norm, graph_norm, size_norm, pair_norm")
@staticmethod
def _get_param_names() -> List[str]:
"""Get parameter names for the estimator.
Returns
-------
List[str]
List of parameter names that can be used for model configuration.
"""
return [
# Model Hyperparameters
"num_task",
"task_type",
"num_layer",
"hidden_size",
"gnn_type",
"drop_ratio",
"norm_layer",
"graph_pooling",
# Augmented Features
"augmented_feature",
# Training Parameters
"batch_size",
"epochs",
"learning_rate",
"weight_decay",
"patience",
"grad_clip_value",
"loss_criterion",
# Evaluation Parameters
"evaluate_name",
"evaluate_criterion",
"evaluate_higher_better",
# Scheduler Parameters
"use_lr_scheduler",
"scheduler_factor",
"scheduler_patience",
# Other Parameters
"fitting_epoch",
"fitting_loss",
"device",
"verbose"
]
    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Collect the architecture hyperparameters used to build the underlying GNN.

        Values stored in ``checkpoint["hyperparameters"]`` take precedence over the
        current instance attributes when a checkpoint is provided.
        """
if checkpoint is not None:
if "hyperparameters" not in checkpoint:
raise ValueError("Checkpoint missing 'hyperparameters' key")
hyperparameters = checkpoint["hyperparameters"]
return {
"num_task": hyperparameters.get("num_task", self.num_task),
"num_layer": hyperparameters.get("num_layer", self.num_layer),
"hidden_size": hyperparameters.get("hidden_size", self.hidden_size),
"gnn_type": hyperparameters.get("gnn_type", self.gnn_type),
"drop_ratio": hyperparameters.get("drop_ratio", self.drop_ratio),
"norm_layer": hyperparameters.get("norm_layer", self.norm_layer),
"graph_pooling": hyperparameters.get("graph_pooling", self.graph_pooling),
"augmented_feature": hyperparameters.get("augmented_feature", self.augmented_feature)
}
else:
return {
"num_task": self.num_task,
"num_layer": self.num_layer,
"hidden_size": self.hidden_size,
"gnn_type": self.gnn_type,
"drop_ratio": self.drop_ratio,
"norm_layer": self.norm_layer,
"graph_pooling": self.graph_pooling,
"augmented_feature": self.augmented_feature
}
def _convert_to_pytorch_data(self, X, y=None):
"""Convert numpy arrays to PyTorch Geometric data format.
"""
if self.verbose:
iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
else:
iterator = enumerate(X)
pyg_graph_list = []
for idx, smiles_or_mol in iterator:
if y is not None:
properties = y[idx]
else:
properties = None
graph = graph_from_smiles(smiles_or_mol, properties, self.augmented_feature)
g = Data()
g.num_nodes = graph["num_nodes"]
g.edge_index = torch.from_numpy(graph["edge_index"])
del graph["num_nodes"]
del graph["edge_index"]
if graph["edge_feat"] is not None:
g.edge_attr = torch.from_numpy(graph["edge_feat"])
del graph["edge_feat"]
if graph["node_feat"] is not None:
g.x = torch.from_numpy(graph["node_feat"])
del graph["node_feat"]
if graph["y"] is not None:
g.y = torch.from_numpy(graph["y"])
del graph["y"]
if graph["morgan"] is not None:
g.morgan = torch.tensor(graph["morgan"], dtype=torch.int8).view(1, -1)
del graph["morgan"]
if graph["maccs"] is not None:
g.maccs = torch.tensor(graph["maccs"], dtype=torch.int8).view(1, -1)
del graph["maccs"]
pyg_graph_list.append(g)
return pyg_graph_list
def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
"""Setup optimization components including optimizer and learning rate scheduler.
Returns
-------
Tuple[optim.Optimizer, Optional[Any]]
A tuple containing:
- The configured optimizer
- The learning rate scheduler (if enabled, else None)
"""
optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
scheduler = None
if self.use_lr_scheduler:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=self.scheduler_factor,
patience=self.scheduler_patience,
min_lr=1e-6,
cooldown=0,
eps=1e-8,
)
return optimizer, scheduler
def _get_default_search_space(self):
"""Get the default hyperparameter search space.
"""
return DEFAULT_GNN_SEARCH_SPACES
def autofit(
self,
X_train: List[str],
y_train: Optional[Union[List, np.ndarray]],
X_val: Optional[List[str]] = None,
y_val: Optional[Union[List, np.ndarray]] = None,
X_unlbl: Optional[List[str]] = None,
search_parameters: Optional[Dict[str, ParameterSpec]] = None,
n_trials: int = 10,
) -> "GNNMolecularPredictor":
"""Automatically find the best hyperparameters using Optuna optimization."""
import optuna
# Default search parameters
default_search_parameters = self._get_default_search_space()
if search_parameters is None:
search_parameters = default_search_parameters
else:
# Validate search parameter keys
invalid_params = set(search_parameters.keys()) - set(default_search_parameters.keys())
if invalid_params:
raise ValueError(
f"Invalid search parameters: {invalid_params}. "
f"Valid parameters are: {list(default_search_parameters.keys())}"
)
if self.verbose:
all_params = set(self._get_param_names())
searched_params = set(search_parameters.keys())
non_searched_params = all_params - searched_params
print("\nParameter Search Configuration:")
print("-" * 50)
print("\n Parameters being searched:")
for param in sorted(searched_params):
spec = search_parameters[param]
if spec.param_type == ParameterType.CATEGORICAL:
print(f" • {param}: {spec.value_range}")
else:
print(f" • {param}: [{spec.value_range[0]}, {spec.value_range[1]}]")
print("\n Fixed parameters (not being searched):")
for param in sorted(non_searched_params):
value = getattr(self, param, "N/A")
print(f" • {param}: {value}")
print("\n" + "-" * 50)
print(f"\nStarting hyperparameter optimization using {self.evaluate_name} metric")
print(f"Direction: {'maximize' if self.evaluate_higher_better else 'minimize'}")
print(f"Number of trials: {n_trials}")
# Variables to track best state
best_score = float('-inf') if self.evaluate_higher_better else float('inf')
best_state_dict = None
best_trial_params = None
best_loss = None
best_epoch = None
def objective(trial):
nonlocal best_score, best_state_dict, best_trial_params, best_loss, best_epoch
# Define hyperparameters to optimize using the parameter specifications
params = {}
for param_name, param_spec in search_parameters.items():
try:
params[param_name] = suggest_parameter(trial, param_name, param_spec)
except Exception as e:
print(f"Error suggesting parameter {param_name}: {str(e)}")
return float('inf')
# Update model parameters and train
if "augmented_feature" in params:
params['augmented_feature'] = parse_list_params(params['augmented_feature'])
self.set_params(**params)
self.fit(X_train, y_train, X_val, y_val, X_unlbl)
# Get evaluation score
eval_data = (X_val if X_val is not None else X_train)
eval_labels = (y_val if y_val is not None else y_train)
eval_results = self.predict(eval_data)
score = float(self.evaluate_criterion(eval_labels, eval_results['prediction']))
# Update best state if current score is better
is_better = (
score > best_score if self.evaluate_higher_better
else score < best_score
)
if is_better:
best_score = score
best_state_dict = {
'model': self.model.state_dict(),
'architecture': self._get_model_params()
}
best_trial_params = params.copy()
                best_loss = self.fitting_loss.copy()  # copy so later trials do not mutate the saved history
best_epoch = self.fitting_epoch
if self.verbose:
print(
f"Trial {trial.number}: {self.evaluate_name} = {score:.4f} "
f"({'better' if is_better else 'worse'} than best = {best_score:.4f})"
)
print("Current parameters:")
for param_name, value in params.items():
print(f" {param_name}: {value}")
# Return score (negated if higher is better, since Optuna minimizes)
return -score if self.evaluate_higher_better else score
        # Control Optuna logging verbosity
optuna.logging.set_verbosity(
optuna.logging.INFO if self.verbose else optuna.logging.WARNING
)
# Create and run study
study = optuna.create_study(
direction="minimize",
study_name=f"{self.model_name}_optimization"
)
study.optimize(
objective,
n_trials=n_trials,
catch=(Exception,),
show_progress_bar=self.verbose
)
if best_state_dict is not None:
self.set_params(**best_trial_params)
# Initialize model with saved architecture parameters
self._initialize_model(self.model_class)
# Load the saved state dict
self.model.load_state_dict(best_state_dict['model'])
self.fitting_loss = best_loss
self.fitting_epoch = best_epoch
self.is_fitted_ = True
if self.verbose:
print(f"\nOptimization completed successfully:")
print(f"Best {self.evaluate_name}: {best_score:.4f}")
eval_data = (X_val if X_val is not None else X_train)
eval_labels = (y_val if y_val is not None else y_train)
eval_results = self.predict(eval_data)
score = float(self.evaluate_criterion(eval_labels, eval_results['prediction']))
                print(f"{self.evaluate_name} of the restored best model: {score:.4f}")
print("\nBest parameters:")
for param, value in best_trial_params.items():
param_spec = search_parameters[param]
print(f" {param}: {value} (type: {param_spec.param_type.value})")
print("\nOptimization statistics:")
print(f" Number of completed trials: {len(study.trials)}")
print(f" Number of pruned trials: {len(study.get_trials(states=[optuna.trial.TrialState.PRUNED]))}")
print(f" Number of failed trials: {len(study.get_trials(states=[optuna.trial.TrialState.FAIL]))}")
return self
def fit(
self,
X_train: List[str],
y_train: Optional[Union[List, np.ndarray]],
X_val: Optional[List[str]] = None,
y_val: Optional[Union[List, np.ndarray]] = None,
X_unlbl: Optional[List[str]] = None,
) -> "GNNMolecularPredictor":
"""Fit the model to the training data with optional validation set.
Parameters
----------
X_train : List[str]
Training set input molecular structures as SMILES strings
y_train : Union[List, np.ndarray]
Training set target values for property prediction
X_val : List[str], optional
Validation set input molecular structures as SMILES strings.
If None, training data will be used for validation
y_val : Union[List, np.ndarray], optional
Validation set target values. Required if X_val is provided
X_unlbl : List[str], optional
Unlabeled set input molecular structures as SMILES strings.
Returns
-------
self : GNNMolecularPredictor
Fitted estimator
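
        Examples
        --------
        A minimal sketch with hypothetical SMILES strings and target values:

        >>> predictor = GNNMolecularPredictor(epochs=10)
        >>> predictor.fit(
        ...     X_train=["CCO", "c1ccccc1", "CC(=O)O"],
        ...     y_train=[0.50, 1.20, 0.75],
        ... )  # doctest: +SKIP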
"""
if (X_val is None) != (y_val is None):
raise ValueError(
"Both X_val and y_val must be provided for validation. "
f"Got X_val={X_val is not None}, y_val={y_val is not None}"
)
self._initialize_model(self.model_class)
self.model.initialize_parameters()
optimizer, scheduler = self._setup_optimizers()
# Prepare datasets and loaders
X_train, y_train = self._validate_inputs(X_train, y_train)
train_dataset = self._convert_to_pytorch_data(X_train, y_train)
train_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=0
)
if X_val is None or y_val is None:
val_loader = train_loader
warnings.warn(
"No validation set provided. Using training set for validation. "
"This may lead to overfitting.",
UserWarning
)
else:
X_val, y_val = self._validate_inputs(X_val, y_val)
val_dataset = self._convert_to_pytorch_data(X_val, y_val)
val_loader = DataLoader(
val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=0
)
# Initialize training state
self.fitting_loss = []
self.fitting_epoch = 0
best_state_dict = None
best_eval = float('-inf') if self.evaluate_higher_better else float('inf')
cnt_wait = 0
# Calculate total steps for global progress bar
steps_per_epoch = len(train_loader)
total_steps = self.epochs * steps_per_epoch
# Initialize global progress bar
global_pbar = None
if self.verbose:
global_pbar = tqdm(
total=total_steps,
desc="Training Progress",
unit="step",
dynamic_ncols=True
)
for epoch in range(self.epochs):
# Training phase
train_losses = self._train_epoch(train_loader, optimizer, epoch, global_pbar)
self.fitting_loss.append(float(np.mean(train_losses)))
# Validation phase
current_eval = self._evaluation_epoch(val_loader)
if scheduler:
scheduler.step(current_eval)
# Model selection (check if current evaluation is better)
is_better = (
current_eval > best_eval if self.evaluate_higher_better
else current_eval < best_eval
)
if is_better:
self.fitting_epoch = epoch
best_eval = current_eval
best_state_dict = self.model.state_dict()
cnt_wait = 0
if self.verbose:
# Update global progress bar with current metrics
global_pbar.set_postfix({
"Epoch": f"{epoch+1}/{self.epochs}",
"Loss": f"{float(np.mean(train_losses)):.4f}",
f"{self.evaluate_name}": f"{best_eval:.4f}",
"Status": "✓ Best"
})
else:
cnt_wait += 1
if self.verbose:
global_pbar.set_postfix({
"Epoch": f"{epoch+1}/{self.epochs}",
"Loss": f"{float(np.mean(train_losses)):.4f}",
f"{self.evaluate_name}": f"{current_eval:.4f}",
"Wait": f"{cnt_wait}/{self.patience}"
})
if cnt_wait > self.patience:
if self.verbose:
global_pbar.set_postfix({
"Status": "Early Stopped",
"Epoch": f"{epoch+1}/{self.epochs}"
})
global_pbar.close()
break
# Close global progress bar
if global_pbar is not None:
global_pbar.close()
# Restore best model
if best_state_dict is not None:
self.model.load_state_dict(best_state_dict)
else:
warnings.warn(
"No improvement was achieved during training. "
"The model may not be fitted properly.",
UserWarning
)
self.is_fitted_ = True
return self
def predict(self, X: List[str]) -> Dict[str, np.ndarray]:
"""Make predictions using the fitted model.
Parameters
----------
X : List[str]
List of SMILES strings to make predictions for
Returns
-------
Dict[str, np.ndarray]
Dictionary containing:
- 'prediction': Model predictions (shape: [n_samples, n_tasks])
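
        Examples
        --------
        A hypothetical call on an already fitted predictor:

        >>> out = predictor.predict(["CCO", "c1ccccc1"])  # doctest: +SKIP
        >>> out["prediction"].shape  # doctest: +SKIP
        (2, 1)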
"""
self._check_is_fitted()
# Convert to PyTorch Geometric format and create loader
X, _ = self._validate_inputs(X)
dataset = self._convert_to_pytorch_data(X)
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
if self.model is None:
raise RuntimeError("Model not initialized")
# Make predictions
self.model = self.model.to(self.device)
self.model.eval()
predictions = []
with torch.no_grad():
if self.verbose:
iterator = tqdm(loader, desc="Predicting")
else:
iterator = loader
for batch in iterator:
batch = batch.to(self.device)
out = self.model(batch)
predictions.append(out["prediction"].cpu().numpy())
return {
"prediction": np.concatenate(predictions, axis=0),
}
def _evaluation_epoch(
self,
loader: DataLoader,
) -> float:
"""Evaluate the model on given data.
Parameters
----------
loader : DataLoader
DataLoader containing evaluation data
        Returns
        -------
        float
            Evaluation metric value computed by ``evaluate_criterion``
"""
self.model.eval()
y_pred_list = []
y_true_list = []
with torch.no_grad():
for batch in loader:
batch = batch.to(self.device)
out = self.model(batch)
y_pred_list.append(out["prediction"].cpu().numpy())
y_true_list.append(batch.y.cpu().numpy())
y_pred = np.concatenate(y_pred_list, axis=0)
y_true = np.concatenate(y_true_list, axis=0)
        # Compute the evaluation metric on the aggregated predictions
        metric_value = float(self.evaluate_criterion(y_true, y_pred))
        return metric_value
def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
"""Training logic for one epoch.
Args:
train_loader: DataLoader containing training data
optimizer: Optimizer instance for model parameter updates
epoch: Current epoch number
global_pbar: Global progress bar for tracking overall training progress
Returns:
list: List of loss values for each training step
"""
self.model.train()
losses = []
for batch_idx, batch in enumerate(train_loader):
batch = batch.to(self.device)
optimizer.zero_grad()
loss = self.model.compute_loss(batch, self.loss_criterion)
loss.backward()
if self.grad_clip_value is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
optimizer.step()
losses.append(loss.item())
# Update global progress bar
if global_pbar is not None:
global_pbar.update(1)
global_pbar.set_postfix({
"Epoch": f"{epoch+1}/{self.epochs}",
"Batch": f"{batch_idx+1}/{len(train_loader)}",
"Loss": f"{loss.item():.4f}"
})
return losses