Source code for torch_molecule.predictor.grea.modeling_grea

import numpy as np
import warnings
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, List, Callable, Literal

import torch
from torch_geometric.loader import DataLoader

from .model import GREA
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

class GREAMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements the GREA model from the paper "Graph Rationalization
    with Environment-based Augmentations". During training, it learns rationales
    (explainable subgraphs) and uses them for data augmentation. During prediction,
    the model uses the rationales to make predictions, and the rationales themselves
    can also explain those predictions.

    References
    ----------
    - Graph Rationalization with Environment-based Augmentations.
      https://dl.acm.org/doi/10.1145/3534678.3539347
    - Code: https://github.com/liugangcode/GREA

    Parameters
    ----------
    gamma : float, default=0.4
        GREA-specific parameter that penalizes the size of the rationales
        (the ratio between the number of nodes in a rationale and the number
        of nodes in the original graph).
    num_task : int, default=1
        Number of prediction tasks.
    task_type : str, default="regression"
        Type of prediction task, either "regression" or "classification".
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of hidden node features.
    gnn_type : str, default="gin-virtual"
        Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual",
        "gin", "gcn"].
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization layer to use. One of ["batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm"].
    graph_pooling : str, default="sum"
        Method for aggregating node features to graph-level representations.
        One of ["sum", "mean", "max"].
    augmented_feature : list or None, default=None
        Additional molecular fingerprints to use as features, concatenated with
        the graph representation after pooling, e.g. ["morgan", "maccs"], or None.
    batch_size : int, default=128
        Number of samples per batch for training.
    epochs : int, default=500
        Maximum number of training epochs.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    weight_decay : float, default=0.0
        L2 regularization strength.
    grad_clip_value : float, optional
        Maximum norm of gradients for gradient clipping.
    patience : int, default=50
        Number of epochs to wait for improvement before early stopping.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate when a plateau is reached.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which the learning rate
        will be reduced.
    loss_criterion : callable, optional
        Loss function for training.
    evaluate_criterion : str or callable, optional
        Metric for model evaluation.
    evaluate_higher_better : bool, optional
        Whether higher values of the evaluation metric are better.
    verbose : bool, default=False
        Whether to print progress information during training.
    device : torch.device or str, optional
        Device to use for computation.
    model_name : str, default="GREAMolecularPredictor"
        Name of the model.
    """
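    # Conceptually, GREA splits each molecular graph G into a rationale subgraph
    # G_r and its environment G_e, and augments training by recombining rationale
    # representations with environment representations drawn from other graphs in
    # the batch. A schematic sketch of the idea (illustrative pseudocode based on
    # the paper, not this module's API):
    #
    #     h_r = pool(encoder(G_r))                 # rationale representation
    #     for G_e_other in environments(batch):    # environments of other graphs
    #         h_aug = h_r + pool(encoder(G_e_other))
    #         loss += criterion(head(h_aug), y)    # same label: rationale suffices
    #
    # The per-sample 'variance' returned by predict() below presumably reflects
    # the spread of predictions over such recombinations (an assumption about
    # the underlying GREA implementation).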
""" def __init__( self, # GREA-specific parameters gamma: float = 0.4, # Core model parameters num_task: int = 1, task_type: str = "regression", # GNN architecture parameters num_layer: int = 5, hidden_size: int = 300, gnn_type: str = "gin-virtual", drop_ratio: float = 0.5, norm_layer: str = "batch_norm", graph_pooling: str = "sum", augmented_feature: Optional[list[Literal["morgan", "maccs"]]] = None, # Training parameters batch_size: int = 128, epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 0.0, grad_clip_value: Optional[float] = None, patience: int = 50, # Learning rate scheduler parameters use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, # Loss and evaluation parameters loss_criterion: Optional[Callable] = None, evaluate_criterion: Optional[Union[str, Callable]] = None, evaluate_higher_better: Optional[bool] = None, # General parameters verbose: bool = False, device: Optional[Union[torch.device, str]] = None, model_name: str = "GREAMolecularPredictor", ): super().__init__( num_task=num_task, task_type=task_type, num_layer=num_layer, hidden_size=hidden_size, gnn_type=gnn_type, drop_ratio=drop_ratio, norm_layer=norm_layer, graph_pooling=graph_pooling, augmented_feature=augmented_feature, batch_size=batch_size, epochs=epochs, learning_rate=learning_rate, weight_decay=weight_decay, grad_clip_value=grad_clip_value, patience=patience, use_lr_scheduler=use_lr_scheduler, scheduler_factor=scheduler_factor, scheduler_patience=scheduler_patience, loss_criterion=loss_criterion, evaluate_criterion=evaluate_criterion, evaluate_higher_better=evaluate_higher_better, verbose=verbose, device=device, model_name=model_name, ) # GREA-specific parameters self.gamma = gamma self.model_class = GREA @staticmethod def _get_param_names() -> List[str]: return ["gamma"] + GNNMolecularPredictor._get_param_names() def _get_default_search_space(self): search_space = super()._get_default_search_space().copy() search_space["gamma"] = ParameterSpec(ParameterType.FLOAT, (0.1, 0.9)) return search_space def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]: base_params = super()._get_model_params(checkpoint) if checkpoint and "hyperparameters" in checkpoint: base_params["gamma"] = checkpoint["hyperparameters"].get("gamma", self.gamma) else: base_params["gamma"] = self.gamma base_params.pop("graph_pooling", None) return base_params
    def predict(self, X: List[str]) -> Dict[str, Union[np.ndarray, List[List]]]:
        """Make predictions using the fitted model.

        Parameters
        ----------
        X : List[str]
            List of SMILES strings to make predictions for.

        Returns
        -------
        Dict[str, np.ndarray]
            Dictionary containing:
                - 'prediction': Model predictions (shape: [n_samples, n_tasks])
                - 'variance': Prediction variances (shape: [n_samples, n_tasks])
                - 'node_importance': A nested list where the outer list has length
                  n_samples and each inner list has length n_nodes for that molecule
        """
        self._check_is_fitted()

        # Convert to PyTorch Geometric format and create a loader
        X, _ = self._validate_inputs(X)
        dataset = self._convert_to_pytorch_data(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Make predictions
        self.model = self.model.to(self.device)
        self.model.eval()
        predictions = []
        variances = []
        node_scores = []
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                batch = batch.to(self.device)
                out = self.model(batch)
                predictions.append(out["prediction"].cpu().numpy())
                variances.append(out["variance"].cpu().numpy())
                node_scores.extend(out["score"])

        if predictions and variances:
            return {
                "prediction": np.concatenate(predictions, axis=0),
                "variance": np.concatenate(variances, axis=0),
                "node_importance": node_scores,
            }
        else:
            warnings.warn(
                "No valid predictions could be made from the input data. "
                "Returning empty results."
            )
            # Empty node_importance is a list to match the documented return type.
            return {"prediction": np.array([]), "variance": np.array([]), "node_importance": []}
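    # Example of consuming the dictionary returned by predict() (illustrative
    # only; the SMILES inputs are hypothetical and `model` is assumed fitted):
    #
    #     out = model.predict(["CCO", "c1ccccc1"])
    #     out["prediction"].shape        # (2, num_task)
    #     out["variance"].shape          # (2, num_task)
    #     atom_scores = out["node_importance"][0]   # one importance score per
    #                                               # node (atom) in the first graph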