Source code for torch_molecule.predictor.dir.modeling_dir

import os
import numpy as np
import warnings
import datetime
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal, Type
from dataclasses import dataclass, field

import torch
from torch_geometric.loader import DataLoader

from .model import DIR
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

@dataclass
class DIRMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements DIR (Discovering Invariant Rationales) for
    molecular property prediction tasks.

    References
    ----------
    - Discovering Invariant Rationales for Graph Neural Networks.
      https://openreview.net/forum?id=hGXij5rfiHw
    - Code: https://github.com/Wuyxin/DIR-GNN

    Parameters
    ----------
    causal_ratio : float, default=0.8
        The ratio of causal edges to keep during training. A higher ratio means
        more edges are considered causal/important for the prediction. This
        controls the sparsity of the learned rationales.
    lw_invariant : float, default=1e-4
        The weight of the invariance loss term. This loss encourages the model
        to learn rationales that are invariant across different
        environments/perturbations. A higher value puts more emphasis on
        learning invariant features.
    """

    # DIR-specific parameters
    causal_ratio: float = 0.8
    lw_invariant: float = 1e-4

    # Override parent defaults
    model_name: str = "DIRMolecularPredictor"
    model_class: Type[DIR] = field(default=DIR, init=False)

    def __post_init__(self):
        super().__post_init__()

    @staticmethod
    def _get_param_names() -> List[str]:
        return ["causal_ratio", "lw_invariant"] + GNNMolecularPredictor._get_param_names()

    def _get_default_search_space(self):
        search_space = super()._get_default_search_space().copy()
        search_space["causal_ratio"] = ParameterSpec(ParameterType.FLOAT, (0.1, 0.9))
        search_space["lw_invariant"] = ParameterSpec(ParameterType.FLOAT, (1e-5, 1e-2))
        return search_space

    def _setup_optimizers(self) -> Tuple[Dict[str, torch.optim.Optimizer], Optional[Any]]:
        # Two optimizers: one for the full model, one for the confounder head
        model_optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        conf_optimizer = torch.optim.Adam(
            self.model.conf_lin.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        if self.grad_clip_value is not None:
            # Attach gradient-clipping settings to the model optimizer's
            # param groups
            for group in model_optimizer.param_groups:
                group.setdefault("max_norm", self.grad_clip_value)
                group.setdefault("norm_type", 2.0)

        scheduler = None
        if self.use_lr_scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                model_optimizer,
                mode="min",
                factor=self.scheduler_factor,
                patience=self.scheduler_patience,
                min_lr=1e-6,
                cooldown=0,
                eps=1e-8,
            )

        optimizer = {"model": model_optimizer, "conf": conf_optimizer}
        return optimizer, scheduler

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        base_params = super()._get_model_params(checkpoint)
        if checkpoint and "hyperparameters" in checkpoint:
            base_params["causal_ratio"] = checkpoint["hyperparameters"]["causal_ratio"]
        else:
            base_params["causal_ratio"] = self.causal_ratio
        return base_params

    def _train_epoch(self, train_loader, optimizer, epoch):
        self.model.train()
        losses = []

        iterator = (
            tqdm(train_loader, desc="Training", leave=False)
            if self.verbose
            else train_loader
        )

        # Invariance weight grows super-linearly with the epoch index
        alpha_prime = self.lw_invariant * (epoch ** 1.6)
        conf_opt = optimizer["conf"]
        model_optimizer = optimizer["model"]

        for batch in iterator:
            batch = batch.to(self.device)

            # Forward pass and loss computation
            causal_loss, conf_loss, env_loss = self.model.compute_loss(
                batch, self.loss_criterion, alpha_prime
            )

            # Update the confounder head on its own loss first
            conf_opt.zero_grad()
            conf_loss.backward()
            conf_opt.step()

            # Then update the rest of the model on the causal and environment losses
            model_optimizer.zero_grad()
            (causal_loss + env_loss).backward()
            model_optimizer.step()

            loss = causal_loss + env_loss + conf_loss
            losses.append(loss.item())

            # Update progress bar if using tqdm
            if self.verbose:
                iterator.set_postfix(
                    {
                        "Epoch": epoch,
                        "Causal Loss": f"{causal_loss.item():.4f}",
                        "Conf Loss": f"{conf_loss.item():.4f}",
                        "Env Loss": f"{env_loss.item():.4f}",
                        "Total Loss": f"{loss.item():.4f}",
                    }
                )

        return losses
    def predict(self, X: List[str]) -> Dict[str, Union[np.ndarray, List[List]]]:
        self._check_is_fitted()

        # Convert to PyTorch Geometric format and create loader
        X, _ = self._validate_inputs(X)
        dataset = self._convert_to_pytorch_data(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Make predictions
        self.model = self.model.to(self.device)
        self.model.eval()
        predictions = []
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                batch = batch.to(self.device)
                out = self.model(batch)
                predictions.append(out["prediction"].cpu().numpy())

        if predictions:
            return {
                "prediction": np.concatenate(predictions, axis=0),
            }
        else:
            warnings.warn(
                "No valid predictions could be made from the input data. "
                "Returning empty results."
            )
            return {"prediction": np.array([])}
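
# Usage sketch (illustrative only, not part of the module above). The SMILES
# strings and target values are hypothetical placeholders, and the `fit(X, y)`
# call is assumed to follow the scikit-learn-style interface inherited from
# GNNMolecularPredictor.
if __name__ == "__main__":
    predictor = DIRMolecularPredictor(causal_ratio=0.8, lw_invariant=1e-4)
    # Hypothetical training data: SMILES strings and scalar property values
    predictor.fit(["CCO", "CCN", "c1ccccc1"], [0.1, 0.2, 0.3])
    result = predictor.predict(["CCC"])
    print(result["prediction"].shape)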