# Source code for torch_molecule.encoder.graphmae.modeling_graphmae

import numpy as np
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import Optional, Union, Dict, Any, Tuple, List, Literal, Type

import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from .model import GNN
from .dataloader import DataLoaderMaskingPred
from ..constant import GNN_ENCODER_MODELS, GNN_ENCODER_READOUTS, GNN_ENCODER_PARAMS
from ...base import BaseMolecularEncoder
from ...utils import graph_from_smiles

ALLOWABLE_ENCODER_MODELS = GNN_ENCODER_MODELS
ALLOWABLE_ENCODER_READOUTS = GNN_ENCODER_READOUTS

@dataclass
class GraphMAEMolecularEncoder(BaseMolecularEncoder):
    """GraphMAE: Self-Supervised Masked Graph Autoencoders

    References
    ----------
    - Paper: https://arxiv.org/abs/2205.10803
    - Code: https://github.com/THUDM/GraphMAE/tree/main/chem

    Parameters
    ----------
    mask_rate : float, default=0.15
        Fraction of nodes to mask during training.
    mask_edge : bool, default=False
        Whether to mask edges in addition to nodes.
    predictor_type : str, default="gin"
        Type of predictor network to use for reconstruction. Options: ["gin", "gcn", "linear"]
    num_layer : int, default=5
        Number of message passing layers in the GNN.
    hidden_size : int, default=300
        Dimension of hidden node representations.
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization to use. Options: ["batch_norm", "layer_norm", "instance_norm", "graph_norm", "size_norm", "pair_norm"]
    encoder_type : str, default="gin-virtual"
        Type of GNN encoder to use. Options: ["gin-virtual", "gcn-virtual", "gin", "gcn"]
    readout : str, default="sum"
        Pooling method to use for graph-level representations. Options: ["sum", "mean", "max"]
    batch_size : int, default=128
        Batch size for training and inference.
    epochs : int, default=500
        Number of training epochs.
    learning_rate : float, default=0.001
        Learning rate for optimizer.
    grad_clip_value : Optional[float], default=None
        Maximum norm of gradients for gradient clipping. No clipping if None.
    weight_decay : float, default=0.0
        L2 regularization factor.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce learning rate when using scheduler.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which learning rate will be reduced.
    verbose : bool, default=False
        Whether to display progress bars and logs.
    model_name : str, default="GraphMAEMolecularEncoder"
        Name of the model.

    Examples
    --------
    >>> from torch_molecule import GraphMAEMolecularEncoder
    >>> encoder = GraphMAEMolecularEncoder(hidden_size=128, epochs=100)
    >>> encoder.fit(["CC(=O)OC1=CC=CC=C1C(=O)O", "CCO", "C1=CC=CC=C1"])
    >>> representations = encoder.encode(["CCO"])
    """

    # Task related parameters
    mask_rate: float = 0.15
    mask_edge: bool = False  # whether to mask edges
    predictor_type: str = "gin"  # one of ["gin", "gcn", "linear"]

    # Model parameters
    num_layer: int = 5
    hidden_size: int = 300
    drop_ratio: float = 0.5
    norm_layer: str = "batch_norm"  # one of ["batch_norm", "layer_norm", "instance_norm", "graph_norm", "size_norm", "pair_norm"]
    encoder_type: str = "gin-virtual"  # one of ["gin-virtual", "gcn-virtual", "gin", "gcn"]
    readout: str = "sum"  # one of ["sum", "mean", "max"]

    # Training parameters
    batch_size: int = 128
    epochs: int = 500
    learning_rate: float = 0.001
    grad_clip_value: Optional[float] = None
    weight_decay: float = 0.0

    # Scheduler parameters
    use_lr_scheduler: bool = False
    scheduler_factor: float = 0.5  # if use_lr_scheduler is True
    scheduler_patience: int = 5  # if use_lr_scheduler is True

    # Other parameters
    verbose: bool = False
    model_name: str = "GraphMAEMolecularEncoder"

    # Non-init fields
    # Mean training loss of every epoch of the most recent ``fit`` call.
    fitting_loss: List[float] = field(default_factory=list, init=False)
    # Zero-based index of the last epoch completed by ``fit``.
    fitting_epoch: int = field(default=0, init=False)
    # Model class handed to ``_initialize_model`` when fitting.
    model_class: Type[GNN] = field(default=GNN, init=False)

    def __post_init__(self):
        """Initialize the model after dataclass initialization.

        Raises
        ------
        ValueError
            If ``encoder_type`` is not in ``ALLOWABLE_ENCODER_MODELS`` or
            ``readout`` is not in ``ALLOWABLE_ENCODER_READOUTS``.
        """
        super().__post_init__()
        if self.encoder_type not in ALLOWABLE_ENCODER_MODELS:
            raise ValueError(f"Invalid encoder_model: {self.encoder_type}. Currently only {ALLOWABLE_ENCODER_MODELS} are supported.")
        if self.readout not in ALLOWABLE_ENCODER_READOUTS:
            raise ValueError(f"Invalid encoder_readout: {self.readout}. Currently only {ALLOWABLE_ENCODER_READOUTS} are supported.")

    @staticmethod
    def _get_param_names() -> List[str]:
        """Get parameter names for the estimator.

        Returns
        -------
        List[str]
            List of parameter names that can be used for model configuration.
            A new list is built on every call, so callers may mutate it freely.
        """
        return ["mask_rate", "mask_edge", "predictor_type", *GNN_ENCODER_PARAMS]

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Collect the keyword arguments describing the underlying model.

        Parameters
        ----------
        checkpoint : Optional[Dict], default=None
            If given, values found under ``checkpoint["hyperparameters"]``
            override the corresponding attributes of this instance; keys that
            are absent there fall back to the instance values.

        Returns
        -------
        Dict[str, Any]
            Mapping of model parameter names to values.

        Raises
        ------
        ValueError
            If ``checkpoint`` is given but has no ``"hyperparameters"`` key.
        """
        params = {
            "num_layer": self.num_layer,
            "hidden_size": self.hidden_size,
            "drop_ratio": self.drop_ratio,
            "norm_layer": self.norm_layer,
            "readout": self.readout,
            "encoder_type": self.encoder_type,
            "predictor_type": self.predictor_type,
            "mask_edge": self.mask_edge,
        }

        if checkpoint is not None:
            if "hyperparameters" not in checkpoint:
                raise ValueError("Checkpoint missing 'hyperparameters' key")
            hyperparameters = checkpoint["hyperparameters"]
            params = {k: hyperparameters.get(k, v) for k, v in params.items()}

        return params

    def _convert_to_pytorch_data(self, X):
        """Convert molecules to a list of PyTorch Geometric ``Data`` objects.

        Parameters
        ----------
        X : sequence
            Molecules, each in a form accepted by ``graph_from_smiles``
            (the callers in this class pass the output of ``_validate_inputs``).

        Returns
        -------
        list of Data
            One graph per molecule with ``edge_index`` always set, and
            ``edge_attr`` / ``x`` set when the corresponding features are
            not None.
        """
        pyg_graph_list = []
        iterator = tqdm(
            X,
            desc="Converting molecules to graphs",
            total=len(X),
            disable=not self.verbose,
        )
        for smiles_or_mol in iterator:
            # NOTE(review): the second argument is None here; presumably no
            # property/label is attached for self-supervised training — confirm
            # against graph_from_smiles.
            graph = graph_from_smiles(smiles_or_mol, None)

            g = Data()
            g.edge_index = torch.from_numpy(graph["edge_index"])
            if graph["edge_feat"] is not None:
                g.edge_attr = torch.from_numpy(graph["edge_feat"])
            if graph["node_feat"] is not None:
                g.x = torch.from_numpy(graph["node_feat"])

            pyg_graph_list.append(g)

        return pyg_graph_list

    def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        """Setup optimization components including optimizer and learning rate scheduler.

        Returns
        -------
        Tuple[torch.optim.Optimizer, Optional[Any]]
            An Adam optimizer over ``self.model.parameters()`` and, when
            ``use_lr_scheduler`` is True, a ``ReduceLROnPlateau`` scheduler
            (otherwise None).
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        scheduler = None
        if self.use_lr_scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=self.scheduler_factor,
                patience=self.scheduler_patience,
                min_lr=1e-6,
                cooldown=0,
                eps=1e-8,
            )

        return optimizer, scheduler

    def fit(
        self,
        X_train: List[str],
    ) -> "GraphMAEMolecularEncoder":
        """Fit the model to the training data.

        The model is (re-)initialized at the start of every call, and
        ``fitting_loss`` is reset before training begins.

        Parameters
        ----------
        X_train : List[str]
            Training set input molecular structures as SMILES strings

        Returns
        -------
        self : GraphMAEMolecularEncoder
            Fitted estimator
        """
        self._initialize_model(self.model_class)
        self.model.initialize_parameters()
        optimizer, scheduler = self._setup_optimizers()

        # Prepare datasets and loaders
        X_train, _ = self._validate_inputs(X_train, return_rdkit_mol=True)
        train_dataset = self._convert_to_pytorch_data(X_train)
        train_loader = DataLoaderMaskingPred(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            mask_rate=self.mask_rate,
            mask_edge=self.mask_edge,
        )

        self.fitting_loss = []
        for epoch in range(self.epochs):
            # Training phase
            train_losses = self._train_epoch(train_loader, optimizer, epoch)

            # Compute the epoch mean once and keep it as a plain float so the
            # stored history matches the declared ``List[float]`` type.
            epoch_loss = float(np.mean(train_losses))
            self.fitting_loss.append(epoch_loss)

            if scheduler is not None:
                # The plateau scheduler monitors the training loss; there is
                # no validation set in this method.
                scheduler.step(epoch_loss)

            self.fitting_epoch = epoch

        self.is_fitted_ = True
        return self

    def _train_epoch(self, train_loader, optimizer, epoch):
        """Training logic for one epoch.

        Parameters
        ----------
        train_loader
            Iterable of training batches; each batch must support ``.to(device)``.
        optimizer
            Optimizer instance for model parameter updates.
        epoch : int
            Current epoch index; only used for the progress-bar postfix.

        Returns
        -------
        list
            Loss value (Python float) of every training step in this epoch.
        """
        self.model.train()
        losses = []

        iterator = (
            tqdm(train_loader, desc="Training", leave=False)
            if self.verbose
            else train_loader
        )

        for batch in iterator:
            batch = batch.to(self.device)
            optimizer.zero_grad()

            loss_atom, loss_edge = self.model.compute_loss(batch)
            loss = loss_atom + loss_edge
            loss.backward()

            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
            optimizer.step()

            # Read the scalar once and reuse it for both the history and the
            # progress bar.
            loss_value = loss.item()
            losses.append(loss_value)

            if self.verbose:
                iterator.set_postfix({"Epoch": f"{epoch}", "Loss": f"{loss_value:.4f}", "Loss_atom": f"{loss_atom.item():.4f}", "Loss_edge": f"{loss_edge.item():.4f}"})

        return losses

    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union[np.ndarray, torch.Tensor]:
        """Encode molecules into vector representations.

        Parameters
        ----------
        X : List[str]
            List of SMILES strings
        return_type : Literal["np", "pt"], default="pt"
            Return type of the representations. ``"pt"`` returns a torch
            tensor; any other value returns a NumPy array.

        Returns
        -------
        representations : ndarray or torch.Tensor
            Molecular representations, taken from the ``"graph"`` entry of the
            model output and concatenated along dimension 0 on the CPU.

        Raises
        ------
        RuntimeError
            If the model has not been initialized.
        """
        self._check_is_fitted()
        # Fail fast, before spending time converting molecules to graphs.
        if self.model is None:
            raise RuntimeError("Model not initialized")

        # Convert to PyTorch Geometric format and create loader
        X, _ = self._validate_inputs(X, return_rdkit_mol=True)
        dataset = self._convert_to_pytorch_data(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Generate encodings
        self.model = self.model.to(self.device)
        self.model.eval()
        encodings = []
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                batch = batch.to(self.device)
                out = self.model(batch)
                encodings.append(out["graph"].cpu())

        # Concatenate and convert to requested format
        encodings = torch.cat(encodings, dim=0)
        return encodings if return_type == "pt" else encodings.numpy()