import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Type
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from .transformer import Transformer
from .utils import PlaceHolder, to_dense, compute_dataset_info
from .diffusion import NoiseScheduleDiscrete, MarginalTransition, sample_discrete_features, sample_discrete_feature_noise, reverse_diffusion
from ...base import BaseMolecularGenerator
from ...utils import graph_from_smiles, graph_to_smiles
[docs]
@dataclass
class GraphDITMolecularGenerator(BaseMolecularGenerator):
    """Graph diffusion transformer (Graph-DiT) molecular generator.

    This generator implements the graph diffusion transformer for
    (multi-conditional and unconditional) molecular generation.

    References
    ----------
    - Graph Diffusion Transformers for Multi-Conditional Molecular Generation.
      NeurIPS 2024. https://openreview.net/forum?id=cfrDLD1wfO
    - Implementation: https://github.com/liugangcode/Graph-DiT

    Parameters
    ----------
    num_layer : int, default=6
        Number of transformer layers.
    hidden_size : int, default=1152
        Dimension of hidden layers.
    dropout : float, default=0.0
        Dropout rate for transformer layers.
    drop_condition : float, default=0.0
        Dropout rate for condition embedding.
    num_head : int, default=16
        Number of attention heads in transformer.
    mlp_ratio : float, default=4
        Ratio of MLP hidden dimension to transformer hidden dimension.
    task_type : List[str], default=[]
        List specifying type of each task ('regression' or 'classification').
        Its length is used as the number of conditioning properties
        (``input_dim_y``).
    timesteps : int, default=500
        Number of diffusion timesteps.
    batch_size : int, default=128
        Batch size for training.
    epochs : int, default=10000
        Number of training epochs.
    learning_rate : float, default=0.0002
        Learning rate for optimization.
    grad_clip_value : Optional[float], default=None
        Maximum gradient norm passed to ``clip_grad_norm_`` during training
        (None = no clipping).
    weight_decay : float, default=0.0
        Weight decay for optimization.
    lw_X : float, default=1
        Loss weight for node reconstruction.
    lw_E : float, default=5
        Loss weight for edge reconstruction.
    guide_scale : float, default=2.0
        Scale factor for classifier-free guidance during sampling. Guidance
        is skipped when the value is ``None`` or exactly 1.
    use_lr_scheduler : bool, default=False
        Whether to use learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce learning rate on plateau.
    scheduler_patience : int, default=5
        Number of epochs with no improvement after which learning rate will
        be reduced.
    verbose : bool, default=False
        Whether to display progress bars.
    model_name : str, default="GraphDITMolecularGenerator"
        Name of the model.

    Attributes
    ----------
    fitting_loss : List[float]
        Mean training loss of every epoch, reset at the start of ``fit``.
    fitting_epoch : int
        Index of the last epoch completed by ``fit``.
    dataset_info : Dict[str, Any]
        Dataset statistics stored by ``_setup_diffusion_params``.
    model_class : Type[Transformer]
        Network class instantiated by ``fit``.
    """
    # Model parameters
    num_layer: int = 6
    hidden_size: int = 1152
    dropout: float = 0.
    drop_condition: float = 0.
    num_head: int = 16
    mlp_ratio: float = 4
    task_type: List[str] = field(default_factory=list)
    # Diffusion parameters
    timesteps: int = 500
    # Training parameters
    batch_size: int = 128
    epochs: int = 10000
    learning_rate: float = 0.0002
    grad_clip_value: Optional[float] = None
    weight_decay: float = 0.0
    lw_X: float = 1
    lw_E: float = 5
    # Sampling parameters
    guide_scale: float = 2.
    # Scheduler parameters
    use_lr_scheduler: bool = False
    scheduler_factor: float = 0.5
    scheduler_patience: int = 5
    verbose: bool = False
    # Attributes (the init=False fields are not constructor arguments)
    model_name: str = "GraphDITMolecularGenerator"
    fitting_loss: List[float] = field(default_factory=list, init=False)
    fitting_epoch: int = field(default=0, init=False)
    dataset_info: Dict[str, Any] = field(default_factory=dict, init=False)
    model_class: Type[Transformer] = field(default=Transformer, init=False)
