Source code for torch_molecule.utils.checker
import warnings
import numpy as np
from rdkit import Chem
from typing import Optional, Union, List, Tuple
[docs]
class MolecularInputChecker:
"""
Class for validating input data used in molecular models.
"""
[docs]
@staticmethod
def validate_smiles(
smiles: str,
idx: int
) -> Tuple[bool, Optional[str], Optional[Chem.Mol]]:
"""Validate a single SMILES string at a given index.
Parameters
----------
smiles : str
The SMILES string to validate
idx : int
The index of the SMILES string in the original list
Returns
-------
Tuple[bool, Optional[str], Optional[Chem.Mol]]
A tuple containing:
- A boolean indicating whether the SMILES string is valid
- A string describing the error if the SMILES is invalid, or None if valid
- The RDKit Mol object if valid, or None if invalid
"""
if not smiles or not smiles.strip():
return False, f"Empty SMILES at index {idx}", None
try:
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return False, f"Invalid SMILES structure at index {idx}: {smiles}", None
return True, None, mol
except Exception as e:
return False, f"RDKit error at index {idx}: {str(e)}", None
[docs]
@staticmethod
def validate_inputs(
X: List[str],
y: Optional[Union[List, np.ndarray]] = None,
num_task: int = 0,
num_pretask: int = 0,
return_rdkit_mol: bool = True
) -> Tuple[Union[List[str], List["Chem.Mol"]], Optional[np.ndarray]]:
"""Validate a list of SMILES strings, and optionally validate a target array.
Parameters
----------
X : List[str]
List of SMILES strings
y : Optional[Union[List, np.ndarray]], optional
Optional target values, by default None
num_task : int, optional
Total number of tasks; used to check dimensions of y, by default 0
num_pretask : int, optional
Number of (pseudo)-tasks that are predefined in the modeling;
used to check dimensions of y. Preliminarily used in supervised pretraining,
by default 0
return_rdkit_mol : bool, optional
If True, convert SMILES to RDKit Mol objects, by default True
Returns
-------
Tuple[Union[List[str], List["Chem.Mol"]], Optional[np.ndarray]]
A tuple containing:
- The original or converted SMILES (RDKit Mol objects if return_rdkit_mol=True)
- The target array as a numpy array, or None if y was not provided
Raises
------
ValueError
If SMILES or target dimensions are invalid
"""
if not isinstance(X, list):
raise ValueError("X must be a list of SMILES strings.")
if not all(isinstance(s, str) for s in X):
raise ValueError("All elements in X must be strings.")
invalid_smiles = []
rdkit_mols = []
for i, smiles in enumerate(X):
is_valid, error_msg, mol = MolecularInputChecker.validate_smiles(smiles, i)
if not is_valid:
invalid_smiles.append(error_msg)
else:
rdkit_mols.append(mol)
if invalid_smiles:
raise ValueError("Invalid SMILES found:\n" + "\n".join(invalid_smiles))
if y is not None:
try:
y = np.asarray(y, dtype=np.float32)
except Exception as e:
raise ValueError(f"Could not convert y to numpy array: {str(e)}")
if len(y.shape) == 1:
if num_task - num_pretask != 1:
raise ValueError(
f"1D target array provided but num_task is {num_task - num_pretask}. "
"For multiple tasks, y must be 2D."
)
y = y.reshape(-1, 1)
if len(y.shape) != 2:
raise ValueError(
"y must be 1D (single task) or 2D (multiple tasks). "
f"Got shape {y.shape}."
)
if y.shape[0] != len(X):
raise ValueError(
f"Number of samples in y ({y.shape[0]}) must match length of X ({len(X)})."
)
if y.shape[1] != num_task - num_pretask:
raise ValueError(
f"Second dimension of y ({y.shape[1]}) must match num_task ({num_task - num_pretask})."
)
inf_mask = np.isinf(y)
if np.any(inf_mask):
inf_indices = np.where(inf_mask)
warnings.warn(
f"Infinite values found in y at indices: {list(zip(*inf_indices))}. "
"Converting to NaN.",
RuntimeWarning,
)
y = y.astype(float)
y[inf_mask] = np.nan
# nan_mask = np.isnan(y)
# if np.any(nan_mask):
# nan_counts = np.sum(nan_mask, axis=0)
# nan_percentages = (nan_counts / len(X)) * 100
# task_warnings = []
# for task_idx, (count, percentage) in enumerate(zip(nan_counts, nan_percentages)):
# if count > 0:
# task_warnings.append(f"Task {task_idx}: {count} NaNs ({percentage:.1f}%)")
# warnings.warn(
# "NaN values present in y:\n"
# + "\n".join(task_warnings)
# + "\nSamples with NaN will be ignored or cause issues unless handled.",
# RuntimeWarning,
# )
return rdkit_mols if return_rdkit_mol else X, y