from tqdm import tqdm
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Tuple, List, Type
import numpy as np
import torch
from .jtnn_vae import JTNNVAE
from .jtnn.mol_tree import MolTree
from .jtnn.vocab import Vocab
from .jtnn.datautils import MolTreeFolder
from ...base import BaseMolecularGenerator
@dataclass
class JTVAEMolecularGenerator(BaseMolecularGenerator):
"""
JT-VAE-based molecular generator. Implemented for unconditional molecular generation.
References
----------
- Junction Tree Variational Autoencoder for Molecular Graph Generation. ICML 2018. https://arxiv.org/pdf/1802.04364
- Code: https://github.com/kamikaze0923/jtvae
Parameters
----------
hidden_size : int, default=450
Dimension of hidden layers in the model.
latent_size : int, default=56
Dimension of the latent space.
depthT : int, default=20
Depth of the tree encoder.
depthG : int, default=3
Depth of the graph message passing network (molecular graph encoder).
batch_size : int, default=32
Number of samples per batch during training.
epochs : int, default=20
Number of epochs to train the model.
learning_rate : float, default=0.003
Initial learning rate for the optimizer.
weight_decay : float, default=0.0
L2 regularization factor.
grad_norm_clip : Optional[float], default=None
Maximum norm for gradient clipping. None means no clipping.
beta : float, default=0.0
Initial KL divergence weight for VAE training.
step_beta : float, default=0.002
Step size for KL annealing.
max_beta : float, default=1.0
Maximum value for KL weight.
warmup : int, default=40000
Number of steps for KL annealing warmup.
use_lr_scheduler : bool, default=True
Whether to use learning rate scheduling.
anneal_rate : float, default=0.9
Learning rate annealing factor.
anneal_iter : int, default=40000
Number of iterations between learning rate updates.
kl_anneal_iter : int, default=2000
Number of iterations between KL weight updates.
verbose : bool, default=False
Whether to print detailed training information.
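Examples
--------
A minimal usage sketch, where smiles_list is a placeholder for a
user-provided list of training SMILES strings (not defined here).

>>> generator = JTVAEMolecularGenerator(epochs=5, batch_size=16)
>>> generator = generator.fit(smiles_list)
>>> sampled = generator.generate(batch_size=10)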
"""
# Model parameters
hidden_size: int = 450
latent_size: int = 56
depthT: int = 20
depthG: int = 3
# Training parameters
batch_size: int = 32
epochs: int = 20
learning_rate: float = 0.003
weight_decay: float = 0.0
grad_norm_clip: Optional[float] = None
# KL annealing parameters
beta: float = 0.0
step_beta: float = 0.002
max_beta: float = 1.0
warmup: int = 40000
# Learning rate scheduling
use_lr_scheduler: bool = True
anneal_rate: float = 0.9
anneal_iter: int = 40000
kl_anneal_iter: int = 2000
# Other parameters
verbose: bool = False
# attributes
model_name: str = "JTVAEMolecularGenerator"
fitting_loss: List[float] = field(default_factory=list, init=False)
fitting_epoch: int = field(default=0, init=False)
model_class: Type[JTNNVAE] = field(default=JTNNVAE, init=False)
def __post_init__(self):
super().__post_init__()
self.vocab = None
@staticmethod
def _get_param_names() -> List[str]:
return [
"hidden_size", "latent_size", "depthT", "depthG",
"batch_size", "epochs", "learning_rate", "weight_decay",
"beta", "step_beta", "max_beta", "warmup",
"use_lr_scheduler", "anneal_rate", "anneal_iter", "kl_anneal_iter",
"verbose", 'model_name', 'vocab'
]
def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
params = [
"vocab", "hidden_size", "latent_size", "depthT", "depthG"
]
if checkpoint is not None:
if "hyperparameters" not in checkpoint:
raise ValueError("Checkpoint missing 'hyperparameters' key")
return {k: checkpoint["hyperparameters"][k] for k in params}
return {k: getattr(self, k) for k in params}
def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
"""Setup optimization components including optimizer and learning rate scheduler.
"""
optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
scheduler = None
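# ExponentialLR multiplies the learning rate by anneal_rate each time scheduler.step()
# is called (every anneal_iter training steps; see _train_epoch).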
if self.use_lr_scheduler:
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, self.anneal_rate)
return optimizer, scheduler
def _extract_vocab(self, X):
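# Build the clique vocabulary: decompose each training molecule into its junction tree
# and collect the SMILES of every tree node (clique).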
cset = set()
for smiles in X:
try:
mol = MolTree(smiles)
for c in mol.nodes:
cset.add(c.smiles)
except Exception as e:
print(f'Error {e} in extracting vocab for smiles: {smiles}')
vocab = list(cset)
return vocab
def _convert_to_tensor(self, X):
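# Preprocess each SMILES into a MolTree: recover clique labels, enumerate assembly
# candidates, and drop the heavy RDKit Mol objects. Molecules that fail preprocessing
# become None and are filtered out before training.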
all_data = []
for smiles in X:
try:
mol_tree = MolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
for node in mol_tree.nodes:
if node.label not in node.cands:
node.cands.append(node.label)
del mol_tree.mol
for node in mol_tree.nodes:
del node.mol
except Exception as e:
print(f'Error {e} in tensorizing smiles: {smiles}')
mol_tree = None
all_data.append(mol_tree)
return all_data
def fit(
self,
X_train: List[str],
) -> "JTVAEMolecularGenerator":
X_train, _ = self._validate_inputs(X_train, None, num_task=0, return_rdkit_mol=False)
vocab = self._extract_vocab(X_train)
vocab = Vocab(vocab)
self.vocab = vocab
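# Build the underlying JTNNVAE with the extracted vocabulary and the configured
# architecture, then initialize its parameters.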
self._initialize_model(self.model_class)
self.model.initialize_parameters()
optimizer, scheduler = self._setup_optimizers()
train_dataset = self._convert_to_tensor(X_train)
train_dataset = list(filter(lambda x: x is not None, train_dataset))
train_loader = MolTreeFolder(train_dataset, vocab, self.batch_size, num_workers=0)
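# MolTreeFolder groups the preprocessed junction trees into batches of size batch_size
# using the extracted vocabulary; step_len below only sizes the progress bar.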
step_len = len(X_train) // self.batch_size
self.fitting_loss = []
self.fitting_epoch = 0
total_step = 0
for epoch in range(self.epochs):
train_losses, total_step = self._train_epoch(train_loader, optimizer, scheduler, epoch, total_step, step_len)
self.fitting_loss.append(np.mean(train_losses).item())
self.fitting_epoch = epoch
self.is_fitted_ = True
return self
def _train_epoch(self, train_loader, optimizer, scheduler, epoch, total_step, step_len):
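# One epoch over the batched junction trees. The JT-VAE objective is the sum of the
# word (node label), topology, and assembly reconstruction losses plus the KL divergence
# weighted by beta. beta grows by step_beta (up to max_beta) every kl_anneal_iter steps
# once warmup steps have elapsed; the learning rate decays every anneal_iter steps when
# the scheduler is enabled.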
self.model.train()
losses = []
iterator = (
tqdm(train_loader, desc="Training", leave=False, total=step_len)
if self.verbose
else train_loader
)
for step, batched_data in enumerate(iterator):
total_step += 1
optimizer.zero_grad()
word_loss, topo_loss, assm_loss, kl_div = self.model.compute_loss(batched_data)
total_loss = word_loss + topo_loss + assm_loss + kl_div * self.beta
total_loss.backward()
if self.grad_norm_clip is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_clip)
optimizer.step()
losses.append(total_loss.item())
if scheduler is not None and total_step % self.anneal_iter == 0:
scheduler.step()
if total_step % self.kl_anneal_iter == 0 and total_step >= self.warmup:
self.beta = min(self.max_beta, self.beta + self.step_beta)
if self.verbose:
iterator.set_postfix({"Epoch": epoch, "Loss": f"{total_loss.item():.4f}", "word": f"{word_loss.item():.4f}", "topo": f"{topo_loss.item():.4f}", "assm": f"{assm_loss.item():.4f}", "kl": f"{kl_div.item():.4f}"})
return losses, total_step
def generate(
self,
batch_size: int = 32
) -> List[str]:
"""Generate molecules using JT-VAE.
Parameters
----------
batch_size : int
Number of molecules to generate.
Returns
-------
List[str]
Generated molecules as SMILES strings.
"""
if not self.is_fitted_:
raise ValueError("Model must be fitted before generating molecules.")
# sample_prior is assumed to decode one molecule per call, so draw batch_size samples.
return [self.model.sample_prior(self.device) for _ in range(batch_size)]