def __post_init__(self):
    """Finish construction once the dataclass fields are populated.

    Runs the parent class hook first, then resets the data-dependent sizes
    and derives the conditioning dimension from ``task_type``.
    """
    super().__post_init__()
    # Not known until training data (or a checkpoint) has been inspected.
    self.max_node = self.input_dim_X = self.input_dim_E = None
    # One conditioning dimension per declared task.
    self.input_dim_y = len(self.task_type)
@staticmethod
def _get_param_names() -> List[str]:
"""Get parameter names for the estimator.
Returns
-------
List[str]
List of parameter names that can be used for model configuration.
"""
return [
# Model Hyperparameters
"max_node", "hidden_size", "num_layer", "num_head",
"mlp_ratio", "dropout", "drop_condition", "input_dim_X", "input_dim_E", "input_dim_y",
"task_type",
# Diffusion parameters
"timesteps", "dataset_info",
# Training Parameters
"batch_size", "epochs", "learning_rate", "grad_clip_value",
"weight_decay", "lw_X", "lw_E",
# Scheduler Parameters
"use_lr_scheduler", "scheduler_factor", "scheduler_patience",
# Sampling Parameters
"guide_scale",
# Other Parameters
"fitting_epoch", "fitting_loss", "device", "verbose", "model_name"
]
def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
params = ["max_node", "hidden_size", "num_layer", "num_head", "mlp_ratio",
"dropout", "drop_condition", "input_dim_X", "input_dim_E", "input_dim_y", "task_type"]
if checkpoint is not None:
if "hyperparameters" not in checkpoint:
raise ValueError("Checkpoint missing 'hyperparameters' key")
return {k: checkpoint["hyperparameters"][k] for k in params}
return {k: getattr(self, k) for k in params}
def _convert_to_pytorch_data(self, X, y=None):
    """Convert molecules (and optional properties) to PyTorch Geometric graphs.

    Parameters
    ----------
    X : sequence
        Molecules in whatever form ``graph_from_smiles`` accepts.
    y : indexable, optional
        Per-molecule properties; ``y[idx]`` is forwarded to
        ``graph_from_smiles`` together with the ``idx``-th molecule.

    Returns
    -------
    list of torch_geometric.data.Data
        One graph per input with 1-D ``x`` (node type indices), ``edge_index``,
        1-D ``edge_attr`` (edge type indices shifted by +1) and ``y``.
    """
    iterator = enumerate(X)
    if self.verbose:
        iterator = tqdm(iterator, desc="Converting molecules to graphs", total=len(X))

    pyg_graph_list = []
    for idx, smiles_or_mol in iterator:
        properties = y[idx] if y is not None else None
        graph = graph_from_smiles(smiles_or_mol, properties)

        # No H, first heavy atom has type 0
        node_type = torch.from_numpy(graph["node_feat"][:, 0] - 1)
        edge_index = torch.from_numpy(graph["edge_index"])
        # Edge types are shifted by one; presumably index 0 is reserved for
        # "no edge" in the dense representation -- TODO confirm against to_dense.
        edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1

        # Filter out invalid node types (< 0)
        valid_mask = node_type >= 0
        if not valid_mask.all():
            # Map old node positions to positions in the filtered node list.
            valid_indices = torch.where(valid_mask)[0]
            index_map = -torch.ones(node_type.size(0), dtype=torch.long)
            index_map[valid_indices] = torch.arange(valid_indices.size(0))
            # Keep only edges whose two endpoints both survive, then remap them.
            valid_edges_mask = valid_mask[edge_index[0]] & valid_mask[edge_index[1]]
            edge_index = index_map[edge_index[:, valid_edges_mask]]
            edge_attr = edge_attr[valid_edges_mask]
            node_type = node_type[valid_mask]

        # * is encoded as "misc" which is 119 - 1 and should be 117
        node_type[node_type == 118] = 117

        g = Data()
        g.edge_index = edge_index
        # reshape(-1) keeps these vectors 1-D even for a single element;
        # squeeze(-1) would collapse a one-atom molecule to a 0-dim tensor,
        # which cannot be batched.
        g.edge_attr = edge_attr.long().reshape(-1)
        g.x = node_type.long().reshape(-1)
        g.y = torch.from_numpy(graph["y"])
        pyg_graph_list.append(g)
    return pyg_graph_list
