Source code for torch_molecule.predictor.bfgnn.modeling_bfgnn
import os
import numpy as np
import warnings
import datetime
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, List, Type
from dataclasses import dataclass, field
import torch
from torch_geometric.data import Data
from .model import BFGNN
from ...utils import graph_from_smiles
from ..gnn.modeling_gnn import GNNMolecularPredictor
from ...utils.search import (
    ParameterSpec,
    ParameterType,
)
@dataclass
class BFGNNMolecularPredictor(GNNMolecularPredictor):
    """
    This predictor implements algorithmic alignment between the Bellman-Ford
    algorithm and a GNN.

    References
    ----------
    - Graph neural networks extrapolate out-of-distribution for shortest paths.
      https://arxiv.org/abs/2503.19173

    :param l1_penalty: Weight of the L1 penalty added to the training loss.
    :type l1_penalty: float, default=1e-3
    """

    l1_penalty: float = 1e-3

    # Other non-init fields
    model_name: str = "BFGNNMolecularPredictor"
    model_class: Type[BFGNN] = field(default=BFGNN, init=False)
    def __post_init__(self):
        super().__post_init__()

    @staticmethod
    def _get_param_names() -> List[str]:
        return GNNMolecularPredictor._get_param_names() + [
            "l1_penalty",
        ]

    def _get_default_search_space(self):
        search_space = super()._get_default_search_space().copy()
        search_space["l1_penalty"] = ParameterSpec(ParameterType.LOG_FLOAT, (1e-6, 1))
        return search_space
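
    # Illustrative sketch of tuning l1_penalty over the log-uniform range
    # registered above. It assumes the hyperparameter-search entry point
    # inherited from GNNMolecularPredictor (the method name `autofit` and its
    # arguments are assumptions about that base-class API, not defined here):
    #
    #     predictor = BFGNNMolecularPredictor()
    #     predictor.autofit(smiles_list, y_train, n_trials=20)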
    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        base_params = super()._get_model_params(checkpoint)
        return base_params
    def _train_epoch(self, train_loader, optimizer, epoch):
        self.model.train()
        losses = []
        iterator = (
            tqdm(train_loader, desc="Training", leave=False)
            if self.verbose
            else train_loader
        )

        for step, batch in enumerate(iterator):
            batch = batch.to(self.device)
            optimizer.zero_grad()

            # The criterion and the L1 penalty weight are passed to the model-side
            # compute_loss (defined in .model.BFGNN), so the returned loss can
            # include the L1 regularization term.
            loss = self.model.compute_loss(batch, self.loss_criterion, self.l1_penalty)
            loss.backward()

            # Optional gradient clipping to stabilize training.
            if self.grad_clip_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)

            optimizer.step()

            if self.verbose:
                iterator.set_postfix({"Epoch": epoch, "Total Loss": f"{loss.item():.4f}"})

            losses.append(loss.item())

        return losses
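

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the module). It assumes the
# fit/predict interface inherited from GNNMolecularPredictor, i.e. that the
# predictor accepts a list of SMILES strings and a target array; the method
# names and arguments below are assumptions about that base-class API.
#
#     from torch_molecule.predictor.bfgnn import BFGNNMolecularPredictor
#
#     smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
#     y = [[0.50], [1.20], [0.75]]
#
#     predictor = BFGNNMolecularPredictor(l1_penalty=1e-3, verbose=True)
#     predictor.fit(smiles, y)
#     predictions = predictor.predict(smiles)
# ---------------------------------------------------------------------------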