# Source code for torch_molecule.predictor.ssr.modeling_ssr

import os
import numpy as np
import warnings
import datetime
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable, Literal, Type
from dataclasses import dataclass, field

import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from .model import SSR
from ...utils import graph_from_smiles
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)

@dataclass
class SSRMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements a SizeShiftReg model with the GNN.

    References
    ----------
    - Paper: SizeShiftReg: a Regularization Method for Improving
      Size-Generalization in Graph Neural Networks.
      https://arxiv.org/abs/2207.07888
    - Reference Code: https://github.com/DavideBuffelli/SizeShiftReg/tree/main

    Parameters
    ----------
    coarse_ratios : List[float], default=[0.8, 0.9]
        List of ratios for graph coarsening. Each ratio determines the
        percentage of nodes to keep in the coarsened graph.
    cmd_coeff : float, default=0.1
        Weight for CMD (Central Moment Discrepancy) loss. Controls the
        strength of the size-shift regularization.
    fine_grained : bool, default=True
        Whether to use fine-grained CMD. When True, matches distributions at
        a more detailed level.
    n_moments : int, default=5
        Number of moments to match in the CMD calculation. Higher values
        capture more complex distribution characteristics.
    coarse_pool : str, default='mean'
        Pooling method for coarsened graphs. Determines how node features
        are aggregated during coarsening.
    model_name : str, default="SSRMolecularPredictor"
        Name of the model.
    model_class : Type[SSR], default=SSR
        The model class to use for prediction.
    """

    # SSR-specific parameters
    coarse_ratios: List[float] = field(default_factory=lambda: [0.8, 0.9])
    cmd_coeff: float = field(default=0.1)
    fine_grained: bool = field(default=True)
    n_moments: int = field(default=5)
    coarse_pool: str = field(default='mean')
    # Other non-init fields
    model_name: str = "SSRMolecularPredictor"
    model_class: Type[SSR] = field(default=SSR, init=False)

    def __post_init__(self):
        super().__post_init__()

    @staticmethod
    def _get_param_names() -> List[str]:
        """Return hyper-parameter names: the base GNN's plus the SSR-specific ones."""
        return GNNMolecularPredictor._get_param_names() + [
            "coarse_ratios",
            "cmd_coeff",
            "fine_grained",
            "n_moments",
            "coarse_pool",
        ]

    def _get_default_search_space(self):
        """Extend the base search space with SSR regularization knobs."""
        search_space = super()._get_default_search_space().copy()
        search_space["cmd_coeff"] = ParameterSpec(ParameterType.FLOAT, (0.01, 1.0))
        search_space["n_moments"] = ParameterSpec(ParameterType.INTEGER, (1, 10))
        return search_space

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """SSR adds no extra constructor arguments; delegate to the base class."""
        base_params = super()._get_model_params(checkpoint)
        return base_params

    def _convert_to_pytorch_data(self, X, y=None):
        """Convert SMILES to PyTorch Geometric data with coarsened versions,
        preserving edge attributes.

        For every molecule, one ``Data`` object is built and, for each ratio in
        ``self.coarse_ratios``, the coarsened edge index / edge attributes /
        cluster assignment are attached as ``coarsened_*_{int(ratio*100)}``
        attributes so they can be batched alongside the original graph.
        """
        if self.verbose:
            iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
        else:
            iterator = enumerate(X)

        pyg_graph_list = []
        for idx, smiles_or_mol in iterator:
            if y is not None:
                properties = y[idx]
            else:
                properties = None

            # Convert SMILES to graph
            graph = graph_from_smiles(smiles_or_mol, properties, self.augmented_feature)
            g = Data()
            g.num_nodes = graph["num_nodes"]
            g.edge_index = torch.from_numpy(graph["edge_index"])

            # Standard attributes (only set when present in the raw graph dict)
            if graph["edge_feat"] is not None:
                g.edge_attr = torch.from_numpy(graph["edge_feat"])
            if graph["node_feat"] is not None:
                g.x = torch.from_numpy(graph["node_feat"])
            if graph["y"] is not None:
                g.y = torch.from_numpy(graph["y"])
            if graph.get("morgan") is not None:
                g.morgan = torch.tensor(graph["morgan"], dtype=torch.int8).view(1, -1)
            if graph.get("maccs") is not None:
                g.maccs = torch.tensor(graph["maccs"], dtype=torch.int8).view(1, -1)

            # Add coarsened versions
            for ratio in self.coarse_ratios:
                num_clusters = max(1, int(g.num_nodes * ratio))
                coarse_ratio_postfix = str(int(ratio * 100))

                # Get coarsened graph with edge attributes
                coarse_edge_index, coarse_edge_attr, clusters = self.spectral_graph_coarsening(g, num_clusters)

                # Attach coarsened tensors under ratio-suffixed attribute names
                setattr(g, f"coarsened_edge_index_{coarse_ratio_postfix}", coarse_edge_index)
                if hasattr(g, 'edge_attr'):
                    setattr(g, f"coarsened_edge_attr_{coarse_ratio_postfix}", coarse_edge_attr)
                setattr(g, f"num_coarse_nodes_{coarse_ratio_postfix}", torch.tensor(num_clusters))
                setattr(g, f"clusters_{coarse_ratio_postfix}", clusters)

            if self.verbose:
                print(f"Processed molecule {idx}: {g.num_nodes} nodes, coarsened versions added")

            pyg_graph_list.append(g)

        return pyg_graph_list

    def spectral_graph_coarsening(self, graph, num_clusters):
        """Coarsen a graph via spectral clustering, preserving edge attributes.

        Nodes are grouped into ``num_clusters`` clusters using k-means on the
        smallest non-trivial Laplacian eigenvectors; inter-cluster edges are
        merged (mean-aggregating their attributes) and re-emitted in both
        directions.

        Returns
        -------
        (coarse_edge_index, coarse_edge_attr, clusters) where
        ``coarse_edge_index`` is a ``(2, E')`` long tensor, ``coarse_edge_attr``
        the mean-aggregated attributes (empty list when the input has no edge
        attributes and edges exist), and ``clusters`` the per-node cluster id.
        """
        # Extract node features and edge information
        edge_index = graph.edge_index
        edge_attr = graph.edge_attr if hasattr(graph, 'edge_attr') else None
        num_nodes = graph.num_nodes
        # BUG FIX: the original called edge_attr.float() unconditionally, which
        # raises AttributeError for graphs without edge attributes (edge_attr
        # is None by construction two lines above). Cast only when present.
        if edge_attr is not None:
            edge_attr = edge_attr.float()

        # Dense adjacency matrix
        adj = torch.zeros((num_nodes, num_nodes), device=edge_index.device)
        adj[edge_index[0], edge_index[1]] = 1

        # Unnormalized graph Laplacian L = D - A
        degree = adj.sum(dim=1)
        degree_mat = torch.diag(degree)
        laplacian = degree_mat - adj

        # Spectral clustering when there is room to coarsen; identity otherwise
        if num_clusters < num_nodes - 1:
            eigvals, eigvecs = torch.linalg.eigh(laplacian)
            # Use the smallest non-zero eigenvalues (skip the trivial first eigenvector)
            indices = torch.argsort(eigvals)[1:num_clusters + 1]
            fiedler_vectors = eigvecs[:, indices]

            # Use k-means for clustering (fixed seed for reproducibility)
            from sklearn.cluster import KMeans
            kmeans = KMeans(n_clusters=num_clusters, random_state=42)
            clusters = kmeans.fit_predict(fiedler_vectors.cpu().numpy())
            clusters = torch.tensor(clusters, device=edge_index.device)
        else:
            clusters = torch.arange(num_nodes, device=edge_index.device)

        # Create coarsened edge index and attributes
        coarse_edge_index = []
        coarse_edge_attr = []

        # Map sorted (cluster_i, cluster_j) -> list of original edge attributes
        cluster_edges = {}
        for e_idx in range(edge_index.shape[1]):
            src, dst = edge_index[0, e_idx], edge_index[1, e_idx]
            c_src, c_dst = clusters[src], clusters[dst]

            # Skip self-loops in coarsened graph (edges collapsed inside a cluster)
            if c_src == c_dst:
                continue
            # Sort clusters to avoid duplicate undirected pairs
            if c_src > c_dst:
                c_src, c_dst = c_dst, c_src

            key = (c_src.item(), c_dst.item())
            if key not in cluster_edges:
                cluster_edges[key] = []
            if edge_attr is not None:
                cluster_edges[key].append(edge_attr[e_idx])

        # Emit each coarsened edge in both directions
        for (c_src, c_dst), edge_attrs in cluster_edges.items():
            coarse_edge_index.append([c_src, c_dst])
            coarse_edge_index.append([c_dst, c_src])
            if edge_attr is not None:
                # Aggregate merged edges' attributes by their mean
                mean_attr = torch.stack(edge_attrs).mean(dim=0)
                coarse_edge_attr.append(mean_attr)
                coarse_edge_attr.append(mean_attr)  # same attr for both directions

        # Convert to tensors (empty graphs get a (2, 0) index / (0, d) attr tensor)
        if coarse_edge_index:
            coarse_edge_index = torch.tensor(coarse_edge_index, dtype=torch.long, device=edge_index.device).t()
            if edge_attr is not None:
                coarse_edge_attr = torch.stack(coarse_edge_attr)
        else:
            coarse_edge_index = torch.zeros((2, 0), dtype=torch.long, device=edge_index.device)
            if edge_attr is not None:
                coarse_edge_attr = torch.zeros((0, edge_attr.size(1)), device=edge_attr.device)

        return coarse_edge_index, coarse_edge_attr, clusters

    def _train_epoch(self, train_loader, optimizer, epoch):
        """Training logic for one epoch with SSR.

        Returns the list of total-loss values (prediction + CMD regularization)
        for every batch, for logging by the caller.
        """
        self.model.train()
        losses = []
        pred_losses = []
        ssr_losses = []

        iterator = (
            tqdm(train_loader, desc="Training", leave=False)
            if self.verbose
            else train_loader
        )

        for batch in iterator:
            batch = batch.to(self.device)
            optimizer.zero_grad()

            # Forward pass and loss computation (prediction loss + CMD loss)
            total_loss, pred_loss, ssr_loss = self.model.compute_loss(
                batch,
                self.loss_criterion,
                self.coarse_ratios,
                self.cmd_coeff,
                self.fine_grained,
                self.n_moments,
            )

            # Backward pass
            total_loss.backward()

            # Clip gradients when gradient clipping is enabled
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)

            optimizer.step()

            losses.append(total_loss.item())
            pred_losses.append(pred_loss.item())
            ssr_losses.append(ssr_loss.item())

            # Update progress bar if using tqdm
            if self.verbose:
                iterator.set_postfix({
                    "Epoch": epoch,
                    "Total loss": f"{total_loss.item():.4f}",
                    "Pred loss": f"{pred_loss.item():.4f}",
                    "ssr_loss": f"{ssr_loss.item():.4f}"
                })

        # Return all loss components for logging
        return losses