Source code for torch_molecule.base.base

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Type, Any, ClassVar, Union, Tuple, Callable, Literal
import torch
import os
import numpy as np
from ..utils.checkpoint import LocalCheckpointManager, HuggingFaceCheckpointManager
from ..utils.checker import MolecularInputChecker

@dataclass
class BaseModel(ABC):
    """Base class for molecular models with shared functionality.

    This abstract class provides common methods and utilities for molecular
    models, including model initialization, saving/loading, and parameter
    management.

    Parameters
    ----------
    device : Optional[Union[torch.device, str]], default=None
        Device the model is moved to. Strings are converted to
        ``torch.device``; ``None`` selects ``cuda:0`` when CUDA is available
        and ``cpu`` otherwise.
    model_name : str, default="BaseModel"
        Name of the model.
    """

    device: Optional[Union[torch.device, str]] = field(default=None)
    model_name: str = field(default="BaseModel")
    # Module class used for model initialization (not an ``__init__`` argument).
    model_class: Optional[Type[torch.nn.Module]] = field(default=None, init=False)
    # The initialized model (not an ``__init__`` argument).
    model: Optional[torch.nn.Module] = field(default=None, init=False)
    is_fitted_: bool = field(default=False, init=False)

    def __post_init__(self):
        """Initialize common device settings after instance creation.

        Sets the device to CUDA if available, otherwise CPU, when no device is
        specified. Converts string device specifications to ``torch.device``
        objects.
        """
        if self.device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        elif isinstance(self.device, str):
            self.device = torch.device(self.device)

    @abstractmethod
    def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        """Set up optimizers for model training.

        Returns
        -------
        Tuple[torch.optim.Optimizer, Optional[Any]]
            Tuple containing the primary optimizer and an optional secondary
            optimizer or scheduler.
        """
        pass

    @abstractmethod
    def _train_epoch(self, train_loader, optimizer):
        """Train the model for one epoch.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            DataLoader containing training batches
        optimizer : torch.optim.Optimizer
            Optimizer to use for parameter updates

        Returns
        -------
        dict
            Training metrics for the epoch
        """
        pass

    @abstractmethod
    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Get model parameters used for model initialization.

        Parameters
        ----------
        checkpoint : Optional[Dict], default=None
            Optional dictionary containing model checkpoint data

        Returns
        -------
        Dict[str, Any]
            Dictionary of parameters to initialize the model
        """
        pass

    @staticmethod
    def _get_param_names() -> List[str]:
        """Get parameter names in the modeling class.

        This is a static method and therefore takes no ``self``; it is called
        as ``self._get_param_names()`` by ``get_params``.

        Returns
        -------
        List[str]
            List of parameter names that can be configured
        """
        return ["model_name", "is_fitted_"]

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        Dict[str, Any]
            Dictionary of parameter names mapped to their values
        """
        out = {}
        for key in self._get_param_names():
            value = getattr(self, key)
            if deep and hasattr(value, "get_params"):
                # Flatten nested estimator parameters as "<key>__<subkey>".
                deep_items = value.get_params().items()
                out.update((key + "__" + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params) -> "BaseModel":
        """Set parameters for this estimator.

        All names are validated before any attribute is modified, so a failed
        call leaves the instance unchanged.

        Parameters
        ----------
        **params
            Parameter names mapped to their values

        Returns
        -------
        BaseModel
            Self instance for method chaining

        Raises
        ------
        ValueError
            If an invalid parameter is provided
        """
        valid_params = self.get_params(deep=True)
        for key in params:
            if key not in valid_params:
                raise ValueError(f"Invalid parameter {key} for model {self}")
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def _initialize_model(
        self,
        model_class: Type[torch.nn.Module],
        checkpoint: Optional[Dict] = None
    ) -> torch.nn.Module:
        """Initialize the model with parameters or a checkpoint.

        Parameters
        ----------
        model_class : Type[torch.nn.Module]
            PyTorch module class to instantiate
        checkpoint : Optional[Dict], default=None
            Optional dictionary containing model checkpoint data. When given,
            its ``"model_state_dict"`` entry is loaded into the new model.

        Returns
        -------
        torch.nn.Module
            Initialized PyTorch model
        """
        model_params = self._get_model_params(checkpoint)
        self.model = model_class(**model_params)
        self.model = self.model.to(self.device)

        if checkpoint is not None:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        return self.model

    def _validate_inputs(
        self,
        X: List[str],
        y: Optional[Union[List, np.ndarray]] = None,
        num_task: int = 0,
        num_pretask: int = 0,
        return_rdkit_mol: bool = True
    ) -> Tuple[Union[List[str], List["Chem.Mol"]], Optional[np.ndarray]]:
        """Validate molecular inputs and targets.

        Delegates to ``MolecularInputChecker.validate_inputs``.

        Parameters
        ----------
        X : List[str]
            List of SMILES strings representing molecules
        y : Optional[Union[List, np.ndarray]], default=None
            Optional target values for supervised learning
        num_task : int, default=0
            Number of prediction tasks
        num_pretask : int, default=0
            Number of pre-training tasks
        return_rdkit_mol : bool, default=True
            Whether to return RDKit Mol objects instead of SMILES

        Returns
        -------
        Tuple[Union[List[str], List["Chem.Mol"]], Optional[np.ndarray]]
            Tuple of validated inputs and targets
        """
        return MolecularInputChecker.validate_inputs(X, y, num_task, num_pretask, return_rdkit_mol)

    def save_to_local(self, path: str) -> None:
        """Save model to local disk.

        Parameters
        ----------
        path : str
            File path to save the model

        Raises
        ------
        ValueError
            If the model is not fitted
        """
        if not self.is_fitted_:
            raise ValueError("Model must be fitted before saving to local disk.")
        LocalCheckpointManager.save_model_to_local(self, path)

    def load_from_local(self, path: str) -> None:
        """Load model from local disk.

        Parameters
        ----------
        path : str
            File path to load the model from
        """
        LocalCheckpointManager.load_model_from_local(self, path)

    def save_to_hf(
        self,
        repo_id: str,
        task_id: str = "default",
        metadata_dict: Optional[Dict[str, Any]] = None,
        metrics: Optional[Dict[str, float]] = None,
        commit_message: str = "Update model",
        hf_token: Optional[str] = None,
        private: bool = False,
        config_filename: Optional[str] = 'config.json',
    ) -> None:
        """Save model to Hugging Face Hub.

        Parameters
        ----------
        repo_id : str
            Hugging Face repository ID
        task_id : str, default="default"
            Task identifier for the model
        metadata_dict : Optional[Dict[str, Any]], default=None
            Optional metadata to store with the model
        metrics : Optional[Dict[str, float]], default=None
            Optional performance metrics to store with the model
        commit_message : str, default="Update model"
            Git commit message
        hf_token : Optional[str], default=None
            Hugging Face authentication token
        private : bool, default=False
            Whether the repository should be private
        config_filename : Optional[str], default='config.json'
            Name of the configuration file to save to the repository

        Raises
        ------
        ValueError
            If the model is not fitted
        """
        if not self.is_fitted_:
            raise ValueError("Model must be fitted before saving to Hugging Face Hub.")

        HuggingFaceCheckpointManager.push_to_huggingface(
            model_instance=self,
            repo_id=repo_id,
            task_id=task_id,
            metadata_dict=metadata_dict,
            metrics=metrics,
            commit_message=commit_message,
            token=hf_token,
            private=private,
            config_filename=config_filename,
        )

    def load_from_hf(
        self,
        repo_id: str,
        local_cache: Optional[str] = None,
        config_filename: Optional[str] = 'config.json'
    ) -> None:
        """Load model from Hugging Face Hub.

        Parameters
        ----------
        repo_id : str
            Hugging Face repository ID
        local_cache : Optional[str], default=None
            Local path to save the model; ``'model.pt'`` is used when None
        config_filename : Optional[str], default='config.json'
            Name of the configuration file to load from the repository
        """
        if local_cache is None:
            local_cache = 'model.pt'
        HuggingFaceCheckpointManager.load_model_from_hf(
            self, repo_id, local_cache, config_filename=config_filename
        )

    def save(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwargs) -> None:
        """Save to local disk and/or Hugging Face Hub.

        When both ``repo_id`` and ``path`` are given, the model is pushed to
        the Hub first and then also written to ``path``.

        Parameters
        ----------
        path : Optional[str], default=None
            File path for local saving (required if repo_id is None)
        repo_id : Optional[str], default=None
            Hugging Face repository ID for remote saving
        **kwargs
            Additional arguments passed to save_to_hf

        Raises
        ------
        ValueError
            If both path and repo_id are None, or if the model is not fitted
        """
        if path is None and repo_id is None:
            raise ValueError("path must be provided if repo_id is not given.")

        if repo_id is not None:
            self.save_to_hf(repo_id=repo_id, **kwargs)

        if path is not None:
            self.save_to_local(path)

    def load(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwargs) -> None:
        """Automatic load from either local disk or Hugging Face Hub.

        Parameters
        ----------
        path : Optional[str], default=None
            File path for local loading.
        repo_id : Optional[str], default=None
            Hugging Face repository ID for remote loading. If path is
            provided, repo_id is ignored.
        **kwargs
            Additional arguments passed to load_from_hf

        Raises
        ------
        FileNotFoundError
            If path is given but no file exists there
        ValueError
            If neither path nor repo_id is provided
        """
        if path is not None:
            if os.path.exists(path):
                self.load_from_local(path)
            else:
                raise FileNotFoundError(f"No local file found at '{path}'.")
        else:
            if repo_id is None:
                raise ValueError("repo_id must be provided if path is not given.")
            self.load_from_hf(repo_id, **kwargs)

    def _check_is_fitted(self) -> None:
        """Check if the model is fitted.

        Raises
        ------
        AttributeError
            If the model is not fitted
        """
        if not self.is_fitted_:
            raise AttributeError("This model is not fitted yet. Call 'fit' before using it.")

    def __str__(self, N_CHAR_MAX: int = 700) -> str:
        """Return a string representation of the model.

        Public, non-callable instance attributes are listed (``fitting_loss``
        is always omitted), with a few key attributes placed first and the
        rest in alphabetical order.

        Parameters
        ----------
        N_CHAR_MAX : int, default=700
            Maximum number of characters in the string representation

        Returns
        -------
        str
            String representation of the model
        """
        attributes = {
            name: value
            for name, value in sorted(self.__dict__.items())
            if not name.startswith("_") and not callable(value) and name != "fitting_loss"
        }

        def format_value(v):
            """Helper to format values for representation."""
            if isinstance(v, (float, np.float32, np.float64)):
                return f"{v:.3g}"
            if isinstance(v, np.ndarray) and v.ndim == 0:
                # 0-d arrays have no len(); fall back to the plain repr.
                return repr(v)
            if isinstance(v, (list, tuple, np.ndarray)) and len(v) > 6:
                return f"{v[:3]}...{v[-3:]}"
            if isinstance(v, str) and len(v) > 50:
                return f"'{v[:25]}...{v[-22:]}'"
            if isinstance(v, dict) and len(v) > 6:
                head = ', '.join(f'{key}: {val}' for key, val in list(v.items())[:3])
                return f"{{{head}...}}"
            if isinstance(v, torch.nn.Module):
                return f"{v.__class__.__name__}(...)"
            return repr(v)

        class_name = self.__class__.__name__
        important_attrs = ["model_name", "is_fitted_", "task_type", "num_task"]

        # Key attributes first (popped so they are not repeated below).
        attributes_str = [
            f"{attr}={format_value(attributes.pop(attr))}"
            for attr in important_attrs
            if attr in attributes
        ]
        attributes_str += [f"{k}={format_value(v)}" for k, v in sorted(attributes.items())]

        content = ",\n ".join(attributes_str)
        repr_str = f"{class_name}(\n {content}\n)"

        if len(repr_str) > N_CHAR_MAX:
            # Keep the head and tail; note ``-N_CHAR_MAX // 2`` floors the negated value.
            repr_str = "\n".join(
                [repr_str[:N_CHAR_MAX // 2], "...", repr_str[-N_CHAR_MAX // 2:]]
            )

        return repr_str