# Source code for torch_molecule.predictor.sgir.modeling_sgir

import os
import numpy as np
import warnings
import datetime
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal
from dataclasses import dataclass

import torch
from torch_geometric.loader import DataLoader

from .strategy import build_selection_dataset, build_augmentation_dataset
from ..grea.modeling_grea import GREAMolecularPredictor
from ..grea.model import GREA

from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

@dataclass
class SGIRMolecularPredictor(GREAMolecularPredictor):
    """
    This predictor implements SGIR for semi-supervised graph imbalanced regression.
    It trains the GREA model based on pseudo-labeling and data augmentation.

    References
    ----------
    - Semi-Supervised Graph Imbalanced Regression.
      https://dl.acm.org/doi/10.1145/3580305.3599497
    - Code: https://github.com/liugangcode/SGIR

    :param num_anchor: Number of anchor points used to split the label space during pseudo-labeling
    :type num_anchor: int, default=10
    :param warmup_epoch: Number of epochs to train before starting pseudo-labeling and data augmentation
    :type warmup_epoch: int, default=20
    :param labeling_interval: Interval (in epochs) between pseudo-labeling steps
    :type labeling_interval: int, default=5
    :param augmentation_interval: Interval (in epochs) between data augmentation steps
    :type augmentation_interval: int, default=5
    :param top_quantile: Quantile threshold for selecting high-confidence predictions during pseudo-labeling
    :type top_quantile: float, default=0.1
    :param label_logscale: Whether to use log scale for the label space during pseudo-labeling and data augmentation
    :type label_logscale: bool, default=False
    :param lw_aug: Weight for the data augmentation loss
    :type lw_aug: float, default=1
    """

    # SGIR-specific parameters
    num_anchor: int = 10
    warmup_epoch: int = 20
    labeling_interval: int = 5
    augmentation_interval: int = 5
    top_quantile: float = 0.1
    label_logscale: bool = False
    lw_aug: float = 1

    # Override parent defaults
    task_type: str = "regression"
    model_name: str = "SGIRMolecularPredictor"

    def __post_init__(self):
        super().__post_init__()
        if self.task_type != "regression" or self.num_task != 1:
            raise ValueError("SGIR only supports regression tasks with 1 task")

    @staticmethod
    def _get_param_names():
        sgir_params = [
            "num_anchor",
            "warmup_epoch",
            "labeling_interval",
            "augmentation_interval",
            "top_quantile",
            "label_logscale",
            "lw_aug",
        ]
        return sgir_params + GREAMolecularPredictor._get_param_names()

    def _get_default_search_space(self):
        search_space = super()._get_default_search_space().copy()
        search_space["num_anchor"] = ParameterSpec(ParameterType.INTEGER, (10, 100))
        search_space["labeling_interval"] = ParameterSpec(ParameterType.INTEGER, (10, 20))
        search_space["augmentation_interval"] = ParameterSpec(ParameterType.INTEGER, (10, 20))
        search_space["top_quantile"] = ParameterSpec(ParameterType.LOG_FLOAT, (0.01, 0.5))
        search_space["lw_aug"] = ParameterSpec(ParameterType.FLOAT, (0.1, 1))
        return search_space
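    # Illustrative construction (a sketch, not part of the library's source or tests):
    # the hyperparameter values below are hypothetical, and every argument not listed
    # falls back to the defaults inherited from GREAMolecularPredictor.
    #
    #   predictor = SGIRMolecularPredictor(
    #       num_anchor=20,
    #       warmup_epoch=10,
    #       labeling_interval=5,
    #       augmentation_interval=5,
    #       top_quantile=0.05,
    #       lw_aug=0.5,
    #   )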
    def fit(
        self,
        X_train: List[str],
        y_train: Optional[Union[List, np.ndarray]],
        X_val: Optional[List[str]] = None,
        y_val: Optional[Union[List, np.ndarray]] = None,
        X_unlbl: Optional[List[str]] = None,
    ) -> "SGIRMolecularPredictor":
        """Fit the model to training data with optional validation set."""
        if (X_val is None) != (y_val is None):
            raise ValueError("X_val and y_val must both be provided for validation")
        if X_unlbl is None:
            raise ValueError("X_unlbl (unlabeled SMILES strings) must be provided in SGIR")
        if len(X_unlbl) == 0:
            raise ValueError("X_unlbl (unlabeled SMILES strings) must not be empty")

        # Initialize model and optimization
        self._initialize_model(self.model_class)
        self.model.initialize_parameters()
        optimizer, scheduler = self._setup_optimizers()

        # Prepare datasets
        X_train, y_train = self._validate_inputs(X_train, y_train)
        train_dataset = self._convert_to_pytorch_data(X_train, y_train)
        train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0
        )
        X_unlbl, _ = self._validate_inputs(X_unlbl, None)
        unlbl_dataset = self._convert_to_pytorch_data(X_unlbl)

        if X_val is None:
            val_loader = train_loader
            warnings.warn(
                "No validation set provided. Using training set for validation.", UserWarning
            )
        else:
            X_val, y_val = self._validate_inputs(X_val, y_val)
            val_dataset = self._convert_to_pytorch_data(X_val, y_val)
            val_loader = DataLoader(
                val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0
            )

        # Training loop
        augmented_dataset = None
        self.fitting_loss = []
        self.fitting_epoch = 0
        best_state_dict = None
        best_eval = float("-inf") if self.evaluate_higher_better else float("inf")
        cnt_wait = 0

        self.model.train()
        for epoch in range(self.epochs):
            # Training phase
            train_losses = self._train_epoch(train_loader, augmented_dataset, optimizer, epoch)

            # Update datasets after warmup
            if epoch > self.warmup_epoch:
                if epoch % self.labeling_interval == 0:
                    train_loader = build_selection_dataset(
                        self.model, train_dataset, unlbl_dataset, self.batch_size,
                        self.num_anchor, self.top_quantile, self.device, self.label_logscale,
                    )
                if epoch % self.augmentation_interval == 0:
                    augmented_dataset = build_augmentation_dataset(
                        self.model, train_dataset, unlbl_dataset, self.batch_size,
                        self.num_anchor, self.device, self.label_logscale,
                    )
            self.fitting_loss.append(np.mean(train_losses))

            # Validation and model selection
            current_eval = self._evaluation_epoch(val_loader)
            if scheduler:
                scheduler.step(current_eval)

            is_better = (
                current_eval > best_eval
                if self.evaluate_higher_better
                else current_eval < best_eval
            )
            if is_better:
                self.fitting_epoch = epoch
                best_eval = current_eval
                best_state_dict = self.model.state_dict()
                cnt_wait = 0
            else:
                cnt_wait += 1
                if cnt_wait > self.patience:
                    if self.verbose:
                        print(f"Early stopping at epoch {epoch}")
                    break

            if self.verbose and epoch % 10 == 0:
                print(
                    f"Epoch {epoch}: Loss = {np.mean(train_losses):.4f}, "
                    f"{self.evaluate_name} = {current_eval:.4f}, "
                    f"Best {self.evaluate_name} = {best_eval:.4f}"
                )

        # Restore best model
        if best_state_dict is not None:
            self.model.load_state_dict(best_state_dict)
        else:
            warnings.warn("No improvement achieved during training.", UserWarning)

        self.is_fitted_ = True
        return self
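    # Sketch of a fit call (hypothetical SMILES strings and target values). With the
    # defaults warmup_epoch=20 and labeling_interval=augmentation_interval=5, the
    # labeled loader is rebuilt by pseudo-labeling and the augmentation set is
    # refreshed at epochs 25, 30, 35, ... (i.e. whenever epoch > warmup_epoch and the
    # epoch index is divisible by the respective interval).
    #
    #   predictor.fit(
    #       X_train=["CCO", "c1ccccc1O"],          # labeled SMILES (hypothetical)
    #       y_train=[0.12, 0.85],                  # regression targets
    #       X_val=["CCN"], y_val=[0.40],           # optional validation split
    #       X_unlbl=["CCCC", "c1ccncc1", "CCOC"],  # unlabeled SMILES required by SGIR
    #   )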
    def _train_epoch(self, train_loader, augmented_dataset, optimizer, epoch):
        losses = []

        if augmented_dataset is not None and self.lw_aug != 0:
            aug_reps = augmented_dataset["representations"]
            aug_targets = augmented_dataset["labels"]
            random_inds = torch.randperm(aug_reps.size(0))
            aug_reps = aug_reps[random_inds]
            aug_targets = aug_targets[random_inds]
            num_step = len(train_loader)
            aug_batch_size = aug_reps.size(0) // max(1, num_step)
            aug_inputs = list(torch.split(aug_reps, aug_batch_size))
            aug_outputs = list(torch.split(aug_targets, aug_batch_size))
        else:
            aug_inputs = None
            aug_outputs = None

        iterator = (
            tqdm(train_loader, desc="Training", leave=False) if self.verbose else train_loader
        )

        for batch_idx, batch in enumerate(iterator):
            batch = batch.to(self.device)
            optimizer.zero_grad()

            # Augmentation loss
            if (
                aug_inputs is not None
                and aug_outputs is not None
                and aug_inputs[batch_idx].size(0) != 1
            ):
                self.model._disable_batchnorm_tracking(self.model)
                pred_aug = self.model.predictor(aug_inputs[batch_idx])
                self.model._enable_batchnorm_tracking(self.model)
                targets_aug = aug_outputs[batch_idx]
                Laug = self.loss_criterion(
                    pred_aug.view(targets_aug.size()).to(torch.float32), targets_aug
                ).mean()
            else:
                Laug = torch.tensor(0.0)

            Lx = self.model.compute_loss(batch, self.loss_criterion)
            loss = Lx + Laug * self.lw_aug
            loss.backward()

            # Clip gradients if gradient clipping is enabled
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)

            optimizer.step()
            losses.append(loss.item())

            # Update progress bar if using tqdm
            if self.verbose:
                iterator.set_postfix(
                    {
                        "Epoch": epoch,
                        "Total Loss": f"{loss.item():.4f}",
                        "Lbls Loss": f"{Lx.item():.4f}",
                        "Aug Loss": f"{Laug.item():.4f}",
                    }
                )

        return losses
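    # Worked example of the augmentation batching above (hypothetical sizes): with
    # 1000 augmented representations and a train_loader of 32 steps,
    # aug_batch_size = 1000 // 32 = 31, so torch.split yields 33 chunks (32 chunks of
    # size 31 and one of size 8). Each training step then pairs one chunk with one
    # labeled batch, and the combined objective is Lx + lw_aug * Laug.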