"""Source code for torch_molecule.nn.mlp."""

import torch.nn as nn

class MLP(nn.Module):
    """Two-layer perceptron: ``fc1 -> bn1 -> act -> drop1 -> fc2``.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    hidden_features : int, optional
        Width of the hidden layer. Falls back to ``in_features`` when ``None``
        (or any other falsy value).
    out_features : int, optional
        Size of the last dimension of the output. Falls back to
        ``in_features`` when ``None`` (or any other falsy value).
    act_layer : type
        Activation class; it is instantiated with no arguments. Default
        ``nn.GELU``.
    bias : bool
        Whether both linear layers have a bias term.
    drop : float
        Dropout probability applied after the activation. No dropout is
        applied after ``fc2``.
    use_bn : bool
        If True, apply ``nn.BatchNorm1d`` over the hidden features after
        ``fc1``; otherwise ``bn1`` is an identity.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.,
        use_bn=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.bn1 = nn.BatchNorm1d(hidden_features) if use_bn else nn.Identity()
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        """Apply the MLP to ``x`` of shape ``(*, in_features)``.

        Returns a tensor of shape ``(*, out_features)``.
        """
        x = self.fc1(x)
        if x.dim() > 2:
            # nn.Linear works on the last dim, but BatchNorm1d treats dim 1 as
            # the channel axis. Fold all leading dims into one batch dim so the
            # statistics are computed per hidden feature, then restore the shape.
            shape = x.shape
            x = self.bn1(x.reshape(-1, shape[-1])).reshape(shape)
        else:
            x = self.bn1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        return x