Source code for torch_molecule.nn.gnn

import torch
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.nn.norm import GraphNorm, PairNorm, DiffGroupNorm, InstanceNorm, LayerNorm, GraphSizeNorm
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool

from ..utils import get_atom_feature_dims, get_bond_feature_dims 

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

[docs] class AtomEncoder(torch.nn.Module): """Encodes atom features into a fixed-size vector representation. This module converts categorical atom features into embeddings and combines them to create a unified atom representation. Parameters ---------- hidden_size : int Dimensionality of the output atom embedding vectors. Notes ----- Each atom feature is embedded separately using an Embedding layer, then these embeddings are summed to produce the final representation. The embedding weights are initialized using Xavier uniform initialization with max_norm=1 constraint. """ def __init__(self, hidden_size): super(AtomEncoder, self).__init__() self.atom_embedding_list = torch.nn.ModuleList() for i, dim in enumerate(full_atom_feature_dims): emb = torch.nn.Embedding(dim, hidden_size, max_norm=1) torch.nn.init.xavier_uniform_(emb.weight.data) self.atom_embedding_list.append(emb)
[docs] def forward(self, x): """Transform atom features into embeddings. Parameters ---------- x : torch.Tensor Tensor of shape [num_atoms, num_features] containing categorical atom features. Returns ------- torch.Tensor Atom embeddings of shape [num_atoms, hidden_size]. """ x_embedding = 0 for i in range(x.shape[1]): x_embedding += self.atom_embedding_list[i](x[:,i]) return x_embedding
[docs] class BondEncoder(torch.nn.Module): """Encodes bond features into a fixed-size vector representation. This module converts categorical bond features into embeddings and combines them to create a unified bond representation. Parameters ---------- hidden_size : int Dimensionality of the output bond embedding vectors. Notes ----- Each bond feature is embedded separately using an Embedding layer, then these embeddings are summed to produce the final representation. The embedding weights are initialized using Xavier uniform initialization with max_norm=1 constraint. """ def __init__(self, hidden_size): super(BondEncoder, self).__init__() self.bond_embedding_list = torch.nn.ModuleList() for i, dim in enumerate(full_bond_feature_dims): emb = torch.nn.Embedding(dim, hidden_size, max_norm=1) torch.nn.init.xavier_uniform_(emb.weight.data) self.bond_embedding_list.append(emb)
[docs] def forward(self, edge_attr): """Transform bond features into embeddings. Parameters ---------- edge_attr : torch.Tensor Tensor of shape [num_bonds, num_features] containing categorical bond features. Returns ------- torch.Tensor Bond embeddings of shape [num_bonds, hidden_size]. """ bond_embedding = 0 for i in range(edge_attr.shape[1]): bond_embedding += self.bond_embedding_list[i](edge_attr[:,i]) return bond_embedding
[docs] class GINConv(MessagePassing): def __init__(self, hidden_size, output_size=None): ''' hidden_size (int): node embedding dimensionality ''' super(GINConv, self).__init__(aggr = "add") if output_size is None: output_size = hidden_size self.mlp = torch.nn.Sequential(torch.nn.Linear(hidden_size, 2*hidden_size), torch.nn.BatchNorm1d(2*hidden_size), torch.nn.ReLU(), torch.nn.Linear(2*hidden_size, output_size)) self.eps = torch.nn.Parameter(torch.Tensor([0])) self.bond_encoder = BondEncoder(hidden_size = hidden_size)
[docs] def forward(self, x, edge_index, edge_attr): edge_embedding = self.bond_encoder(edge_attr) out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) return out
[docs] def message(self, x_j, edge_attr): return F.relu(x_j + edge_attr)
[docs] def update(self, aggr_out): return aggr_out
[docs] class GCNConv(MessagePassing): def __init__(self, hidden_size, output_size=None): super(GCNConv, self).__init__(aggr='add') if output_size is None: output_size = hidden_size self.linear = torch.nn.Linear(hidden_size, output_size) self.root_emb = torch.nn.Embedding(1, output_size) self.bond_encoder = BondEncoder(hidden_size = output_size)
[docs] def forward(self, x, edge_index, edge_attr): x = self.linear(x) edge_embedding = self.bond_encoder(edge_attr) row, col = edge_index #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) deg = degree(row, x.size(0), dtype = x.dtype) + 1 deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)
[docs] def message(self, x_j, edge_attr, norm): return norm.view(-1, 1) * F.relu(x_j + edge_attr)
[docs] def update(self, aggr_out): return aggr_out
### GNN to generate node embedding
[docs] class GNN_node(torch.nn.Module): """ Output: node representations """ def __init__(self, num_layer, hidden_size, drop_ratio = 0.5, JK = "last", residual = False, gnn_name = 'gin', norm_layer = 'batch_norm', encode_atom = True): ''' hidden_size (int): node embedding dimensionality num_layer (int): number of GNN message passing layers ''' super(GNN_node, self).__init__() self.num_layer = num_layer self.drop_ratio = drop_ratio self.JK = JK ### add residual connection or not self.residual = residual self.norm_layer = norm_layer self.encode_atom = encode_atom if self.num_layer < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.atom_encoder = AtomEncoder(hidden_size) self.bond_encoder = BondEncoder(hidden_size) ###List of GNNs self.convs = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() for layer in range(num_layer): if gnn_name == 'gin': self.convs.append(GINConv(hidden_size)) elif gnn_name == 'gcn': self.convs.append(GCNConv(hidden_size)) else: raise ValueError('Undefined GNN type called {}'.format(gnn_name)) if norm_layer.split('_')[0] == 'batch': if norm_layer.split('_')[-1] == 'notrack': self.batch_norms.append(torch.nn.BatchNorm1d(hidden_size, track_running_stats=False, affine=False)) else: self.batch_norms.append(torch.nn.BatchNorm1d(hidden_size)) elif norm_layer.split('_')[0] == 'instance': self.batch_norms.append(InstanceNorm(hidden_size)) elif norm_layer.split('_')[0] == 'layer': self.batch_norms.append(LayerNorm(hidden_size)) elif norm_layer.split('_')[0] == 'graph': self.batch_norms.append(GraphNorm(hidden_size)) elif norm_layer.split('_')[0] == 'size': self.batch_norms.append(GraphSizeNorm()) elif norm_layer.split('_')[0] == 'pair': self.batch_norms.append(PairNorm(hidden_size)) elif norm_layer.split('_')[0] == 'group': self.batch_norms.append(DiffGroupNorm(hidden_size, groups=4)) else: raise ValueError('Undefined normalization layer called {}'.format(norm_layer)) if norm_layer.split('_')[1] == 'size': self.graph_size_norm = GraphSizeNorm() # def forward(self, batched_data): # x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
[docs] def forward(self, *args): if len(args) == 1: # Case 1: batched_data input batched_data = args[0] x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch elif len(args) == 4: # Case 2: separate inputs x, edge_index, edge_attr, batch = args else: raise ValueError("forward expects either 1 batched_data argument or 4 separate arguments (x, edge_index, edge_attr, batch)") if self.encode_atom: h_list = [self.atom_encoder(x)] else: h_list = [x] for layer in range(self.num_layer): h = self.convs[layer](h_list[layer], edge_index, edge_attr) if self.norm_layer.split('_')[0] == 'batch': h = self.batch_norms[layer](h) else: h = self.batch_norms[layer](h, batch) if self.norm_layer.split('_')[1] == 'size': h = self.graph_size_norm(h, batch) if layer == self.num_layer - 1: #remove relu for the last layer h = F.dropout(h, self.drop_ratio, training = self.training) else: h = F.relu(h) h = F.dropout(h, self.drop_ratio, training = self.training) if self.residual: h = h + h_list[layer] h_list.append(h) ### Different implementations of Jk-concat if self.JK == "last": node_representation = h_list[-1] elif self.JK == "sum": node_representation = 0 for layer in range(self.num_layer + 1): node_representation += h_list[layer] return node_representation, h_list
### Virtual GNN to generate node embedding
[docs] class GNN_node_Virtualnode(torch.nn.Module): """ Output: node representations """ def __init__(self, num_layer, hidden_size, drop_ratio = 0.5, JK = "last", residual = False, gnn_name = 'gin', norm_layer = 'batch_norm', encode_atom = True): ''' hidden_size (int): node embedding dimensionality ''' super(GNN_node_Virtualnode, self).__init__() self.num_layer = num_layer self.drop_ratio = drop_ratio self.JK = JK ### add residual connection or not self.residual = residual self.norm_layer = norm_layer self.encode_atom = encode_atom if self.num_layer < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.atom_encoder = AtomEncoder(hidden_size) self.bond_encoder = BondEncoder(hidden_size) ### set the initial virtual node embedding to 0. self.virtualnode_embedding = torch.nn.Embedding(1, hidden_size) torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) ### List of GNNs self.convs = torch.nn.ModuleList() ### batch norms applied to node embeddings self.batch_norms = torch.nn.ModuleList() ### List of MLPs to transform virtual node at every layer self.mlp_virtualnode_list = torch.nn.ModuleList() for layer in range(num_layer): if gnn_name == 'gin': self.convs.append(GINConv(hidden_size)) elif gnn_name == 'gcn': self.convs.append(GCNConv(hidden_size)) else: raise ValueError('Undefined GNN type called {}'.format(gnn_name)) if norm_layer.split('_')[0] == 'batch': if norm_layer.split('_')[-1] == 'notrack': self.batch_norms.append(torch.nn.BatchNorm1d(hidden_size, track_running_stats=False, affine=False)) else: self.batch_norms.append(torch.nn.BatchNorm1d(hidden_size)) elif norm_layer.split('_')[0] == 'instance': self.batch_norms.append(InstanceNorm(hidden_size)) elif norm_layer.split('_')[0] == 'layer': self.batch_norms.append(LayerNorm(hidden_size)) elif norm_layer.split('_')[0] == 'graph': self.batch_norms.append(GraphNorm(hidden_size)) elif norm_layer.split('_')[0] == 'size': self.batch_norms.append(GraphSizeNorm()) elif norm_layer.split('_')[0] == 'pair': self.batch_norms.append(PairNorm(hidden_size)) elif norm_layer.split('_')[0] == 'group': self.batch_norms.append(DiffGroupNorm(hidden_size, groups=4)) else: raise ValueError('Undefined normalization layer called {}'.format(norm_layer)) if norm_layer.split('_')[1] == 'size': self.graph_size_norm = GraphSizeNorm() for layer in range(num_layer - 1): self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(hidden_size, 2*hidden_size), torch.nn.BatchNorm1d(2*hidden_size), torch.nn.ReLU(), \ torch.nn.Linear(2*hidden_size, hidden_size), torch.nn.BatchNorm1d(hidden_size), torch.nn.ReLU()))
[docs] def forward(self, *args): if len(args) == 1: # Case 1: batched_data input batched_data = args[0] x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch elif len(args) == 4: # Case 2: separate inputs x, edge_index, edge_attr, batch = args else: raise ValueError("forward expects either 1 batched_data argument or 4 separate arguments (x, edge_index, edge_attr, batch)") ### virtual node embeddings for graphs virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) if self.encode_atom: h_list = [self.atom_encoder(x)] else: h_list = [x] for layer in range(self.num_layer): ### add message from virtual nodes to graph nodes h_list[layer] = h_list[layer] + virtualnode_embedding[batch] ### Message passing among graph nodes h = self.convs[layer](h_list[layer], edge_index, edge_attr) if self.norm_layer.split('_')[0] == 'batch': h = self.batch_norms[layer](h) else: h = self.batch_norms[layer](h, batch) if self.norm_layer.split('_')[1] == 'size': h = self.graph_size_norm(h, batch) if layer == self.num_layer - 1: h = F.dropout(h, self.drop_ratio, training = self.training) else: h = F.relu(h) h = F.dropout(h, self.drop_ratio, training = self.training) if self.residual: h = h + h_list[layer] h_list.append(h) ### update the virtual nodes if layer < self.num_layer - 1: ### add message from graph nodes to virtual nodes virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding if self.residual: virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) else: virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) ### Different implementations of Jk-concat if self.JK == "last": node_representation = h_list[-1] elif self.JK == "sum": node_representation = 0 for layer in range(self.num_layer + 1): node_representation += h_list[layer] return node_representation, h_list