Base Modules
Base
- class torch_molecule.base.base.BaseModel(device: device | None = None, model_name: str = 'BaseModel')
Bases: ABC
Base class for molecular models with shared functionality.
This abstract class provides common methods and utilities for molecular models, including model initialization, saving/loading, and parameter management.
- device: device | None = None
- get_params(deep: bool = True) -> Dict[str, Any]
Get parameters for this estimator.
- Parameters:
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
Dictionary of parameter names mapped to their values
- Return type:
Dict[str, Any]
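The get_params docstring follows the scikit-learn estimator convention. The sketch below assumes the implementation does too: it reads constructor arguments back as attributes and, with deep=True, also reports nested estimators' parameters under outer__inner keys. ParamsMixin, ToyModel and ToyEncoder are hypothetical names, not part of torch_molecule.

```python
import inspect
from typing import Any, Dict, Optional


class ParamsMixin:
    """Hypothetical scikit-learn-style get_params (not torch_molecule's code)."""

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        # Constructor argument names, excluding self, *args and **kwargs.
        signature = inspect.signature(type(self).__init__)
        names = [
            name
            for name, param in signature.parameters.items()
            if name != "self"
            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        ]
        params: Dict[str, Any] = {}
        for name in names:
            value = getattr(self, name, None)
            # deep=True also reports a nested estimator's parameters as "outer__inner".
            if deep and hasattr(value, "get_params") and not isinstance(value, type):
                for sub_name, sub_value in value.get_params(deep=True).items():
                    params[f"{name}__{sub_name}"] = sub_value
            params[name] = value
        return params


class ToyEncoder(ParamsMixin):
    def __init__(self, hidden_size: int = 64) -> None:
        self.hidden_size = hidden_size


class ToyModel(ParamsMixin):
    def __init__(self, model_name: str = "ToyModel", encoder: Optional[ToyEncoder] = None) -> None:
        self.model_name = model_name
        self.encoder = encoder


model = ToyModel(encoder=ToyEncoder(hidden_size=128))
print(sorted(model.get_params()))  # ['encoder', 'encoder__hidden_size', 'model_name']
```

With deep=False, only the top-level names (encoder, model_name) would be returned.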
- is_fitted_: bool = False
- load(path: str | None = None, repo_id: str | None = None, **kwargs) -> None
Load the model automatically from either local disk or the Hugging Face Hub.
- Parameters:
path (Optional[str], default=None) – File path for local loading.
repo_id (Optional[str], default=None) – Hugging Face repository ID for remote loading. If path is provided, repo_id is ignored.
**kwargs – Additional arguments passed to load_from_hf
- Raises:
FileNotFoundError – If no local file is found and no repo_id is provided
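The rules above leave one case open: a path is given but does not exist, and a repo_id is also supplied. The sketch below takes one reading of the rules. An existing local file always wins and repo_id is then ignored. A missing file falls back to the Hub when repo_id is set. Otherwise FileNotFoundError is raised. resolve_load_source is a hypothetical helper that only decides the route. It does not call load_from_local or load_from_hf.

```python
import os
from typing import Optional


def resolve_load_source(path: Optional[str] = None, repo_id: Optional[str] = None) -> str:
    """Return "local" or "hub" (hypothetical helper; routing only, no I/O)."""
    # An existing local file wins, so repo_id is ignored in that case.
    if path is not None and os.path.exists(path):
        return "local"
    # No usable local file: fall back to the Hub if a repository was named.
    if repo_id is not None:
        return "hub"
    # Neither a local file nor a repo_id: mirror the documented FileNotFoundError.
    raise FileNotFoundError(f"No local file found at {path!r} and no repo_id provided")


print(resolve_load_source(os.curdir, "user/repo"))  # local
print(resolve_load_source(None, "user/repo"))       # hub
```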
- load_from_hf(repo_id: str, local_cache: str | None = None, config_filename: str | None = 'config.json') -> None
Load model from Hugging Face Hub.
- Parameters:
repo_id (str) – Hugging Face repository ID
local_cache (Optional[str], default=None) – Local path at which to cache the downloaded model
config_filename (str, default='config.json') – Name of the configuration file to load from the repository
- load_from_local(path: str) -> None
Load model from local disk.
- Parameters:
path (str) – File path to load the model from
- model: Module | None = None
- model_class: Type[Module] | None = None
- model_name: str = 'BaseModel'
- save(path: str | None = None, repo_id: str | None = None, **kwargs) -> None
Save the model automatically to either local disk or the Hugging Face Hub.
- Parameters:
path (Optional[str], default=None) – File path for local saving (required if repo_id is None)
repo_id (Optional[str], default=None) – Hugging Face repository ID for remote saving
**kwargs – Additional arguments passed to save_to_hf
- Raises:
ValueError – If path is None when repo_id is None
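save routes in the opposite direction. A repo_id sends the model to the Hub. Without one, a local path is mandatory. The following is a minimal sketch of that decision. It assumes repo_id takes priority when both arguments are given, a case the documentation above does not settle. resolve_save_target is a hypothetical helper that performs no I/O.

```python
from typing import Optional


def resolve_save_target(path: Optional[str] = None, repo_id: Optional[str] = None) -> str:
    """Return "hub" or "local" (hypothetical helper; routing only, no I/O)."""
    # Assumption: a repo_id takes priority; extra kwargs would go to save_to_hf.
    if repo_id is not None:
        return "hub"
    # Without a repo_id, a local path is required, as documented.
    if path is None:
        raise ValueError("path is required when repo_id is None")
    return "local"


print(resolve_save_target(path="model.pt"))      # local
print(resolve_save_target(repo_id="user/repo"))  # hub
```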
- save_to_hf(repo_id: str, task_id: str = 'default', metadata_dict: Dict[str, Any] | None = None, metrics: Dict[str, float] | None = None, commit_message: str = 'Update model', hf_token: str | None = None, private: bool = False, config_filename: str | None = 'config.json') -> None
Save model to Hugging Face Hub.
- Parameters:
repo_id (str) – Hugging Face repository ID
task_id (str, default="default") – Task identifier for the model
metadata_dict (Optional[Dict[str, Any]], default=None) – Optional metadata to store with the model
metrics (Optional[Dict[str, float]], default=None) – Optional performance metrics to store with the model
commit_message (str, default="Update model") – Git commit message
hf_token (Optional[str], default=None) – Hugging Face authentication token
private (bool, default=False) – Whether the repository should be private
config_filename (Optional[str], default='config.json') – Name of the configuration file to save to the repository
- Raises:
ValueError – If the model is not fitted
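save_to_hf refuses to upload an unfitted model. The documentation only says that the model must be fitted. This sketch assumes the check reads the is_fitted_ attribute listed above, in which case the guard reduces to a few lines. ToyModel and ensure_fitted are hypothetical.

```python
class ToyModel:
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self) -> None:
        self.is_fitted_ = False  # same default as the documented attribute

    def fit(self) -> "ToyModel":
        self.is_fitted_ = True
        return self


def ensure_fitted(model: object) -> None:
    """Hypothetical pre-upload guard that raises ValueError like save_to_hf."""
    # Objects without the flag are treated as unfitted.
    if not getattr(model, "is_fitted_", False):
        raise ValueError(f"{type(model).__name__} must be fitted before saving to the Hub")


ensure_fitted(ToyModel().fit())  # passes silently
```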
Base Predictor
- class torch_molecule.base.predictor.BaseMolecularPredictor(device: device | None = None, model_name: str = 'BaseMolecularPredictor', num_task: int = 0, task_type: str = None)
Bases: BaseModel, ABC
Base class for molecular discovery estimators.
- DEFAULT_METRICS: ClassVar[Dict] = {'classification': {'default': ('roc_auc', <function roc_auc_score>, True)}, 'regression': {'default': ('mae', <function mean_absolute_error>, False)}}
- abstractmethod autofit(X_train, y_train, X_val=None, y_val=None, search_parameters: dict | None = None, n_trials: int = 10) -> BaseMolecularPredictor
- abstractmethod fit(X_train, y_train, X_val=None, y_val=None, search_parameters: dict | None = None, n_trials: int = 10) -> BaseMolecularPredictor
- model_name: str = 'BaseMolecularPredictor'
- num_task: int = 0
- task_type: str = None
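DEFAULT_METRICS pairs each task type with a (name, function, flag) tuple. The flag appears to mark whether a higher score is better: True for ROC-AUC and False for MAE. The sketch below assumes exactly that. It also assumes autofit tries at most n_trials configurations drawn from search_parameters and keeps the model that scores best on the validation set under the default metric. A plain grid search stands in for whatever strategy autofit actually uses. LengthRegressor, is_better and this autofit function are hypothetical. A pure-Python MAE replaces scikit-learn's mean_absolute_error so the example has no dependencies.

```python
import itertools
from typing import Callable, Dict, List, Optional, Tuple


def mean_absolute_error(y_true: List[float], y_pred: List[float]) -> float:
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Assumed tuple layout: (metric name, scoring function, higher_is_better).
DEFAULT_METRICS: Dict[str, Dict[str, Tuple[str, Callable, bool]]] = {
    "regression": {"default": ("mae", mean_absolute_error, False)},
}


def is_better(score: float, best: Optional[float], higher_is_better: bool) -> bool:
    # The first score seen is always the best so far.
    if best is None:
        return True
    return score > best if higher_is_better else score < best


class LengthRegressor:
    """Hypothetical toy predictor: y_hat = scale * len(smiles)."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale
        self.is_fitted_ = False

    def fit(self, X_train: List[str], y_train: List[float]) -> "LengthRegressor":
        self.is_fitted_ = True  # nothing to learn in this toy model
        return self

    def predict(self, X: List[str]) -> List[float]:
        return [self.scale * len(s) for s in X]


def autofit(
    X_train: List[str],
    y_train: List[float],
    X_val: List[str],
    y_val: List[float],
    search_parameters: Dict[str, list],
    n_trials: int = 10,
) -> Tuple[Optional[LengthRegressor], Optional[float]]:
    _name, metric, higher_is_better = DEFAULT_METRICS["regression"]["default"]
    keys = list(search_parameters)
    grid = itertools.product(*(search_parameters[k] for k in keys))
    best_model: Optional[LengthRegressor] = None
    best_score: Optional[float] = None
    # Evaluate at most n_trials configurations, keeping the best on validation data.
    for values in itertools.islice(grid, n_trials):
        model = LengthRegressor(**dict(zip(keys, values))).fit(X_train, y_train)
        score = metric(y_val, model.predict(X_val))
        if is_better(score, best_score, higher_is_better):
            best_model, best_score = model, score
    return best_model, best_score


model, score = autofit(["CCO"], [1.5], ["CCO", "CC"], [1.5, 1.0], {"scale": [1.0, 0.5, 2.0]})
print(model.scale, score)  # 0.5 0.0
```

Keeping the direction flag next to the metric function means a classification entry could reuse the same selection loop. Only the lookup key and the comparison would change.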
Base Generator
- class torch_molecule.base.generator.BaseMolecularGenerator(device: device | None = None, model_name: str = 'BaseMolecularGenerator')
Bases: BaseModel, ABC
Base class for molecular generation.
- abstractmethod fit(X: List[str], y: ndarray | None = None) -> BaseMolecularGenerator
- abstractmethod generate(n_samples: int, **kwargs) -> List[str]
Generate molecular structures.
- model_name: str = 'BaseMolecularGenerator'
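A concrete generator only has to implement fit and generate. torch_molecule is not imported here. Instead, GeneratorInterface is a hypothetical stand-in that mirrors the two documented abstract methods. ReplayGenerator is a deliberately trivial baseline that resamples its training SMILES. To avoid a NumPy dependency, y is typed as a plain sequence rather than an ndarray.

```python
import random
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class GeneratorInterface(ABC):
    """Hypothetical stand-in mirroring BaseMolecularGenerator's abstract methods."""

    @abstractmethod
    def fit(self, X: List[str], y: Optional[Sequence[float]] = None) -> "GeneratorInterface": ...

    @abstractmethod
    def generate(self, n_samples: int, **kwargs) -> List[str]: ...


class ReplayGenerator(GeneratorInterface):
    """Hypothetical baseline that resamples SMILES seen during fit."""

    def __init__(self, seed: int = 0) -> None:
        self.seed = seed
        self.is_fitted_ = False
        self._pool: List[str] = []

    def fit(self, X: List[str], y: Optional[Sequence[float]] = None) -> "ReplayGenerator":
        # y is accepted to match the documented signature but is unused here.
        self._pool = list(X)
        self.is_fitted_ = True
        return self

    def generate(self, n_samples: int, **kwargs) -> List[str]:
        if not self.is_fitted_:
            raise ValueError("Call fit before generate")
        rng = random.Random(self.seed)  # seeded for reproducible samples
        return [rng.choice(self._pool) for _ in range(n_samples)]


gen = ReplayGenerator(seed=1).fit(["CCO", "c1ccccc1", "CC(=O)O"])
print(len(gen.generate(5)))  # 5
```

A real generator would replace the resampling step with a learned model. The fit/generate contract stays the same.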
Base Encoder
- class torch_molecule.base.encoder.BaseMolecularEncoder(device: device | None = None, model_name: str = 'BaseMolecularEncoder')
Bases: BaseModel, ABC
Base class for molecular representation learning.
- abstractmethod encode(X: List[str], return_type: Literal['np', 'pt'] = 'pt') -> ndarray | Tensor
- abstractmethod fit(X: List[str], y: ndarray | None = None) -> BaseMolecularEncoder
- model_name: str = 'BaseMolecularEncoder'
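An encoder implements fit and encode. The return_type argument selects either a NumPy array ('np') or a PyTorch tensor ('pt'). The sketch assumes both formats hold the same (n_molecules, n_features) matrix. EncoderInterface and CharCountEncoder are hypothetical. Character counts stand in for a learned representation. PyTorch is imported only when 'pt' is requested.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional

import numpy as np


class EncoderInterface(ABC):
    """Hypothetical stand-in mirroring BaseMolecularEncoder's abstract methods."""

    @abstractmethod
    def fit(self, X: List[str], y: Optional[np.ndarray] = None) -> "EncoderInterface": ...

    @abstractmethod
    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Any: ...


class CharCountEncoder(EncoderInterface):
    """Hypothetical toy encoder: per-character counts over a vocabulary learned in fit."""

    def __init__(self) -> None:
        self.vocab_: Dict[str, int] = {}
        self.is_fitted_ = False

    def fit(self, X: List[str], y: Optional[np.ndarray] = None) -> "CharCountEncoder":
        chars = sorted({c for s in X for c in s})
        self.vocab_ = {c: i for i, c in enumerate(chars)}
        self.is_fitted_ = True
        return self

    def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Any:
        if not self.is_fitted_:
            raise ValueError("Call fit before encode")
        features = np.zeros((len(X), len(self.vocab_)), dtype=np.float32)
        for row, smiles in enumerate(X):
            for ch in smiles:
                col = self.vocab_.get(ch)
                if col is not None:  # characters unseen during fit are skipped
                    features[row, col] += 1
        if return_type == "np":
            return features
        if return_type == "pt":
            import torch  # lazy import so the "np" path works without PyTorch

            return torch.from_numpy(features)
        raise ValueError(f"Unsupported return_type: {return_type!r}")


enc = CharCountEncoder().fit(["CCO", "CN"])
print(enc.encode(["CCO", "NN"], return_type="np").tolist())  # [[2.0, 0.0, 1.0], [0.0, 2.0, 0.0]]
```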