Molecular Property Prediction Models
The predictor models inherit from the torch_molecule.base.predictor.BaseMolecularPredictor
class and share common methods for model training, evaluation, prediction and persistence.
Training and Prediction
fit(X, y, **kwargs)
: Train the model on the given data, where X contains SMILES strings and y contains target values
autofit(X, y, search_parameters, n_trials=50, **kwargs)
: Automatically search for optimal hyperparameters using Optuna and train the model
predict(X, **kwargs)
: Make predictions on new SMILES strings and return a dictionary containing predictions and optional uncertainty estimates
Model Persistence
(inherited from torch_molecule.base.base.BaseModel)
save_to_local(path)
: Save the trained model to a local file
load_from_local(path)
: Load a trained model from a local file
save_to_hf(repo_id)
: Push the model to the Hugging Face Hub
load_from_hf(repo_id, local_cache)
: Load a model from the Hugging Face Hub and save it to a local file
save(path, repo_id)
: Save the model to either local storage or the Hugging Face Hub
load(path, repo_id)
: Load a model from either local storage or the Hugging Face Hub
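A minimal end-to-end sketch of this shared interface, using GNNMolecularPredictor as the concrete class; the SMILES strings, target values, and file path below are illustrative placeholders.

```python
from torch_molecule.predictor.gnn.modeling_gnn import GNNMolecularPredictor

# Illustrative data: SMILES strings with scalar regression targets
X = ["CCO", "c1ccccc1", "CC(=O)O", "CCN"]
y = [0.50, 1.21, 0.17, 0.44]

model = GNNMolecularPredictor(num_task=1, task_type="regression", epochs=10)
model.fit(X, y)  # or model.autofit(X, y, n_trials=10) to search hyperparameters first

out = model.predict(["CCOC"])      # returns a dictionary
print(out["prediction"])           # shape: [n_samples, n_tasks]

model.save_to_local("gnn_model.pt")    # persist locally ...
model.load_from_local("gnn_model.pt")  # ... and restore
```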
Modeling Molecules as Graphs with Graph Neural Networks
Graph Neural Networks
- class torch_molecule.predictor.gnn.modeling_gnn.GNNMolecularPredictor(device: torch.device | None = None, model_name: str = 'GNNMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False)
Bases:
BaseMolecularPredictor
This predictor implements a GNN model for molecular property prediction tasks.
- Parameters:
num_task (int, default=1) – Number of prediction tasks.
task_type (str, default="regression") – Type of prediction task, either “regression” or “classification”.
num_layer (int, default=5) – Number of GNN layers.
hidden_size (int, default=300) – Dimension of hidden node features.
gnn_type (str, default="gin-virtual") – Type of GNN architecture to use. One of [“gin-virtual”, “gcn-virtual”, “gin”, “gcn”].
drop_ratio (float, default=0.5) – Dropout probability.
norm_layer (str, default="batch_norm") – Type of normalization layer to use. One of [“batch_norm”, “layer_norm”, “instance_norm”, “graph_norm”, “size_norm”, “pair_norm”].
graph_pooling (str, default="sum") – Method for aggregating node features to graph-level representations. One of [“sum”, “mean”, “max”].
augmented_feature (list, default=["morgan", "maccs"]) – Additional molecular fingerprints to use as features. It will be concatenated with the graph representation after pooling.
batch_size (int, default=128) – Number of samples per batch for training.
epochs (int, default=500) – Maximum number of training epochs.
loss_criterion (callable, optional) – Loss function for training.
evaluate_criterion (str or callable, optional) – Metric for model evaluation.
evaluate_higher_better (bool, optional) – Whether higher values of the evaluation metric are better.
learning_rate (float, default=0.001) – Learning rate for optimizer.
grad_clip_value (float, optional) – Maximum norm of gradients for gradient clipping.
weight_decay (float, default=0.0) – L2 regularization strength.
patience (int, default=50) – Number of epochs to wait for improvement before early stopping.
use_lr_scheduler (bool, default=False) – Whether to use learning rate scheduler.
scheduler_factor (float, default=0.5) – Factor by which to reduce learning rate when plateau is reached.
scheduler_patience (int, default=5) – Number of epochs with no improvement after which learning rate will be reduced.
verbose (bool, default=False) – Whether to print progress information during training.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → GNNMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
X_unlbl (List[str], optional) – Unlabeled set input molecular structures as SMILES strings.
- Returns:
self – Fitted estimator
- Return type:
GNNMolecularPredictor
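As a usage sketch of the constructor and autofit (hyperparameter values here are arbitrary; with search_parameters left unset, autofit is assumed to fall back to its built-in search space, and the data is illustrative):

```python
from torch_molecule.predictor.gnn.modeling_gnn import GNNMolecularPredictor

model = GNNMolecularPredictor(
    num_task=1,
    task_type="regression",
    gnn_type="gin-virtual",                 # or "gcn-virtual", "gin", "gcn"
    num_layer=5,
    hidden_size=300,
    augmented_feature=["morgan", "maccs"],  # fingerprints concatenated after pooling
    epochs=100,
    patience=20,
    verbose=True,
)
# Optuna search over the default space, then training with the best trial
model.autofit(["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4], n_trials=10)
```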
Graph Rationalization with Environment-based Augmentations
- class torch_molecule.predictor.grea.modeling_grea.GREAMolecularPredictor(device: torch.device | None = None, model_name: str = 'GREAMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, gamma: float = 0.4)
Bases:
GNNMolecularPredictor
This predictor implements GREA (Graph Rationalization with Environment-based Augmentations). During training, it learns rationales (explainable subgraphs) and uses them for molecular property prediction.
References
Graph Rationalization with Environment-based Augmentations. https://dl.acm.org/doi/10.1145/3534678.3539347
- Parameters:
gamma (float, default=0.4) – GREA-specific hyperparameter controlling the expected size of the learned rationale subgraph relative to the whole molecular graph.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → GNNMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
X_unlbl (List[str], optional) – Unlabeled set input molecular structures as SMILES strings.
- Returns:
self – Fitted estimator
- Return type:
GNNMolecularPredictor
- predict(X: List[str]) → Dict[str, ndarray | List[List]]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
'variance': Prediction variances (shape: [n_samples, n_tasks])
'node_importance': A nested list where the outer list has length n_samples and each inner list has length n_nodes for that molecule
- Return type:
Dict[str, Union[np.ndarray, List[List]]]
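A sketch of how these extra outputs could be consumed (the SMILES and targets are illustrative):

```python
from torch_molecule.predictor.grea.modeling_grea import GREAMolecularPredictor

model = GREAMolecularPredictor(gamma=0.4, num_task=1, task_type="regression", epochs=10)
model.fit(["CCO", "CCC", "CCN", "CCCl"], [0.5, 0.3, 0.4, 0.9])

out = model.predict(["c1ccccc1O"])
print(out["prediction"].shape)         # (1, 1): one sample, one task
print(out["variance"].shape)           # (1, 1): per-prediction variance
print(len(out["node_importance"][0]))  # one importance score per atom
```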
Semi-Supervised Graph Imbalanced Regression Models
- class torch_molecule.predictor.sgir.modeling_sgir.SGIRMolecularPredictor(device: torch.device | None = None, model_name: str = 'SGIRMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, gamma: float = 0.4, num_anchor: int = 10, warmup_epoch: int = 20, labeling_interval: int = 5, augmentation_interval: int = 5, top_quantile: float = 0.1, label_logscale: bool = False, lw_aug: float = 1)
Bases:
GREAMolecularPredictor
This predictor implements SGIR for semi-supervised graph imbalanced regression. It trains a GREA backbone with pseudo-labeling and data augmentation.
References
Semi-Supervised Graph Imbalanced Regression. https://dl.acm.org/doi/10.1145/3580305.3599497
- Parameters:
num_anchor (int, default=10) – Number of anchor points used to split the label space during pseudo-labeling
warmup_epoch (int, default=20) – Number of epochs to train before starting pseudo-labeling and data augmentation
labeling_interval (int, default=5) – Interval (in epochs) between pseudo-labeling steps
augmentation_interval (int, default=5) – Interval (in epochs) between data augmentation steps
top_quantile (float, default=0.1) – Quantile threshold for selecting high confidence predictions during pseudo-labeling
label_logscale (bool, default=False) – Whether to use log scale for the label space during pseudo-labeling and data augmentation
lw_aug (float, default=1) – Weight for the data augmentation loss
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → SGIRMolecularPredictor
Fit the model to the training data with optional validation set.
- predict(X: List[str]) → Dict[str, ndarray | List[List]]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
'variance': Prediction variances (shape: [n_samples, n_tasks])
'node_importance': A nested list where the outer list has length n_samples and each inner list has length n_nodes for that molecule
- Return type:
Dict[str, Union[np.ndarray, List[List]]]
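A sketch of the semi-supervised workflow, passing unlabeled SMILES through X_unlbl (the data and epoch counts are illustrative):

```python
from torch_molecule.predictor.sgir.modeling_sgir import SGIRMolecularPredictor

labeled_X, labeled_y = ["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4]
unlabeled_X = ["CCCl", "CCBr", "c1ccccc1"]  # no targets required

model = SGIRMolecularPredictor(
    warmup_epoch=20,          # supervised training only, at first
    labeling_interval=5,      # then pseudo-label every 5 epochs
    augmentation_interval=5,  # and augment every 5 epochs
    top_quantile=0.1,         # keep only the most confident pseudo-labels
    epochs=50,
)
model.fit(labeled_X, labeled_y, X_unlbl=unlabeled_X)
```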
Discovering Invariant Rationales
- class torch_molecule.predictor.dir.modeling_dir.DIRMolecularPredictor(device: torch.device | None = None, model_name: str = 'DIRMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, causal_ratio: float = 0.8, lw_invariant: float = 0.0001)
Bases:
GNNMolecularPredictor
This predictor implements DIR (Discovering Invariant Rationales) for molecular property prediction tasks.
References
Discovering Invariant Rationales for Graph Neural Networks. https://openreview.net/forum?id=hGXij5rfiHw
- Parameters:
causal_ratio (float, default=0.8) – The ratio of causal edges to keep during training. A higher ratio means more edges are considered causal/important for the prediction. This controls the sparsity of the learned rationales.
lw_invariant (float, default=1e-4) – The weight of the invariance loss term. This loss encourages the model to learn rationales that are invariant across different environments/perturbations. A higher value puts more emphasis on learning invariant features.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → GNNMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
X_unlbl (List[str], optional) – Unlabeled set input molecular structures as SMILES strings.
- Returns:
self – Fitted estimator
- Return type:
GNNMolecularPredictor
- predict(X: List[str]) → Dict[str, ndarray | List[List]]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
- Return type:
Dict[str, np.ndarray]
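A configuration sketch for the DIR-specific knobs (the data and epoch count are illustrative):

```python
from torch_molecule.predictor.dir.modeling_dir import DIRMolecularPredictor

model = DIRMolecularPredictor(
    causal_ratio=0.8,   # keep 80% of edges as the causal rationale
    lw_invariant=1e-4,  # weight of the invariance loss across environments
    epochs=10,
)
model.fit(["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4])
print(model.predict(["CCOC"])["prediction"])  # shape: [1, 1]
```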
Invariant Risk Minimization with GNNs
- class torch_molecule.predictor.irm.modeling_irm.IRMMolecularPredictor(device: torch.device | None = None, model_name: str = 'IRMMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, IRM_environment: torch.Tensor | numpy.ndarray | List | str = 'random', scale: float = 1.0, penalty_weight: float = 1.0, penalty_anneal_iters: int = 100)
Bases:
GNNMolecularPredictor
This predictor implements an Invariant Risk Minimization (IRM) model with a GNN backbone.
References
Invariant Risk Minimization. https://arxiv.org/abs/1907.02893
Reference Code: https://github.com/facebookresearch/InvariantRiskMinimization
- Parameters:
IRM_environment (Union[torch.Tensor, np.ndarray, List, str], default="random") – Environment assignments for IRM. Can be a list of integers (one per sample), or “random” to assign environments randomly.
scale (float, default=1.0) – Scaling factor for the IRM penalty term.
penalty_weight (float, default=1.0) – Weight of the IRM penalty in the loss function.
penalty_anneal_iters (int, default=100) – Number of iterations for annealing the penalty weight.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → GNNMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
X_unlbl (List[str], optional) – Unlabeled set input molecular structures as SMILES strings.
- Returns:
self – Fitted estimator
- Return type:
GNNMolecularPredictor
- predict(X: List[str]) → Dict[str, ndarray]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
- Return type:
Dict[str, np.ndarray]
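A sketch of supplying explicit environment assignments, one integer per training sample (the data is illustrative):

```python
from torch_molecule.predictor.irm.modeling_irm import IRMMolecularPredictor

X_train = ["CCO", "CCC", "CCN", "CCCl"]
y_train = [0.5, 0.3, 0.4, 0.9]

model = IRMMolecularPredictor(
    IRM_environment=[0, 0, 1, 1],  # two environments; or "random" to assign randomly
    penalty_weight=1.0,            # IRM penalty strength after annealing
    penalty_anneal_iters=100,
    epochs=10,
)
model.fit(X_train, y_train)
```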
Relational Pooling for Graph Representations with GNNs
- class torch_molecule.predictor.rpgnn.modeling_rpgnn.RPGNNMolecularPredictor(device: torch.device | None = None, model_name: str = 'RPGNNMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, num_perm: int = 3, fixed_size: int = 10, num_node_feature: int = 9)
Bases:
GNNMolecularPredictor
This predictor implements RPGNN (Relational Pooling for Graph Representations), a GNN model based on relational pooling.
References
Relational Pooling for Graph Representations. https://arxiv.org/abs/1903.02541
Reference Code: https://github.com/PurdueMINDS/RelationalPooling/tree/master?tab=readme-ov-file
- Parameters:
num_perm (int, default=3) – Number of node permutations sampled to approximate the permutation-averaged graph representation used by relational pooling.
fixed_size (int, default=10) – Fixed graph size assumed when assigning permutation-based node identifiers.
num_node_feature (int, default=9) – Number of input node (atom) features.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → GNNMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
X_unlbl (List[str], optional) – Unlabeled set input molecular structures as SMILES strings.
- Returns:
self – Fitted estimator
- Return type:
GNNMolecularPredictor
- predict(X: List[str]) → Dict[str, ndarray]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
- Return type:
Dict[str, np.ndarray]
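A minimal sketch, assuming num_perm controls how many node permutations are sampled (see the parameter notes above; the data is illustrative):

```python
from torch_molecule.predictor.rpgnn.modeling_rpgnn import RPGNNMolecularPredictor

model = RPGNNMolecularPredictor(num_perm=3, epochs=10)  # 3 sampled permutations
model.fit(["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4])
print(model.predict(["CCOC"])["prediction"])
```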
SizeShiftReg: a Regularization Method for Improving Size-Generalization in GNNs
- class torch_molecule.predictor.ssr.modeling_ssr.SSRMolecularPredictor(device: torch.device | None = None, model_name: str = 'SSRMolecularPredictor', num_task: int = 1, task_type: str = 'regression', num_layer: int = 5, hidden_size: int = 300, gnn_type: str = 'gin-virtual', drop_ratio: float = 0.5, norm_layer: str = 'batch_norm', graph_pooling: str = 'sum', augmented_feature: list[Literal['morgan', 'maccs']] | None = <factory>, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, grad_clip_value: float | None = None, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, coarse_ratios: List[float] = <factory>, cmd_coeff: float = 0.1, fine_grained: bool = True, n_moments: int = 5, coarse_pool: str = 'mean')
Bases:
GNNMolecularPredictor
This predictor implements a SizeShiftReg (SSR) model with a GNN backbone.
References
Paper: SizeShiftReg: a Regularization Method for Improving Size-Generalization in Graph Neural Networks. https://arxiv.org/abs/2207.07888
Reference Code: https://github.com/DavideBuffelli/SizeShiftReg/tree/main
- Parameters:
coarse_ratios (List[float], default=[0.8, 0.9]) – List of ratios for graph coarsening. Each ratio determines the percentage of nodes to keep in the coarsened graph.
cmd_coeff (float, default=0.1) – Weight for CMD (Central Moment Discrepancy) loss. Controls the strength of the size-shift regularization.
fine_grained (bool, default=True) – Whether to use fine-grained CMD. When True, matches distributions at a more detailed level.
n_moments (int, default=5) – Number of moments to match in the CMD calculation. Higher values capture more complex distribution characteristics.
coarse_pool (str, default='mean') – Pooling method for coarsened graphs. Determines how node features are aggregated during coarsening.
model_name (str, default="SSRMolecularPredictor") – Name of the model.
model_class (Type[SSR], default=SSR) – The model class to use for prediction.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → GNNMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, X_unlbl: List[str] | None = None) → GNNMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
X_unlbl (List[str], optional) – Unlabeled set input molecular structures as SMILES strings.
- Returns:
self – Fitted estimator
- Return type:
GNNMolecularPredictor
- predict(X: List[str]) → Dict[str, ndarray]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
- Return type:
Dict[str, np.ndarray]
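A configuration sketch for the size-shift regularizer (values follow the documented defaults; the data is illustrative):

```python
from torch_molecule.predictor.ssr.modeling_ssr import SSRMolecularPredictor

model = SSRMolecularPredictor(
    coarse_ratios=[0.8, 0.9],  # regularize against graphs coarsened to 80% and 90% of nodes
    cmd_coeff=0.1,             # weight of the Central Moment Discrepancy loss
    n_moments=5,               # number of moments matched by CMD
    epochs=10,
)
model.fit(["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4])
```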
Modeling Molecules as Sequences with RNNs and Transformers
LSTM models based on SMILES
- class torch_molecule.predictor.lstm.modeling_lstm.LSTMMolecularPredictor(device: device | None = None, model_name: str = 'LSTMMolecularPredictor', num_task: int = 1, task_type: str = 'regression', input_dim: int = 54, output_dim: int = 15, LSTMunits: int = 60, max_input_len: int = 200, batch_size: int = 128, epochs: int = 500, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.001, weight_decay: float = 0.0, patience: int = 50, use_lr_scheduler: bool = False, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False)
Bases:
BaseMolecularPredictor
This predictor implements an LSTM model for molecular property prediction tasks.
References
Predicting Polymers’ Glass Transition Temperature by a Chemical Language Processing Model. https://www.semanticscholar.org/reader/f43ed533b2520567be2d8c24f6396f4e63e96430
- Parameters:
num_task (int, default=1) – Number of prediction tasks.
task_type (str, default="regression") – Type of prediction task, either “regression” or “classification”.
input_dim (int, default=54) – Size of vocabulary for SMILES tokenization.
output_dim (int, default=15) – Dimension of embedding vectors.
LSTMunits (int, default=60) – Number of hidden units in LSTM layers.
max_input_len (int, default=200) – Maximum length of input sequences. Longer sequences will be truncated.
batch_size (int, default=128) – Number of samples per batch for training.
epochs (int, default=500) – Maximum number of training epochs.
loss_criterion (callable, optional) – Loss function for training. Defaults to MSELoss for regression.
evaluate_criterion (str or callable, optional) – Metric for model evaluation.
evaluate_higher_better (bool, optional) – Whether higher values of the evaluation metric are better.
learning_rate (float, default=0.001) – Learning rate for optimizer.
weight_decay (float, default=0.0) – L2 regularization strength.
patience (int, default=50) – Number of epochs to wait for improvement before early stopping.
use_lr_scheduler (bool, default=False) – Whether to use learning rate scheduler.
scheduler_factor (float, default=0.5) – Factor by which to reduce learning rate when plateau is reached.
scheduler_patience (int, default=5) – Number of epochs with no improvement after which learning rate will be reduced.
verbose (bool, default=False) – Whether to print progress information during training.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → LSTMMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None) → LSTMMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
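A minimal sketch of the sequence-based predictor (the data is illustrative; note that fit here takes no unlabeled set):

```python
from torch_molecule.predictor.lstm.modeling_lstm import LSTMMolecularPredictor

model = LSTMMolecularPredictor(
    input_dim=54,       # SMILES vocabulary size
    LSTMunits=60,       # hidden units per LSTM layer
    max_input_len=200,  # longer SMILES are truncated
    epochs=10,
)
model.fit(["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4])
print(model.predict(["CCOC"])["prediction"])
```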
Transformer models based on SMILES
- class torch_molecule.predictor.smiles_transformer.modeling_transformer.SMILESTransformerMolecularPredictor(device: device | None = None, model_name: str = 'SMILESTransformerMolecularPredictor', num_task: int = 1, task_type: str = 'regression', input_dim: int = 54, output_dim: int = 15, LSTMunits: int = 60, max_input_len: int = 200, batch_size: int = 64, epochs: int = 200, loss_criterion: Callable | None = None, evaluate_criterion: str | Callable | None = None, evaluate_higher_better: bool | None = None, learning_rate: float = 0.0001, weight_decay: float = 0.0, patience: int = 20, use_lr_scheduler: bool = True, scheduler_factor: float = 0.5, scheduler_patience: int = 5, verbose: bool = False, hidden_size: int = 128, n_heads: int = 4, num_layers: int = 3, dim_feedforward: int | None = 256, dropout: float = 0.1)
Bases:
LSTMMolecularPredictor
This predictor implements a Transformer model for SMILES-based molecular property prediction.
Notes
This implementation uses a transformer encoder architecture to learn representations of molecular structures from SMILES strings.
- Parameters:
num_task (int, default=1) – Number of prediction tasks.
task_type (str, default="regression") – Type of prediction task, either “regression” or “classification”.
input_dim (int, default=54) – Size of vocabulary for SMILES tokenization.
hidden_size (int, default=128) – Dimension of embedding vectors.
n_heads (int, default=4) – Number of attention heads in transformer layers.
num_layers (int, default=3) – Number of transformer encoder layers.
dim_feedforward (int, default=256) – Dimension of the feedforward network in transformer layers.
max_input_len (int, default=200) – Maximum length of input sequences. Shorter sequences will be padded.
dropout (float, default=0.1) – Dropout rate for transformer layers.
batch_size (int, default=64) – Number of samples per batch for training.
epochs (int, default=200) – Maximum number of training epochs.
loss_criterion (callable, optional) – Loss function for training. Defaults to MSELoss for regression.
evaluate_criterion (str or callable, optional) – Metric for model evaluation.
evaluate_higher_better (bool, optional) – Whether higher values of the evaluation metric are better.
learning_rate (float, default=0.0001) – Learning rate for optimizer.
weight_decay (float, default=0.0) – L2 regularization strength.
patience (int, default=20) – Number of epochs to wait for improvement before early stopping.
use_lr_scheduler (bool, default=True) – Whether to use learning rate scheduler.
scheduler_factor (float, default=0.5) – Factor by which to reduce learning rate when plateau is reached.
scheduler_patience (int, default=5) – Number of epochs with no improvement after which learning rate will be reduced.
verbose (bool, default=False) – Whether to print progress information during training.
- autofit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None, search_parameters: Dict[str, ParameterSpec] | None = None, n_trials: int = 10) → LSTMMolecularPredictor
Automatically find the best hyperparameters using Optuna optimization.
- fit(X_train: List[str], y_train: List | ndarray | None, X_val: List[str] | None = None, y_val: List | ndarray | None = None) → LSTMMolecularPredictor
Fit the model to the training data with optional validation set.
- Parameters:
X_train (List[str]) – Training set input molecular structures as SMILES strings
y_train (Union[List, np.ndarray]) – Training set target values for property prediction
X_val (List[str], optional) – Validation set input molecular structures as SMILES strings. If None, training data will be used for validation
y_val (Union[List, np.ndarray], optional) – Validation set target values. Required if X_val is provided
- predict(X: List[str]) → Dict[str, ndarray]
Make predictions using the fitted model.
- Parameters:
X (List[str]) – List of SMILES strings to make predictions for
- Returns:
- Dictionary containing:
'prediction': Model predictions (shape: [n_samples, n_tasks])
- Return type:
Dict[str, np.ndarray]
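And a closing sketch for the Transformer variant (hyperparameters mirror the documented defaults; the data is illustrative):

```python
from torch_molecule.predictor.smiles_transformer.modeling_transformer import (
    SMILESTransformerMolecularPredictor,
)

model = SMILESTransformerMolecularPredictor(
    hidden_size=128,
    n_heads=4,        # hidden_size should be divisible by n_heads
    num_layers=3,
    dropout=0.1,
    epochs=10,
)
model.fit(["CCO", "CCC", "CCN"], [0.5, 0.3, 0.4])
print(model.predict(["CCOC"])["prediction"].shape)  # (1, 1)
```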