Source code for torch_molecule.nn.mlp
import torch.nn as nn
class MLP(nn.Module):
    """Two-layer feed-forward network with an optional batch normalization
    layer, a configurable activation, and dropout between the layers."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.,
        use_bn=True,
    ):
        super().__init__()
        # Default the hidden and output widths to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        linear_layer = nn.Linear
        self.fc1 = linear_layer(in_features, hidden_features, bias=bias)
        # Batch normalization after the first linear layer is optional;
        # nn.Identity keeps the forward pass uniform when it is disabled.
        if use_bn:
            self.bn1 = nn.BatchNorm1d(hidden_features)
        else:
            self.bn1 = nn.Identity()
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias)
    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        return x
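
The sketch below is not part of the library source; it only illustrates how the class above could be instantiated and applied to a batch of feature vectors. The batch size, feature widths, and dropout rate are assumed values chosen for the example.

# Minimal usage sketch (illustrative; sizes and dropout rate are assumptions).
import torch

mlp = MLP(in_features=64, hidden_features=128, out_features=32, drop=0.1)
x = torch.randn(16, 64)   # batch of 16 vectors, 64 features each
out = mlp(x)              # resulting shape: (16, 32)
print(out.shape)

Because BatchNorm1d is enabled by default, inputs are expected as 2-D tensors of shape (batch, features); call mlp.eval() before running single-sample or inference-time batches so the stored running statistics are used.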