def _setup_diffusion_params(self, X: Union[List, Dict]) -> None:
    """Store dataset statistics and build the discrete-diffusion components.

    Parameters
    ----------
    X : Union[List, Dict]
        Either a checkpoint dict (values are read from its
        ``"hyperparameters"`` entry) or a list of training inputs from which
        ``compute_dataset_info`` derives the statistics.

    Notes
    -----
    Sets ``input_dim_X``, ``input_dim_E``, ``dataset_info``, ``timesteps``,
    ``max_node``, ``transition_model``, ``limit_dist`` and ``noise_schedule``.
    """
    if isinstance(X, dict):
        # Checkpoint: reuse what was stored at training time.
        info = X["hyperparameters"]["dataset_info"]
        n_steps = X["hyperparameters"]["timesteps"]
        node_cap = X["hyperparameters"]["max_node"]
    else:
        assert isinstance(X, list)
        info = compute_dataset_info(X)
        n_steps = self.timesteps
        node_cap = info["max_node"]

    # Feature dimensions follow the length of the marginal distributions.
    self.input_dim_X = info["x_margins"].shape[0]
    self.input_dim_E = info["e_margins"].shape[0]
    self.dataset_info, self.timesteps, self.max_node = info, n_steps, node_cap

    x_limit, e_limit, xe_conditions, ex_conditions = (
        info[key].to(self.device)
        for key in ("x_margins", "e_margins", "xe_conditions", "ex_conditions")
    )
    self.transition_model = MarginalTransition(
        x_limit, e_limit, xe_conditions, ex_conditions, self.max_node
    )
    self.limit_dist = PlaceHolder(X=x_limit, E=e_limit, y=None)
    schedule = NoiseScheduleDiscrete(timesteps=self.timesteps)
    self.noise_schedule = schedule.to(self.device)
def _initialize_model(
self,
model_class: Type[torch.nn.Module],
checkpoint: Optional[Dict] = None
) -> torch.nn.Module:
"""Initialize the model with parameters or a checkpoint."""
model_params = self._get_model_params(checkpoint)
self.model = model_class(**model_params)
self.model = self.model.to(self.device)
if checkpoint is not None:
self._setup_diffusion_params(checkpoint)
self.model.load_state_dict(checkpoint["model_state_dict"])
return self.model
def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
"""Setup optimization components including optimizer and learning rate scheduler.
"""
optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
scheduler = None
if self.use_lr_scheduler:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=self.scheduler_factor,
patience=self.scheduler_patience,
min_lr=1e-6,
cooldown=0,
eps=1e-8,
)
return optimizer, scheduler
[docs]
def fit(
    self,
    X_train: List[str],
    y_train: Optional[Union[List, np.ndarray]] = None,
) -> "GraphDITMolecularGenerator":
    """Train the diffusion model on molecules and optional property labels.

    Parameters
    ----------
    X_train : List[str]
        Training molecules.
    y_train : Optional[Union[List, np.ndarray]], default=None
        Property labels; validated against ``len(self.task_type)`` tasks.

    Returns
    -------
    GraphDITMolecularGenerator
        This instance, with ``is_fitted_`` set and per-epoch mean losses
        recorded in ``fitting_loss``.
    """
    X_train, y_train = self._validate_inputs(
        X_train, y_train, num_task=len(self.task_type)
    )
    # Dataset statistics must exist before the network is built, because the
    # model constructor reads input_dim_X / input_dim_E / max_node.
    self._setup_diffusion_params(X_train)
    self._initialize_model(self.model_class)
    self.model.initialize_parameters()
    optimizer, scheduler = self._setup_optimizers()

    graphs = self._convert_to_pytorch_data(X_train, y_train)
    train_loader = DataLoader(
        graphs, batch_size=self.batch_size, shuffle=True, num_workers=0
    )

    self.fitting_loss, self.fitting_epoch = [], 0
    for epoch in range(self.epochs):
        batch_losses = self._train_epoch(train_loader, optimizer, epoch)
        epoch_loss = np.mean(batch_losses).item()
        self.fitting_loss.append(epoch_loss)
        if scheduler:
            scheduler.step(epoch_loss)
        self.fitting_epoch = epoch

    self.is_fitted_ = True
    return self
