import numpy as np
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Literal, Type
from dataclasses import dataclass, field
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from .model import GNN
from ..constant import GNN_ENCODER_MODELS, GNN_ENCODER_READOUTS, GNN_ENCODER_PARAMS
from ...base import BaseMolecularEncoder
from ...utils import graph_from_smiles
ALLOWABLE_ENCODER_MODELS = GNN_ENCODER_MODELS
ALLOWABLE_ENCODER_READOUTS = GNN_ENCODER_READOUTS
@dataclass
class InfoGraphMolecularEncoder(BaseMolecularEncoder):
"""This encoder implements a InfoGraph for molecular representation learning.
InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization (ICLR 2020)
References
----------
- Paper: https://arxiv.org/abs/1908.01000
- Code: https://github.com/sunfanyunn/InfoGraph/tree/master/unsupervised
Parameters
----------
    lw_prior : float, default=0.0
        Weight of the prior-matching loss term. Set to 0.0 to disable the prior loss.
embedding_dim : int, default=160
Dimension of final graph embedding. Must be divisible by num_layer.
num_layer : int, default=5
Number of GNN layers.
drop_ratio : float, default=0.5
Dropout probability.
norm_layer : str, default="batch_norm"
Type of normalization layer to use. One of ["batch_norm", "layer_norm", "instance_norm", "graph_norm", "size_norm", "pair_norm"].
encoder_type : str, default="gin-virtual"
Type of GNN architecture to use. One of ["gin-virtual", "gcn-virtual", "gin", "gcn"].
readout : str, default="sum"
Method for aggregating node features to obtain graph-level representations. One of ["sum", "mean", "max"].
batch_size : int, default=128
Number of samples per batch for training.
epochs : int, default=500
Maximum number of training epochs.
learning_rate : float, default=0.001
Learning rate for optimizer.
grad_clip_value : float, optional
Maximum norm of gradients for gradient clipping.
weight_decay : float, default=0.0
L2 regularization strength.
use_lr_scheduler : bool, default=False
Whether to use a learning rate scheduler.
scheduler_factor : float, default=0.5
Factor by which to reduce the learning rate when plateau is detected.
scheduler_patience : int, default=5
Number of epochs with no improvement after which learning rate will be reduced.
verbose : bool, default=False
Whether to print progress information during training.
    model_name : str, default="InfoGraphMolecularEncoder"
Name of the encoder model.
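
    Examples
    --------
    A minimal, illustrative usage sketch (the SMILES strings are arbitrary
    examples; the reported shape assumes the default ``embedding_dim=160``)::

        encoder = InfoGraphMolecularEncoder(epochs=10, verbose=True)
        encoder.fit(["CCO", "c1ccccc1", "CC(=O)O"])
        reps = encoder.encode(["CCO"], return_type="np")
        print(reps.shape)  # (1, 160)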
"""
# Task related parameters
    lw_prior: float = 0.0
embedding_dim: int = 160
# Model parameters
num_layer: int = 5
drop_ratio: float = 0.5
norm_layer: str = "batch_norm"
encoder_type: str = "gin-virtual"
readout: str = "sum"
# Training parameters
batch_size: int = 128
epochs: int = 500
learning_rate: float = 0.001
grad_clip_value: Optional[float] = None
weight_decay: float = 0.0
# Scheduler parameters
use_lr_scheduler: bool = False
scheduler_factor: float = 0.5
scheduler_patience: int = 5
# Other parameters
verbose: bool = False
    model_name: str = "InfoGraphMolecularEncoder"
# Non-init fields
fitting_loss: List[float] = field(default_factory=list, init=False)
fitting_epoch: int = field(default=0, init=False)
model_class: Type[GNN] = field(default=GNN, init=False)
def __post_init__(self):
"""Initialize the model after dataclass initialization."""
super().__post_init__()
        if self.encoder_type not in ALLOWABLE_ENCODER_MODELS:
            raise ValueError(
                f"Invalid encoder_type: {self.encoder_type}. "
                f"Currently only {ALLOWABLE_ENCODER_MODELS} are supported."
            )
        if self.readout not in ALLOWABLE_ENCODER_READOUTS:
            raise ValueError(
                f"Invalid readout: {self.readout}. "
                f"Currently only {ALLOWABLE_ENCODER_READOUTS} are supported."
            )
        self.use_prior = self.lw_prior != 0
        assert self.embedding_dim % self.num_layer == 0, (
            "embedding_dim must be divisible by num_layer for InfoGraphMolecularEncoder"
        )
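        # E.g., the defaults embedding_dim=160 and num_layer=5 yield
        # 160 / 5 = 32 dimensions per layer; the check exists because the
        # graph embedding presumably concatenates per-layer readouts.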
@staticmethod
def _get_param_names() -> List[str]:
"""Get parameter names for the estimator.
Returns
-------
List[str]
List of parameter names that can be used for model configuration.
"""
        params = GNN_ENCODER_PARAMS.copy()
        params.remove("hidden_size")
        params += ["embedding_dim", "use_prior", "lw_prior"]
return params
    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Collect model hyperparameters from this instance or, if given, from a checkpoint."""
params = [
"num_layer", "embedding_dim", "drop_ratio", "norm_layer", "encoder_type", "readout", "use_prior",
]
if checkpoint is not None:
if "hyperparameters" not in checkpoint:
raise ValueError("Checkpoint missing 'hyperparameters' key")
return {k: checkpoint["hyperparameters"][k] for k in params}
return {k: getattr(self, k) for k in params}
def _convert_to_pytorch_data(self, X):
"""Convert numpy arrays to PyTorch Geometric data format.
"""
if self.verbose:
iterator = tqdm(enumerate(X), desc="Converting molecules to graphs", total=len(X))
else:
iterator = enumerate(X)
pyg_graph_list = []
for idx, smiles_or_mol in iterator:
graph = graph_from_smiles(smiles_or_mol, None)
g = Data()
g.num_nodes = graph["num_nodes"]
g.edge_index = torch.from_numpy(graph["edge_index"])
del graph["num_nodes"]
del graph["edge_index"]
if graph["edge_feat"] is not None:
g.edge_attr = torch.from_numpy(graph["edge_feat"])
del graph["edge_feat"]
if graph["node_feat"] is not None:
g.x = torch.from_numpy(graph["node_feat"])
del graph["node_feat"]
pyg_graph_list.append(g)
return pyg_graph_list
def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
"""Setup optimization components including optimizer and learning rate scheduler.
"""
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
scheduler = None
if self.use_lr_scheduler:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=self.scheduler_factor,
patience=self.scheduler_patience,
min_lr=1e-6,
cooldown=0,
eps=1e-8,
)
return optimizer, scheduler
def fit(
self,
X_train: List[str],
) -> "InfographMolecularEncoder":
"""Fit the model to the training data with optional validation set.
Parameters
----------
X_train : List[str]
Training set input molecular structures as SMILES strings
Returns
-------
self : InfographMolecularEncoder
Fitted estimator
"""
self._initialize_model(self.model_class)
self.model.initialize_parameters()
optimizer, scheduler = self._setup_optimizers()
# Prepare datasets and loaders
X_train, _ = self._validate_inputs(X_train, return_rdkit_mol=True)
train_dataset = self._convert_to_pytorch_data(X_train)
train_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=0
)
self.fitting_loss = []
for epoch in range(self.epochs):
# Training phase
train_losses = self._train_epoch(train_loader, optimizer, epoch)
            mean_loss = float(np.mean(train_losses))
            self.fitting_loss.append(mean_loss)
            if scheduler:
                scheduler.step(mean_loss)
self.fitting_epoch = epoch
self.is_fitted_ = True
return self
def _train_epoch(self, train_loader, optimizer, epoch):
"""Training logic for one epoch.
Args:
train_loader: DataLoader containing training data
optimizer: Optimizer instance for model parameter updates
Returns:
list: List of loss values for each training step
"""
self.model.train()
losses = []
iterator = (
tqdm(train_loader, desc="Training", leave=False)
if self.verbose
else train_loader
)
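        # InfoGraph objective: maximize mutual information between node-level
        # ("local") and graph-level ("global") representations; the optional
        # prior-matching term (weighted by lw_prior) regularizes the embedding
        # distribution. Both terms come from self.model.compute_loss.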
for batch in iterator:
batch = batch.to(self.device)
optimizer.zero_grad()
local_global_loss, prior_loss = self.model.compute_loss(batch, self.lw_prior)
loss = local_global_loss + prior_loss
loss.backward()
if self.grad_clip_value is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_value)
optimizer.step()
losses.append(loss.item())
            if self.verbose:
                iterator.set_postfix({
                    "Epoch": f"{epoch}",
                    "Loss": f"{loss.item():.4f}",
                    "Local/Global": f"{local_global_loss.item():.4f}",
                    "Prior": f"{prior_loss.item():.4f}",
                })
return losses
def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union[np.ndarray, torch.Tensor]:
"""Encode molecules into vector representations.
Parameters
----------
X : List[str]
List of SMILES strings
return_type : Literal["np", "pt"], default="pt"
Return type of the representations
Returns
-------
representations : ndarray or torch.Tensor
Molecular representations
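
        Examples
        --------
        Illustrative only; assumes a fitted encoder and the default
        ``embedding_dim=160``::

            reps = encoder.encode(["CCO", "c1ccccc1"], return_type="np")
            reps.shape  # (2, 160)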
"""
self._check_is_fitted()
# Convert to PyTorch Geometric format and create loader
X, _ = self._validate_inputs(X, return_rdkit_mol=True)
dataset = self._convert_to_pytorch_data(X)
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
if self.model is None:
raise RuntimeError("Model not initialized")
# Generate encodings
self.model = self.model.to(self.device)
self.model.eval()
encodings = []
with torch.no_grad():
            for batch in tqdm(loader, desc="Encoding", disable=not self.verbose):
batch = batch.to(self.device)
out = self.model(batch)
encodings.append(out["graph"].cpu())
# Concatenate and convert to requested format
encodings = torch.cat(encodings, dim=0)
return encodings if return_type == "pt" else encodings.numpy()