Source code for torch_molecule.generator.molgpt.modeling_molgpt

import re
import warnings
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Type
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.cuda.amp import GradScaler

from .gpt import GPT
from .dataset import SmilesDataset

from ...base import BaseMolecularGenerator

@dataclass
class MolGPTMolecularGenerator(BaseMolecularGenerator):
    """
    This generator implements the molecular GPT model for generating molecules.

    The model uses a GPT-like architecture to learn the distribution of SMILES strings
    and generate new molecules. It supports conditional generation based on properties
    and/or molecular scaffolds.

    References
    ----------
    - MolGPT: Molecular Generation Using a Transformer-Decoder Model.
      Journal of Chemical Information and Modeling.
      https://pubs.acs.org/doi/10.1021/acs.jcim.1c00600
    - Code: https://github.com/devalab/molgpt

    Parameters
    ----------
    num_layer : int, default=8
        Number of transformer layers in the model.
    num_head : int, default=8
        Number of attention heads in each transformer layer.
    hidden_size : int, default=256
        Dimension of the hidden representations.
    max_len : int, default=128
        Maximum length of SMILES strings.
    num_task : int, default=0
        Number of property prediction tasks for conditional generation.
        0 for unconditional generation.
    use_scaffold : bool, default=False
        Whether to use scaffold conditioning.
    use_lstm : bool, default=False
        Whether to use an LSTM for encoding the scaffold.
    lstm_layers : int, default=0
        Number of LSTM layers if use_lstm is True.
    batch_size : int, default=64
        Batch size for training.
    epochs : int, default=1000
        Number of training epochs.
    learning_rate : float, default=3e-4
        Learning rate for the optimizer.
    adamw_betas : Tuple[float, float], default=(0.9, 0.95)
        Beta parameters for the AdamW optimizer.
    weight_decay : float, default=0.1
        Weight decay for the optimizer.
    grad_norm_clip : float, default=1.0
        Gradient norm clipping value.
    verbose : bool, default=False
        Whether to display progress bars during training.
    """

    # Model parameters
    num_layer: int = 8
    num_head: int = 8
    hidden_size: int = 256
    max_len: int = 128
    # Conditioning parameters
    num_task: int = 0
    use_scaffold: bool = False
    use_lstm: bool = False
    lstm_layers: int = 0
    # Training parameters
    batch_size: int = 64
    epochs: int = 1000
    learning_rate: float = 3e-4
    adamw_betas: Tuple[float, float] = (0.9, 0.95)
    weight_decay: float = 0.1
    grad_norm_clip: float = 1.0
    verbose: bool = False
    # Attributes
    model_name: str = "MolGPTMolecularGenerator"
    fitting_loss: List[float] = field(default_factory=list, init=False)
    fitting_epoch: int = field(default=0, init=False)
    model_class: Type[GPT] = field(default=GPT, init=False)

    def __post_init__(self):
        """Initialize the model after dataclass initialization."""
        super().__post_init__()
        self.vocab_size = None
        self.token_to_id = None
        self.id_to_token = None
        self.scaffold_maxlen = None
        self.pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.pattern)

    @staticmethod
    def _get_param_names() -> List[str]:
        """Get parameter names for the estimator.

        Returns
        -------
        List[str]
            List of parameter names that can be used for model configuration.
        """
        return [
            # Model hyperparameters
            "num_layer", "num_head", "hidden_size",
            # Conditioning parameters
            "num_task", "use_scaffold", "use_lstm", "lstm_layers",
            # Post-initialization parameters
            "vocab_size", "token_to_id", "id_to_token", "scaffold_maxlen", "max_len",
            # Training parameters
            "batch_size", "epochs", "learning_rate", "weight_decay", "grad_norm_clip",
            # Other parameters
            "verbose", "model_name", "fitting_epoch", "fitting_loss", "device",
        ]

    def _initialize_model(
        self, model_class: Type[torch.nn.Module], checkpoint: Optional[Dict] = None
    ) -> torch.nn.Module:
        """Initialize the model with parameters or a checkpoint.

        Parameters
        ----------
        model_class : Type[torch.nn.Module]
            PyTorch module class to instantiate
        checkpoint : Optional[Dict], default=None
            Optional dictionary containing model checkpoint data

        Returns
        -------
        torch.nn.Module
            Initialized PyTorch model
        """
        model_params = self._get_model_params(checkpoint)
        self.model = model_class(**model_params)
        self.model = self.model.to(self.device)

        if checkpoint is not None:
            self.model.load_state_dict(checkpoint["model_state_dict"])
            # Get params in checkpoint["hyperparameters"] that are NOT in model_params
            other_params = {
                k: checkpoint["hyperparameters"][k]
                for k in checkpoint["hyperparameters"]
                if k not in model_params
            }
            # Set the remaining params on self
            for k, v in other_params.items():
                setattr(self, k, v)
        return self.model

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        params = [
            "vocab_size", "max_len", "num_task", "num_layer", "num_head",
            "hidden_size", "use_scaffold", "scaffold_maxlen", "use_lstm", "lstm_layers",
        ]
        if checkpoint is not None:
            if "hyperparameters" not in checkpoint:
                raise ValueError("Checkpoint missing 'hyperparameters' key")
            return {k: checkpoint["hyperparameters"][k] for k in params}
        return {k: getattr(self, k) for k in params}

    def _setup_optimizers(self):
        train_config = {
            "learning_rate": self.learning_rate,
            "weight_decay": self.weight_decay,
            "betas": self.adamw_betas,
        }
        return self.model.configure_optimizers(train_config)
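    # Note (inferred from the keys used above, not from a documented save format):
    # _initialize_model expects a checkpoint dict shaped roughly as
    #
    #     checkpoint = {
    #         "model_state_dict": gpt_model.state_dict(),
    #         "hyperparameters": {"vocab_size": ..., "max_len": ..., "num_task": ...,
    #                             "num_layer": ..., "num_head": ..., "hidden_size": ...,
    #                             "use_scaffold": ..., "scaffold_maxlen": ...,
    #                             "use_lstm": ..., "lstm_layers": ..., ...},
    #     }
    #
    # Any extra keys under "hyperparameters" are copied onto the generator instance.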
    def fit(self, X_train, y_train=None, X_scaffold=None):
        """
        Train the MolGPT model on SMILES strings.

        Parameters
        ----------
        X_train : List[str]
            List of SMILES strings for training
        y_train : Optional[List[float]]
            Optional list of property values for conditional generation
        X_scaffold : Optional[List[str]]
            Optional list of scaffold SMILES strings for conditional generation

        Returns
        -------
        self : MolGPTMolecularGenerator
            The fitted model
        """
        X_train, y_train = self._validate_inputs(
            X_train, y_train, num_task=self.num_task, return_rdkit_mol=False
        )

        # Calculate max length for padding
        lens = [len(self.regex.findall(i.strip())) for i in X_train]
        max_len = max(lens)

        if X_scaffold is not None:
            assert len(X_scaffold) == len(X_train), "X_scaffold and X_train must have the same length"
            assert self.use_scaffold, "use_scaffold must be True"
            X_scaffold, _ = self._validate_inputs(
                X_scaffold, num_task=self.num_task, return_rdkit_mol=False
            )
            scaffold_maxlen = max([len(self.regex.findall(i.strip())) for i in X_scaffold])
        else:
            scaffold_maxlen = 0

        self.scaffold_maxlen = scaffold_maxlen
        self.max_len = max_len

        # Create dataset
        train_dataset = SmilesDataset(
            X_train,
            self.regex,
            max_len,
            properties=y_train,
            scaffolds=X_scaffold,
            scaffold_maxlen=scaffold_maxlen,
        )

        # Save vocabulary
        self.vocab_size = train_dataset.vocab_size
        self.token_to_id = train_dataset.stoi
        self.id_to_token = train_dataset.itos

        # Initialize model
        self._initialize_model(self.model_class)
        self.model.initialize_parameters()
        optimizer = self._setup_optimizers()

        # Create data loader
        train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0
        )

        # Training loop
        self.fitting_loss = []
        self.fitting_epoch = 0
        scaler = GradScaler()
        for epoch in range(self.epochs):
            train_losses = self._train_epoch(train_loader, optimizer, epoch, scaler)
            self.fitting_loss.append(np.mean(train_losses))
            self.fitting_epoch = epoch

        self.is_fitted_ = True
        return self
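    # Example (a hedged usage sketch; the SMILES list and settings are
    # illustrative, not taken from the MolGPT paper):
    #
    #     gen = MolGPTMolecularGenerator(num_task=0, epochs=10, verbose=True)
    #     gen.fit(["CCO", "c1ccccc1", "CC(=O)O", "CCN(CC)CC"])
    #
    # fit() tokenizes the SMILES with the class regex, builds a SmilesDataset,
    # and trains the GPT with AdamW, gradient clipping, and a CUDA AMP GradScaler.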
    def _train_epoch(self, train_loader, optimizer, epoch, scaler):
        self.model.train()
        losses = []
        iterator = (
            tqdm(train_loader, desc="Training", leave=False)
            if self.verbose
            else train_loader
        )

        for step, (x, y, prop, scaffold) in enumerate(iterator):
            x = x.to(self.device)
            y = y.to(self.device)
            prop = prop.to(self.device) if prop.numel() > 0 else None
            scaffold = scaffold.to(self.device) if scaffold.numel() > 0 else None

            optimizer.zero_grad()
            loss = self.model.compute_loss(x, targets=y, prop=prop, scaffold=scaffold)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_clip)
            scaler.step(optimizer)
            scaler.update()

            losses.append(loss.item())
            if self.verbose:
                iterator.set_postfix({"Epoch": epoch, "Loss": f"{loss.item():.4f}"})

        return losses

    @torch.no_grad()
    def sample(self, x, steps, temperature=1.0, top_k=None, prop=None, scaffold=None):
        """
        Sample from the model given a context.

        Parameters
        ----------
        x : torch.Tensor
            Context tensor of shape (batch_size, seq_len)
        steps : int
            Number of steps to sample
        temperature : float
            Sampling temperature
        top_k : int
            Top-k sampling parameter
        prop : torch.Tensor
            Property conditioning tensor
        scaffold : torch.Tensor
            Scaffold conditioning tensor

        Returns
        -------
        torch.Tensor
            Generated sequences
        """
        model = self.model
        model.eval()

        for k in range(steps):
            # Get block size from model
            max_len = model.get_max_len()
            # Crop context if needed
            x_cond = x if x.size(1) <= max_len else x[:, -max_len:]
            # Forward pass
            logits, _ = model(x_cond, prop=prop, scaffold=scaffold)
            # Get logits for the next token and apply temperature
            logits = logits[:, -1, :] / temperature
            # Apply top-k sampling if specified
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)
            # Append to the sequence
            x = torch.cat((x, next_token), dim=1)

        return x
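    # Example (hedged sketch): sampling 4 continuations of 50 tokens directly
    # from a fitted generator, starting from the token id of 'C':
    #
    #     ctx = torch.full((4, 1), gen.token_to_id["C"], dtype=torch.long,
    #                      device=gen.device)
    #     out = gen.sample(ctx, steps=50, temperature=0.9, top_k=10)
    #
    # generate() below wraps this loop and additionally handles property and
    # scaffold conditioning as well as decoding token ids back to SMILES.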
    def generate(self, n_samples=10, properties=None, scaffolds=None, max_len=None,
                 temperature=1.0, top_k=10, starting_token='C'):
        """
        Generate molecules using the trained model.

        Parameters
        ----------
        n_samples : int, default=10
            Number of molecules to generate
        properties : Optional[List[List[float]]]
            Property values for conditional generation
        scaffolds : Optional[List[str]]
            Scaffold SMILES for conditional generation
        max_len : Optional[int]
            Maximum length of generated SMILES
        temperature : float, default=1.0
            Sampling temperature
        top_k : int, default=10
            Top-k sampling parameter
        starting_token : str, default='C'
            Starting token for generation

        Returns
        -------
        List[str]
            List of generated SMILES strings
        """
        if not self.is_fitted_:
            raise ValueError("Model must be fitted before generating molecules")

        if max_len is None:
            max_len = self.max_len

        # Prepare property conditioning if provided
        if properties is not None:
            if len(properties) != n_samples:
                raise ValueError(f"Number of property values ({len(properties)}) must match n_samples ({n_samples})")
            prop_tensor = torch.tensor(properties, dtype=torch.float).to(self.device)
        else:
            prop_tensor = None

        # Prepare scaffold conditioning if provided
        if scaffolds is not None:
            if len(scaffolds) != n_samples:
                raise ValueError(f"Number of scaffolds ({len(scaffolds)}) must match n_samples ({n_samples})")
            # Tokenize scaffolds
            regex = re.compile(self.pattern)
            scaffold_tokens = []
            for scaffold in scaffolds:
                tokens = regex.findall(scaffold.strip())
                # Pad with '<' if needed
                tokens += ['<'] * (self.scaffold_maxlen - len(tokens))
                # Convert to indices
                scaffold_tokens.append([self.token_to_id.get(t, 0) for t in tokens])
            scaffold_tensor = torch.tensor(scaffold_tokens, dtype=torch.long).to(self.device)
        else:
            scaffold_tensor = None

        if starting_token in self.token_to_id:
            start_token = self.token_to_id[starting_token]
        else:
            warnings.warn(f"Starting token {starting_token} not found in vocabulary, using first token instead")
            start_token = 0
        context = torch.tensor([[start_token]] * n_samples, dtype=torch.long).to(self.device)

        # Sample from the model
        generated = self.sample(
            context,
            steps=max_len - 1,  # -1 because we already have one token
            temperature=temperature,
            top_k=top_k,
            prop=prop_tensor,
            scaffold=scaffold_tensor,
        )

        # Convert to SMILES strings
        smiles_list = []
        for i in range(n_samples):
            # Convert indices to tokens
            tokens = [self.id_to_token[idx.item()] for idx in generated[i]]
            # Join tokens and remove padding
            smiles = ''.join(tokens).replace('<', '')
            smiles_list.append(smiles)

        return smiles_list
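# Example usage (a hedged sketch; the property values and scaffold below are
# placeholders, and train_smiles / train_logp / train_scaffolds are hypothetical
# user-supplied lists, not data from the MolGPT paper):
#
#     gen = MolGPTMolecularGenerator(num_task=1, use_scaffold=True, epochs=100)
#     gen.fit(train_smiles, y_train=train_logp, X_scaffold=train_scaffolds)
#     smiles = gen.generate(
#         n_samples=5,
#         properties=[[2.5]] * 5,
#         scaffolds=["c1ccccc1"] * 5,
#         temperature=1.0,
#         top_k=10,
#     )
#
# properties and scaffolds, when given, must each contain exactly n_samples
# entries; otherwise generate() raises a ValueError.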