Source code for torch_molecule.utils.generic.weights
import torch
def init_weights(net, init_type='xavier', init_gain=0.02, verbose=False):
    """Initialize network weights in place.

    ``net.apply`` visits every submodule (including ``net`` itself). Layers are
    selected by class name:

    * a class whose name contains ``'Conv'`` or ``'Linear'`` and that has a
      non-``None`` ``weight`` gets its weight drawn according to ``init_type``
      and its ``bias`` (if present and not ``None``) set to 0;
    * a class whose name contains ``'BatchNorm2d'`` gets its weight drawn from
      N(1.0, init_gain) and its bias set to 0. Either parameter is skipped
      when it is ``None`` (e.g. a layer built with ``affine=False``).

    Parameters:
        net (torch.nn.Module) -- network to be initialized
        init_type (str) -- the name of an initialization method:
            normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal
            (kaiming does not use it)
        verbose (bool) -- if True, print the chosen initialization method

    Raises:
        NotImplementedError: if ``init_type`` is not one of the methods above.
            This is checked before any layer is modified.
    """
    # torch.nn.init functions run under torch.no_grad() internally, so the
    # parameters can be passed directly instead of through `.data`.
    initializers = {
        'normal': lambda w: torch.nn.init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: torch.nn.init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: torch.nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        # init_gain is documented as scaling orthogonal init, so pass it through.
        'orthogonal': lambda w: torch.nn.init.orthogonal_(w, gain=init_gain),
    }
    if init_type not in initializers:
        # Validate up front so an unknown method is reported even when the net
        # has no Conv/Linear layers, and before any layer is touched.
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    init_weight = initializers[init_type]

    def init_func(m):  # applied to every submodule by net.apply
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            init_weight(weight)
            if bias is not None:
                torch.nn.init.constant_(bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; with affine=False both
            # weight and bias are None and there is nothing to initialize.
            if weight is not None:
                torch.nn.init.normal_(weight, 1.0, init_gain)
            if bias is not None:
                torch.nn.init.constant_(bias, 0.0)

    if verbose:
        print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function