Source code for torch_molecule.predictor.irm.modeling_irm

import os
import numpy as np
import warnings
import datetime
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, List, Type
from dataclasses import dataclass, field

import torch
from torch_geometric.data import Data

from .model import GNN
from ...utils import graph_from_smiles
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

@dataclass
class IRMMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements an Invariant Risk Minimization model with the GNN.

    The full name of IRM is Invariant Risk Minimization.

    References
    ----------
    - Invariant Risk Minimization. https://arxiv.org/abs/1907.02893
    - Reference Code: https://github.com/facebookresearch/InvariantRiskMinimization

    Parameters
    ----------
    IRM_environment : Union[torch.Tensor, np.ndarray, List, str], default="random"
        Environment assignments for IRM. Can be a list of integers (one per
        sample), or "random" to assign environments randomly.
    scale : float, default=1.0
        Scaling factor for the IRM penalty term.
    penalty_weight : float, default=1.0
        Weight of the IRM penalty in the loss function.
    penalty_anneal_iters : int, default=100
        Number of iterations for annealing the penalty weight.
    """

    IRM_environment: Union[torch.Tensor, np.ndarray, List, str] = "random"
    scale: float = 1.0
    penalty_weight: float = 1.0
    penalty_anneal_iters: int = 100
    # Other Non-init fields
    model_name: str = "IRMMolecularPredictor"
    model_class: Type[GNN] = field(default=GNN, init=False)

    def __post_init__(self):
        # Delegate all setup to the GNN base predictor; IRM adds no extra state here.
        super().__post_init__()

    @staticmethod
    def _get_param_names() -> List[str]:
        """Return the base GNN parameter names plus the IRM-specific ones."""
        return GNNMolecularPredictor._get_param_names() + [
            "IRM_environment",
            "scale",
            "penalty_weight",
            "penalty_anneal_iters",
        ]

    def _get_default_search_space(self):
        """Extend the base hyperparameter search space with the IRM penalty knobs."""
        search_space = super()._get_default_search_space().copy()
        search_space["penalty_weight"] = ParameterSpec(ParameterType.LOG_FLOAT, (1e-10, 1))
        search_space["penalty_anneal_iters"] = ParameterSpec(ParameterType.INTEGER, (10, 100))
        return search_space

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """IRM uses the same model-construction parameters as the base GNN predictor."""
        base_params = super()._get_model_params(checkpoint)
        return base_params

    def set_IRM_environment(self, environment: Union[torch.Tensor, np.ndarray, List, str]):
        """Validate and store the per-sample environment assignment.

        Parameters
        ----------
        environment : Union[torch.Tensor, np.ndarray, List, str]
            Either the string "random", or a 1D collection of integer
            environment labels (tensors/arrays are flattened to a list).

        Raises
        ------
        ValueError
            If a string other than "random" is given, or the collection
            contains non-integer entries.
        """
        if isinstance(environment, str):
            if environment != "random":
                raise ValueError("IRM_environment must be 'random' if specified with a string")
            self.IRM_environment = environment
        else:
            if isinstance(environment, (np.ndarray, torch.Tensor)):
                # Flatten to a plain Python list so downstream indexing is uniform.
                self.IRM_environment = environment.reshape(-1).tolist()
            else:
                self.IRM_environment = environment
            if not all(isinstance(item, int) for item in self.IRM_environment):
                raise ValueError("IRM_environment must be a list of integers")

    def _convert_to_pytorch_data(self, X, y=None):
        """Convert numpy arrays to PyTorch Geometric data format.

        Each SMILES/mol in ``X`` is turned into a ``Data`` graph carrying node,
        edge, fingerprint, and label tensors, plus a per-sample ``environment``
        id used by the IRM loss.

        Parameters
        ----------
        X : sequence
            SMILES strings (or mol objects) to convert.
        y : array-like, optional
            Per-sample property targets aligned with ``X``.

        Returns
        -------
        list of torch_geometric.data.Data
        """
        if self.verbose:
            iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
        else:
            iterator = enumerate(X)

        # Hoisted out of the loop: the environment assignment is either the
        # literal string "random" or a list that must match len(X) exactly.
        use_random_env = self.IRM_environment == "random"
        if not use_random_env and len(X) != len(self.IRM_environment):
            raise ValueError(
                "IRM_environment must have the same length as the input, which is {}".format(len(X))
            )

        pyg_graph_list = []
        for idx, smiles_or_mol in iterator:
            properties = y[idx] if y is not None else None
            graph = graph_from_smiles(smiles_or_mol, properties, self.augmented_feature)
            g = Data()
            g.num_nodes = graph["num_nodes"]
            g.edge_index = torch.from_numpy(graph["edge_index"])
            del graph["num_nodes"]
            del graph["edge_index"]

            if graph["edge_feat"] is not None:
                g.edge_attr = torch.from_numpy(graph["edge_feat"])
                del graph["edge_feat"]

            if graph["node_feat"] is not None:
                g.x = torch.from_numpy(graph["node_feat"])
                del graph["node_feat"]

            if graph["y"] is not None:
                g.y = torch.from_numpy(graph["y"])
                del graph["y"]

            # Fingerprints are stored as compact int8 row vectors.
            if graph["morgan"] is not None:
                g.morgan = torch.tensor(graph["morgan"], dtype=torch.int8).view(1, -1)
                del graph["morgan"]

            if graph["maccs"] is not None:
                g.maccs = torch.tensor(graph["maccs"], dtype=torch.int8).view(1, -1)
                del graph["maccs"]

            if use_random_env:
                # Two environments {0, 1} assigned uniformly at random per sample.
                g.environment = torch.randint(0, 2, (1,)).view(1, 1)
            else:
                env = self.IRM_environment[idx]
                if not isinstance(env, int):
                    raise ValueError("IRM_environment must be a list of integers")
                g.environment = torch.tensor(env, dtype=torch.int64).view(1, 1)

            pyg_graph_list.append(g)

        return pyg_graph_list

    def _train_epoch(self, train_loader, optimizer, epoch):
        """Run one training epoch with the IRM objective.

        The penalty weight is annealed: before ``penalty_anneal_iters`` epochs
        the IRM penalty contributes with weight 1.0, afterwards with the full
        ``penalty_weight`` (scheme from the IRM reference implementation).

        Parameters
        ----------
        train_loader : iterable
            Batches of PyG data (must carry the ``environment`` attribute).
        optimizer : torch.optim.Optimizer
            Optimizer over ``self.model`` parameters.
        epoch : int
            Current epoch index, used for penalty annealing.

        Returns
        -------
        list of float
            Total (ERM + penalty) loss per batch.
        """
        self.model.train()
        losses = []
        losses_erm = []
        penalties = []

        iterator = (
            tqdm(train_loader, desc="Training", leave=False) if self.verbose else train_loader
        )

        for step, batch in enumerate(iterator):
            batch = batch.to(self.device)
            optimizer.zero_grad()

            if epoch >= self.penalty_anneal_iters:
                penalty_weight = self.penalty_weight
            else:
                penalty_weight = 1.0

            loss, loss_erm, penalty = self.model.compute_loss(
                batch, self.loss_criterion, self.scale, penalty_weight
            )
            loss.backward()
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
            optimizer.step()

            losses.append(loss.item())
            losses_erm.append(loss_erm.item())
            penalties.append(penalty.item())
            if self.verbose:
                iterator.set_postfix(
                    {
                        "Epoch": epoch,
                        "Total Loss": f"{loss.item():.4f}",
                        "ERM Loss": f"{loss_erm.item():.4f}",
                        "IRM Penalty": f"{penalty.item():.4f}",
                    }
                )

        return losses