# Source code for torch_molecule.predictor.rpgnn.modeling_rpgnn

import os
import numpy as np
import warnings
import datetime
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal, Type
from dataclasses import dataclass, field

import torch
from torch_geometric.loader import DataLoader

from .model import RPGNN
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

class RPGNNMolecularPredictor(GNNMolecularPredictor):
    """Molecular property predictor backed by a Relational Pooling GNN (RPGNN).

    RPGNN is short for "Relational Pooling for Graph Representations". This
    class reuses the configuration handled by ``GNNMolecularPredictor`` and
    adds the relational-pooling hyperparameters ``num_perm``, ``fixed_size``
    and ``num_node_feature``, with ``RPGNN`` as its model class.

    References
    ----------
    - Relational Pooling for Graph Representations.
      https://arxiv.org/abs/1903.02541
    - Reference code: https://github.com/PurdueMINDS/RelationalPooling

    Parameters
    ----------
    num_perm : int, default=3
        Number of random permutations used for relational pooling.
    fixed_size : int, default=10
        Maximum number of nodes to consider in the graph.
    num_node_feature : int, default=9
        Dimension of the input node features. It should match the number of
        atomic features used to describe each node of the molecular graph
        (e.g. atomic number, degree, hybridization).
    num_task : int, default=1
        Number of prediction tasks.
    task_type : str, default="regression"
        Either "regression" or "classification".
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of the hidden node features.
    gnn_type : str, default="gin-virtual"
        GNN architecture: one of "gin-virtual", "gcn-virtual", "gin", "gcn".
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Normalization layer: one of "batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm".
    graph_pooling : str, default="sum"
        Nominally one of "sum", "mean", "max". It is forwarded to the base
        predictor, but ``_get_model_params`` removes it. It is therefore not
        part of the parameter dict returned for the RPGNN model.
    augmented_feature : list or None, default=None
        Extra molecular fingerprints (e.g. ["morgan", "maccs"]) concatenated
        with the pooled graph representation.
    batch_size : int, default=128
        Number of samples per training batch.
    epochs : int, default=500
        Maximum number of training epochs.
    learning_rate : float, default=0.001
        Optimizer learning rate.
    weight_decay : float, default=0.0
        L2 regularization strength.
    grad_clip_value : float, optional
        Maximum gradient norm used for gradient clipping.
    patience : int, default=50
        Epochs without improvement tolerated before early stopping.
    use_lr_scheduler : bool, default=False
        Whether to use a learning-rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which the learning rate is reduced on a plateau.
    scheduler_patience : int, default=5
        Epochs without improvement before the learning rate is reduced.
    loss_criterion : callable, optional
        Loss function used during training.
    evaluate_criterion : str or callable, optional
        Metric used for model evaluation.
    evaluate_higher_better : bool, optional
        Whether higher values of the evaluation metric are better.
    verbose : bool, default=False
        Whether to print progress information during training.
    device : torch.device or str, optional
        Device to run on. When None, a GPU is auto-detected, falling back to CPU.
    model_name : str, default="RPGNNMolecularPredictor"
        Name identifier for the model.
    """

    def __init__(
        self,
        # --- relational-pooling specific ---
        num_perm: int = 3,
        fixed_size: int = 10,
        num_node_feature: int = 9,
        # --- task definition ---
        num_task: int = 1,
        task_type: str = "regression",
        # --- GNN backbone ---
        num_layer: int = 5,
        hidden_size: int = 300,
        gnn_type: str = "gin-virtual",
        drop_ratio: float = 0.5,
        norm_layer: str = "batch_norm",
        graph_pooling: str = "sum",
        augmented_feature: Optional[list[Literal["morgan", "maccs"]]] = None,
        # --- optimisation ---
        batch_size: int = 128,
        epochs: int = 500,
        learning_rate: float = 0.001,
        weight_decay: float = 0.0,
        grad_clip_value: Optional[float] = None,
        patience: int = 50,
        # --- learning-rate scheduling ---
        use_lr_scheduler: bool = False,
        scheduler_factor: float = 0.5,
        scheduler_patience: int = 5,
        # --- loss / evaluation ---
        loss_criterion: Optional[Callable] = None,
        evaluate_criterion: Optional[Union[str, Callable]] = None,
        evaluate_higher_better: Optional[bool] = None,
        # --- misc ---
        verbose: bool = False,
        device: Optional[Union[torch.device, str]] = None,
        model_name: str = "RPGNNMolecularPredictor",
    ):
        # Every shared option goes to the base GNN predictor; only the
        # relational-pooling options are handled locally below.
        super().__init__(
            num_task=num_task,
            task_type=task_type,
            num_layer=num_layer,
            hidden_size=hidden_size,
            gnn_type=gnn_type,
            drop_ratio=drop_ratio,
            norm_layer=norm_layer,
            graph_pooling=graph_pooling,
            augmented_feature=augmented_feature,
            batch_size=batch_size,
            epochs=epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            grad_clip_value=grad_clip_value,
            patience=patience,
            use_lr_scheduler=use_lr_scheduler,
            scheduler_factor=scheduler_factor,
            scheduler_patience=scheduler_patience,
            loss_criterion=loss_criterion,
            evaluate_criterion=evaluate_criterion,
            evaluate_higher_better=evaluate_higher_better,
            verbose=verbose,
            device=device,
            model_name=model_name,
        )
        self.num_perm = num_perm
        self.fixed_size = fixed_size
        self.num_node_feature = num_node_feature
        # Model class tied to this predictor; presumably instantiated by the
        # base class -- NOTE(review): confirm in GNNMolecularPredictor.
        self.model_class = RPGNN

    @staticmethod
    def _get_param_names() -> List[str]:
        """Return hyperparameter names: RPGNN-specific first, then the base GNN ones."""
        rpgnn_names = ["num_perm", "fixed_size", "num_node_feature"]
        return rpgnn_names + GNNMolecularPredictor._get_param_names()

    def _get_default_search_space(self):
        """Extend the base search space with integer ranges for the RP options.

        ``num_perm`` and ``fixed_size`` are each searched over the integer
        range (1, 10). ``num_node_feature`` is not searched.
        """
        space = super()._get_default_search_space().copy()
        space.update(
            {
                "num_perm": ParameterSpec(ParameterType.INTEGER, (1, 10)),
                "fixed_size": ParameterSpec(ParameterType.INTEGER, (1, 10)),
            }
        )
        return space

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Assemble the keyword arguments describing the RPGNN model.

        Parameters
        ----------
        checkpoint : dict, optional
            When truthy and containing a ``"hyperparameters"`` entry, the
            RPGNN-specific values are read from that entry (a missing key
            raises ``KeyError``). Otherwise the values on this instance are used.

        Returns
        -------
        dict
            The base predictor's model parameters plus ``num_perm``,
            ``fixed_size`` and ``num_node_feature``, with ``graph_pooling``
            removed.
        """
        params = super()._get_model_params(checkpoint)
        rpgnn_keys = ("num_perm", "fixed_size", "num_node_feature")

        if checkpoint and "hyperparameters" in checkpoint:
            saved = checkpoint["hyperparameters"]
            params.update({key: saved[key] for key in rpgnn_keys})
        else:
            params.update({key: getattr(self, key) for key in rpgnn_keys})

        # The RPGNN parameter set deliberately omits graph_pooling.
        params.pop("graph_pooling", None)
        return params