Source code for torch_molecule.generator.gdss.modeling_gdss

import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Type
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from .sde import load_sde
from .solver import get_pc_sampler
from .model import GDSSModel, get_sde_loss_fn
from .utils import compute_dataset_info, to_dense, quantize_mol

from ...base import BaseMolecularGenerator
from ...utils import graph_from_smiles, graph_to_smiles

[docs] class GDSSMolecularGenerator(BaseMolecularGenerator): """This generator implements "Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations" References ---------- - Paper: https://arxiv.org/abs/2202.02514 - Official Implementation: https://github.com/harryjo97/GDSS Parameters ---------- num_layer : int, default=3 Number of layers in the score networks. hidden_size_adj : float, default=8 Hidden dimension size for the adjacency in the adjacency score network. hidden_size : int, default=16 Hidden dimension size latent representation. attention_dim : int, default=16 Dimension of attention layers. num_head : int, default=4 Number of attention heads. sde_type_x : str, default='VE' SDE type for node features. One of 'VP', 'VE', 'subVP'. sde_beta_min_x : float, default=0.1 Minimum noise level for node features. sde_beta_max_x : float, default=1 Maximum noise level for node features. sde_num_scales_x : int, default=1000 Number of noise scales for node features. sde_type_adj : str, default='VE' SDE type for adjacency matrix. One of 'VP', 'VE', 'subVP'. sde_beta_min_adj : float, default=0.1 Minimum noise level for adjacency matrix. sde_beta_max_adj : float, default=1 Maximum noise level for adjacency matrix. sde_num_scales_adj : int, default=1000 Number of noise scales for adjacency matrix. batch_size : int, default=128 Batch size for training. epochs : int, default=500 Number of training epochs. learning_rate : float, default=0.005 Learning rate for optimizer. grad_clip_value : Optional[float], default=1 Value for gradient clipping. None means no clipping. weight_decay : float, default=1e-4 Weight decay for optimizer. use_loss_reduce_mean : bool, default=False Whether to use mean reduction for loss calculation. use_lr_scheduler : bool, default=False Whether to use learning rate scheduler. scheduler_factor : float, default=0.5 Factor by which to reduce learning rate when using scheduler (only used if use_lr_scheduler is True). scheduler_patience : int, default=5 Number of epochs with no improvement after which learning rate will be reduced (only used if use_lr_scheduler is True). sampler_predictor : str, default='Reverse' Predictor method for sampling. One of 'Euler', 'Reverse'. sampler_corrector : str, default='Langevin' Corrector method for sampling. One of 'Langevin', 'None'. sampler_snr : float, default=0.2 Signal-to-noise ratio for corrector. sampler_scale_eps : float, default=0.7 Scale factor for noise level in corrector. sampler_n_steps : int, default=1 Number of corrector steps per predictor step. sampler_probability_flow : bool, default=False Whether to use probability flow ODE for sampling. sampler_noise_removal : bool, default=True Whether to remove noise in the final step of sampling. verbose : bool, default=False Whether to display progress bars and logs. device : Optional[Union[torch.device, str]], optional Device to use for computation (cuda/cpu) model_name : str, optional Name of the model, defaults to "GDSSMolecularGenerator" """ def __init__( self, *, num_layer: int = 3, hidden_size_adj: float = 8, hidden_size: int = 16, attention_dim: int = 16, num_head: int = 4, sde_type_x: str = 'VE', sde_beta_min_x: float = 0.1, sde_beta_max_x: float = 1, sde_num_scales_x: int = 1000, sde_type_adj: str = 'VE', sde_beta_min_adj: float = 0.1, sde_beta_max_adj: float = 1, sde_num_scales_adj: int = 1000, batch_size: int = 128, epochs: int = 500, learning_rate: float = 0.005, grad_clip_value: Optional[float] = 1.0, weight_decay: float = 1e-4, use_loss_reduce_mean: bool = False, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, sampler_predictor: str = 'Reverse', sampler_corrector: str = 'Langevin', sampler_snr: float = 0.2, sampler_scale_eps: float = 0.7, sampler_n_steps: int = 1, sampler_probability_flow: bool = False, sampler_noise_removal: bool = True, verbose: bool = False, device: Optional[Union[torch.device, str]] = None, model_name: str = "GDSSMolecularGenerator" ): super().__init__( device=device, model_name=model_name, ) self.num_layer = num_layer self.hidden_size_adj = hidden_size_adj self.hidden_size = hidden_size self.attention_dim = attention_dim self.num_head = num_head self.sde_type_x = sde_type_x self.sde_beta_min_x = sde_beta_min_x self.sde_beta_max_x = sde_beta_max_x self.sde_num_scales_x = sde_num_scales_x self.sde_type_adj = sde_type_adj self.sde_beta_min_adj = sde_beta_min_adj self.sde_beta_max_adj = sde_beta_max_adj self.sde_num_scales_adj = sde_num_scales_adj self.batch_size = batch_size self.epochs = epochs self.learning_rate = learning_rate self.grad_clip_value = grad_clip_value self.weight_decay = weight_decay self.use_loss_reduce_mean = use_loss_reduce_mean self.use_lr_scheduler = use_lr_scheduler self.scheduler_factor = scheduler_factor self.scheduler_patience = scheduler_patience self.sampler_predictor = sampler_predictor self.sampler_corrector = sampler_corrector self.sampler_snr = sampler_snr self.sampler_scale_eps = sampler_scale_eps self.sampler_n_steps = sampler_n_steps self.sampler_probability_flow = sampler_probability_flow self.sampler_noise_removal = sampler_noise_removal self.verbose = verbose self.fitting_loss = list() self.fitting_epoch = 0 self.model_class = GDSSModel self.max_node = None self.input_dim_X = None self.input_dim_adj = None self.conv = 'GCN' self.sampler_eps = 1e-4 self.train_eps = 1e-5 self.dataset_info = None self.loss_fn = None self.sampling_fn = None @staticmethod def _get_param_names() -> List[str]: """Get parameter names for the estimator. Returns ------- List[str] List of parameter names that can be used for model configuration. """ return [ # Model Hyperparameters "max_node", "num_layer", "num_head", "input_dim_X", "input_dim_adj", "hidden_size_adj", "hidden_size", "attention_dim", # Diffusion parameters "dataset_info", "sde_type_x", "sde_beta_min_x", "sde_beta_max_x", "sde_num_scales_x", "sde_type_adj", "sde_beta_min_adj", "sde_beta_max_adj", "sde_num_scales_adj", # Training Parameters "batch_size", "epochs", "learning_rate", "grad_clip_value", "weight_decay", "use_loss_reduce_mean", "train_eps", "conv", # Scheduler Parameters "use_lr_scheduler", "scheduler_factor", "scheduler_patience", # Sampling Parameters "sampler_predictor", "sampler_corrector", "sampler_snr", "sampler_scale_eps", "sampler_n_steps", "sampler_probability_flow", "sampler_noise_removal", "sampler_eps", # Other Parameters "fitting_epoch", "fitting_loss", "device", "verbose", "model_name" ] def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]: params = ["input_dim_X", "max_node", "hidden_size", "num_layer", "input_dim_adj", "hidden_size_adj", "attention_dim", "num_head"] if checkpoint is not None: if "hyperparameters" not in checkpoint: raise ValueError("Checkpoint missing 'hyperparameters' key") return {k: checkpoint["hyperparameters"][k] for k in params} return {k: getattr(self, k) for k in params} def _convert_to_pytorch_data(self, X, y=None): """Convert numpy arrays to PyTorch Geometric data format. """ if self.verbose: iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X)) else: iterator = enumerate(X) pyg_graph_list = [] for idx, smiles_or_mol in iterator: if y is not None: properties = y[idx] else: properties = None graph = graph_from_smiles(smiles_or_mol, properties) g = Data() # No H, first heavy atom has type 0 node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1) if node_type.numel() <= 1: continue # Filter out invalid node types (< 0) valid_mask = node_type >= 0 if not valid_mask.all(): # Get valid nodes and adjust edge indices valid_indices = torch.where(valid_mask)[0] index_map = -torch.ones(node_type.size(0), dtype=torch.long) index_map[valid_indices] = torch.arange(valid_indices.size(0)) # Filter edges that connect to invalid nodes edge_index = torch.from_numpy(graph["edge_index"]) valid_edges_mask = valid_mask[edge_index[0]] & valid_mask[edge_index[1]] valid_edge_index = edge_index[:, valid_edges_mask] # Remap edge indices to account for removed nodes remapped_edge_index = index_map[valid_edge_index] # Filter edge attributes edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1 valid_edge_attr = edge_attr[valid_edges_mask] # Update node and edge data node_type = node_type[valid_mask] g.edge_index = remapped_edge_index g.edge_attr = valid_edge_attr.long().squeeze(-1) else: # No invalid nodes, proceed normally g.edge_index = torch.from_numpy(graph["edge_index"]) edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1 g.edge_attr = edge_attr.long().squeeze(-1) # * is encoded as "misc" which is 119 - 1 and should be 117 node_type[node_type == 118] = 117 g.x = node_type.long().squeeze(-1) del graph["node_feat"] del graph["edge_index"] del graph["edge_feat"] g.y = torch.from_numpy(graph["y"]) del graph["y"] pyg_graph_list.append(g) return pyg_graph_list def _setup_diffusion_params(self, X: Union[List, Dict]) -> None: # Extract dataset info from X if it's a dict (from checkpoint), otherwise compute it if isinstance(X, dict): dataset_info = X["hyperparameters"]["dataset_info"] max_node = X["hyperparameters"]["max_node"] else: assert isinstance(X, list) dataset_info = compute_dataset_info(X) max_node = dataset_info["max_node"] self.input_dim_X = dataset_info["x_margins"].shape[0] self.input_dim_adj = dataset_info["e_margins"].shape[0] self.dataset_info = dataset_info self.max_node = max_node x_sde = load_sde(self.sde_type_x, self.sde_beta_min_x, self.sde_beta_max_x, self.sde_num_scales_x) adj_sde = load_sde(self.sde_type_adj, self.sde_beta_min_adj, self.sde_beta_max_adj, self.sde_num_scales_adj) self.loss_fn = get_sde_loss_fn( x_sde, adj_sde, train=True, reduce_mean=self.use_loss_reduce_mean, continuous=True, likelihood_weighting=False, eps=self.train_eps, ) self.sampling_fn = get_pc_sampler( sde_x=x_sde, sde_adj=adj_sde, predictor=self.sampler_predictor, corrector=self.sampler_corrector, snr=self.sampler_snr, scale_eps=self.sampler_scale_eps, n_steps=self.sampler_n_steps, probability_flow=self.sampler_probability_flow, continuous=True, denoise=self.sampler_noise_removal, eps=self.sampler_eps, device=self.device, ) def _initialize_model( self, model_class: Type[torch.nn.Module], checkpoint: Optional[Dict] = None ) -> torch.nn.Module: """Initialize the model with parameters or a checkpoint.""" model_params = self._get_model_params(checkpoint) self.model = model_class(**model_params) self.model = self.model.to(self.device) if checkpoint is not None: self._setup_diffusion_params(checkpoint) self.model.load_state_dict(checkpoint["model_state_dict"]) return self.model def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]: """Setup optimization components including optimizer and learning rate scheduler. """ optimizer = torch.optim.Adam( self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay ) scheduler = None if self.use_lr_scheduler: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=self.scheduler_factor, patience=self.scheduler_patience, min_lr=1e-6, cooldown=0, eps=1e-8, ) return optimizer, scheduler
[docs] def fit( self, X_train: List[str], ) -> "GDSSMolecularGenerator": """Fit the model to the training data. Parameters ---------- X_train : List[str] List of training data in SMILES format. Returns ------- self : GDSSMolecularGenerator The fitted model. """ X_train, _ = self._validate_inputs(X_train) self._setup_diffusion_params(X_train) self._initialize_model(self.model_class) self.model.initialize_parameters() optimizer, scheduler = self._setup_optimizers() train_dataset = self._convert_to_pytorch_data(X_train) train_loader = DataLoader( train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0 ) # Calculate total steps for global progress bar steps_per_epoch = len(train_loader) total_steps = self.epochs * steps_per_epoch # Initialize global progress bar global_pbar = None if self.verbose: global_pbar = tqdm( total=total_steps, desc="GDSS Training Progress", unit="step", dynamic_ncols=True, leave=True ) self.fitting_loss = [] self.fitting_epoch = 0 try: for epoch in range(self.epochs): train_losses = self._train_epoch(train_loader, optimizer, epoch, global_pbar) epoch_loss = np.mean(train_losses).item() self.fitting_loss.append(epoch_loss) if scheduler: scheduler.step(epoch_loss) # Update global progress bar with epoch summary if global_pbar is not None: global_pbar.set_postfix({ "Epoch": f"{epoch+1}/{self.epochs}", "Avg Loss": f"{epoch_loss:.4f}" }) self.fitting_epoch = epoch finally: # Ensure progress bar is closed if global_pbar is not None: global_pbar.close() self.is_fitted_ = True return self
def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None): """Training logic for one epoch. Args: train_loader: DataLoader containing training data optimizer: Optimizer instance for model parameter updates epoch: Current epoch number global_pbar: Global progress bar for tracking overall training progress Returns: list: List of loss values for each training step """ self.model.train() losses = [] active_index = self.dataset_info["active_index"] for step, batched_data in enumerate(train_loader): batched_data = batched_data.to(self.device) optimizer.zero_grad() data_x = F.one_hot(batched_data.x, num_classes=118).float()[:, active_index] data_edge_attr = batched_data.edge_attr.float() X, E, node_mask = to_dense(data_x, batched_data.edge_index, data_edge_attr, batched_data.batch, self.max_node) loss_x, loss_adj = self.model.compute_loss(x=X, adj=E, flags=node_mask, loss_fn=self.loss_fn) loss = loss_x + loss_adj loss.backward() if self.grad_clip_value is not None: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value) optimizer.step() losses.append(loss.item()) # Update global progress bar if global_pbar is not None: global_pbar.update(1) global_pbar.set_postfix({ "Epoch": f"{epoch+1}/{self.epochs}", "Step": f"{step+1}/{len(train_loader)}", "Loss": f"{loss.item():.4f}", "Node Loss": f"{loss_x.item():.4f}", "Adj Loss": f"{loss_adj.item():.4f}" }) return losses
[docs] @torch.no_grad() def generate(self, num_nodes: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None, batch_size: int = 32) -> List[str]: """Randomly generate molecules with specified node counts. Parameters ---------- num_nodes : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None Number of nodes for each molecule in the batch. If None, samples from the training distribution. Can be provided as: - A list of lists - A numpy array of shape (batch_size, 1) - A torch tensor of shape (batch_size, 1) batch_size : int, default=32 Number of molecules to generate. Returns ------- List[str] List of generated molecules in SMILES format. """ if not self.is_fitted_: raise ValueError("Model must be fitted before generating molecules.") if num_nodes is not None: batch_size = len(num_nodes) if num_nodes is None: num_nodes_dist = self.dataset_info["num_nodes_dist"] num_nodes = num_nodes_dist.sample_n(batch_size, self.device) elif isinstance(num_nodes, list): num_nodes = torch.tensor(num_nodes).to(self.device) elif isinstance(num_nodes, np.ndarray): num_nodes = torch.from_numpy(num_nodes).to(self.device) if num_nodes.dim() == 1: num_nodes = num_nodes.unsqueeze(-1) assert num_nodes.size(0) == batch_size arange = ( torch.arange(self.max_node).to(self.device) .unsqueeze(0) .expand(batch_size, -1) ) node_mask = arange < num_nodes if not hasattr(self, 'dataset_info') or self.dataset_info is None: raise ValueError("Dataset info not found. Please call setup_diffusion_params first.") shape_x = ( batch_size, self.max_node, self.input_dim_X, ) shape_adj = (batch_size, self.max_node, self.max_node) X, E, _ = self.sampling_fn(self.model.score_network_x, self.model.score_network_a, shape_x, shape_adj, node_mask) E = quantize_mol(E) X = X.argmax(dim=-1) molecule_list = [] for i in range(batch_size): n = num_nodes[i][0].item() atom_types = X[i, :n].cpu() edge_types = E[i, :n, :n].cpu() molecule_list.append([atom_types, edge_types]) smiles_list = graph_to_smiles(molecule_list, self.dataset_info["atom_decoder"]) return smiles_list