# Source code for torch_molecule.utils.checkpoint

import os
import json
import datetime
import tempfile
import torch
import warnings
import numpy as np

# Import typing dependencies if needed
from typing import Optional, Dict, Any

# Default repository metadata applied when pushing checkpoints to the
# Hugging Face Hub; used as the fallback `metadata_dict` in
# HuggingFaceCheckpointManager.push_to_huggingface.
HF_METADATA = {
    "tags": [
        "torch_molecule",
        "molecular-property-prediction",
    ],
    "library_name": "torch_molecule", 
}
from ..utils.format import sanitize_config

class LocalCheckpointManager:
    """Handles saving and loading of models to and from local paths."""

    @staticmethod
    def save_model_to_local(model_instance, path: str) -> None:
        """Save model weights and configuration to a local file.

        Parameters
        ----------
        model_instance : object
            Fitted estimator exposing ``is_fitted_``, ``model`` (a torch
            module) and ``get_params()``.
        path : str
            Destination file path; must end with ``.pt`` or ``.pth``.

        Raises
        ------
        ValueError
            If the model is not fitted or ``path`` has a wrong extension.
        """
        if not model_instance.is_fitted_:
            raise ValueError("Model must be fitted before saving.")
        if not path.endswith((".pt", ".pth")):
            raise ValueError("Save path should end with '.pt' or '.pth'.")
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        # The model name is derived from the file name (without extension).
        model_name = os.path.splitext(os.path.basename(path))[0]
        save_dict = {
            "model_state_dict": model_instance.model.state_dict(),
            "hyperparameters": model_instance.get_params(),
            "model_name": model_name,
            "date_saved": datetime.datetime.now().isoformat(),
            "version": getattr(model_instance, "__version__", "1.0.0"),
        }
        torch.save(save_dict, path)

    @staticmethod
    def load_model_from_local(model_instance, path: str) -> None:
        """Load model weights and configuration from a local file.

        Restores hyperparameters onto ``model_instance``, reinitializes its
        underlying torch module from the checkpoint, and moves it to
        ``model_instance.device``.

        Raises
        ------
        FileNotFoundError
            If ``path`` does not exist.
        ValueError
            If the checkpoint cannot be deserialized or lacks required keys.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"No model file found at '{path}'.")
        try:
            checkpoint = torch.load(path, map_location=model_instance.device, weights_only=False)
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Error loading model from {path}: {str(e)}") from e
        verbose = model_instance.get_params().get("verbose", False)
        required_keys = {"model_state_dict", "hyperparameters", "model_name"}
        if not all(key in checkpoint for key in required_keys):
            missing_keys = required_keys - set(checkpoint.keys())
            raise ValueError(f"Checkpoint missing required keys: {missing_keys}")
        parameter_status = LocalCheckpointManager._apply_hyperparameters(model_instance, checkpoint)
        if parameter_status and verbose:
            LocalCheckpointManager._print_hyperparameter_report(parameter_status)
        # Reinitialize the underlying torch module from the checkpoint.
        model_instance._initialize_model(model_instance.model_class, checkpoint)
        model_instance.model_name = checkpoint["model_name"]
        model_instance.is_fitted_ = True
        model_instance.model = model_instance.model.to(model_instance.device)
        print(f"Model successfully loaded from local path: {path}")

    @staticmethod
    def _apply_hyperparameters(model_instance, checkpoint):
        """Copy checkpoint hyperparameters onto matching instance attributes.

        Only keys that already exist as attributes on ``model_instance`` are
        applied. Returns a list of per-parameter status dicts recording old
        value, new value, and whether the value changed.
        """
        parameter_status = []
        for key, new_value in checkpoint["hyperparameters"].items():
            if not hasattr(model_instance, key):
                continue
            old_value = getattr(model_instance, key)
            is_changed = old_value != new_value
            parameter_status.append({
                "Parameter": key,
                "Old Value": old_value,
                "New Value": new_value,
                "Status": "Changed" if is_changed else "Unchanged",
            })
            if is_changed:
                setattr(model_instance, key, new_value)
        return parameter_status

    @staticmethod
    def _print_hyperparameter_report(parameter_status):
        """Pretty-print the hyperparameter change table, changed entries first."""
        print("\nHyperparameter Status:")
        print("-" * 80)
        print(f"{'Parameter':<20} {'Old Value':<20} {'New Value':<20} {'Status':<10}")
        print("-" * 80)
        # Sort so that changed parameters appear first
        parameter_status.sort(key=lambda x: (x["Status"] != "Changed", x["Parameter"]))
        for param in parameter_status:
            color = "\033[91m" if param["Status"] == "Changed" else "\033[92m"
            print(
                f"{param['Parameter']:<20} "
                f"{str(param['Old Value']):<20} "
                f"{str(param['New Value']):<20} "
                f"{color}{param['Status']}\033[0m"
            )
        print("-" * 80)
        changes_count = sum(1 for p in parameter_status if p["Status"] == "Changed")
        print(
            f"\nSummary: {changes_count} parameters changed, "
            f"{len(parameter_status) - changes_count} unchanged"
        )
class HuggingFaceCheckpointManager:
    """Handles saving and loading of models to and from the Hugging Face Hub."""

    @staticmethod
    def load_model_from_hf(model_instance, repo_id: str, path: str, config_filename: str = "config.json") -> None:
        """Load model from Hugging Face Hub, saving locally to `path` first.

        Downloads ``<model_name>.pt`` (where ``model_name`` is derived from
        ``path``) and ``config_filename`` from ``repo_id``, then restores the
        checkpoint onto ``model_instance``.

        Raises
        ------
        ImportError
            If ``huggingface_hub`` is not installed.
        RuntimeError
            If the download or checkpoint restore fails (original error chained).
        """
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            raise ImportError(
                "huggingface_hub package is required to load from Hugging Face Hub. "
                "Install it with: pip install huggingface_hub"
            )
        local_dir = os.path.dirname(path)
        model_name = os.path.splitext(os.path.basename(path))[0]
        # The Hub client always writes "<model_name>.pt", which can differ
        # from `path` itself (e.g. when `path` ends in ".pth").
        downloaded_file = os.path.join(local_dir, f"{model_name}.pt")
        try:
            os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=f"{model_name}.pt",
                local_dir=local_dir,
            )
            hf_hub_download(
                repo_id=repo_id,
                filename=config_filename,
                local_dir=local_dir,
            )
            checkpoint = torch.load(downloaded_path, map_location=model_instance.device, weights_only=False)
            required_keys = {"model_state_dict", "hyperparameters", "model_name"}
            if not all(key in checkpoint for key in required_keys):
                missing_keys = required_keys - set(checkpoint.keys())
                raise ValueError(f"Checkpoint missing required keys: {missing_keys}")
            parameter_status = HuggingFaceCheckpointManager._sync_hyperparameters(model_instance, checkpoint)
            if parameter_status:
                HuggingFaceCheckpointManager._report_hyperparameters(parameter_status)
            # Reinitialize the underlying torch module from the checkpoint.
            model_instance._initialize_model(model_instance.model_class, checkpoint)
            model_instance.model_name = checkpoint["model_name"]
            model_instance.is_fitted_ = True
            model_instance.model = model_instance.model.to(model_instance.device)
            print(f"Model successfully loaded from repository {repo_id}")
        except Exception as e:
            # Clean up partial downloads: the file on disk may be either
            # `path` or the "<model_name>.pt" the Hub client wrote.
            for stale in (path, downloaded_file):
                if os.path.exists(stale):
                    os.remove(stale)
            raise RuntimeError(f"Failed to download or load model from repository '{repo_id}': {str(e)}") from e

    @staticmethod
    def _sync_hyperparameters(model_instance, checkpoint):
        """Copy checkpoint hyperparameters onto matching instance attributes.

        Only keys already present as attributes on ``model_instance`` are
        applied. Returns a list of per-parameter status dicts.
        """
        parameter_status = []
        for key, new_value in checkpoint["hyperparameters"].items():
            if not hasattr(model_instance, key):
                continue
            old_value = getattr(model_instance, key)
            is_changed = old_value != new_value
            parameter_status.append({
                "Parameter": key,
                "Old Value": old_value,
                "New Value": new_value,
                "Status": "Changed" if is_changed else "Unchanged",
            })
            if is_changed:
                setattr(model_instance, key, new_value)
        return parameter_status

    @staticmethod
    def _report_hyperparameters(parameter_status):
        """Pretty-print the hyperparameter change table, changed entries first."""
        print("\nHyperparameter Status:")
        print("-" * 80)
        print(f"{'Parameter':<20} {'Old Value':<20} {'New Value':<20} {'Status':<10}")
        print("-" * 80)
        parameter_status.sort(key=lambda x: (x["Status"] != "Changed", x["Parameter"]))
        for param in parameter_status:
            color = "\033[91m" if param["Status"] == "Changed" else "\033[92m"
            print(
                f"{param['Parameter']:<20} "
                f"{str(param['Old Value']):<20} "
                f"{str(param['New Value']):<20} "
                f"{color}{param['Status']}\033[0m"
            )
        print("-" * 80)
        changes_count = sum(1 for p in parameter_status if p["Status"] == "Changed")
        print(
            f"\nSummary: {changes_count} parameters changed, "
            f"{len(parameter_status) - changes_count} unchanged"
        )

    @staticmethod
    def push_to_huggingface(
        model_instance,
        repo_id: str,
        task_id: str = "default",
        metadata_dict: Optional[Dict[str, Any]] = None,
        metrics: Optional[Dict[str, float]] = None,
        commit_message: str = "Update model",
        token: Optional[str] = None,
        private: bool = False,
        config_filename: str = "config.json",
    ) -> None:
        """Push a task-specific model checkpoint to Hugging Face Hub.

        Saves the model locally via :class:`LocalCheckpointManager`, merges the
        task configuration with any existing repository config, regenerates the
        model card, uploads everything, and updates repository metadata
        (defaulting to ``HF_METADATA``).

        Raises
        ------
        ImportError
            If ``huggingface_hub`` is not installed.
        ValueError
            If ``repo_id``/``task_id`` are invalid or the model is not fitted.
        RuntimeError
            If the upload fails (original error chained).
        """
        try:
            from huggingface_hub import HfApi, create_repo, metadata_update
            from .hf import merge_task_configs, get_existing_repo_data, create_model_card
        except ImportError:
            raise ImportError(
                "huggingface_hub package is required to push to Hugging Face Hub. "
                "Install it with: pip install huggingface_hub"
            )
        if not isinstance(repo_id, str) or "/" not in repo_id:
            raise ValueError("repo_id must be in format '<username>/<model_name>'.")
        if not task_id:
            raise ValueError("task_id must be provided.")
        if not model_instance.is_fitted_:
            raise ValueError("Model must be fitted before pushing to Hugging Face Hub.")
        try:
            api = HfApi(token=token)
            repo_exists, existing_config, existing_readme = get_existing_repo_data(repo_id, token)
            create_repo(repo_id, private=private, token=token, exist_ok=True)
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_path = os.path.join(tmp_dir, f"{model_instance.model_name}.pt")
                # Use the local manager to save a copy
                LocalCheckpointManager.save_model_to_local(model_instance, local_path)
                # Prepare task-specific config
                task_config = {
                    **({"task_type": model_instance.task_type} if hasattr(model_instance, "task_type") else {}),
                    "config": sanitize_config(model_instance.get_params(deep=True)),
                    "metrics": metrics or {},
                }
                # Merge with any existing config; keep the original creation
                # date when the repository already exists.
                base_config = (
                    {
                        "model_name": model_instance.model_name,
                        "framework": "torch_molecule",
                        "date_created": existing_config.get(
                            "date_created", datetime.datetime.now().isoformat()
                        ),
                    }
                    if not repo_exists
                    else existing_config
                )
                num_params = sum(p.numel() for p in model_instance.model.parameters())
                final_config = merge_task_configs(
                    task_id=task_id,
                    existing_config=base_config,
                    new_task_config=task_config,
                    num_params=num_params,
                )
                # Save config file
                config_path = os.path.join(tmp_dir, config_filename)
                with open(config_path, "w") as f:
                    json.dump(final_config, f, indent=2)
                # Create model card
                readme_content = create_model_card(
                    model_class=model_instance.__class__.__name__,
                    model_name=model_instance.model_name,
                    tasks_config=final_config.get("tasks", {}),
                    model_config=final_config,
                    repo_id=repo_id,
                    existing_readme=existing_readme,
                )
                readme_path = os.path.join(tmp_dir, "README.md")
                with open(readme_path, "w") as f:
                    f.write(readme_content)
                # Upload everything
                api.upload_folder(
                    repo_id=repo_id,
                    folder_path=tmp_dir,
                    commit_message=f"{commit_message} - Task: {task_id}",
                )
                # Update or add repository metadata
                metadata_dict = HF_METADATA if metadata_dict is None else metadata_dict
                metadata_update(repo_id=repo_id, metadata=metadata_dict, token=token, overwrite=True)
                # Print summary
                task_info = final_config["tasks"][task_id]
                print(f"Successfully pushed model for task {task_id} to {repo_id}")
                print(f"Task version: {task_info['current_version']}")
                if metrics:
                    print("Metrics:")
                    for metric, value in metrics.items():
                        print(f"  {metric}: {value:.4f}")
        except Exception as e:
            raise RuntimeError(f"Failed to push to Hugging Face Hub: {str(e)}") from e