# Source code for torch_molecule.predictor.sgir.modeling_sgir

import warnings
import numpy as np
from tqdm import tqdm
from typing import Optional, Union, List, Callable, Literal

import torch
from torch_geometric.loader import DataLoader

from .strategy import build_selection_dataset, build_augmentation_dataset
from ..grea.modeling_grea import GREAMolecularPredictor
from ..grea.model import GREA

from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

class SGIRMolecularPredictor(GREAMolecularPredictor):
    """This predictor implements SGIR for semi-supervised graph imbalanced regression.

    It trains the GREA model based on pseudo-labeling and data augmentation.

    References
    ----------
    - Semi-Supervised Graph Imbalanced Regression. https://dl.acm.org/doi/10.1145/3580305.3599497
    - Code: https://github.com/liugangcode/SGIR

    Parameters
    ----------
    num_anchor : int, default=10
        Number of anchor points used to split the label space during pseudo-labeling.
    warmup_epoch : int, default=20
        Number of epochs to train before starting pseudo-labeling and data augmentation.
    labeling_interval : int, default=5
        Interval (in epochs) between two pseudo-labeling steps. It controls the update
        frequency of pseudo-labeling.
    augmentation_interval : int, default=5
        Interval (in epochs) between two data augmentation steps. It controls the update
        frequency of data augmentation.
    top_quantile : float, default=0.1
        Quantile threshold for selecting high-confidence predictions during pseudo-labeling.
    label_logscale : bool, default=False
        Whether to use log scale for the label space during pseudo-labeling and data
        augmentation.
    lw_aug : float, default=1
        Loss weight for the augmented data.
    gamma : float, default=0.4
        GREA-specific parameter that penalizes the size of the rationales (ratio between
        the number of nodes in the rationale and the number of nodes in the original graph).
    num_task : int, default=1
        Number of prediction tasks. SGIR currently only supports regression tasks with 1 task.
    task_type : str, default="regression"
        Type of prediction task, either "regression" or "classification".
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of hidden node features.
    gnn_type : str, default="gin-virtual"
        Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization layer to use. One of ["batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm"].
    graph_pooling : str, default="sum"
        Method for aggregating node features to graph-level representations.
        One of ["sum", "mean", "max"].
    augmented_feature : list or None, default=None
        Additional molecular fingerprints to use as features; they are concatenated with
        the graph representation after pooling. For example, ["morgan", "maccs"] or None.
    batch_size : int, default=128
        Number of samples per batch for training.
    epochs : int, default=500
        Maximum number of training epochs.
    loss_criterion : callable, optional
        Loss function for training.
    evaluate_criterion : str or callable, optional
        Metric for model evaluation.
    evaluate_higher_better : bool, optional
        Whether higher values of the evaluation metric are better.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    grad_clip_value : float, optional
        Maximum norm of gradients for gradient clipping.
    weight_decay : float, default=0.0
        L2 regularization strength.
    patience : int, default=50
        Number of epochs to wait for improvement before early stopping.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate when a plateau is reached.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which the learning rate will be reduced.
    verbose : bool, default=False
        Whether to print progress information during training.
    device : torch.device or str, optional
        Device to run the model on. If None, will auto-detect GPU or use CPU.
    model_name : str, default="SGIRMolecularPredictor"
        Name identifier for the model.
    """

    def __init__(
        self,
        # SGIR-specific parameters
        num_anchor: int = 10,
        warmup_epoch: int = 20,
        labeling_interval: int = 5,
        augmentation_interval: int = 5,
        top_quantile: float = 0.1,
        label_logscale: bool = False,
        lw_aug: float = 1,
        gamma: float = 0.4,
        # Core model parameters
        num_task: int = 1,
        task_type: str = "regression",
        # GNN architecture parameters
        num_layer: int = 5,
        hidden_size: int = 300,
        gnn_type: str = "gin-virtual",
        drop_ratio: float = 0.5,
        norm_layer: str = "batch_norm",
        graph_pooling: str = "sum",
        augmented_feature: Optional[List[Literal["morgan", "maccs"]]] = None,
        # Training parameters
        batch_size: int = 128,
        epochs: int = 500,
        learning_rate: float = 0.001,
        weight_decay: float = 0.0,
        grad_clip_value: Optional[float] = None,
        patience: int = 50,
        # Learning rate scheduler parameters
        use_lr_scheduler: bool = False,
        scheduler_factor: float = 0.5,
        scheduler_patience: int = 5,
        # Loss and evaluation parameters
        loss_criterion: Optional[Callable] = None,
        evaluate_criterion: Optional[Union[str, Callable]] = None,
        evaluate_higher_better: Optional[bool] = None,
        # General parameters
        verbose: bool = False,
        device: Optional[Union[torch.device, str]] = None,
        model_name: str = "SGIRMolecularPredictor",
    ):
        super().__init__(
            gamma=gamma,
            num_task=num_task,
            task_type=task_type,
            num_layer=num_layer,
            hidden_size=hidden_size,
            gnn_type=gnn_type,
            drop_ratio=drop_ratio,
            norm_layer=norm_layer,
            graph_pooling=graph_pooling,
            augmented_feature=augmented_feature,
            batch_size=batch_size,
            epochs=epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            grad_clip_value=grad_clip_value,
            patience=patience,
            use_lr_scheduler=use_lr_scheduler,
            scheduler_factor=scheduler_factor,
            scheduler_patience=scheduler_patience,
            loss_criterion=loss_criterion,
            evaluate_criterion=evaluate_criterion,
            evaluate_higher_better=evaluate_higher_better,
            verbose=verbose,
            device=device,
            model_name=model_name,
        )
        # SGIR-specific parameters
        self.num_anchor = num_anchor
        self.warmup_epoch = warmup_epoch
        self.labeling_interval = labeling_interval
        self.augmentation_interval = augmentation_interval
        self.top_quantile = top_quantile
        self.label_logscale = label_logscale
        self.lw_aug = lw_aug

        if self.task_type != "regression" or self.num_task != 1:
            raise ValueError("SGIR only supports regression tasks with 1 task")

    @staticmethod
    def _get_param_names():
        sgir_params = [
            "num_anchor",
            "warmup_epoch",
            "labeling_interval",
            "augmentation_interval",
            "top_quantile",
            "label_logscale",
            "lw_aug",
        ]
        return sgir_params + GREAMolecularPredictor._get_param_names()

    def _get_default_search_space(self):
        search_space = super()._get_default_search_space().copy()
        search_space["num_anchor"] = ParameterSpec(ParameterType.INTEGER, (10, 100))
        search_space["labeling_interval"] = ParameterSpec(ParameterType.INTEGER, (10, 20))
        search_space["augmentation_interval"] = ParameterSpec(ParameterType.INTEGER, (10, 20))
        search_space["top_quantile"] = ParameterSpec(ParameterType.LOG_FLOAT, (0.01, 0.5))
        search_space["lw_aug"] = ParameterSpec(ParameterType.FLOAT, (0.1, 1))
        return search_space
    def fit(
        self,
        X_train: List[str],
        y_train: Optional[Union[List, np.ndarray]],
        X_val: Optional[List[str]] = None,
        y_val: Optional[Union[List, np.ndarray]] = None,
        X_unlbl: Optional[List[str]] = None,
    ) -> "SGIRMolecularPredictor":
        """Fit the model to training data with optional validation set."""
        if (X_val is None) != (y_val is None):
            raise ValueError("X_val and y_val must both be provided for validation")
        if X_unlbl is None:
            raise ValueError("X_unlbl (unlabeled SMILES strings) must be provided in SGIR")
        if len(X_unlbl) == 0:
            raise ValueError("X_unlbl (unlabeled SMILES strings) must not be empty")

        # Initialize model and optimization
        self._initialize_model(self.model_class)
        self.model.initialize_parameters()
        optimizer, scheduler = self._setup_optimizers()

        # Prepare datasets
        X_train, y_train = self._validate_inputs(X_train, y_train)
        train_dataset = self._convert_to_pytorch_data(X_train, y_train)
        train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0
        )
        X_unlbl, _ = self._validate_inputs(X_unlbl, None)
        unlbl_dataset = self._convert_to_pytorch_data(X_unlbl)

        if X_val is None:
            val_loader = train_loader
            warnings.warn(
                "No validation set provided. Using training set for validation.", UserWarning
            )
        else:
            X_val, y_val = self._validate_inputs(X_val, y_val)
            val_dataset = self._convert_to_pytorch_data(X_val, y_val)
            val_loader = DataLoader(
                val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0
            )

        # Training loop
        augmented_dataset = None
        self.fitting_loss = []
        self.fitting_epoch = 0
        best_state_dict = None
        best_eval = float('-inf') if self.evaluate_higher_better else float('inf')
        cnt_wait = 0

        # Calculate total steps for global progress bar
        steps_per_epoch = len(train_loader)
        total_steps = self.epochs * steps_per_epoch

        # Initialize global progress bar
        global_pbar = None
        if self.verbose:
            global_pbar = tqdm(
                total=total_steps,
                desc="SGIR Training Progress",
                unit="step",
                dynamic_ncols=True,
                leave=True,
            )

        self.model.train()
        try:
            for epoch in range(self.epochs):
                # Training phase
                train_losses = self._train_epoch(
                    train_loader, augmented_dataset, optimizer, epoch, global_pbar
                )

                # Update datasets after warmup
                if epoch > self.warmup_epoch:
                    if epoch % self.labeling_interval == 0:
                        train_loader = build_selection_dataset(
                            self.model, train_dataset, unlbl_dataset, self.batch_size,
                            self.num_anchor, self.top_quantile, self.device, self.label_logscale
                        )
                    if epoch % self.augmentation_interval == 0:
                        augmented_dataset = build_augmentation_dataset(
                            self.model, train_dataset, unlbl_dataset, self.batch_size,
                            self.num_anchor, self.device, self.label_logscale
                        )

                self.fitting_loss.append(np.mean(train_losses))

                # Validation and model selection
                current_eval = self._evaluation_epoch(val_loader)
                if scheduler:
                    scheduler.step(current_eval)

                is_better = (
                    current_eval > best_eval
                    if self.evaluate_higher_better
                    else current_eval < best_eval
                )
                if is_better:
                    self.fitting_epoch = epoch
                    best_eval = current_eval
                    best_state_dict = self.model.state_dict()
                    cnt_wait = 0
                    if self.verbose and global_pbar:
                        global_pbar.set_postfix({
                            "Epoch": f"{epoch+1}/{self.epochs}",
                            "Loss": f"{np.mean(train_losses):.4f}",
                            f"{self.evaluate_name}": f"{best_eval:.4f}",
                            "Status": "✓ Best",
                        })
                else:
                    cnt_wait += 1
                    if self.verbose and global_pbar:
                        global_pbar.set_postfix({
                            "Epoch": f"{epoch+1}/{self.epochs}",
                            "Loss": f"{np.mean(train_losses):.4f}",
                            f"{self.evaluate_name}": f"{current_eval:.4f}",
                            "Wait": f"{cnt_wait}/{self.patience}",
                        })
                    if cnt_wait > self.patience:
                        if self.verbose:
                            if global_pbar:
                                global_pbar.set_postfix({
                                    "Status": "Early Stopped",
                                    "Epoch": f"{epoch+1}/{self.epochs}",
                                })
                        break
        finally:
            # Ensure progress bar is closed
            if global_pbar is not None:
                global_pbar.close()

        # Restore best model
        if best_state_dict is not None:
            self.model.load_state_dict(best_state_dict)
        else:
            warnings.warn("No improvement achieved during training.", UserWarning)

        self.is_fitted_ = True
        return self
    def _train_epoch(self, train_loader, augmented_dataset, optimizer, epoch, global_pbar=None):
        """Training logic for one epoch.

        Args:
            train_loader: DataLoader containing training data
            augmented_dataset: Augmented dataset for SGIR training
            optimizer: Optimizer instance for model parameter updates
            epoch: Current epoch number
            global_pbar: Global progress bar for tracking overall training progress

        Returns:
            list: List of loss values for each training step
        """
        losses = []
        if augmented_dataset is not None and self.lw_aug != 0:
            aug_reps = augmented_dataset['representations']
            aug_targets = augmented_dataset['labels']
            random_inds = torch.randperm(aug_reps.size(0))
            aug_reps = aug_reps[random_inds]
            aug_targets = aug_targets[random_inds]
            num_step = len(train_loader)
            aug_batch_size = aug_reps.size(0) // max(1, num_step)
            aug_inputs = list(torch.split(aug_reps, aug_batch_size))
            aug_outputs = list(torch.split(aug_targets, aug_batch_size))
        else:
            aug_inputs = None
            aug_outputs = None

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(self.device)
            optimizer.zero_grad()

            # Augmentation loss
            if (
                aug_inputs is not None
                and aug_outputs is not None
                and aug_inputs[batch_idx].size(0) != 1
            ):
                self.model._disable_batchnorm_tracking(self.model)
                pred_aug = self.model.predictor(aug_inputs[batch_idx])
                self.model._enable_batchnorm_tracking(self.model)
                targets_aug = aug_outputs[batch_idx]
                Laug = self.loss_criterion(
                    pred_aug.view(targets_aug.size()).to(torch.float32), targets_aug
                ).mean()
            else:
                Laug = torch.tensor(0.)

            Lx = self.model.compute_loss(batch, self.loss_criterion)
            loss = Lx + Laug * self.lw_aug
            loss.backward()

            # Clip gradients if gradient clipping is enabled
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)

            optimizer.step()
            losses.append(loss.item())

            # Update global progress bar
            if global_pbar is not None:
                global_pbar.update(1)
                global_pbar.set_postfix({
                    "Epoch": f"{epoch+1}/{self.epochs}",
                    "Batch": f"{batch_idx+1}/{len(train_loader)}",
                    "Total Loss": f"{loss.item():.4f}",
                    "(Pseudo)Labeled Loss": f"{Lx.item():.4f}",
                    "Aug Loss": f"{Laug.item():.4f}",
                })

        return losses
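# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library source).
# The SMILES strings, target values, and hyperparameter choices below are
# placeholder assumptions; they only show how the class defined above is
# constructed and fit with both labeled and unlabeled molecules.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Small, fast configuration chosen for demonstration purposes.
    predictor = SGIRMolecularPredictor(
        num_anchor=10,
        warmup_epoch=2,
        labeling_interval=1,
        augmentation_interval=1,
        top_quantile=0.1,
        lw_aug=1.0,
        epochs=5,
        batch_size=4,
        verbose=True,
    )

    # Labeled molecules (SMILES) with scalar regression targets.
    X_train = ["CCO", "CCCO", "CCCCO", "c1ccccc1", "CC(=O)O", "CCN"]
    y_train = [0.50, 0.65, 0.80, 1.20, 0.30, 0.45]

    # Unlabeled molecules used for pseudo-labeling and data augmentation.
    X_unlbl = ["CCCC", "CCOC", "c1ccncc1", "CC(C)O"]

    # Without X_val/y_val, the training set is reused for validation (a warning
    # is raised); a held-out split is passed here for illustration.
    predictor.fit(X_train, y_train, X_val=X_train, y_val=y_train, X_unlbl=X_unlbl)

    # Downstream inference would use the prediction API inherited from the base
    # predictor class (not shown in this file).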