from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Type
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import numpy as np
from .transformer import Transformer
from .utils import PlaceHolder, to_dense, compute_dataset_info
from .diffusion import NoiseScheduleDiscrete, MarginalTransition, sample_discrete_features, sample_discrete_feature_noise, reverse_diffusion
from ...base import BaseMolecularGenerator
from ...utils import graph_from_smiles, graph_to_smiles
class GraphDITMolecularGenerator(BaseMolecularGenerator):
"""
This generator implements the Graph Diffusion Transformer (Graph DiT) for multi-conditional and unconditional molecular generation.

References
----------
- Graph Diffusion Transformers for Multi-Conditional Molecular Generation. NeurIPS 2024.
https://openreview.net/forum?id=cfrDLD1wfO
- Implementation: https://github.com/liugangcode/Graph-DiT

Parameters
----------
num_layer : int, default=6
Number of transformer layers
hidden_size : int, default=1152
Dimension of hidden layers
dropout : float, default=0.0
Dropout rate for transformer layers
drop_condition : float, default=0.0
Dropout rate for condition embedding
num_head : int, default=16
Number of attention heads in transformer
mlp_ratio : float, default=4
Ratio of MLP hidden dimension to transformer hidden dimension
task_type : Optional[List[str]], default=None
List specifying the type of each task ('regression' or 'classification'); None is treated as an empty list (unconditional generation)
timesteps : int, default=500
Number of diffusion timesteps
batch_size : int, default=128
Batch size for training
epochs : int, default=10000
Number of training epochs
learning_rate : float, default=0.0002
Learning rate for optimization
grad_clip_value : Optional[float], default=None
Value for gradient clipping (None = no clipping)
weight_decay : float, default=0.0
Weight decay for optimization
lw_X : float, default=1
Loss weight for node reconstruction
lw_E : float, default=5
Loss weight for edge reconstruction
guide_scale : float, default=2.0
Scale factor for classifier-free guidance during sampling
use_lr_scheduler : bool, default=False
Whether to use learning rate scheduler
scheduler_factor : float, default=0.5
Factor by which to reduce learning rate on plateau
scheduler_patience : int, default=5
Number of epochs with no improvement after which learning rate will be reduced
verbose : bool, default=False
Whether to display progress bars and logs
device : Optional[Union[torch.device, str]], default=None
Device to run the model on (CPU or GPU)
model_name : str, default="GraphDITMolecularGenerator"
Name identifier for the model
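
Examples
--------
A minimal usage sketch; the class, method, and parameter names come from this module, while the SMILES strings and property values are illustrative placeholders only.

>>> model = GraphDITMolecularGenerator(task_type=["regression"], epochs=10)
>>> model = model.fit(["CCO", "c1ccccc1", "CC(=O)O"], [[0.5], [1.2], [0.9]])
>>> smiles = model.generate(labels=[[0.8], [1.0]])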
"""
def __init__(
self,
num_layer: int = 6,
hidden_size: int = 1152,
dropout: float = 0.,
drop_condition: float = 0.,
num_head: int = 16,
mlp_ratio: float = 4,
task_type: Optional[List[str]] = None,
timesteps: int = 500,
batch_size: int = 128,
epochs: int = 10000,
learning_rate: float = 0.0002,
grad_clip_value: Optional[float] = None,
weight_decay: float = 0.0,
lw_X: float = 1,
lw_E: float = 5,
guide_scale: float = 2.,
use_lr_scheduler: bool = False,
scheduler_factor: float = 0.5,
scheduler_patience: int = 5,
verbose: bool = False,
*,
device: Optional[Union[torch.device, str]] = None,
model_name: str = "GraphDITMolecularGenerator"
):
super().__init__(device=device, model_name=model_name)
self.num_layer = num_layer
self.hidden_size = hidden_size
self.dropout = dropout
self.drop_condition = drop_condition
self.num_head = num_head
self.mlp_ratio = mlp_ratio
if task_type is None:
self.task_type = list()
else:
self.task_type = task_type
self.timesteps = timesteps
self.batch_size = batch_size
self.epochs = epochs
self.learning_rate = learning_rate
self.grad_clip_value = grad_clip_value
self.weight_decay = weight_decay
self.lw_X = lw_X
self.lw_E = lw_E
self.guide_scale = guide_scale
self.use_lr_scheduler = use_lr_scheduler
self.scheduler_factor = scheduler_factor
self.scheduler_patience = scheduler_patience
self.verbose = verbose
self.fitting_loss = list()
self.fitting_epoch = 0
self.dataset_info = dict()
self.model_class = Transformer
self.max_node = None
self.input_dim_X = None
self.input_dim_E = None
self.input_dim_y = len(self.task_type)
@staticmethod
def _get_param_names() -> List[str]:
"""Get parameter names for the estimator.
Returns
-------
List[str]
List of parameter names that can be used for model configuration.
"""
return [
# Model Hyperparameters
"max_node", "hidden_size", "num_layer", "num_head",
"mlp_ratio", "dropout", "drop_condition", "input_dim_X", "input_dim_E", "input_dim_y",
"task_type",
# Diffusion parameters
"timesteps", "dataset_info",
# Training Parameters
"batch_size", "epochs", "learning_rate", "grad_clip_value",
"weight_decay", "lw_X", "lw_E",
# Scheduler Parameters
"use_lr_scheduler", "scheduler_factor", "scheduler_patience",
# Sampling Parameters
"guide_scale",
# Other Parameters
"fitting_epoch", "fitting_loss", "device", "verbose", "model_name"
]
def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
params = ["max_node", "hidden_size", "num_layer", "num_head", "mlp_ratio",
"dropout", "drop_condition", "input_dim_X", "input_dim_E", "input_dim_y", "task_type"]
if checkpoint is not None:
if "hyperparameters" not in checkpoint:
raise ValueError("Checkpoint missing 'hyperparameters' key")
return {k: checkpoint["hyperparameters"][k] for k in params}
return {k: getattr(self, k) for k in params}
def _convert_to_pytorch_data(self, X, y=None):
"""Convert numpy arrays to PyTorch Geometric data format.
"""
if self.verbose:
iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
else:
iterator = enumerate(X)
pyg_graph_list = []
for idx, smiles_or_mol in iterator:
if y is not None:
properties = y[idx]
else:
properties = None
graph = graph_from_smiles(smiles_or_mol, properties)
g = Data()
# Hydrogens are excluded; shift indices so that the first heavy-atom type is 0
node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1)
if node_type.numel() <= 1:
continue
# Filter out invalid node types (< 0)
valid_mask = node_type >= 0
if not valid_mask.all():
# Get valid nodes and adjust edge indices
valid_indices = torch.where(valid_mask)[0]
index_map = -torch.ones(node_type.size(0), dtype=torch.long)
index_map[valid_indices] = torch.arange(valid_indices.size(0))
# Filter edges that connect to invalid nodes
edge_index = torch.from_numpy(graph["edge_index"])
valid_edges_mask = valid_mask[edge_index[0]] & valid_mask[edge_index[1]]
valid_edge_index = edge_index[:, valid_edges_mask]
# Remap edge indices to account for removed nodes
remapped_edge_index = index_map[valid_edge_index]
# Filter edge attributes
edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1
valid_edge_attr = edge_attr[valid_edges_mask]
# Update node and edge data
node_type = node_type[valid_mask]
g.edge_index = remapped_edge_index
g.edge_attr = valid_edge_attr.long().squeeze(-1)
else:
# No invalid nodes, proceed normally
g.edge_index = torch.from_numpy(graph["edge_index"])
edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1
g.edge_attr = edge_attr.long().squeeze(-1)
# The wildcard atom "*" is encoded as "misc" (118 after the -1 shift above) and is remapped to 117
node_type[node_type == 118] = 117
g.x = node_type.long().squeeze(-1)
del graph["node_feat"]
del graph["edge_index"]
del graph["edge_feat"]
g.y = torch.from_numpy(graph["y"])
del graph["y"]
pyg_graph_list.append(g)
return pyg_graph_list
def _setup_diffusion_params(self, X: Union[List, Dict]) -> None:
# Extract dataset info from X if it's a dict (from checkpoint), otherwise compute it
if isinstance(X, dict):
dataset_info = X["hyperparameters"]["dataset_info"]
timesteps = X["hyperparameters"]["timesteps"]
max_node = X["hyperparameters"]["max_node"]
else:
assert isinstance(X, list)
dataset_info = compute_dataset_info(X)
timesteps = self.timesteps
max_node = dataset_info["max_node"]
self.input_dim_X = dataset_info["x_margins"].shape[0]
self.input_dim_E = dataset_info["e_margins"].shape[0]
self.dataset_info = dataset_info
self.timesteps = timesteps
self.max_node = max_node
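# Marginal transition: the noising process converges to the dataset's empirical node/edge marginals (x_margins / e_margins), which also serve as the limit distribution used to initialize sampling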
x_limit = dataset_info["x_margins"].to(self.device)
e_limit = dataset_info["e_margins"].to(self.device)
xe_conditions = dataset_info["xe_conditions"].to(self.device)
ex_conditions = dataset_info["ex_conditions"].to(self.device)
self.transition_model = MarginalTransition(x_limit, e_limit, xe_conditions, ex_conditions, self.max_node)
self.limit_dist = PlaceHolder(X=x_limit, E=e_limit, y=None)
self.noise_schedule = NoiseScheduleDiscrete(timesteps=self.timesteps).to(self.device)
def _initialize_model(
self,
model_class: Type[torch.nn.Module],
checkpoint: Optional[Dict] = None
) -> torch.nn.Module:
"""Initialize the model with parameters or a checkpoint."""
model_params = self._get_model_params(checkpoint)
self.model = model_class(**model_params)
self.model = self.model.to(self.device)
if checkpoint is not None:
self._setup_diffusion_params(checkpoint)
self.model.load_state_dict(checkpoint["model_state_dict"])
return self.model
def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
"""Setup optimization components including optimizer and learning rate scheduler.
"""
optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
scheduler = None
if self.use_lr_scheduler:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=self.scheduler_factor,
patience=self.scheduler_patience,
min_lr=1e-6,
cooldown=0,
eps=1e-8,
)
return optimizer, scheduler
def fit(
self,
X_train: List[str],
y_train: Optional[Union[List, np.ndarray]] = None,
) -> "GraphDITMolecularGenerator":
num_task = len(self.task_type)
X_train, y_train = self._validate_inputs(X_train, y_train, num_task=num_task)
self._setup_diffusion_params(X_train)
self._initialize_model(self.model_class)
self.model.initialize_parameters()
optimizer, scheduler = self._setup_optimizers()
train_dataset = self._convert_to_pytorch_data(X_train, y_train)
train_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=0
)
self.fitting_loss = []
self.fitting_epoch = 0
# Calculate total steps for progress tracking
total_steps = self.epochs * len(train_loader)
# Initialize global progress bar
global_pbar = tqdm(total=total_steps, desc="Training Progress", disable=not self.verbose)
for epoch in range(self.epochs):
train_losses = self._train_epoch(train_loader, optimizer, epoch, global_pbar)
epoch_loss = float(np.mean(train_losses))
self.fitting_loss.append(epoch_loss)
if scheduler:
scheduler.step(epoch_loss)
global_pbar.close()
self.fitting_epoch = epoch
self.is_fitted_ = True
return self
def _train_epoch(self, train_loader, optimizer, epoch, global_pbar=None):
self.model.train()
losses = []
# Per-step progress is reported through the shared global progress bar
active_index = self.dataset_info["active_index"]
for step, batched_data in enumerate(train_loader):
batched_data = batched_data.to(self.device)
optimizer.zero_grad()
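# One-hot encode atom types (118 classes) and bond types (5 classes), keeping only the atom-type columns listed in active_index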
data_x = F.one_hot(batched_data.x, num_classes=118).float()[:, active_index]
data_edge_attr = F.one_hot(batched_data.edge_attr, num_classes=5).float()
dense_data, node_mask = to_dense(data_x, batched_data.edge_index, data_edge_attr, batched_data.batch, self.max_node)
dense_data = dense_data.mask(node_mask)
X, E = dense_data.X, dense_data.E
noisy_data = self.apply_noise(X, E, batched_data.y, node_mask)
loss, loss_X, loss_E = self.model.compute_loss(noisy_data, true_X=X, true_E=E, lw_X=self.lw_X, lw_E=self.lw_E)
loss.backward()
if self.grad_clip_value is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
optimizer.step()
losses.append(loss.item())
# Update global progress bar
if global_pbar is not None:
global_pbar.set_postfix({
"Epoch": f"{epoch+1}/{self.epochs}",
"Step": f"{step+1}/{len(train_loader)}",
"Loss": f"{loss.item():.4f}",
"Loss_X": f"{loss_X.item():.4f}",
"Loss_E": f"{loss_E.item():.4f}"
})
global_pbar.update(1)
return losses
def apply_noise(self, X, E, y, node_mask) -> Dict[str, Any]:
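"""Forward diffusion: sample a random timestep per graph, build the cumulative transition matrix Q_t_bar, and draw noised one-hot node/edge states (X_t, E_t)."""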
t_int = torch.randint(0, self.timesteps + 1, size=(X.size(0), 1), device=X.device).float() # (bs, 1)
s_int = t_int - 1
t_float = t_int / self.timesteps
s_float = s_int / self.timesteps
# beta_t and alpha_s_bar are used for denoising/loss computation
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
bs, n, _ = X.shape
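# Nodes and flattened edges are noised jointly: concatenate X with the reshaped E along the feature axis and apply a single transition matrix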
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Qtb.X
probX = prob_all[:, :, :self.input_dim_X]
probE = prob_all[:, :, self.input_dim_X:].reshape(bs, n, n, -1)
# Sample discrete node and edge states from the noised categorical distributions
sampled_t = sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)
X_t = F.one_hot(sampled_t.X, num_classes=self.input_dim_X)
E_t = F.one_hot(sampled_t.E, num_classes=self.input_dim_E)
assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
z_t = PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)
noisy_data = {'t': t_float * self.timesteps, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar,
'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask}
return noisy_data
@torch.no_grad()
def generate(self, labels: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None, num_nodes: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None, batch_size: int = 32) -> List[str]:
"""Generate molecules with specified properties and optional node counts.
Parameters
----------
labels : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None
Target properties for the generated molecules. Can be provided as:
- A list of lists for multiple properties
- A numpy array of shape (batch_size, n_properties)
- A torch tensor of shape (batch_size, n_properties)
For a single property, a 1D array/tensor is also accepted.
If None, generates unconditional samples.
num_nodes : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None
Number of nodes for each molecule in the batch. If None, samples from
the training distribution. Can be provided as:
- A list of lists
- A numpy array of shape (batch_size, 1)
- A torch tensor of shape (batch_size, 1)
batch_size : int, default=32
Number of molecules to generate. Only used when both labels and num_nodes are None.
Returns
-------
List[str]
List of generated molecules in SMILES format.
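
Examples
--------
A hedged sketch; the property values are illustrative and assume the model was fitted with a single regression task (unconditional generation requires task_type == []).

>>> smiles = model.generate(labels=[[0.5], [1.2]])
>>> smiles = model.generate(batch_size=8)  # unconditional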
"""
if not self.is_fitted_:
raise ValueError("Model must be fitted before generating molecules.")
if self.input_dim_X is None or self.input_dim_E is None or self.max_node is None:
raise ValueError(f"Model may not be fitted correctly as one of below attributes is not set: input_dim_X={self.input_dim_X}, input_dim_E={self.input_dim_E}, max_node={self.max_node}")
if len(self.task_type) > 0 and labels is None:
raise ValueError(f"labels must be provided if task_type is not empty: {self.task_type}")
if labels is not None:
if num_nodes is not None:
assert len(labels) == len(num_nodes), "labels and num_nodes must have the same batch size"
batch_size = len(labels)
elif num_nodes is not None:
batch_size = len(num_nodes)
# Convert properties to 2D tensor if needed
if isinstance(labels, list):
labels = torch.tensor(labels)
elif isinstance(labels, np.ndarray):
labels = torch.from_numpy(labels)
if labels is not None and labels.dim() == 1:
labels = labels.unsqueeze(-1)
if num_nodes is None:
num_nodes_dist = self.dataset_info["num_nodes_dist"]
num_nodes = num_nodes_dist.sample_n(batch_size, self.device)
elif isinstance(num_nodes, list):
num_nodes = torch.tensor(num_nodes).to(self.device)
elif isinstance(num_nodes, np.ndarray):
num_nodes = torch.from_numpy(num_nodes).to(self.device)
if num_nodes.dim() == 1:
num_nodes = num_nodes.unsqueeze(-1)
assert num_nodes.size(0) == batch_size
arange = (
torch.arange(self.max_node).to(self.device)
.unsqueeze(0)
.expand(batch_size, -1)
)
node_mask = arange < num_nodes
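# node_mask[i, j] is True for the first num_nodes[i] positions of graph i; padded positions are masked out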
if not hasattr(self, 'limit_dist') or self.limit_dist is None:
raise ValueError("Limit distribution not found. Fit the model or load a checkpoint before generating.")
if not hasattr(self, 'dataset_info') or self.dataset_info is None:
raise ValueError("Dataset info not found. Fit the model or load a checkpoint before generating.")
z_T = sample_discrete_feature_noise(
limit_dist=self.limit_dist, node_mask=node_mask
)
X, E = z_T.X, z_T.E
assert (E == torch.transpose(E, 1, 2)).all()
# Iteratively sample z_s ~ p(z_s | z_t) for t = T, ..., 1, with s = t - 1.
if labels is not None:
y = labels.to(self.device).float()
else:
y = None
self.model.eval()
for s_int in reversed(range(0, self.timesteps)):
s_array = s_int * torch.ones((batch_size, 1)).float().to(self.device)
t_array = s_array + 1
s_norm = s_array / self.timesteps
t_norm = t_array / self.timesteps
# Sample z_s
sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
# Collapse one-hot node/edge states to class indices for the final sample
sampled_s = sampled_s.mask(node_mask, collapse=True)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
molecule_list = []
for i in range(batch_size):
n = num_nodes[i][0].item()
atom_types = X[i, :n].cpu()
edge_types = E[i, :n, :n].cpu()
molecule_list.append([atom_types, edge_types])
smiles_list = graph_to_smiles(molecule_list, self.dataset_info["atom_decoder"])
return smiles_list
def sample_p_zs_given_zt(
self, s, t, X_t, E_t, properties, node_mask
):
"""Samples from zs ~ p(zs | zt). Only used during sampling.
if last_step, return the graph prediction as well"""
bs, n, _ = X_t.shape
beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
# Neural net predictions
noisy_data = {
"X_t": X_t,
"E_t": E_t,
"y_t": properties,
"t": t,
"node_mask": node_mask,
}
def get_prob(noisy_data, unconditioned=False):
pred = self.model(noisy_data, unconditioned=unconditioned)
# Normalize predictions
pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
device = pred_X.device
# Retrieve transitions matrix
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
Qt = self.transition_model.get_Qt(beta_t, device)
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
# reverse_diffusion combines the predicted clean graph with Qt, Qsb and Qtb to produce unnormalized probabilities over the previous (less noisy) state
unnormalized_probX_all = reverse_diffusion(
predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
)
unnormalized_prob_X = unnormalized_probX_all[:, :, : self.input_dim_X]
unnormalized_prob_E = unnormalized_probX_all[
:, :, self.input_dim_X :
].reshape(bs, n * n, -1)
unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
prob_X = unnormalized_prob_X / torch.sum(
unnormalized_prob_X, dim=-1, keepdim=True
) # bs, n, d_t-1
prob_E = unnormalized_prob_E / torch.sum(
unnormalized_prob_E, dim=-1, keepdim=True
) # bs, n * n, d_t-1
prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
return prob_X, prob_E
prob_X, prob_E = get_prob(noisy_data)
### Guidance
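# Classifier-free guidance: reweight the conditional probabilities by (p_cond / p_uncond) ** guide_scale and renormalize; guide_scale > 1 strengthens conditioning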
if self.guide_scale is not None and self.guide_scale != 1:
uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
prob_X = (
uncon_prob_X
* (prob_X / uncon_prob_X.clamp_min(1e-5)) ** self.guide_scale
)
prob_E = (
uncon_prob_E
* (prob_E / uncon_prob_E.clamp_min(1e-5)) ** self.guide_scale
)
prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
sampled_s = sample_discrete_features(prob_X, prob_E, node_mask=node_mask)
X_s = F.one_hot(sampled_s.X, num_classes=self.input_dim_X).to(self.device).float()
E_s = F.one_hot(sampled_s.E, num_classes=self.input_dim_E).to(self.device).float()
assert (E_s == torch.transpose(E_s, 1, 2)).all()
assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
out_one_hot = PlaceHolder(X=X_s, E=E_s, y=properties)
return out_one_hot.mask(node_mask)