def _train_epoch(self, train_loader, optimizer, epoch):
    """Run one optimization pass over ``train_loader``.

    Parameters
    ----------
    train_loader : iterable
        Yields batched graphs exposing ``x``, ``edge_index``, ``edge_attr``,
        ``batch`` and ``y``.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per batch.
    epoch : int
        Epoch index, only shown in the progress bar.

    Returns
    -------
    list of float
        Loss value of every batch.
    """
    self.model.train()
    progress = train_loader
    if self.verbose:
        progress = tqdm(train_loader, desc="Training", leave=False)
    active_index = self.dataset_info["active_index"]

    batch_losses = []
    for batch in progress:
        batch = batch.to(self.device)
        optimizer.zero_grad()

        # One-hot over 118 node classes (restricted to the active ones) and
        # 5 edge classes; both counts are hard-coded here.
        node_onehot = F.one_hot(batch.x, num_classes=118).float()[:, active_index]
        edge_onehot = F.one_hot(batch.edge_attr, num_classes=5).float()
        dense, node_mask = to_dense(
            node_onehot, batch.edge_index, edge_onehot, batch.batch, self.max_node
        )
        dense = dense.mask(node_mask)
        X, E = dense.X, dense.E

        noisy = self.apply_noise(X, E, batch.y, node_mask)
        loss, loss_X, loss_E = self.model.compute_loss(
            noisy, true_X=X, true_E=E, lw_X=self.lw_X, lw_E=self.lw_E
        )
        loss.backward()
        if self.grad_clip_value is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
        optimizer.step()

        batch_losses.append(loss.item())
        if self.verbose:
            progress.set_postfix({"Epoch": epoch, "Loss": f"{loss.item():.4f}", "Loss_X": f"{loss_X.item():.4f}", "Loss_E": f"{loss_E.item():.4f}"})
    return batch_losses
def apply_noise(self, X, E, y, node_mask) -> Dict[str, Any]:
    """Sample a random timestep per graph and corrupt ``X``/``E`` accordingly.

    Parameters
    ----------
    X : torch.Tensor
        Dense one-hot node features; unpacked as ``(bs, n, dx)``.
    E : torch.Tensor
        Dense one-hot edge features; reshaped per node to ``(bs, n, -1)``.
    y : Any
        Conditioning labels, passed through to the output.
    node_mask : torch.Tensor
        Mask applied to the sampled features.

    Returns
    -------
    Dict[str, Any]
        Keys ``t`` (timestep rescaled to ``[0, timesteps]``), ``beta_t``,
        ``alpha_s_bar``, ``alpha_t_bar``, ``X_t``, ``E_t``, ``y_t`` and
        ``node_mask``.
    """
    # One integer timestep per graph, drawn from {0, ..., timesteps}.
    t_int = torch.randint(
        0, self.timesteps + 1, size=(X.size(0), 1), device=X.device
    ).float()  # (bs, 1)
    t_float = t_int / self.timesteps
    s_float = (t_int - 1) / self.timesteps

    # beta_t and alpha_s_bar are used for denoising/loss computation
    schedule = self.noise_schedule
    beta_t = schedule(t_normalized=t_float)  # (bs, 1)
    alpha_s_bar = schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
    alpha_t_bar = schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

    Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)

    # Node and edge features are concatenated per node, pushed through the
    # joint transition matrix, then split back at input_dim_X.
    bs, n, _ = X.shape
    joint_features = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
    joint_probs = joint_features @ Qtb.X
    probX = joint_probs[:, :, :self.input_dim_X]
    probE = joint_probs[:, :, self.input_dim_X:].reshape(bs, n, n, -1)

    sampled_t = sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)
    X_t = F.one_hot(sampled_t.X, num_classes=self.input_dim_X)
    E_t = F.one_hot(sampled_t.E, num_classes=self.input_dim_E)
    assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

    z_t = PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)
    return {
        't': t_float * self.timesteps,
        'beta_t': beta_t,
        'alpha_s_bar': alpha_s_bar,
        'alpha_t_bar': alpha_t_bar,
        'X_t': z_t.X,
        'E_t': z_t.E,
        'y_t': z_t.y,
        'node_mask': node_mask,
    }
