Source code for torch_molecule.predictor.ssr.modeling_ssr

from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal, Type

import torch
from torch_geometric.data import Data

from .model import SSR
from ...utils import graph_from_smiles
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

class SSRMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements a SizeShiftReg model with the GNN.

    References
    ----------
    - Paper: SizeShiftReg: a Regularization Method for Improving Size-Generalization
      in Graph Neural Networks. https://arxiv.org/abs/2207.07888
    - Reference Code: https://github.com/DavideBuffelli/SizeShiftReg/tree/main

    Parameters
    ----------
    coarse_ratios : List[float], default=[0.8, 0.9]
        List of ratios for graph coarsening. Each ratio determines the percentage
        of nodes to keep in the coarsened graph.
    cmd_coeff : float, default=0.1
        Weight for the CMD (Central Moment Discrepancy) loss. Controls the strength
        of the size-shift regularization.
    fine_grained : bool, default=True
        Whether to use fine-grained CMD. When True, matches distributions at a more
        detailed level.
    n_moments : int, default=5
        Number of moments to match in the CMD calculation. Higher values capture
        more complex distribution characteristics.
    coarse_pool : str, default='mean'
        Pooling method for coarsened graphs.
    num_task : int, default=1
        Number of prediction tasks.
    task_type : str, default="regression"
        Type of prediction task, either "regression" or "classification".
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of hidden node features.
    gnn_type : str, default="gin-virtual"
        Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization layer to use. One of ["batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm"].
    graph_pooling : str, default="sum"
        Method for aggregating node features to graph-level representations.
        One of ["sum", "mean", "max"].
    augmented_feature : list or None, default=None
        Additional molecular fingerprints to use as features; concatenated with the
        graph representation after pooling. For example ["morgan", "maccs"] or None.
    batch_size : int, default=128
        Number of samples per batch for training.
    epochs : int, default=500
        Maximum number of training epochs.
    loss_criterion : callable, optional
        Loss function for training.
    evaluate_criterion : str or callable, optional
        Metric for model evaluation.
    evaluate_higher_better : bool, optional
        Whether higher values of the evaluation metric are better.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    grad_clip_value : float, optional
        Maximum norm of gradients for gradient clipping.
    weight_decay : float, default=0.0
        L2 regularization strength.
    patience : int, default=50
        Number of epochs to wait for improvement before early stopping.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate when a plateau is reached.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which the learning rate is reduced.
    verbose : bool, default=False
        Whether to print progress information during training.
    device : torch.device or str, optional
        Device to use for computations.
    model_name : str, default="SSRMolecularPredictor"
        Name of the model.
    """
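    # A minimal usage sketch (illustrative only): the training data below is
    # hypothetical, and fit/predict are the scikit-learn-style methods
    # inherited from GNNMolecularPredictor.
    #
    #     model = SSRMolecularPredictor(coarse_ratios=[0.8, 0.9], cmd_coeff=0.1)
    #     model.fit(train_smiles, train_labels)   # lists of SMILES strings / targets
    #     predictions = model.predict(test_smiles)
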
""" def __init__( self, # SSR-specific parameters coarse_ratios: List[float] = [0.8, 0.9], cmd_coeff: float = 0.1, fine_grained: bool = True, n_moments: int = 5, coarse_pool: str = 'mean', # Core model parameters num_task: int = 1, task_type: str = "regression", # GNN architecture parameters num_layer: int = 5, hidden_size: int = 300, gnn_type: str = "gin-virtual", drop_ratio: float = 0.5, norm_layer: str = "batch_norm", graph_pooling: str = "sum", augmented_feature: Optional[list[Literal["morgan", "maccs"]]] = None, # Training parameters batch_size: int = 128, epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 0.0, grad_clip_value: Optional[float] = None, patience: int = 50, # Learning rate scheduler parameters use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, # Loss and evaluation parameters loss_criterion: Optional[Callable] = None, evaluate_criterion: Optional[Union[str, Callable]] = None, evaluate_higher_better: Optional[bool] = None, # General parameters verbose: bool = False, device: Optional[Union[torch.device, str]] = None, model_name: str = "SSRMolecularPredictor", ): super().__init__( num_task=num_task, task_type=task_type, num_layer=num_layer, hidden_size=hidden_size, gnn_type=gnn_type, drop_ratio=drop_ratio, norm_layer=norm_layer, graph_pooling=graph_pooling, augmented_feature=augmented_feature, batch_size=batch_size, epochs=epochs, learning_rate=learning_rate, weight_decay=weight_decay, grad_clip_value=grad_clip_value, patience=patience, use_lr_scheduler=use_lr_scheduler, scheduler_factor=scheduler_factor, scheduler_patience=scheduler_patience, loss_criterion=loss_criterion, evaluate_criterion=evaluate_criterion, evaluate_higher_better=evaluate_higher_better, verbose=verbose, device=device, model_name=model_name, ) self.coarse_ratios = coarse_ratios self.cmd_coeff = cmd_coeff self.fine_grained = fine_grained self.n_moments = n_moments self.coarse_pool = coarse_pool self.model_class = SSR @staticmethod def _get_param_names() -> List[str]: return GNNMolecularPredictor._get_param_names() + [ "coarse_ratios", "cmd_coeff", "fine_grained", "n_moments", "coarse_pool", ] def _get_default_search_space(self): search_space = super()._get_default_search_space().copy() search_space["cmd_coeff"] = ParameterSpec(ParameterType.FLOAT, (0.01, 1.0)) search_space["n_moments"] = ParameterSpec(ParameterType.INTEGER, (1, 10)) return search_space def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]: base_params = super()._get_model_params(checkpoint) return base_params def _convert_to_pytorch_data(self, X, y=None): """Convert SMILES to PyTorch Geometric data with coarsened versions, preserving edge attributes.""" if self.verbose: iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X)) else: iterator = enumerate(X) pyg_graph_list = [] for idx, smiles_or_mol in iterator: if y is not None: properties = y[idx] else: properties = None # Convert SMILES to graph graph = graph_from_smiles(smiles_or_mol, properties, self.augmented_feature) g = Data() g.num_nodes = graph["num_nodes"] g.edge_index = torch.from_numpy(graph["edge_index"]) # Standard attributes if graph["edge_feat"] is not None: g.edge_attr = torch.from_numpy(graph["edge_feat"]) if graph["node_feat"] is not None: g.x = torch.from_numpy(graph["node_feat"]) if graph["y"] is not None: g.y = torch.from_numpy(graph["y"]) if graph.get("morgan") is not None: g.morgan = torch.tensor(graph["morgan"], dtype=torch.int8).view(1, -1) if 
graph.get("maccs") is not None: g.maccs = torch.tensor(graph["maccs"], dtype=torch.int8).view(1, -1) # Add coarsened versions for ratio in self.coarse_ratios: num_clusters = max(1, int(g.num_nodes * ratio)) coarse_ratio_postfix = str(int(ratio*100)) # Get coarsened graph with edge attributes coarse_edge_index, coarse_edge_attr, clusters = self.spectral_graph_coarsening(g, num_clusters) # Add attributes to graph setattr(g, f"coarsened_edge_index_{coarse_ratio_postfix}", coarse_edge_index) if hasattr(g, 'edge_attr'): setattr(g, f"coarsened_edge_attr_{coarse_ratio_postfix}", coarse_edge_attr) setattr(g, f"num_coarse_nodes_{coarse_ratio_postfix}", torch.tensor(num_clusters)) setattr(g, f"clusters_{coarse_ratio_postfix}", clusters) pyg_graph_list.append(g) return pyg_graph_list def spectral_graph_coarsening(self, graph, num_clusters): """Coarsen graph based on spectral clustering while preserving edge attributes""" # Extract node features and edge information edge_index = graph.edge_index edge_attr = graph.edge_attr if hasattr(graph, 'edge_attr') else None num_nodes = graph.num_nodes edge_attr = edge_attr.float() # Convert to adjacency matrix adj = torch.zeros((num_nodes, num_nodes), device=edge_index.device) adj[edge_index[0], edge_index[1]] = 1 # Calculate Laplacian degree = adj.sum(dim=1) degree_mat = torch.diag(degree) laplacian = degree_mat - adj # Compute eigenvectors if num_clusters < num_nodes - 1: eigvals, eigvecs = torch.linalg.eigh(laplacian) # Use smallest non-zero eigenvalues indices = torch.argsort(eigvals)[1:num_clusters+1] # Skip first eigenvector fiedler_vectors = eigvecs[:, indices] # Use k-means for clustering from sklearn.cluster import KMeans kmeans = KMeans(n_clusters=num_clusters, random_state=42) clusters = kmeans.fit_predict(fiedler_vectors.cpu().numpy()) clusters = torch.tensor(clusters, device=edge_index.device) else: clusters = torch.arange(num_nodes, device=edge_index.device) # Create coarsened edge index and attributes coarse_edge_index = [] coarse_edge_attr = [] # Store mapping from (cluster_i, cluster_j) to list of edge indices cluster_edges = {} # Group edges by their clusters for e_idx in range(edge_index.shape[1]): src, dst = edge_index[0, e_idx], edge_index[1, e_idx] c_src, c_dst = clusters[src], clusters[dst] # Skip self-loops in coarsened graph if c_src == c_dst: continue # Sort clusters to avoid duplicates if c_src > c_dst: c_src, c_dst = c_dst, c_src key = (c_src.item(), c_dst.item()) if key not in cluster_edges: cluster_edges[key] = [] if edge_attr is not None: cluster_edges[key].append(edge_attr[e_idx]) # Create coarsened edges for (c_src, c_dst), edge_attrs in cluster_edges.items(): coarse_edge_index.append([c_src, c_dst]) coarse_edge_index.append([c_dst, c_src]) # Add both directions if edge_attr is not None: # Aggregate edge attributes using mean mean_attr = torch.stack(edge_attrs).mean(dim=0) coarse_edge_attr.append(mean_attr) coarse_edge_attr.append(mean_attr) # Same for both directions # Convert to tensors if coarse_edge_index: coarse_edge_index = torch.tensor(coarse_edge_index, dtype=torch.long, device=edge_index.device).t() if edge_attr is not None: coarse_edge_attr = torch.stack(coarse_edge_attr) else: coarse_edge_index = torch.zeros((2, 0), dtype=torch.long, device=edge_index.device) if edge_attr is not None: coarse_edge_attr = torch.zeros((0, edge_attr.size(1)), device=edge_attr.device) return coarse_edge_index, coarse_edge_attr, clusters def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None): self.model.train() losses = [] 
    def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
        self.model.train()
        losses = []
        pred_losses = []
        ssr_losses = []

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(self.device)
            optimizer.zero_grad()
            total_loss, pred_loss, ssr_loss = self.model.compute_loss(
                batch, self.loss_criterion, self.coarse_ratios,
                self.cmd_coeff, self.fine_grained, self.n_moments
            )
            total_loss.backward()
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
            optimizer.step()

            losses.append(total_loss.item())
            pred_losses.append(pred_loss.item())
            ssr_losses.append(ssr_loss.item())

            if global_pbar is not None:
                global_pbar.update(1)
                global_pbar.set_postfix({
                    "Epoch": f"{epoch+1}/{self.epochs}",
                    "Batch": f"{batch_idx+1}/{len(train_loader)}",
                    "Total loss": f"{total_loss.item():.4f}",
                    "Pred loss": f"{pred_loss.item():.4f}",
                    "SSR loss": f"{ssr_loss.item():.4f}"
                })

        return losses
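
# A small, self-contained demo of the coarsening step (illustrative only: the
# 6-node path graph and 3-dimensional dummy bond features below are synthetic,
# and default construction of the predictor is assumed to succeed).
if __name__ == "__main__":
    predictor = SSRMolecularPredictor()
    demo = Data()
    demo.num_nodes = 6
    # Undirected path 0-1-2-3-4-5, stored with both edge directions
    src = torch.tensor([0, 1, 1, 2, 2, 3, 3, 4, 4, 5])
    dst = torch.tensor([1, 0, 2, 1, 3, 2, 4, 3, 5, 4])
    demo.edge_index = torch.stack([src, dst])
    demo.edge_attr = torch.ones(10, 3)
    coarse_edge_index, coarse_edge_attr, clusters = predictor.spectral_graph_coarsening(demo, num_clusters=3)
    print("cluster assignment:", clusters.tolist())
    print("coarsened edge_index:", coarse_edge_index.tolist())
    print("coarsened edge_attr shape:", tuple(coarse_edge_attr.shape))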