Source code for torch_molecule.utils.graph.graph_from_smiles

import numpy as np
from rdkit import Chem
from rdkit.Chem import Crippen
from .features import atom_to_feature_vector, bond_to_feature_vector
from .features import getmaccsfingerprint, getmorganfingerprint
from ..generic.pseudo_tasks import PSEUDOTASK

def add_fingerprint_feature(mol, feature_type, get_fingerprint_fn):
    """
    Compute a fingerprint for a molecule as a single-row int8 array.

    Parameters
    ----------
    mol : Any
        Molecule object forwarded unchanged to ``get_fingerprint_fn``.
    feature_type : Any
        Switch for the feature; when it is ``None`` nothing is computed.
    get_fingerprint_fn : callable
        Function called as ``get_fingerprint_fn(mol)`` whose result is
        converted to an array.

    Returns
    -------
    numpy.ndarray or None
        ``None`` when ``feature_type`` is ``None``; otherwise the fingerprint
        cast to ``int8`` with a new leading axis of length 1.
    """
    # Guard clause: the feature was not requested, so skip the computation.
    if feature_type is None:
        return None

    bits = np.array(get_fingerprint_fn(mol), dtype="int8")
    # Prepend a length-1 axis so the result has one row per molecule.
    return bits[np.newaxis, ...]
def get_augmented_property(mol, properties):
    """
    Collect the requested augmented property values for a molecule.

    Parameters
    ----------
    mol : Any
        Molecule passed to the fingerprint / logP helpers. When it is
        ``None`` the function returns ``None`` before any validation.
    properties : iterable
        Names of the properties to compute. Every name must be a key of
        ``PSEUDOTASK``.

    Returns
    -------
    list or None
        ``None`` when ``mol`` is ``None``. Otherwise a flat list holding, in
        this fixed order, the MACCS fingerprint values (if ``'maccs'`` was
        requested), the Morgan fingerprint values (if ``'morgan'`` was
        requested) and the Crippen logP value (if ``'logP'`` was requested).

    Raises
    ------
    ValueError
        If ``properties`` contains a name that is not a key of ``PSEUDOTASK``.
    """
    if mol is None:
        return None

    # Reject unknown names up front, before any descriptor is computed.
    supported_properties = set(PSEUDOTASK.keys())
    unsupported = set(properties).difference(supported_properties)
    if unsupported:
        raise ValueError(f"Unsupported properties: {unsupported}. Supported properties are: {supported_properties}")

    values = []
    if 'maccs' in properties:
        values += getmaccsfingerprint(mol)
    if 'morgan' in properties:
        values += getmorganfingerprint(mol)
    if 'logP' in properties:
        values.append(Crippen.MolLogP(mol))
    return values
def graph_from_smiles(smiles_or_mol, properties, augmented_features=None, augmented_properties=None):
    """
    Converts SMILES string or RDKit molecule to graph Data object

    Parameters
    ----------
    smiles_or_mol : Union[str, rdkit.Chem.rdchem.Mol]
        SMILES string or RDKit molecule object. Strings are parsed with
        ``Chem.MolFromSmiles``; any other object is used as the molecule as-is.
    properties : Any
        Properties to include in the graph under ``"y"``. May be ``None``,
        a scalar, or an array-like; the values are cast to ``float32`` and
        flattened to one dimension.
    augmented_features : list, optional
        List of augmented features to include. ``'morgan'`` and ``'maccs'``
        are the names checked for.
    augmented_properties : list, optional
        List of augmented properties to include. They are computed with
        ``get_augmented_property`` and appended after ``properties`` in ``"y"``.

    Returns
    -------
    dict
        Graph object dictionary with the keys:

        - ``"edge_index"``: int64 array, shape ``[2, num_edges]`` (COO format,
          every bond stored in both directions).
        - ``"edge_feat"``: int64 array with one row per directed edge.
        - ``"node_feat"``: int64 array built from the per-atom feature vectors.
        - ``"num_nodes"``: ``len`` of the node feature array.
        - ``"y"``: float32 array of shape ``[1, n]``; a single NaN when neither
          ``properties`` nor augmented properties contribute any array.
        - ``"morgan"`` / ``"maccs"``: fingerprint arrays from
          ``add_fingerprint_feature``, or ``None`` when not requested.

    Raises
    ------
    ValueError
        If no molecule could be obtained from ``smiles_or_mol`` (it is
        ``None`` or a SMILES string that RDKit could not parse).
    """
    if isinstance(smiles_or_mol, str):
        mol = Chem.MolFromSmiles(smiles_or_mol)
    else:
        mol = smiles_or_mol

    # MolFromSmiles reports a parse failure by returning None; fail here with
    # a clear message instead of an AttributeError on the first method call.
    if mol is None:
        raise ValueError(f"Could not convert {smiles_or_mol!r} to an RDKit molecule")

    # atoms
    atom_features_list = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    x = np.array(atom_features_list, dtype=np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    bonds = mol.GetBonds()
    if len(bonds) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in bonds:
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    graph = {
        "edge_index": edge_index,
        "edge_feat": edge_attr,
        "node_feat": x,
        "num_nodes": len(x),
    }

    # Handle properties and augmented properties
    props_list = []
    if properties is not None:
        # Flatten to 1-D: np.concatenate rejects 0-d arrays (scalar targets)
        # and cannot join arrays of different rank.
        props_list.append(np.array(properties, dtype=np.float32).reshape(-1))
    if augmented_properties is not None:
        aug_props = get_augmented_property(mol, augmented_properties)
        if aug_props:
            props_list.append(np.array(aug_props, dtype=np.float32))

    if props_list:
        combined_props = np.concatenate(props_list)
        graph['y'] = combined_props.reshape(1, -1)
    else:
        graph['y'] = np.full((1, 1), np.nan, dtype=np.float32)

    # Handle augmented features
    if augmented_features is not None:
        graph['morgan'] = add_fingerprint_feature(
            mol,
            'morgan' if 'morgan' in augmented_features else None,
            getmorganfingerprint,
        )
        graph['maccs'] = add_fingerprint_feature(
            mol,
            'maccs' if 'maccs' in augmented_features else None,
            getmaccsfingerprint,
        )
    else:
        graph['morgan'] = None
        graph['maccs'] = None

    return graph