Source code for torch_molecule.encoder.supervised.modeling_supervised

import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Literal, Type
from dataclasses import dataclass, field

import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from .model import GNN
from ..constant import GNN_ENCODER_MODELS, GNN_ENCODER_READOUTS, GNN_ENCODER_PARAMS
from ...base import BaseMolecularEncoder
from ...utils import graph_from_smiles
from ...utils import PSEUDOTASK

ALLOWABLE_ENCODER_MODELS = GNN_ENCODER_MODELS
ALLOWABLE_ENCODER_READOUTS = GNN_ENCODER_READOUTS

@dataclass
class SupervisedMolecularEncoder(BaseMolecularEncoder):
    """This encoder implements a GNN model for supervised molecular representation learning
    with user-defined or predefined fingerprint/calculated-property tasks.

    Parameters
    ----------
    num_task : int, optional
        Number of user-defined tasks for supervised pretraining. If specified, the user must
        provide y_train to the fit function.
    predefined_task : List[str], optional
        List of predefined tasks to use. Must be from the supported task list
        ["morgan", "maccs", "logP"]. If None and num_task is None, all predefined tasks
        will be used.
    encoder_type : str, default="gin-virtual"
        Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
    readout : str, default="sum"
        Method for aggregating node features to obtain graph-level representations.
        One of ["sum", "mean", "max"].
    num_layer : int, default=5
        Number of GNN layers.
    hidden_size : int, default=300
        Dimension of hidden node features.
    drop_ratio : float, default=0.5
        Dropout probability.
    norm_layer : str, default="batch_norm"
        Type of normalization layer to use. One of ["batch_norm", "layer_norm",
        "instance_norm", "graph_norm", "size_norm", "pair_norm"].
    batch_size : int, default=128
        Number of samples per batch for training.
    epochs : int, default=500
        Maximum number of training epochs.
    learning_rate : float, default=0.001
        Learning rate for the optimizer.
    grad_clip_value : float, optional
        Maximum norm of gradients for gradient clipping.
    weight_decay : float, default=0.0
        L2 regularization strength.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate when a plateau is detected.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which the learning rate will be reduced.
    verbose : bool, default=False
        Whether to print progress information during training.
    """

    # Pretraining tasks
    num_task: Optional[int] = None
    predefined_task: Optional[List[str]] = None
    # Model parameters
    encoder_type: str = "gin-virtual"
    readout: str = "sum"
    num_layer: int = 5
    hidden_size: int = 300
    drop_ratio: float = 0.5
    norm_layer: str = "batch_norm"
    # Training parameters
    batch_size: int = 128
    epochs: int = 500
    learning_rate: float = 0.001
    grad_clip_value: Optional[float] = None
    weight_decay: float = 0.0
    # Scheduler parameters
    use_lr_scheduler: bool = False
    scheduler_factor: float = 0.5
    scheduler_patience: int = 5
    # Other parameters
    verbose: bool = False
    model_name: str = "SupervisedMolecularEncoder"
    # Non-init fields
    fitting_loss: List[float] = field(default_factory=list, init=False)
    fitting_epoch: int = field(default=0, init=False)
    model_class: Type[GNN] = field(default=GNN, init=False)

    def __post_init__(self):
        """Initialize the model after dataclass initialization."""
        super().__post_init__()
        self.num_pretask = None
        if self.encoder_type not in ALLOWABLE_ENCODER_MODELS:
            raise ValueError(
                f"Invalid encoder: {self.encoder_type}. "
                f"Currently only {ALLOWABLE_ENCODER_MODELS} are supported."
            )
        if self.readout not in ALLOWABLE_ENCODER_READOUTS:
            raise ValueError(
                f"Invalid readout: {self.readout}. "
                f"Currently only {ALLOWABLE_ENCODER_READOUTS} are supported."
            )
        if self.predefined_task is not None:
            for task in self.predefined_task:
                if task not in PSEUDOTASK.keys():
                    raise ValueError(
                        f"Invalid predefined_task: {task}. "
                        f"Currently only {PSEUDOTASK.keys()} are supported."
                    )

        # Calculate the number of predefined tasks, if any are specified
        num_pretask = 0
        if self.predefined_task is not None:
            num_pretask = sum(PSEUDOTASK[task][0] for task in self.predefined_task)
        elif self.predefined_task is None and self.num_task is None:
            # Use all predefined tasks if none are specified
            self.predefined_task = list(PSEUDOTASK.keys())
            num_pretask = sum(task[0] for task in PSEUDOTASK.values())
        self.num_pretask = num_pretask
        self.num_task = (self.num_task or 0) + num_pretask

        if self.verbose:
            if self.predefined_task is None:
                print(f"Using {self.num_task} user-defined tasks.")
            elif self.num_task == num_pretask:
                print(f"Using {num_pretask} predefined tasks from: {self.predefined_task}")
            else:
                print(
                    f"Using {num_pretask} predefined tasks and "
                    f"{self.num_task - num_pretask} user-defined tasks."
                )

    @staticmethod
    def _get_param_names() -> List[str]:
        """Get parameter names for the estimator.

        Returns
        -------
        List[str]
            List of parameter names that can be used for model configuration.
        """
        return ["num_task", "predefined_task", "num_pretask"] + GNN_ENCODER_PARAMS.copy()

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        params = {
            "num_layer": self.num_layer,
            "hidden_size": self.hidden_size,
            "num_task": self.num_task,
            "encoder_type": self.encoder_type,
            "drop_ratio": self.drop_ratio,
            "norm_layer": self.norm_layer,
            "readout": self.readout,
        }
        if checkpoint is not None:
            if "hyperparameters" not in checkpoint:
                raise ValueError("Checkpoint missing 'hyperparameters' key")
            hyperparameters = checkpoint["hyperparameters"]
            params = {k: hyperparameters.get(k, v) for k, v in params.items()}
        return params

    def _convert_to_pytorch_data(self, X, y=None):
        """Convert input molecules (SMILES or RDKit mols) to PyTorch Geometric data format."""
        if self.verbose:
            iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
        else:
            iterator = enumerate(X)

        pyg_graph_list = []
        for idx, smiles_or_mol in iterator:
            if y is not None:
                properties = y[idx]
            else:
                properties = None
            graph = graph_from_smiles(
                smiles_or_mol, properties, augmented_properties=self.predefined_task
            )
            g = Data()
            g.num_nodes = graph["num_nodes"]
            g.edge_index = torch.from_numpy(graph["edge_index"])
            del graph["num_nodes"]
            del graph["edge_index"]

            if graph["edge_feat"] is not None:
                g.edge_attr = torch.from_numpy(graph["edge_feat"])
                del graph["edge_feat"]
            if graph["node_feat"] is not None:
                g.x = torch.from_numpy(graph["node_feat"])
                del graph["node_feat"]
            if graph["y"] is not None:
                g.y = torch.from_numpy(graph["y"])
                del graph["y"]
            if graph["morgan"] is not None:
                g.morgan = torch.tensor(graph["morgan"], dtype=torch.int8).view(1, -1)
                del graph["morgan"]
            if graph["maccs"] is not None:
                g.maccs = torch.tensor(graph["maccs"], dtype=torch.int8).view(1, -1)
                del graph["maccs"]

            pyg_graph_list.append(g)

        return pyg_graph_list

    def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        """Set up optimization components, including the optimizer and learning rate scheduler."""
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = None
        if self.use_lr_scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=self.scheduler_factor,
                patience=self.scheduler_patience,
                min_lr=1e-6,
                cooldown=0,
                eps=1e-8,
            )
        return optimizer, scheduler
    def fit(
        self,
        X_train: List[str],
        y_train: Optional[Union[List, np.ndarray]] = None,
    ) -> "SupervisedMolecularEncoder":
        """Fit the model to the training data.

        Parameters
        ----------
        X_train : List[str]
            Training set input molecular structures as SMILES strings.
        y_train : Union[List, np.ndarray], optional
            Training set target values for representation learning. Required only when
            user-defined tasks are specified.

        Returns
        -------
        self : SupervisedMolecularEncoder
            Fitted estimator.
        """
        user_defined_task = self.num_task - self.num_pretask
        if user_defined_task > 0:
            if y_train is None:
                raise ValueError("User-defined tasks require target values but y_train is None.")
            if y_train.shape[1] != user_defined_task:
                raise ValueError(
                    f"Number of user-defined tasks ({user_defined_task}) must match the number "
                    f"of target values in y_train ({y_train.shape[1]})."
                )

        self._initialize_model(self.model_class)
        self.model.initialize_parameters()
        optimizer, scheduler = self._setup_optimizers()

        # Prepare datasets and loaders
        X_train, y_train = self._validate_inputs(
            X_train,
            y_train,
            return_rdkit_mol=True,
            num_task=self.num_task,
            num_pretask=self.num_pretask,
        )
        train_dataset = self._convert_to_pytorch_data(X_train, y_train)
        train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0
        )

        self.fitting_loss = []

        # Build a boolean mask marking which tasks are classification tasks
        if user_defined_task > 0:
            is_class_user = self._inspect_task_types(y_train, return_type="pt")
        else:
            is_class_user = torch.tensor([], dtype=torch.bool)
        if self.predefined_task is not None:
            is_class_predefined = torch.cat(
                [
                    torch.full(
                        (PSEUDOTASK[task][0],),
                        PSEUDOTASK[task][1] == "classification",
                        dtype=torch.bool,
                    )
                    for task in self.predefined_task
                ]
            )
        else:
            is_class_predefined = torch.tensor([], dtype=torch.bool)
        is_class = torch.cat([is_class_user, is_class_predefined])

        for epoch in range(self.epochs):
            # Training phase
            train_losses = self._train_epoch(train_loader, optimizer, is_class, epoch)
            self.fitting_loss.append(np.mean(train_losses).item())
            if scheduler:
                scheduler.step(np.mean(train_losses).item())

        self.fitting_epoch = epoch
        self.is_fitted_ = True
        return self
    def _train_epoch(self, train_loader, optimizer, is_class, epoch):
        """Training logic for one epoch.

        Args:
            train_loader: DataLoader containing training data.
            optimizer: Optimizer instance for model parameter updates.
            is_class: Boolean tensor marking which tasks are classification tasks.
            epoch: Index of the current epoch (used for progress reporting).

        Returns:
            list: List of loss values for each training step.
        """
        self.model.train()
        losses = []
        iterator = (
            tqdm(train_loader, desc="Training", leave=False) if self.verbose else train_loader
        )

        for batch in iterator:
            batch = batch.to(self.device)
            optimizer.zero_grad()
            loss = self.model.compute_loss(batch, is_class)
            loss.backward()
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
            optimizer.step()
            losses.append(loss.item())

            # Update progress bar if using tqdm
            if self.verbose:
                iterator.set_postfix({"Epoch": f"{epoch}", "Loss": f"{loss.item():.4f}"})

        return losses
    def encode(
        self, X: List[str], return_type: Literal["np", "pt"] = "pt"
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode molecules into vector representations.

        Parameters
        ----------
        X : List[str]
            List of SMILES strings.
        return_type : Literal["np", "pt"], default="pt"
            Return type of the representations.

        Returns
        -------
        representations : ndarray or torch.Tensor
            Molecular representations.
        """
        self._check_is_fitted()

        # Convert to PyTorch Geometric format and create a loader
        X, _ = self._validate_inputs(X, return_rdkit_mol=True)
        dataset = self._convert_to_pytorch_data(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        if self.model is None:
            raise RuntimeError("Model not initialized")

        # Generate encodings
        self.model = self.model.to(self.device)
        self.model.eval()
        encodings = []
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                batch = batch.to(self.device)
                out = self.model(batch)
                encodings.append(out["graph"].cpu())

        # Concatenate and convert to the requested format
        encodings = torch.cat(encodings, dim=0)
        return encodings if return_type == "pt" else encodings.numpy()
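A minimal usage sketch, not part of the module source: it assumes the class is importable from this module path, fits on a few placeholder SMILES strings using only predefined tasks (so no y_train is required), and then encodes new molecules. Parameter values are illustrative, not recommendations.

from torch_molecule.encoder.supervised.modeling_supervised import SupervisedMolecularEncoder

# Pretrain with predefined fingerprint/property tasks only; no y_train is needed.
encoder = SupervisedMolecularEncoder(
    predefined_task=["morgan", "logP"],  # must come from the supported task list
    encoder_type="gin-virtual",
    epochs=10,        # small value for a quick run; the default is 500
    verbose=True,
)
encoder.fit(["CCO", "c1ccccc1", "CC(=O)O"])              # placeholder SMILES strings
reps = encoder.encode(["CCO", "CCN"], return_type="pt")  # graph-level representations
print(reps.shape)                                        # one row per input molecule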