Source code for torch_molecule.generator.graph_ga.modeling_graph_ga

import numpy as np
import random
import joblib
from joblib import delayed
from rdkit import Chem
from tqdm import tqdm
from typing import Optional, Union, Dict, Any, Tuple, List, Callable
from dataclasses import dataclass, field
from warnings import warn

from .crossover import crossover
from .mutate import mutate
from .oracle import Oracle

from ...base import BaseMolecularGenerator

@dataclass
class GraphGAMolecularGenerator(BaseMolecularGenerator):
    """This generator implements the Graph Genetic Algorithm for molecular generation.

    References
    ----------
    - A Graph-Based Genetic Algorithm and Its Application to the Multiobjective
      Evolution of Median Molecules. Journal of Chemical Information and Computer
      Sciences. https://pubs.acs.org/doi/10.1021/ci034290p
    - Implementation: https://github.com/wenhao-gao/mol_opt

    Parameters
    ----------
    num_task : int, default=0
        Number of properties to condition on. Set to 0 for unconditional generation.
    population_size : int, default=100
        Size of the population in each iteration.
    offspring_size : int, default=50
        Number of offspring molecules to generate in each iteration.
    mutation_rate : float, default=0.0067
        Probability of mutation occurring during reproduction.
    n_jobs : int, default=1
        Number of parallel jobs to run. -1 means using all processors.
    iteration : int, default=5
        Number of genetic algorithm iterations to run for each target label
        (or random sample).
    verbose : bool, default=False
        Whether to display progress bars and logs.
    """
    # GA parameters
    num_task: int = 0
    population_size: int = 100
    offspring_size: int = 50
    mutation_rate: float = 0.0067
    n_jobs: int = 1
    iteration: int = 5
    # Other parameters
    verbose: bool = False
    model_name: str = "GraphGAMolecularGenerator"
    model_class = None

    def __post_init__(self):
        super().__post_init__()

    @staticmethod
    def _get_param_names() -> List[str]:
        return [
            "num_task", "population_size", "offspring_size",
            "mutation_rate", "n_jobs", "iteration", "verbose"
        ]

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        raise NotImplementedError("GraphGA does not support getting model parameters")

    def save_to_local(self, path: str):
        joblib.dump(self.oracle, path)
        if self.verbose:
            print(f"Saved oracle to {path}")

    def load_from_local(self):
        raise NotImplementedError(
            "GraphGA does not support loading from local. "
            "If you want to load the oracles saved through save_to_local, "
            "you need to manually load the oracle from the path with joblib.load(path) "
            "and pass it to the fit function."
        )

    def save_to_hf(self, repo_id: str, task_id: str = "default"):
        raise NotImplementedError("GraphGA does not support pushing to huggingface")

    def load_from_hf(self, repo_id: str, task_id: str = "default"):
        raise NotImplementedError("GraphGA does not support loading from huggingface")

    def _setup_optimizers(self):
        raise NotImplementedError("GraphGA does not support setting up optimizers")

    def _train_epoch(self, train_loader, optimizer):
        raise NotImplementedError("GraphGA does not support training epochs")
    def fit(
        self,
        X_train: List[str],
        y_train: Optional[Union[List, np.ndarray]] = None,
        oracle: Optional[Callable] = None
    ) -> "GraphGAMolecularGenerator":
        """Fit the model to the training data.

        Parameters
        ----------
        X_train : List[str]
            Training data, used as the initial population.
        y_train : Optional[Union[List, np.ndarray]]
            Training labels for conditional generation (when num_task is not 0).
        oracle : Optional[Callable]
            Oracle used to score the generated molecules. If not provided, a default
            oracle based on ``sklearn.ensemble.RandomForestRegressor`` is trained on
            X_train and y_train. A customized oracle must be a callable,
            ``oracle(X, y)``, that takes two inputs:

            - a list of ``rdkit.Chem.rdchem.Mol`` objects, and
            - a (1, num_task) numpy array of target values that all molecules in the
              list should achieve (handle NaN values if any).

            Scores for different tasks should be aggregated (e.g., mean or sum), and
            the return value should be a list of floats, where smaller scores mean
            closer to the target. Oracles are not needed for unconditional generation.

        Returns
        -------
        self : GraphGAMolecularGenerator
            Fitted model.
        """
        self.y_train = None
        if oracle is not None:
            self.oracle = oracle
        else:
            X_train, y_train = self._validate_inputs(
                X_train, y_train, num_task=self.num_task, return_rdkit_mol=False
            )
            if y_train is not None:
                warn(
                    "No oracles provided but y_train is provided, "
                    "using default oracles (RandomForestRegressor)",
                    UserWarning,
                )
                self.oracle = Oracle(num_task=self.num_task)
                self.oracle.fit(X_train, y_train)
                self.y_train = y_train
            else:
                assert self.num_task == 0, "No oracles or y_train provided but num_task is not 0"
                self.oracle = None
        self.X_train = X_train
        self.is_fitted_ = True
        return self
    def _make_mating_pool(self, population_mol, population_scores, offspring_size: int):
        """Create a mating pool where smaller scores receive higher selection probabilities."""
        max_score = max(population_scores)
        # Invert scores so that smaller scores become larger probabilities;
        # add a small constant to avoid zero probabilities
        inverted_scores = [max_score - s + 1e-6 for s in population_scores]
        sum_scores = sum(inverted_scores)
        population_probs = [p / sum_scores for p in inverted_scores]
        mating_pool = np.random.choice(
            population_mol, p=population_probs, size=offspring_size, replace=True
        )
        return mating_pool

    def _reproduce(self, mating_pool, mutation_rate):
        """Create a new molecule through crossover and mutation."""
        parent_a = random.choice(mating_pool)
        parent_b = random.choice(mating_pool)
        new_child = crossover(parent_a, parent_b)
        if new_child is not None:
            new_child = mutate(new_child, mutation_rate)
        return new_child

    def _sanitize_molecules(self, population_mol):
        """Sanitize the population by removing duplicate and invalid molecules."""
        new_mol_list = []
        smiles_set = set()
        for mol in population_mol:
            if mol is not None:
                try:
                    smiles = Chem.MolToSmiles(mol)
                    if smiles is not None and smiles not in smiles_set:
                        smiles_set.add(smiles)
                        new_mol_list.append(mol)
                except ValueError:
                    pass
        return new_mol_list

    def _get_score(self, mol_list, label):
        if label is None:
            return [1.0] * len(mol_list)  # Uniform scores for unconditional generation
        return self.oracle(mol_list, label)
    def generate(
        self,
        labels: Optional[Union[List[List], np.ndarray]] = None,
        num_samples: int = 32
    ) -> List[str]:
        """Generate molecules using genetic algorithm optimization.

        Parameters
        ----------
        labels : Optional[Union[List[List], np.ndarray]]
            Target labels of shape (n_samples, num_task) for conditional generation.
            If None, molecules are generated unconditionally.
        num_samples : int, default=32
            Number of molecules to generate when labels is None.

        Returns
        -------
        List[str]
            Generated molecules as SMILES strings.
        """
        if not self.is_fitted_:
            raise RuntimeError("Model must be fitted before generating")

        all_generated_mols = []
        if labels is not None:
            try:
                labels = np.array(labels).reshape(-1, self.num_task)
            except Exception:
                raise ValueError(
                    f"labels must be convertible to a numpy array with shape (-1, {self.num_task})"
                )

            # Prepare all inputs for parallel processing
            parallel_inputs = []
            for i in range(labels.shape[0]):
                label = labels[i:i+1]  # Keep as 2D array
                # Initialize population based on similarity to the target label
                if self.y_train is not None:
                    population_mol = self._initialize_population_for_label(label)
                else:
                    population_idx = np.random.choice(
                        len(self.X_train), min(self.population_size, len(self.X_train))
                    )
                    population_smiles = [self.X_train[idx] for idx in population_idx]
                    population_mol = [Chem.MolFromSmiles(s) for s in population_smiles]
                parallel_inputs.append((population_mol, label))

            # Run the GA for all labels in parallel, with a tqdm progress bar if verbose
            if self.verbose:
                results = joblib.Parallel(n_jobs=self.n_jobs)(
                    delayed(self._run_generation)(pop_mol, lbl)
                    for pop_mol, lbl in tqdm(parallel_inputs, desc="Generating molecules")
                )
            else:
                results = joblib.Parallel(n_jobs=self.n_jobs)(
                    delayed(self._run_generation)(pop_mol, lbl)
                    for pop_mol, lbl in parallel_inputs
                )

            # Convert results to SMILES in the original order
            all_generated_mols = [Chem.MolToSmiles(mol) for mol in results]
        else:
            # Prepare all inputs for parallel processing
            parallel_inputs = []
            for _ in range(num_samples):
                population_idx = np.random.choice(
                    len(self.X_train), min(self.population_size, len(self.X_train))
                )
                population_smiles = [self.X_train[idx] for idx in population_idx]
                population_mol = [Chem.MolFromSmiles(s) for s in population_smiles]
                parallel_inputs.append((population_mol, None))

            # Run the GA for all samples in parallel, with a tqdm progress bar if verbose
            if self.verbose:
                results = joblib.Parallel(n_jobs=self.n_jobs)(
                    delayed(self._run_generation)(pop_mol, lbl)
                    for pop_mol, lbl in tqdm(
                        parallel_inputs, desc="Generating molecules", total=num_samples
                    )
                )
            else:
                results = joblib.Parallel(n_jobs=self.n_jobs)(
                    delayed(self._run_generation)(pop_mol, lbl)
                    for pop_mol, lbl in parallel_inputs
                )

            # Convert results to SMILES
            all_generated_mols = [Chem.MolToSmiles(mol) for mol in results]

        return all_generated_mols
    def _initialize_population_for_label(self, label):
        """Initialize the population from training samples most similar to the target label."""
        similarities = []
        for i in range(len(self.X_train)):
            sample_label = self.y_train[i]
            # Negative squared distance: larger values mean closer to the target
            similarity = -np.nansum((sample_label - label[0]) ** 2)
            similarities.append((i, similarity))

        if similarities:
            similarities.sort(key=lambda x: x[1], reverse=True)
            top_indices = [x[0] for x in similarities[:self.population_size]]
        else:
            top_indices = np.random.choice(
                len(self.X_train), min(self.population_size, len(self.X_train))
            )

        population_smiles = [self.X_train[i] for i in top_indices]
        return [Chem.MolFromSmiles(s) for s in population_smiles]

    def _run_generation(self, population_mol, label):
        """Run the genetic algorithm for a specific population and label."""
        for generation_idx in range(self.iteration):
            if label is not None:
                population_scores = self._get_score(population_mol, label)
            else:
                population_scores = [1.0] * len(population_mol)  # Unconditional generation

            mating_pool = self._make_mating_pool(
                population_mol, population_scores, self.offspring_size
            )

            # Create offspring sequentially (parallelization is at the higher level now)
            offspring_mol = []
            for _ in range(self.offspring_size):
                offspring = self._reproduce(mating_pool, self.mutation_rate)
                offspring_mol.append(offspring)

            population_mol += offspring_mol
            population_mol = self._sanitize_molecules(population_mol)

            # Re-score the expanded population
            if label is not None:
                population_scores = self._get_score(population_mol, label)
            else:
                population_scores = [1.0] * len(population_mol)

            # Select top molecules for the next generation (lower score is better)
            population_tuples = list(zip(population_scores, population_mol))
            population_tuples = sorted(population_tuples, key=lambda x: x[0])[:self.population_size]
            population_mol = [t[1] for t in population_tuples]

        # Return the best molecule
        return population_mol[0]
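

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the original module).
# The SMILES strings, target values, and the `score_logp` oracle below are
# hypothetical placeholders that follow the oracle interface documented in
# `fit`: a callable taking a list of RDKit Mol objects and a (1, num_task)
# target array, returning one score per molecule (smaller is better).
# ---------------------------------------------------------------------------
#
# from rdkit.Chem import Descriptors
# from torch_molecule.generator.graph_ga.modeling_graph_ga import GraphGAMolecularGenerator
#
# # Unconditional generation: the training SMILES seed the initial population.
# generator = GraphGAMolecularGenerator(num_task=0, population_size=50, iteration=3)
# generator.fit(["CCO", "c1ccccc1", "CC(=O)O"])
# smiles_out = generator.generate(num_samples=4)
#
# # Conditional generation with a custom oracle: score each molecule by how far
# # its logP is from the single target value.
# def score_logp(mols, target):
#     return [abs(Descriptors.MolLogP(m) - float(target[0, 0])) for m in mols]
#
# cond_generator = GraphGAMolecularGenerator(num_task=1, population_size=50, iteration=3)
# cond_generator.fit(["CCO", "c1ccccc1", "CC(=O)O"], oracle=score_logp)
# smiles_out = cond_generator.generate(labels=[[2.0]])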