Source code for torch_molecule.generator.defog.modeling_defog

import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from .model import GraphTransformer
from .utils import PlaceHolder, to_dense, compute_dataset_info, DistributionNodes
from .flow_matching import NoiseDistribution, TimeDistorter, p_xt_g_x1, RateMatrixDesigner, sample_discrete_feature_noise, sample_discrete_features

from .extra_features import ExtraFeatures

from ...base import BaseMolecularGenerator
from ...utils import graph_from_smiles, graph_to_smiles

class DeFoGMolecularGenerator(BaseMolecularGenerator):
    """
    This generator implements the Discrete Flow Graph Model (DeFoG) for (multi-conditional and
    unconditional) molecular generation. It uses a flow matching approach with a graph
    transformer backbone.

    References
    ----------
    - DeFoG: A Discrete Flow Model for Graph Generation. ICML 2025.
      (https://openreview.net/forum?id=KPRIwWhqAZ)
    - Implementation based on: https://github.com/manuelmlmadeira/DeFoG

    Parameters
    ----------
    num_layer : int, default=6
        Number of transformer layers.
    hidden_mlp_dims : Optional[Dict[str, int]], default=None
        Hidden dimensions for MLP layers in X (node dim), E (edge dim), and y (property dim)
        components. If None, defaults to {'X': 256, 'E': 128, 'y': 128}.
    hidden_dims : Optional[Dict[str, Any]], default=None
        Hidden dimensions for transformer components, including attention heads and feed-forward
        layers. If None, defaults to {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8,
        'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}.
        Keys: 'dx' (node dim), 'de' (edge dim), 'dy' (property dim), 'n_head' (number of
        attention heads), 'dim_ffX' (feed-forward dim for node features), 'dim_ffE'
        (feed-forward dim for edge features), 'dim_ffy' (feed-forward dim for property features).
    transition : str, default='marginal'
        Transition type for flow matching.
        Options: 'marginal', 'absorbing', 'uniform', 'absorbfirst', 'argmax', 'edge_marginal', 'node_marginal'
    time_distortion : str, default="polydec"
        Time distortion schedule used during training/sampling.
        Options: 'identity', 'cos', 'revcos', 'polyinc', 'polydec'
    lambda_train : Optional[List[float]], default=None
        Loss weights [edge_loss_weight, property_loss_weight]. If None, defaults to [5.0, 1.0].
    extra_features_type : str, default='rrwp'
        Extra feature type.
        Options: 'rrwp', 'rrwp_double', 'rrwp_only', 'rrwp_comp', 'cycles', 'eigenvalues', 'all'
    rrwp_steps : int, default=16
        Number of steps for (R)RWP features.
    sample_steps : int, default=500
        Number of sampling steps during generation.
    rdb : str, default="general"
        Rate matrix design (R^db) method. Options: 'general', 'marginal', 'column', 'entry'
    rdb_crit : str, default="p_x1_g_xt"
        Criterion used by the 'column'/'entry' designs.
        Options: 'max_marginal', 'x_t', 'abs_state', 'p_x1_g_xt', 'x_1', 'p_xt_g_x1', 'xhat_t', 'first'
    eta : float, default=0
        Stochasticity scaling for R^db (higher = more stochastic).
    omega : float, default=0.1
        Target-guidance scaling for R^tg (higher = stronger guidance).
    guidance_weight : float, default=0.2
        Classifier-free guidance weight during sampling.
    batch_size : int, default=128
        Batch size for training.
    epochs : int, default=1000
        Number of training epochs.
    learning_rate : float, default=0.0002
        Learning rate for optimization.
    grad_clip_value : Optional[float], default=1.0
        Value for gradient clipping (None = no clipping).
    weight_decay : float, default=0.0
        Weight decay for optimization.
    use_lr_scheduler : bool, default=False
        Whether to use a learning rate scheduler.
    scheduler_factor : float, default=0.5
        Factor by which to reduce the learning rate on plateau.
    scheduler_patience : int, default=10
        Number of epochs with no improvement after which the learning rate is reduced.
    task_type : Optional[List[str]], default=None
        List specifying the type of each task ('regression' or 'classification').
        If None or empty, the model is trained unconditionally.
    verbose : bool, default=False
        Whether to display progress bars and logs.
    device : Optional[Union[torch.device, str]], default=None
        Device to run the model on (CPU or GPU).
    model_name : str, default="DeFoGMolecularGenerator"
        Name identifier for the model.
    """

    def __init__(
        self,
        num_layer: int = 6,
        hidden_mlp_dims: Optional[Dict[str, int]] = None,
        hidden_dims: Optional[Dict[str, Any]] = None,
        transition: str = 'marginal',
        time_distortion: str = "polydec",
        lambda_train: Optional[List[float]] = None,
        extra_features_type: str = 'rrwp',
        rrwp_steps: int = 16,
        sample_steps: int = 500,
        rdb: str = "general",
        rdb_crit: str = "p_x1_g_xt",
        eta: float = 0,
        omega: float = 0.1,
        guidance_weight: float = 0.2,
        batch_size: int = 128,
        epochs: int = 1000,
        learning_rate: float = 0.0002,
        grad_clip_value: Optional[float] = 1.0,
        weight_decay: float = 0.0,
        use_lr_scheduler: bool = False,
        scheduler_factor: float = 0.5,
        scheduler_patience: int = 10,
        task_type: Optional[List[str]] = None,
        verbose: bool = False,
        *,
        device: Optional[Union[torch.device, str]] = None,
        model_name: str = "DeFoGMolecularGenerator",
    ):
        super().__init__(device=device, model_name=model_name)

        # Defaults for dict/list parameters
        if hidden_mlp_dims is None:
            hidden_mlp_dims = {'X': 256, 'E': 128, 'y': 128}
        if hidden_dims is None:
            hidden_dims = {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8,
                           'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}
        if lambda_train is None:
            lambda_train = [5.0, 1.0]

        # Model hyperparameters
        self.num_layer = num_layer
        self.hidden_mlp_dims = hidden_mlp_dims
        self.hidden_dims = hidden_dims

        # Flow Matching & Training parameters
        self.transition = transition
        self.time_distortion = time_distortion
        self.lambda_train = lambda_train
        self.extra_features_type = extra_features_type
        self.rrwp_steps = rrwp_steps

        # Sampling parameters
        self.sample_steps = sample_steps
        self.rdb = rdb
        self.rdb_crit = rdb_crit
        self.eta = eta
        self.omega = omega
        self.guidance_weight = guidance_weight

        # Standard training parameters (mirror GraphDiT style)
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.grad_clip_value = grad_clip_value
        self.weight_decay = weight_decay
        self.use_lr_scheduler = use_lr_scheduler
        self.scheduler_factor = scheduler_factor
        self.scheduler_patience = scheduler_patience

        # Task and verbosity
        if task_type is None:
            self.task_type = list()
        else:
            self.task_type = task_type
        self.verbose = verbose

        # Book-keeping attributes
        self.fitting_loss = list()
        self.fitting_epoch = 0
        self.dataset_info = dict()
        self.model_class = GraphTransformer

        # Dimensions, to be set up later
        self.max_node = None
        self.input_dim_X = None
        self.input_dim_E = None
        self.input_dim_y = len(self.task_type)

    @staticmethod
    def _get_param_names() -> List[str]:
        """Get parameter names for the estimator.

        Returns
        -------
        List[str]
            List of parameter names that can be used for model configuration.
        """
        return [
            # Model Hyperparameters
            "max_node", "num_layer", "hidden_mlp_dims", "hidden_dims",
            "input_dim_X", "input_dim_E", "input_dim_y", "task_type",
            # Flow Matching parameters
            "transition", "time_distortion", "lambda_train", "extra_features_type",
            "rrwp_steps", "sample_steps", "rdb", "rdb_crit", "eta", "omega",
            "guidance_weight", "dataset_info",
            # Training Parameters
            "batch_size", "epochs", "learning_rate", "grad_clip_value",
            "weight_decay", "use_lr_scheduler", "scheduler_factor", "scheduler_patience",
            # Other Parameters
            "fitting_epoch", "fitting_loss", "device", "verbose", "model_name"
        ]

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Get model parameters for initialization."""
        params = [
            "num_layer", "hidden_mlp_dims", "hidden_dims",
            "input_dim_X", "input_dim_E", "input_dim_y", "task_type"
        ]
        if checkpoint is not None:
            if "hyperparameters" not in checkpoint:
                raise ValueError("Checkpoint missing 'hyperparameters' key")
            return {k: checkpoint["hyperparameters"][k] for k in params}
        return {k: getattr(self, k) for k in params}

    def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        """Setup optimization components including optimizer and learning rate scheduler."""
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        scheduler = None
        if self.use_lr_scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=self.scheduler_factor,
                patience=self.scheduler_patience,
                min_lr=1e-6,
                cooldown=0,
                eps=1e-8,
            )
        return optimizer, scheduler

    def _convert_to_pytorch_data(self, X, y=None):
        """Convert numpy arrays to PyTorch Geometric data format."""
        if self.verbose:
            iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
        else:
            iterator = enumerate(X)

        pyg_graph_list = []
        for idx, smiles_or_mol in iterator:
            # For unconditional models, properties should be None
            properties = y[idx] if y is not None and len(self.task_type) > 0 else None
            graph = graph_from_smiles(smiles_or_mol, properties)
            g = Data()
            node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1)
            if node_type.numel() <= 1:
                continue
            valid_mask = node_type >= 0
            if not valid_mask.all():
                # Get valid nodes and adjust edge indices
                valid_indices = torch.where(valid_mask)[0]
                index_map = -torch.ones(node_type.size(0), dtype=torch.long)
                index_map[valid_indices] = torch.arange(valid_indices.size(0))
                edge_index = torch.from_numpy(graph["edge_index"])
                valid_edges_mask = valid_mask[edge_index[0]] & valid_mask[edge_index[1]]
                valid_edge_index = edge_index[:, valid_edges_mask]
                remapped_edge_index = index_map[valid_edge_index]
                edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1
                valid_edge_attr = edge_attr[valid_edges_mask]
                node_type = node_type[valid_mask]
                g.edge_index = remapped_edge_index
                g.edge_attr = valid_edge_attr.long()
            else:
                g.edge_index = torch.from_numpy(graph["edge_index"])
                edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1
                g.edge_attr = edge_attr.long()

            node_type[node_type == 118] = 117  # Remap '*'
            g.x = node_type.long().squeeze(-1)

            # Only set g.y for conditional models
            if len(self.task_type) > 0:
                g.y = torch.from_numpy(graph["y"])
            pyg_graph_list.append(g)

        return pyg_graph_list

    def _setup_flow_params(self, X: Union[List, Dict]):
        """Sets up parameters and distributions required for the flow matching process."""
        if isinstance(X, dict):
            self.dataset_info = X["hyperparameters"]["dataset_info"]
        else:
            self.dataset_info = compute_dataset_info(X)

        class DeFoGDatasetInfos:
            def __init__(self, info, task_type):
                self.info = info
                self.output_dims = {'X': len(info['x_margins']), 'E': len(info['e_margins']), 'y': len(task_type)}
                self.input_dims = self.output_dims.copy()
                self.nodes_dist = info['num_nodes_dist']
                self.node_types = info['x_margins']
                self.edge_types = info['e_margins']
                self.atom_decoder = info['atom_decoder']

        defog_infos = DeFoGDatasetInfos(self.dataset_info, self.task_type)
        self.noise_dist = NoiseDistribution(self.transition, defog_infos)
        self.limit_dist = self.noise_dist.get_limit_dist()
        self.limit_dist.to_device(self.device)
        self.time_distorter = TimeDistorter(train_distortion=self.time_distortion,
                                            sample_distortion=self.time_distortion,
                                            alpha=1, beta=1)
        self.input_dim_X = self.noise_dist.x_num_classes
        self.input_dim_E = self.noise_dist.e_num_classes
        self.max_node = self.dataset_info['max_node']

    def _initialize_model(
        self, model_class: Type[torch.nn.Module], checkpoint: Optional[Dict] = None
    ) -> torch.nn.Module:
        """Initialize the model with parameters or a checkpoint."""
        if checkpoint is not None:
            # When loading from checkpoint, we need to set up flow params first
            self._setup_flow_params(checkpoint)

        # Initialize ExtraFeatures
        class DeFoGDatasetInfosForFeatures:
            def __init__(self, max_n_nodes):
                self.max_n_nodes = max_n_nodes

        feature_infos = DeFoGDatasetInfosForFeatures(self.max_node)
        self.extra_features_computer = ExtraFeatures(
            self.extra_features_type, self.rrwp_steps, feature_infos
        )

        # Temp calculation to get feature dimensions
        # This is a bit of a hack, but necessary to get the dims before full model init
        dummy_x = torch.zeros(1, self.max_node, self.input_dim_X, device=self.device)
        dummy_e = torch.zeros(1, self.max_node, self.max_node, self.input_dim_E, device=self.device)
        dummy_y = torch.zeros(1, self.input_dim_y, device=self.device)
        dummy_mask = torch.ones(1, self.max_node, device=self.device).bool()
        dummy_t = torch.zeros(1, 1, device=self.device)
        dummy_noisy_data = {'X_t': dummy_x, 'E_t': dummy_e, 'y_t': dummy_y, 't': dummy_t, 'node_mask': dummy_mask}
        extra_data = self._compute_extra_data(dummy_noisy_data)
        self.input_dim_X_extra = extra_data.X.shape[-1]
        self.input_dim_E_extra = extra_data.E.shape[-1]
        self.input_dim_y_extra = extra_data.y.shape[-1]

        model_params = {
            'n_layers': self.num_layer,
            'input_dims': {
                'X': self.input_dim_X + self.input_dim_X_extra,
                'E': self.input_dim_E + self.input_dim_E_extra,
                'y': self.input_dim_y + self.input_dim_y_extra,
            },
            'hidden_mlp_dims': self.hidden_mlp_dims,
            'hidden_dims': self.hidden_dims,
            'output_dims': {'X': self.input_dim_X, 'E': self.input_dim_E, 'y': self.input_dim_y},
            'act_fn_in': nn.ReLU(),
            'act_fn_out': nn.ReLU(),
        }
        self.model = model_class(**model_params)
        self.model = self.model.to(self.device)
        if checkpoint is not None:
            self.model.load_state_dict(checkpoint["model_state_dict"])

        self.rate_matrix_designer = RateMatrixDesigner(
            rdb=self.rdb,
            rdb_crit=self.rdb_crit,
            eta=self.eta,
            omega=self.omega,
            limit_dist=self.limit_dist
        )
    def fit(self, X_train: List[str], y_train: Optional[Union[List, np.ndarray]] = None) -> "DeFoGMolecularGenerator":
        num_task = len(self.task_type)
        # For unconditional models, ensure y_train is None
        if num_task == 0 and y_train is not None:
            print("Warning: y_train provided for unconditional model. Ignoring y_train.")
            y_train = None

        X_train, y_train = self._validate_inputs(X_train, y_train, num_task=num_task)
        self._setup_flow_params(X_train)
        self._initialize_model(self.model_class)
        optimizer, scheduler = self._setup_optimizers()

        train_dataset = self._convert_to_pytorch_data(X_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)

        # Calculate total steps for global progress bar
        total_steps = self.epochs * len(train_loader)
        global_progress = tqdm(total=total_steps, desc="Training Progress", leave=True) if self.verbose else None

        self.fitting_loss = []
        for epoch in range(self.epochs):
            train_losses = self._train_epoch(train_loader, optimizer, epoch, global_progress)
            avg_loss = np.mean(train_losses)
            self.fitting_loss.append(avg_loss)
            if scheduler:
                scheduler.step(avg_loss)

        if global_progress:
            global_progress.close()

        self.is_fitted_ = True
        return self
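    # Minimal usage sketch for `fit` (illustrative only; `smiles_list` and `logp_values`
    # are hypothetical variables assumed to exist in the caller's scope):
    #
    #     gen = DeFoGMolecularGenerator(task_type=["regression"], epochs=10, verbose=True)
    #     gen.fit(smiles_list, logp_values)  # one regression property per molecule
    #
    # For unconditional training, leave `task_type` at its default and call
    # `gen.fit(smiles_list)` without labels.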
    def _train_epoch(self, train_loader, optimizer, epoch, global_progress=None):
        self.model.train()
        losses = []
        active_index = self.dataset_info["active_index"]

        for batched_data in train_loader:
            batched_data = batched_data.to(self.device)
            optimizer.zero_grad()

            data_x = F.one_hot(batched_data.x, num_classes=118).float()[:, active_index]
            data_edge_attr = F.one_hot(batched_data.edge_attr, num_classes=5).float()
            dense_data, node_mask = to_dense(data_x, batched_data.edge_index, data_edge_attr, batched_data.batch)
            dense_data = dense_data.mask(node_mask)
            X, E = dense_data.X, dense_data.E

            # Process y data
            if hasattr(batched_data, "y") and len(self.task_type) > 0:
                # Conditional model
                y = batched_data.y
                if torch.rand(1) < 0.1:
                    # 10% probability to use unconditional training
                    y = torch.ones_like(y, device=self.device) * -1
            else:
                # Unconditional model
                # Ensure y has correct dimensions
                batch_size = X.size(0)
                y = torch.zeros(batch_size, self.input_dim_y, device=self.device)

            noisy_data = self.apply_noise(X, E, y, node_mask)
            extra_data = self._compute_extra_data(noisy_data)
            pred = self.model(
                torch.cat((noisy_data["X_t"], extra_data.X), dim=2).float(),
                torch.cat((noisy_data["E_t"], extra_data.E), dim=3).float(),
                torch.cat((noisy_data["y_t"], extra_data.y), dim=1).float(),
                node_mask
            )

            loss_X = F.cross_entropy(pred.X.transpose(1, 2), X.argmax(-1), reduction='none')
            loss_E = F.cross_entropy(pred.E.transpose(1, 3), E.argmax(-1), reduction='none')
            masked_loss_X = (loss_X * node_mask).sum() / node_mask.sum()
            e_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
            masked_loss_E = (loss_E * e_mask).sum() / e_mask.sum()

            if len(self.task_type) > 0 and y.shape[-1] > 0:
                # For conditional models, compute loss on y predictions
                # Use MSE loss for regression tasks, similar to original DeFoG
                loss_y = F.mse_loss(pred.y, y, reduction='mean')
            else:
                loss_y = torch.tensor(0.0, device=self.device)

            loss = masked_loss_X + self.lambda_train[0] * masked_loss_E + self.lambda_train[1] * loss_y
            loss.backward()
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
            optimizer.step()
            losses.append(loss.item())

            # Update global progress bar
            if global_progress:
                global_progress.set_postfix({
                    "Epoch": f"{epoch+1}",
                    "Loss": f"{loss.item():.4f}",
                    "Loss_X": f"{masked_loss_X.item():.4f}",
                    "Loss_E": f"{masked_loss_E.item():.4f}",
                    "Loss_y": f"{loss_y.item():.4f}"
                })
                global_progress.update(1)

        return losses

    def apply_noise(self, X, E, y, node_mask):
        bs = X.size(0)
        t_float = self.time_distorter.train_ft(bs, self.device)
        X_1_label = torch.argmax(X, dim=-1)
        E_1_label = torch.argmax(E, dim=-1)
        prob_X_t, prob_E_t = p_xt_g_x1(X_1_label, E_1_label, t_float, self.limit_dist)
        sampled_t = sample_discrete_features(prob_X_t, prob_E_t, node_mask=node_mask)
        X_t = F.one_hot(sampled_t.X, num_classes=self.input_dim_X)
        E_t = F.one_hot(sampled_t.E, num_classes=self.input_dim_E)
        z_t = PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)
        return {"t": t_float, "X_t": z_t.X, "E_t": z_t.E, "y_t": z_t.y, "node_mask": node_mask}

    def _compute_extra_data(self, noisy_data):
        """Computes extra features for the model, like time embedding and RRWP."""
        extra_features = self.extra_features_computer(noisy_data)
        t = noisy_data['t']
        extra_y = torch.cat((extra_features.y, t), dim=1)
        return PlaceHolder(X=extra_features.X, E=extra_features.E, y=extra_y)
    @torch.no_grad()
    def generate(
        self,
        labels: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None,
        num_nodes: Optional[Union[List[List], np.ndarray, torch.Tensor]] = None,
        batch_size: int = 32
    ) -> List[str]:
        """Generate molecules with specified properties and optional node counts.

        Parameters
        ----------
        labels : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None
            Target properties for the generated molecules. Can be provided as:

            - A list of lists for multiple properties
            - A numpy array of shape (batch_size, n_properties)
            - A torch tensor of shape (batch_size, n_properties)

            For a single property, a 1D array/tensor of shape (batch_size,) is also accepted.
            If None, generates unconditional samples.
        num_nodes : Optional[Union[List[List], np.ndarray, torch.Tensor]], default=None
            Number of nodes for each molecule in the batch. If None, samples from the
            training distribution. Can be provided as:

            - A list of lists
            - A numpy array of shape (batch_size, 1)
            - A torch tensor of shape (batch_size, 1)
        batch_size : int, default=32
            Number of molecules to generate. Only used if both labels and num_nodes are None.

        Returns
        -------
        List[str]
            List of generated molecules in SMILES format.
        """
        if not self.is_fitted_:
            raise ValueError("Model must be fitted before generating molecules.")
        if self.input_dim_X is None or self.input_dim_E is None or self.max_node is None:
            raise ValueError(
                "Model may not be fitted correctly, as one of the following attributes is not set: "
                f"input_dim_X={self.input_dim_X}, input_dim_E={self.input_dim_E}, max_node={self.max_node}"
            )
        if len(self.task_type) > 0 and labels is None:
            raise ValueError(f"labels must be provided if task_type is not empty: {self.task_type}")

        if labels is not None:
            if num_nodes is not None:
                assert len(labels) == len(num_nodes), "labels and num_nodes must have the same batch size"
            batch_size = len(labels)
        elif num_nodes is not None:
            batch_size = len(num_nodes)

        # Convert properties to 2D tensor if needed
        if isinstance(labels, list):
            labels = torch.tensor(labels)
        elif isinstance(labels, np.ndarray):
            labels = torch.from_numpy(labels)
        if labels is not None and labels.dim() == 1:
            labels = labels.unsqueeze(-1)

        if num_nodes is None:
            num_nodes_dist = self.dataset_info["num_nodes_dist"]
            num_nodes = num_nodes_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, list):
            num_nodes = torch.tensor(num_nodes).to(self.device)
        elif isinstance(num_nodes, np.ndarray):
            num_nodes = torch.from_numpy(num_nodes).to(self.device)
        if num_nodes.dim() == 1:
            num_nodes = num_nodes.unsqueeze(-1)
        assert num_nodes.size(0) == batch_size

        arange = (
            torch.arange(self.max_node).to(self.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes

        self.model.eval()
        z_T = sample_discrete_feature_noise(self.limit_dist, node_mask)
        X, E = z_T.X, z_T.E
        if len(self.task_type) > 0 and labels is not None:
            y = labels.to(self.device).float()
        else:
            # unconditional
            y = torch.zeros(batch_size, self.input_dim_y, device=self.device)

        for t_int in tqdm(reversed(range(0, self.sample_steps)), desc="Generating",
                          total=self.sample_steps, disable=not self.verbose):
            t_array = t_int * torch.ones((batch_size, 1), device=self.device)
            s_array = (t_int + 1) * torch.ones((batch_size, 1), device=self.device)
            t_norm = t_array / self.sample_steps
            s_norm = s_array / self.sample_steps
            t_norm_distorted = self.time_distorter.sample_ft(t_norm, self.time_distortion)
            s_norm_distorted = self.time_distorter.sample_ft(s_norm, self.time_distortion)
            sampled_s, _ = self._sample_step(t_norm_distorted, s_norm_distorted, X, E, y, node_mask)
            X, E = sampled_s.X, sampled_s.E

        sampled_s = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
        X_final, E_final, _ = self.noise_dist.ignore_virtual_classes(sampled_s.X, sampled_s.E, sampled_s.y)

        molecule_list = []
        for i in range(batch_size):
            n = num_nodes[i].item()
            atom_types = X_final[i, :n].cpu()
            edge_types = E_final[i, :n, :n].cpu()
            molecule_list.append([atom_types, edge_types])

        return graph_to_smiles(molecule_list, self.dataset_info["atom_decoder"])
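    # Hedged example of conditional sampling (assuming a model fitted with a single
    # regression task); `target_values` is a hypothetical (batch_size, 1) array of
    # desired property values:
    #
    #     smiles = gen.generate(labels=target_values)
    #
    # Unconditional sampling simply omits `labels`: `gen.generate(batch_size=32)`.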
    def _sample_step(self, t, s, X_t, E_t, y_t, node_mask):
        dt = (s - t)[0]

        def get_rates(y_in):
            noisy_data = {"X_t": X_t, "E_t": E_t, "y_t": y_in, "t": t, "node_mask": node_mask}
            extra_data = self._compute_extra_data(noisy_data)
            pred = self.model(
                torch.cat((noisy_data["X_t"], extra_data.X), dim=2).float(),
                torch.cat((noisy_data["E_t"], extra_data.E), dim=3).float(),
                torch.cat((noisy_data["y_t"], extra_data.y), dim=1).float(),
                node_mask
            )
            pred_X = F.softmax(pred.X, dim=-1)
            pred_E = F.softmax(pred.E, dim=-1)
            return self.rate_matrix_designer.compute_graph_rate_matrix(
                t, node_mask, (X_t, E_t), (pred_X, pred_E)
            )

        R_t_X, R_t_E = get_rates(y_t)

        is_conditional = len(self.task_type) > 0 and y_t.shape[-1] > 0 and y_t.sum() != 0
        if is_conditional and self.guidance_weight > 0:
            # Use -1 as unconditional signal, similar to original DeFoG
            uncond_y = torch.ones_like(y_t, device=self.device) * -1
            R_t_X_uncond, R_t_E_uncond = get_rates(uncond_y)
            # Use logarithmic interpolation
            R_t_X = torch.exp(
                torch.log(R_t_X_uncond + 1e-6) * (1 - self.guidance_weight)
                + torch.log(R_t_X + 1e-6) * self.guidance_weight
            )
            R_t_E = torch.exp(
                torch.log(R_t_E_uncond + 1e-6) * (1 - self.guidance_weight)
                + torch.log(R_t_E + 1e-6) * self.guidance_weight
            )

        step_probs_X, step_probs_E = self._compute_step_probs(R_t_X, R_t_E, X_t, E_t, dt)

        if s[0] >= 1.0:
            noisy_data = {"X_t": X_t, "E_t": E_t, "y_t": y_t, "t": t, "node_mask": node_mask}
            extra_data = self._compute_extra_data(noisy_data)
            pred = self.model(
                torch.cat((noisy_data["X_t"], extra_data.X), dim=2).float(),
                torch.cat((noisy_data["E_t"], extra_data.E), dim=3).float(),
                torch.cat((noisy_data["y_t"], extra_data.y), dim=1).float(),
                node_mask
            )
            step_probs_X, step_probs_E = F.softmax(pred.X, dim=-1), F.softmax(pred.E, dim=-1)

        sampled_s = sample_discrete_features(step_probs_X, step_probs_E, node_mask=node_mask)
        X_s = F.one_hot(sampled_s.X, num_classes=self.input_dim_X).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.input_dim_E).float()

        if is_conditional:
            y_to_save = y_t
        else:
            y_to_save = torch.zeros([y_t.shape[0], 0], device=self.device)

        out_one_hot = PlaceHolder(X=X_s, E=E_s, y=y_to_save).mask(node_mask).type_as(y_t)
        out_discrete = PlaceHolder(X=X_s, E=E_s, y=y_to_save).mask(node_mask, collapse=True).type_as(y_t)
        return out_one_hot, out_discrete

    def _compute_step_probs(self, R_t_X, R_t_E, X_t, E_t, dt):
        step_probs_X = (R_t_X * dt).clamp(0, 1)
        step_probs_E = (R_t_E * dt).clamp(0, 1)
        step_probs_X.scatter_(-1, X_t.argmax(-1, keepdim=True), 0.0)
        step_probs_E.scatter_(-1, E_t.argmax(-1, keepdim=True), 0.0)
        diag_X = (1.0 - step_probs_X.sum(dim=-1, keepdim=True)).clamp(min=0.0)
        diag_E = (1.0 - step_probs_E.sum(dim=-1, keepdim=True)).clamp(min=0.0)
        step_probs_X.scatter_(-1, X_t.argmax(-1, keepdim=True), diag_X)
        step_probs_E.scatter_(-1, E_t.argmax(-1, keepdim=True), diag_E)
        return step_probs_X, step_probs_E
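
# The block below is a self-contained, hedged end-to-end sketch, not part of the library API:
# it fits the generator unconditionally on a tiny toy SMILES list and then samples a few
# molecules. The toy SMILES strings and the small epoch/step counts are illustrative
# assumptions, chosen only so the script finishes quickly; real use would need a proper dataset
# and the default training budget.
if __name__ == "__main__":
    toy_smiles = ["CCO", "CCN", "c1ccccc1", "CC(=O)O", "CCCC", "CCOC"]

    generator = DeFoGMolecularGenerator(
        epochs=2,          # deliberately tiny for this sketch
        batch_size=4,
        sample_steps=50,   # fewer steps than the default 500, for speed
        verbose=True,
    )
    generator.fit(toy_smiles)            # unconditional: no labels, task_type left empty
    sampled = generator.generate(batch_size=4)
    print(sampled)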