[docs]
@torch.no_grad()
def generate(self, labels: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None, num_nodes: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None, batch_size: int = 32) -> List[str]:
    """Generate molecules with specified properties and optional node counts.

    Parameters
    ----------
    labels : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None
        Target properties for the generated molecules. Can be provided as:
        - A list of lists for multiple properties
        - A numpy array of shape (batch_size, n_properties)
        - A torch tensor of shape (batch_size, n_properties)
        For a single property, a 1D array/tensor is also accepted and is
        treated as one value per molecule.
        If None, generates unconditional samples.
    num_nodes : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None
        Number of nodes for each molecule in the batch. If None, samples from
        ``dataset_info["num_nodes_dist"]``. Can be provided as:
        - A list of lists
        - A numpy array of shape (batch_size, 1)
        - A torch tensor of shape (batch_size, 1)
        1D inputs are accepted as well. Values are moved to the model device
        and cast to integers.
    batch_size : int, default=32
        Number of molecules to generate. Only used when both ``labels`` and
        ``num_nodes`` are None; otherwise their length is used.

    Returns
    -------
    List[str]
        List of generated molecules in SMILES format.

    Raises
    ------
    ValueError
        If the model is not fitted, required attributes are missing, labels
        are required but absent, or ``labels`` and ``num_nodes`` disagree in
        length.
    """
    if not self.is_fitted_:
        raise ValueError("Model must be fitted before generating molecules.")
    if self.input_dim_X is None or self.input_dim_E is None or self.max_node is None:
        raise ValueError(f"Model may not be fitted correctly as one of below attributes is not set: input_dim_X={self.input_dim_X}, input_dim_E={self.input_dim_E}, max_node={self.max_node}")
    # Checked up front so the failure is reported before either is used.
    if getattr(self, "limit_dist", None) is None:
        raise ValueError("Limit distribution not found. Please call setup_diffusion_params first.")
    if getattr(self, "dataset_info", None) is None:
        raise ValueError("Dataset info not found. Please call setup_diffusion_params first.")
    if len(self.task_type) > 0 and labels is None:
        raise ValueError(f"labels must be provided if task_type is not empty: {self.task_type}")
    # Raised explicitly (not asserted) so the check survives ``python -O``.
    if labels is not None and num_nodes is not None and len(labels) != len(num_nodes):
        raise ValueError("labels and num_nodes must have the same batch size")

    # The explicit inputs, when present, determine the batch size.
    if labels is not None:
        batch_size = len(labels)
        labels = torch.as_tensor(labels)
        if labels.dim() == 1:
            labels = labels.unsqueeze(-1)
        y = labels.to(self.device).float()
    else:
        y = None
        if num_nodes is not None:
            batch_size = len(num_nodes)

    if num_nodes is None:
        num_nodes = self.dataset_info["num_nodes_dist"].sample_n(batch_size, self.device)
    else:
        # Covers lists, arrays and tensors alike; tensors supplied by the
        # caller may live on another device or carry a float dtype.
        num_nodes = torch.as_tensor(num_nodes).to(device=self.device, dtype=torch.long)
    if num_nodes.dim() == 1:
        num_nodes = num_nodes.unsqueeze(-1)
    assert num_nodes.size(0) == batch_size

    arange = (
        torch.arange(self.max_node, device=self.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    node_mask = arange < num_nodes

    # fit() leaves the network in training mode; sampling must not apply
    # dropout, so switch to evaluation mode explicitly.
    self.model.eval()

    z_T = sample_discrete_feature_noise(
        limit_dist=self.limit_dist, node_mask=node_mask
    )
    X, E = z_T.X, z_T.E
    assert (E == torch.transpose(E, 1, 2)).all()

    # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
    for s_int in reversed(range(0, self.timesteps)):
        s_array = s_int * torch.ones((batch_size, 1)).float().to(self.device)
        t_array = s_array + 1
        s_norm = s_array / self.timesteps
        t_norm = t_array / self.timesteps
        # Sample z_s
        sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

    # Collapse the final one-hot sample to class indices.
    sampled_s = sampled_s.mask(node_mask, collapse=True)
    X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

    molecule_list = []
    for i in range(batch_size):
        n = int(num_nodes[i][0].item())
        atom_types = X[i, :n].cpu()
        edge_types = E[i, :n, :n].cpu()
        molecule_list.append([atom_types, edge_types])
    smiles_list = graph_to_smiles(molecule_list, self.dataset_info["atom_decoder"])
    return smiles_list
def sample_p_zs_given_zt(
    self, s, t, X_t, E_t, properties, node_mask
):
    """Sample z_s ~ p(z_s | z_t) for one reverse step. Only used during sampling.

    Parameters
    ----------
    s, t : torch.Tensor
        Normalized timesteps (step index divided by ``self.timesteps``) of
        the target and current step.
    X_t, E_t : torch.Tensor
        Current one-hot node and edge features; ``X_t`` is unpacked as
        ``(bs, n, dx)``.
    properties : Any
        Conditioning labels handed to the network as ``y_t`` (may be None).
    node_mask : torch.Tensor
        Mask of real nodes.

    Returns
    -------
    PlaceHolder
        Masked one-hot sample with fields ``X``, ``E`` and ``y``.
    """
    bs, n, _ = X_t.shape
    device = X_t.device
    beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
    alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
    alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

    # Neural net predictions
    noisy_data = {
        "X_t": X_t,
        "E_t": E_t,
        "y_t": properties,
        # apply_noise feeds the network ``t_float * self.timesteps`` during
        # training, so the normalized ``t`` is rescaled the same way here.
        # NOTE(review): confirm against the Transformer's timestep embedding.
        "t": t * self.timesteps,
        "node_mask": node_mask,
    }

    # These do not depend on the network output, so they are computed once
    # and shared by the conditional and unconditional passes.
    Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
    Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
    Qt = self.transition_model.get_Qt(beta_t, device)
    Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)

    def get_prob(noisy_data, unconditioned=False):
        pred = self.model(noisy_data, unconditioned=unconditioned)
        # Normalize predictions
        pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

        predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
        unnormalized_probX_all = reverse_diffusion(
            predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
        )
        unnormalized_prob_X = unnormalized_probX_all[:, :, : self.input_dim_X]
        unnormalized_prob_E = unnormalized_probX_all[
            :, :, self.input_dim_X :
        ].reshape(bs, n * n, -1)

        # Rows that sum to zero get a small uniform mass so that the
        # normalization below never divides by zero.
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

        prob_X = unnormalized_prob_X / torch.sum(
            unnormalized_prob_X, dim=-1, keepdim=True
        )  # bs, n, d_t-1
        prob_E = unnormalized_prob_E / torch.sum(
            unnormalized_prob_E, dim=-1, keepdim=True
        )  # bs, n, d_t-1
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
        return prob_X, prob_E

    prob_X, prob_E = get_prob(noisy_data)

    ### Guidance
    if self.guide_scale is not None and self.guide_scale != 1:
        uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
        prob_X = (
            uncon_prob_X
            * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** self.guide_scale
        )
        prob_E = (
            uncon_prob_E
            * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** self.guide_scale
        )
        prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
        prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

    sampled_s = sample_discrete_features(prob_X, prob_E, node_mask=node_mask)
    X_s = F.one_hot(sampled_s.X, num_classes=self.input_dim_X).to(self.device).float()
    E_s = F.one_hot(sampled_s.E, num_classes=self.input_dim_E).to(self.device).float()
    assert (E_s == torch.transpose(E_s, 1, 2)).all()
    assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

    out_one_hot = PlaceHolder(X=X_s, E=E_s, y=properties)
    return out_one_hot.mask(node_mask)