Source code for torch_molecule.predictor.dir.modeling_dir

import numpy as np
import warnings
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal

import torch
from torch_geometric.loader import DataLoader

from .model import DIR
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

class DIRMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements DIR (Discovering Invariant Rationales) for molecular
    property prediction tasks. A minimal usage sketch is given after the class
    definition below.

    References
    ----------
    - Discovering Invariant Rationales for Graph Neural Networks.
      https://openreview.net/forum?id=hGXij5rfiHw
    - Code: https://github.com/Wuyxin/DIR-GNN

    Parameters
    ----------
    causal_ratio : float, default=0.8
        The ratio of causal edges to keep during training. A higher ratio means more
        edges are considered causal/important for the prediction. This controls the
        sparsity of the learned rationales.
    lw_invariant : float, default=1e-4
        The weight of the invariance loss term. This loss encourages the model to learn
        rationales that are invariant across different environments/perturbations. A
        higher value puts more emphasis on learning invariant features.
    num_task : int, default=1
        Number of prediction tasks.
    task_type : str, default="regression"
        Type of prediction task, either "regression" or "classification".
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of hidden node features.
    gnn_type : str, default="gin-virtual"
        Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization layer to use. One of ["batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm"].
    graph_pooling : str, default="sum"
        Method for aggregating node features to graph-level representations.
        One of ["sum", "mean", "max"].
    augmented_feature : list or None, default=None
        Additional molecular fingerprints to use as features, e.g. ["morgan", "maccs"].
        They are concatenated with the graph representation after pooling.
    batch_size : int, default=128
        Number of samples per batch for training.
    epochs : int, default=500
        Maximum number of training epochs.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    weight_decay : float, default=0.0
        L2 regularization strength.
    grad_clip_value : float, optional
        Maximum norm of gradients for gradient clipping.
    patience : int, default=50
        Number of epochs to wait for improvement before early stopping.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate when a plateau is reached.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which the learning rate is reduced.
    loss_criterion : callable, optional
        Loss function for training.
    evaluate_criterion : str or callable, optional
        Metric for model evaluation.
    evaluate_higher_better : bool, optional
        Whether higher values of the evaluation metric are better.
    verbose : bool, default=False
        Whether to print progress information during training.
    device : torch.device or str, optional
        Device to use for computation.
    model_name : str, default="DIRMolecularPredictor"
        Name of the model.
    """
""" def __init__( self, # DIR-specific parameters causal_ratio: float = 0.8, lw_invariant: float = 1e-4, # Core model parameters num_task: int = 1, task_type: str = "regression", # GNN architecture parameters num_layer: int = 5, hidden_size: int = 300, gnn_type: str = "gin-virtual", drop_ratio: float = 0.5, norm_layer: str = "batch_norm", graph_pooling: str = "sum", augmented_feature: Optional[list[Literal["morgan", "maccs"]]] = None, # Training parameters batch_size: int = 128, epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 0.0, grad_clip_value: Optional[float] = None, patience: int = 50, # Learning rate scheduler parameters use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, # Loss and evaluation parameters loss_criterion: Optional[Callable] = None, evaluate_criterion: Optional[Union[str, Callable]] = None, evaluate_higher_better: Optional[bool] = None, # General parameters verbose: bool = False, device: Optional[Union[torch.device, str]] = None, model_name: str = "DIRMolecularPredictor", ): super().__init__( num_task=num_task, task_type=task_type, num_layer=num_layer, hidden_size=hidden_size, gnn_type=gnn_type, drop_ratio=drop_ratio, norm_layer=norm_layer, graph_pooling=graph_pooling, augmented_feature=augmented_feature, batch_size=batch_size, epochs=epochs, learning_rate=learning_rate, weight_decay=weight_decay, grad_clip_value=grad_clip_value, patience=patience, use_lr_scheduler=use_lr_scheduler, scheduler_factor=scheduler_factor, scheduler_patience=scheduler_patience, loss_criterion=loss_criterion, evaluate_criterion=evaluate_criterion, evaluate_higher_better=evaluate_higher_better, verbose=verbose, device=device, model_name=model_name, ) # DIR-specific parameters self.causal_ratio = causal_ratio self.lw_invariant = lw_invariant self.model_class = DIR @staticmethod def _get_param_names() -> List[str]: return ["causal_ratio", "lw_invariant"] + GNNMolecularPredictor._get_param_names() def _get_default_search_space(self): search_space = super()._get_default_search_space().copy() search_space["causal_ratio"] = ParameterSpec(ParameterType.FLOAT, (0.1, 0.9)) search_space["lw_invariant"] = ParameterSpec(ParameterType.FLOAT, (1e-5, 1e-2)) return search_space def _setup_optimizers(self) -> Tuple[Dict[str, torch.optim.Optimizer], Optional[Any]]: model_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) conf_optimizer = torch.optim.Adam(self.model.conf_lin.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) if self.grad_clip_value is not None: for group in model_optimizer.param_groups: group.setdefault("max_norm", self.grad_clip_value) group.setdefault("norm_type", 2.0) scheduler = None if self.use_lr_scheduler: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( model_optimizer, mode="min", factor=self.scheduler_factor, patience=self.scheduler_patience, min_lr=1e-6, cooldown=0, eps=1e-8, ) optimizer = {"model": model_optimizer, "conf": conf_optimizer} return optimizer, scheduler def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]: base_params = super()._get_model_params(checkpoint) if checkpoint and "hyperparameters" in checkpoint: base_params["causal_ratio"] = checkpoint["hyperparameters"]["causal_ratio"] else: base_params["causal_ratio"] = self.causal_ratio return base_params def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None): self.model.train() losses = [] alpha_prime = self.lw_invariant * (epoch ** 
1.6) conf_opt = optimizer["conf"] model_optimizer = optimizer["model"] for batch_idx, batch in enumerate(train_loader): batch = batch.to(self.device) # Forward pass and loss computation causal_loss, conf_loss, env_loss = self.model.compute_loss(batch, self.loss_criterion, alpha_prime) conf_opt.zero_grad() conf_loss.backward() conf_opt.step() model_optimizer.zero_grad() (causal_loss + env_loss).backward() model_optimizer.step() loss = causal_loss + env_loss + conf_loss losses.append(loss.item()) # Update progress bar if using tqdm if global_pbar is not None: global_pbar.update(1) global_pbar.set_postfix({ "Epoch": f"{epoch+1}/{self.epochs}", "Batch": f"{batch_idx+1}/{len(train_loader)}", "Loss": f"{loss.item():.4f}" }) return losses
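
    # Illustration of the invariance-weight schedule used in `_train_epoch` above:
    # alpha_prime = lw_invariant * epoch ** 1.6 grows polynomially over training, so the
    # invariance penalty is soft early on and dominates later. With the default
    # lw_invariant = 1e-4 (and the 0-based epoch counter implied by the "epoch + 1"
    # progress display, which is an assumption about the caller):
    #     epoch 0   -> alpha_prime = 0.0
    #     epoch 10  -> alpha_prime ≈ 4.0e-3
    #     epoch 100 -> alpha_prime ≈ 1.6e-1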

    def predict(self, X: List[str]) -> Dict[str, Union[np.ndarray, List[List]]]:
        """Make predictions for a list of molecules.

        Parameters
        ----------
        X : List[str]
            List of SMILES strings to predict.

        Returns
        -------
        Dict[str, np.ndarray]
            Dictionary with key "prediction" holding an array of shape
            (n_samples, num_task).
        """
        self._check_is_fitted()

        # Convert to PyTorch Geometric format and create loader
        X, _ = self._validate_inputs(X)
        dataset = self._convert_to_pytorch_data(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Make predictions
        self.model = self.model.to(self.device)
        self.model.eval()
        predictions = []
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                batch = batch.to(self.device)
                out = self.model(batch)
                predictions.append(out["prediction"].cpu().numpy())

        if predictions:
            return {
                "prediction": np.concatenate(predictions, axis=0),
            }
        else:
            warnings.warn(
                "No valid predictions could be made from the input data. Returning empty results."
            )
            return {"prediction": np.array([])}
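
A minimal usage sketch (not part of the module source). It assumes the fit interface
inherited from GNNMolecularPredictor accepts a list of SMILES strings and an array-like
of targets, and that DIRMolecularPredictor is exported from the package root; the
molecules and target values below are purely illustrative.

from torch_molecule import DIRMolecularPredictor   # assumed public import path

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]            # illustrative molecules
y = [[0.5], [1.2], [0.3]]                          # one regression target per molecule

model = DIRMolecularPredictor(
    causal_ratio=0.8,      # keep the top 80% of edges as the causal rationale
    lw_invariant=1e-4,     # weight of the invariance loss term
    task_type="regression",
    epochs=10,
    verbose=True,
)
model.fit(smiles, y)                               # fit() is inherited from GNNMolecularPredictor
out = model.predict(["CCN"])
print(out["prediction"].shape)                     # expected (1, num_task) for a single input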