Neural Network Components

Multi-Head Attention with Node Masking

class torch_molecule.nn.attention.AttentionWithNodeMask(dim, num_head=8, qkv_bias=False, qk_norm=False, attn_drop=0.0, proj_drop=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)

Bases: Module

fast_attn: Final[bool]
forward(x, node_mask)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (e.g. module(x, node_mask)) instead of calling forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
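The masking idea can be illustrated with a minimal sketch, assuming x has shape (B, N, dim) and node_mask is a boolean (B, N) tensor that is True for real nodes and False for padding. The class name MaskedSelfAttention is hypothetical, and the sketch omits the qk_norm, norm_layer, dropout, and fast_attn options, so it is not the library's implementation.

```python
import torch
import torch.nn as nn


class MaskedSelfAttention(nn.Module):
    """Hypothetical stand-in for AttentionWithNodeMask, not the library code."""

    def __init__(self, dim: int, num_head: int = 8, qkv_bias: bool = False):
        super().__init__()
        assert dim % num_head == 0, "dim must be divisible by num_head"
        self.num_head = num_head
        self.head_dim = dim // num_head
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, node_mask: torch.Tensor) -> torch.Tensor:
        # x: (B, N, dim); node_mask: (B, N) bool, True for real nodes
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_head, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (B, heads, N, head_dim)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5  # (B, heads, N, N)
        # Padded keys get -inf so that softmax assigns them zero weight
        scores = scores.masked_fill(~node_mask[:, None, None, :], float("-inf"))
        # A graph with no valid nodes gives an all -inf row (NaN after softmax); zero it
        attn = torch.nan_to_num(scores.softmax(dim=-1), nan=0.0)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        # Project back to dim and zero the outputs at padded positions
        out = self.proj(out)
        return out * node_mask[..., None].to(out.dtype)
```

Because padded keys receive zero attention weight, changing the features of padded nodes leaves the outputs at real nodes unchanged.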

Embedding Methods for Different Input Features

class torch_molecule.nn.embedder.CategoricalEmbedder(num_classes, hidden_size, dropout_prob)

Bases: Module

Embeds categorical conditions (e.g., data source labels) into vector representations. Supports label dropout for classifier-free guidance.

Parameters:
  • num_classes (int) – Number of distinct label categories.

  • hidden_size (int) – Size of the embedding vectors.

  • dropout_prob (float) – Probability of label dropout.

forward(labels, train, force_drop_ids=None)

Forward pass for categorical embedding with optional label dropout.

Parameters:
  • labels (torch.Tensor) – Tensor of categorical labels.

  • train (bool) – Whether the model is in training mode.

  • force_drop_ids (torch.Tensor or None, optional) – Explicit mask for which labels to drop.

Returns:

Embedded label representations, with optional noise added during training.

Return type:

torch.Tensor

token_drop(labels, force_drop_ids=None)

Drops labels to enable classifier-free guidance.

Parameters:
  • labels (torch.Tensor) – Tensor of integer labels.

  • force_drop_ids (torch.Tensor or None, optional) – Boolean mask to force specific labels to be dropped.

Returns:

Labels with some entries replaced by a dropout token.

Return type:

torch.Tensor
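Label dropout for classifier-free guidance can be sketched as follows, assuming (as in DiT-style label embedders) that the embedding table reserves one extra row, index num_classes, as the dropped or "null" label. LabelDropEmbedder is a hypothetical name, and the sketch omits the training-time noise that forward() is documented to add.

```python
import torch
import torch.nn as nn


class LabelDropEmbedder(nn.Module):
    """Hypothetical sketch of CategoricalEmbedder-style label dropout."""

    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        # Assumption: one extra row (index num_classes) is the "null" label
        self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        if force_drop_ids is None:
            # Drop each label independently with probability dropout_prob
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids.bool()
        # Replace dropped labels with the null index
        return torch.where(drop_ids, torch.full_like(labels, self.num_classes), labels)

    def forward(self, labels, train, force_drop_ids=None):
        labels = labels.long().view(-1)
        if (train and self.dropout_prob > 0) or force_drop_ids is not None:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
```

At sampling time, passing an all-True force_drop_ids yields the unconditional embedding needed for the guidance term.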

class torch_molecule.nn.embedder.ClusterContinuousEmbedder(input_size, hidden_size, dropout_prob)

Bases: Module

Embeds continuous input features into vector representations using a multilayer perceptron (MLP). Supports optional embedding dropout for classifier-free guidance.

Parameters:
  • input_size (int) – The size of the input features.

  • hidden_size (int) – The size of the output embedding vectors.

  • dropout_prob (float) – Probability of embedding dropout, used for classifier-free guidance.

forward(labels, train, force_drop_ids=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (e.g. module(labels, train)) instead of calling forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
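For continuous conditions, one plausible sketch is to embed the features with a small MLP and replace dropped rows with a learned null vector. ContinuousDropEmbedder, the two-layer SiLU MLP, and the null_embedding parameter are all assumptions for illustration, not the library's internals.

```python
import torch
import torch.nn as nn


class ContinuousDropEmbedder(nn.Module):
    """Hypothetical sketch: MLP embedding of continuous features with row-level dropout."""

    def __init__(self, input_size: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        # Assumption: a learned vector stands in for dropped embeddings
        self.null_embedding = nn.Parameter(torch.zeros(hidden_size))
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        emb = self.mlp(labels)  # (B, hidden_size)
        drop = torch.zeros(labels.shape[0], dtype=torch.bool, device=labels.device)
        if train and self.dropout_prob > 0:
            # Random embedding dropout during training
            drop |= torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        if force_drop_ids is not None:
            drop |= force_drop_ids.bool()
        # Swap dropped rows for the null embedding (classifier-free guidance)
        return torch.where(drop[:, None], self.null_embedding.expand_as(emb), emb)
```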

class torch_molecule.nn.embedder.TimestepEmbedder(hidden_size, frequency_embedding_size=256)

Bases: Module

Embeds scalar timesteps into vector representations using a sinusoidal embedding followed by a multilayer perceptron (MLP).

Parameters:
  • hidden_size (int) – Output dimension of the MLP embedding.

  • frequency_embedding_size (int, optional) – Size of the input frequency embedding, by default 256.

forward(t)

Forward pass for timestep embedding.

Parameters:

t (torch.Tensor) – 1D tensor of scalar timesteps.

Returns:

The final embedded representation of shape (N, hidden_size).

Return type:

torch.Tensor

static timestep_embedding(t, dim, max_period=10000)

Create sinusoidal timestep embeddings.

Parameters:
  • t – a 1-D Tensor of N indices, one per batch element. These may be fractional.

  • dim – the dimension of the output.

  • max_period – controls the minimum frequency of the embeddings.

Returns:

An (N, dim) tensor of positional embeddings.
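The sinusoidal embedding and the MLP that follows it can be sketched as below, following the widely used DiT/guided-diffusion formulation: half of the output uses cosines and half uses sines over geometrically spaced frequencies, so max_period sets the lowest frequency. The SiLU MLP in SketchTimestepEmbedder is an assumption and is not confirmed by this page.

```python
import math

import torch
import torch.nn as nn


def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000) -> torch.Tensor:
    """Sinusoidal embedding of (possibly fractional) timesteps; returns (N, dim)."""
    half = dim // 2
    # Geometric frequencies from 1 down toward 1/max_period
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half
    )
    args = t[:, None].float() * freqs[None, :]  # (N, half)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero column when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


class SketchTimestepEmbedder(nn.Module):
    """Hypothetical: sinusoidal features followed by an assumed two-layer SiLU MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # (N,) -> (N, frequency_embedding_size) -> (N, hidden_size)
        return self.mlp(timestep_embedding(t, self.frequency_embedding_size))
```

At t = 0 every cosine entry is 1 and every sine entry is 0, which is a quick sanity check on the layout.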

Layers for Graph Neural Networks

class torch_molecule.nn.gnn.AtomEncoder(hidden_size)

Bases: Module

Encodes atom features into a fixed-size vector representation.

This module converts categorical atom features into embeddings and combines them to create a unified atom representation.

Parameters:

hidden_size (int) – Dimensionality of the output atom embedding vectors.

Notes

Each atom feature is embedded separately by its own Embedding layer, and the resulting embeddings are summed to produce the final representation. The embedding weights are initialized with Xavier uniform initialization, and each Embedding layer applies a max_norm=1 constraint.

forward(x)

Transform atom features into embeddings.

Parameters:

x (torch.Tensor) – Tensor of shape [num_atoms, num_features] containing categorical atom features.

Returns:

Atom embeddings of shape [num_atoms, hidden_size].

Return type:

torch.Tensor

class torch_molecule.nn.gnn.BondEncoder(hidden_size)

Bases: Module

Encodes bond features into a fixed-size vector representation.

This module converts categorical bond features into embeddings and combines them to create a unified bond representation.

Parameters:

hidden_size (int) – Dimensionality of the output bond embedding vectors.

Notes

Each bond feature is embedded separately by its own Embedding layer, and the resulting embeddings are summed to produce the final representation. The embedding weights are initialized with Xavier uniform initialization, and each Embedding layer applies a max_norm=1 constraint.

forward(edge_attr)

Transform bond features into embeddings.

Parameters:

edge_attr (torch.Tensor) – Tensor of shape [num_bonds, num_features] containing categorical bond features.

Returns:

Bond embeddings of shape [num_bonds, hidden_size].

Return type:

torch.Tensor
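Both encoders above follow the same pattern, which can be sketched as follows. The feature_dims list is a hypothetical stand-in for the real per-feature vocabulary sizes, and SumFeatureEncoder is an illustrative name.

```python
import torch
import torch.nn as nn


class SumFeatureEncoder(nn.Module):
    """Hypothetical sketch of the AtomEncoder/BondEncoder pattern."""

    def __init__(self, feature_dims, hidden_size: int):
        super().__init__()
        self.embeddings = nn.ModuleList()
        for num_categories in feature_dims:
            # max_norm=1 rescales any looked-up row whose norm exceeds 1
            emb = nn.Embedding(num_categories, hidden_size, max_norm=1)
            nn.init.xavier_uniform_(emb.weight)
            self.embeddings.append(emb)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_items, num_features) integer codes, one column per feature
        return sum(emb(x[:, i]) for i, emb in enumerate(self.embeddings))
```

Because max_norm=1 caps each looked-up vector at unit norm, the sum over k features has norm at most k.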

class torch_molecule.nn.gnn.GCNConv(hidden_size, output_size=None)

Bases: MessagePassing

forward(x, edge_index, edge_attr)

Runs the forward pass of the module.

message(x_j, edge_attr, norm)

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take as input any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

update(aggr_out)

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes the output of the aggregation as its first argument, plus any argument that was initially passed to propagate().
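The message signature message(x_j, edge_attr, norm) suggests a symmetrically normalized GCN layer with edge features. The sketch below, written in plain PyTorch instead of MessagePassing, is modeled on the GCNConv from the OGB molecule examples; whether this GCNConv matches it exactly is an assumption, and SketchGCNConv is hypothetical. Row 0 of edge_index holds source nodes j and row 1 holds target nodes i, matching PyG's default source_to_target flow.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchGCNConv(nn.Module):
    """Hypothetical GCN layer with edge features, modeled on the OGB example GCNConv."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.root_emb = nn.Embedding(1, hidden_size)  # learned self-loop term

    def forward(self, x, edge_index, edge_attr):
        # x: (N, H); edge_index: (2, E) as (source j, target i); edge_attr: (E, H)
        x = self.linear(x)
        src, dst = edge_index
        # Degree (+1 for the implicit self-loop) and symmetric normalization
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        deg.index_add_(0, src, torch.ones_like(src, dtype=x.dtype))
        deg = deg + 1
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[src] * deg_inv_sqrt[dst]
        # message(x_j, edge_attr, norm): normalized message from j to i
        msg = norm[:, None] * F.relu(x[src] + edge_attr)
        # Sum aggregation at each target node i
        aggr = torch.zeros_like(x).index_add_(0, dst, msg)
        # Add the self-loop contribution, scaled by 1/deg
        return aggr + F.relu(x + self.root_emb.weight) / deg[:, None]
```

An isolated node receives no messages, so its output reduces to the self-loop term alone.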

class torch_molecule.nn.gnn.GINConv(hidden_size, output_size=None)

Bases: MessagePassing

forward(x, edge_index, edge_attr)

Runs the forward pass of the module.

message(x_j, edge_attr)

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take as input any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

update(aggr_out)

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes the output of the aggregation as its first argument, plus any argument that was initially passed to propagate().
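GINConv's message(x_j, edge_attr) fits the edge-aware GIN update \(h_i' = \mathrm{MLP}\big((1+\epsilon)\,h_i + \sum_{j} \mathrm{ReLU}(h_j + e_{ji})\big)\). The plain-PyTorch sketch below is modeled on the GINConv from the OGB molecule examples; SketchGINConv, the MLP shape, and the learnable eps are assumptions rather than confirmed details of this class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchGINConv(nn.Module):
    """Hypothetical GIN layer with edge features, modeled on the OGB example GINConv."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 2 * hidden_size),
            nn.BatchNorm1d(2 * hidden_size),
            nn.ReLU(),
            nn.Linear(2 * hidden_size, hidden_size),
        )
        self.eps = nn.Parameter(torch.zeros(1))  # learnable epsilon, starts at 0

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index  # (source j, target i)
        # message(x_j, edge_attr): ReLU(x_j + e_ji)
        msg = F.relu(x[src] + edge_attr)
        # Sum aggregation over incoming edges
        aggr = torch.zeros_like(x).index_add_(0, dst, msg)
        # Update: MLP((1 + eps) * x_i + aggregated messages)
        return self.mlp((1 + self.eps) * x + aggr)
```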

class torch_molecule.nn.gnn.GNN_node(num_layer, hidden_size, drop_ratio=0.5, JK='last', residual=False, gnn_name='gin', norm_layer='batch_norm', encode_atom=True)

Bases: Module

Output:

node representations

forward(*args)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (e.g. module(*args)) instead of calling forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.

class torch_molecule.nn.gnn.GNN_node_Virtualnode(num_layer, hidden_size, drop_ratio=0.5, JK='last', residual=False, gnn_name='gin', norm_layer='batch_norm', encode_atom=True)

Bases: Module

Output:

node representations

forward(*args)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (e.g. module(*args)) instead of calling forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.

Multi-Layer Perceptron

class torch_molecule.nn.mlp.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, bias=True, drop=0.0, use_bn=True)

Bases: Module

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (e.g. module(x)) instead of calling forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.
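A minimal sketch of how the MLP parameters might fit together, assuming a timm-style layout (fc1, activation, dropout, fc2, dropout) with BatchNorm1d after fc1 when use_bn=True, and missing sizes defaulting to in_features. SketchMLP and this layer order are assumptions, not the library's code.

```python
import torch
import torch.nn as nn


class SketchMLP(nn.Module):
    """Hypothetical two-layer MLP mirroring the MLP signature above."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, bias=True, drop=0.0, use_bn=True):
        super().__init__()
        # Assumed defaults: missing sizes fall back to in_features
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.bn = nn.BatchNorm1d(hidden_features) if use_bn else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_features); BatchNorm1d expects 2-D input here
        x = self.drop(self.act(self.bn(self.fc1(x))))
        return self.drop(self.fc2(x))
```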