Source code for torch_molecule.predictor.irm.modeling_irm

import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, List, Callable, Literal

import torch
from torch_geometric.data import Data

from .model import GNN
from ...utils import graph_from_smiles
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)
class IRMMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements an Invariant Risk Minimization (IRM) model on top of the GNN.

    References
    ----------
    - Invariant Risk Minimization. https://arxiv.org/abs/1907.02893
    - Reference Code: https://github.com/facebookresearch/InvariantRiskMinimization

    Parameters
    ----------
    IRM_environment : Union[torch.Tensor, np.ndarray, List, str], default="random"
        Environment assignments for IRM. Either a list of integers (one per sample),
        or "random" to assign environments randomly.
    scale : float, default=1.0
        Scaling factor for the IRM penalty term.
    penalty_weight : float, default=1.0
        Weight of the IRM penalty in the loss function.
    penalty_anneal_iters : int, default=100
        Number of epochs during which the penalty weight is held at 1.0 before
        being raised to ``penalty_weight``.
    num_task : int, default=1
        Number of prediction tasks.
    task_type : str, default="regression"
        Type of prediction task, either "regression" or "classification".
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of hidden node features.
    gnn_type : str, default="gin-virtual"
        Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization layer to use. One of ["batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm"].
    graph_pooling : str, default="sum"
        Method for aggregating node features to graph-level representations.
        One of ["sum", "mean", "max"].
    augmented_feature : list or None, default=None
        Additional molecular fingerprints to use as features, concatenated with
        the graph representation after pooling, e.g. ["morgan", "maccs"], or None.
    batch_size : int, default=128
        Number of samples per batch for training.
    epochs : int, default=500
        Maximum number of training epochs.
    loss_criterion : callable, optional
        Loss function for training.
    evaluate_criterion : str or callable, optional
        Metric for model evaluation.
    evaluate_higher_better : bool, optional
        Whether higher values of the evaluation metric are better.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    grad_clip_value : float, optional
        Maximum norm of gradients for gradient clipping.
    weight_decay : float, default=0.0
        L2 regularization strength.
    patience : int, default=50
        Number of epochs to wait for improvement before early stopping.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate when a plateau is reached.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which the learning rate is reduced.
    verbose : bool, default=False
        Whether to print progress information during training.
    device : torch.device or str, optional
        Device to run the model on.
    model_name : str, default="IRMMolecularPredictor"
        Name of the model.
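
    Examples
    --------
    A minimal usage sketch. ``fit`` and ``predict`` are inherited from
    :class:`GNNMolecularPredictor`; the SMILES lists and labels below are
    hypothetical placeholders.

    >>> model = IRMMolecularPredictor(
    ...     task_type="classification", penalty_weight=10.0, penalty_anneal_iters=20
    ... )
    >>> model.set_IRM_environment([0, 0, 1, 1])  # one integer label per training sample
    >>> model.fit(train_smiles, train_labels)  # doctest: +SKIP
    >>> preds = model.predict(test_smiles)  # doctest: +SKIP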
""" def __init__( self, # IRM specific parameters IRM_environment: Union[torch.Tensor, np.ndarray, List, str] = "random", scale: float = 1.0, penalty_weight: float = 1.0, penalty_anneal_iters: int = 100, # Core model parameters num_task: int = 1, task_type: str = "regression", # GNN architecture parameters num_layer: int = 5, hidden_size: int = 300, gnn_type: str = "gin-virtual", drop_ratio: float = 0.5, norm_layer: str = "batch_norm", graph_pooling: str = "sum", augmented_feature: Optional[list[Literal["morgan", "maccs"]]] = None, # Training parameters batch_size: int = 128, epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 0.0, grad_clip_value: Optional[float] = None, patience: int = 50, # Learning rate scheduler parameters use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, # Loss and evaluation parameters loss_criterion: Optional[Callable] = None, evaluate_criterion: Optional[Union[str, Callable]] = None, evaluate_higher_better: Optional[bool] = None, # General parameters verbose: bool = False, device: Optional[Union[torch.device, str]] = None, model_name: str = "IRMMolecularPredictor", ): super().__init__( num_task=num_task, task_type=task_type, num_layer=num_layer, hidden_size=hidden_size, gnn_type=gnn_type, drop_ratio=drop_ratio, norm_layer=norm_layer, graph_pooling=graph_pooling, augmented_feature=augmented_feature, batch_size=batch_size, epochs=epochs, learning_rate=learning_rate, weight_decay=weight_decay, grad_clip_value=grad_clip_value, patience=patience, use_lr_scheduler=use_lr_scheduler, scheduler_factor=scheduler_factor, scheduler_patience=scheduler_patience, loss_criterion=loss_criterion, evaluate_criterion=evaluate_criterion, evaluate_higher_better=evaluate_higher_better, verbose=verbose, device=device, model_name=model_name ) self.IRM_environment = IRM_environment self.scale = scale self.penalty_weight = penalty_weight self.penalty_anneal_iters = penalty_anneal_iters self.model_class = GNN @staticmethod def _get_param_names() -> List[str]: return GNNMolecularPredictor._get_param_names() + [ "IRM_environment", "scale", "penalty_weight", "penalty_anneal_iters", ] def _get_default_search_space(self): search_space = super()._get_default_search_space().copy() search_space["penalty_weight"] = ParameterSpec(ParameterType.LOG_FLOAT, (1e-10, 1)) search_space["penalty_anneal_iters"] = ParameterSpec(ParameterType.INTEGER, (10, 100)) return search_space def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]: base_params = super()._get_model_params(checkpoint) return base_params def set_IRM_environment(self, environment: Union[torch.Tensor, np.ndarray, List, str]): if isinstance(environment, str): if environment != "random": raise ValueError("IRM_environment must be 'random' if specified with a string") self.IRM_environment = environment else: if isinstance(environment, np.ndarray) or isinstance(environment, torch.Tensor): self.IRM_environment = environment.reshape(-1).tolist() else: self.IRM_environment = environment if not all(isinstance(item, int) for item in self.IRM_environment): raise ValueError("IRM_environment must be a list of integers") def _convert_to_pytorch_data(self, X, y=None): """Convert numpy arrays to PyTorch Geometric data format. 
""" if self.verbose: iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X)) else: iterator = enumerate(X) pyg_graph_list = [] for idx, smiles_or_mol in iterator: if y is not None: properties = y[idx] else: properties = None graph = graph_from_smiles(smiles_or_mol, properties, self.augmented_feature) g = Data() g.num_nodes = graph["num_nodes"] g.edge_index = torch.from_numpy(graph["edge_index"]) del graph["num_nodes"] del graph["edge_index"] if graph["edge_feat"] is not None: g.edge_attr = torch.from_numpy(graph["edge_feat"]) del graph["edge_feat"] if graph["node_feat"] is not None: g.x = torch.from_numpy(graph["node_feat"]) del graph["node_feat"] if graph["y"] is not None: g.y = torch.from_numpy(graph["y"]) del graph["y"] if graph["morgan"] is not None: g.morgan = torch.tensor(graph["morgan"], dtype=torch.int8).view(1, -1) del graph["morgan"] if graph["maccs"] is not None: g.maccs = torch.tensor(graph["maccs"], dtype=torch.int8).view(1, -1) del graph["maccs"] if self.IRM_environment == "random": g.environment = torch.randint(0, 2, (1,)).view(1, 1) elif len(X) != len(self.IRM_environment): raise ValueError("IRM_environment must has the same length as the input, which is {}".format(len(X))) else: if isinstance(self.IRM_environment[idx], int): g.environment = torch.tensor(self.IRM_environment[idx], dtype=torch.int64).view(1, 1) else: raise ValueError("IRM_environment must be a list of integers") pyg_graph_list.append(g) return pyg_graph_list def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None): self.model.train() losses = [] losses_erm = [] penalties = [] for batch_idx, batch in enumerate(train_loader): batch = batch.to(self.device) optimizer.zero_grad() if epoch >= self.penalty_anneal_iters: penalty_weight = self.penalty_weight else: penalty_weight = 1.0 loss, loss_erm, penalty = self.model.compute_loss(batch, self.loss_criterion, self.scale, penalty_weight) loss.backward() if self.grad_clip_value is not None: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value) optimizer.step() losses.append(loss.item()) losses_erm.append(loss_erm.item()) penalties.append(penalty.item()) if global_pbar is not None: global_pbar.update(1) global_pbar.set_postfix({ "Epoch": f"{epoch+1}/{self.epochs}", "Batch": f"{batch_idx+1}/{len(train_loader)}", "Loss": f"{loss.item():.4f}", "ERM Loss": f"{loss_erm.item():.4f}", "IRM Penalty": f"{penalty.item():.4f}" }) return losses