Source code for torch_molecule.utils.format
import torch
import numpy as np
import types
import inspect
[docs]
def serialize_config(obj):
    """Helper function to make config JSON serializable.

    Handles special cases like lambda functions, torch modules, and numpy arrays.

    Parameters
    ----------
    obj : Any
        The object to serialize

    Returns
    -------
    Any
        JSON serializable representation of the object. ``None`` and plain
        ``str``/``int``/``float``/``bool`` values are returned unchanged.
        Single-element tensors/arrays are unwrapped to Python scalars.
        Other special objects are described by a dict carrying a ``"_type"``
        tag: ``"lambda"``, ``"function"``, ``"torch_class"``,
        ``"torch_tensor"``, ``"numpy_array"``, ``"set"``, ``"custom_class"``
        or ``"unknown"``.
    """
    # Handle None
    if obj is None:
        return None

    # Handle Python functions, bound methods and built-in functions.
    # NOTE: ``types.LambdaType`` is merely an alias of ``types.FunctionType``,
    # so an isinstance check cannot tell a lambda from a ``def`` function;
    # lambdas are recognised by their ``__name__`` instead.
    if isinstance(obj, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
        if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
            try:
                source = inspect.getsource(obj).strip()
            except (OSError, TypeError):
                # Source isn't available (e.g. defined in a REPL or via exec)
                return {"_type": "function", "name": str(obj)}
            return {"_type": "lambda", "source": source}
        return {"_type": "function", "name": obj.__name__}

    # Handle PyTorch modules and optimizers: only the class identity is kept
    if isinstance(obj, (torch.nn.Module, torch.optim.Optimizer)):
        return {
            "_type": "torch_class",
            "class_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
        }

    # Handle PyTorch tensors
    if isinstance(obj, torch.Tensor):
        # A single number wrapped in a tensor is returned as a plain number
        if obj.numel() == 1:
            return obj.item()
        # Small tensors keep their data; larger ones only record their shape
        if obj.numel() < 1000:
            return {
                "_type": "torch_tensor",
                "data": obj.detach().cpu().tolist(),
                "shape": list(obj.shape),
            }
        return {
            "_type": "torch_tensor",
            "shape": list(obj.shape),
        }

    # Handle numpy arrays and numpy scalars. This must stay ahead of the
    # basic-type check because e.g. ``np.float64`` subclasses ``float``.
    if isinstance(obj, (np.ndarray, np.generic)):
        # A single number wrapped in a numpy array is returned as a plain number
        if obj.size == 1:
            return obj.item()
        # Small arrays keep their data; larger ones only record their shape
        if obj.size < 1000:
            return {
                "_type": "numpy_array",
                "data": obj.tolist(),
                "shape": list(obj.shape),
            }
        return {
            "_type": "numpy_array",
            "shape": list(obj.shape),
        }

    # Handle sets and frozensets. Sort when possible so the saved config does
    # not depend on hash-based iteration order (e.g. string hash
    # randomisation); fall back to iteration order for unorderable elements.
    # The elements themselves are not serialized recursively.
    if isinstance(obj, (set, frozenset)):
        try:
            data = sorted(obj)
        except TypeError:
            data = list(obj)
        return {"_type": "set", "data": data}

    # Handle basic JSON types *before* the generic ``__dict__`` check so that
    # subclasses carrying an instance dict (e.g. ``str``/``int`` mix-in enums)
    # keep their value instead of being reduced to a class name.
    if isinstance(obj, (str, int, float, bool)):
        return obj

    # Handle custom objects with __dict__: only the class identity is kept
    if hasattr(obj, "__dict__"):
        return {
            "_type": "custom_class",
            "class_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
        }

    # Handle any other types by converting to string
    return {
        "_type": "unknown",
        "repr": str(obj),
    }
[docs]
def sanitize_config(config_dict):
    """Recursively sanitize config dictionary for JSON serialization.

    Handles nested structures and special cases.

    Parameters
    ----------
    config_dict : dict
        Configuration dictionary to sanitize. Nested dicts are sanitized
        recursively. Lists and tuples, at any nesting depth, are sanitized
        element-wise and returned as lists. Any other non-dict value is
        passed to ``serialize_config``.

    Returns
    -------
    dict
        Sanitized configuration dictionary that is JSON serializable. A list
        is returned for list/tuple input, and the ``serialize_config`` result
        for any other non-dict input.
    """
    # Sequences are recursed into here, not only when they appear as dict
    # values. Otherwise an inner list such as the elements of
    # ``[[1, 2], [3, 4]]`` would be handed to ``serialize_config``, which has
    # no list handling and would reduce it to an "unknown" repr string.
    if isinstance(config_dict, (list, tuple)):
        return [sanitize_config(item) for item in config_dict]
    if not isinstance(config_dict, dict):
        return serialize_config(config_dict)

    clean_dict = {}
    for key, value in config_dict.items():
        # Skip private attributes and callable objects stored as attributes
        if isinstance(key, str) and (key.startswith("_") or callable(value)):
            continue
        # Nested dicts, lists/tuples and leaf values are all dispatched
        # through the same recursion
        clean_dict[key] = sanitize_config(value)
    return clean_dict