Source code for torch_molecule.utils.graph.graph_from_smiles
import numpy as np
from rdkit import Chem
from rdkit.Chem import Crippen
from .features import atom_to_feature_vector, bond_to_feature_vector
from .features import getmaccsfingerprint, getmorganfingerprint
from ..generic.pseudo_tasks import PSEUDOTASK
[docs]
def add_fingerprint_feature(mol, feature_type, get_fingerprint_fn):
    """
    Compute an optional fingerprint feature for a molecule.

    Parameters
    ----------
    mol : Any
        Molecule object forwarded unchanged to ``get_fingerprint_fn``.
    feature_type : Any
        Switch for the feature; ``None`` disables it. Any other value
        enables it (the value itself is not inspected further).
    get_fingerprint_fn : Callable
        Function called as ``get_fingerprint_fn(mol)`` to produce the
        fingerprint values.

    Returns
    -------
    numpy.ndarray or None
        ``None`` when ``feature_type`` is ``None``; otherwise the fingerprint
        converted to an ``int8`` array with a new leading axis of length 1.
    """
    if feature_type is not None:
        bits = np.array(get_fingerprint_fn(mol), dtype="int8")
        # Prepend a length-1 axis so the result is one row per molecule.
        return bits[np.newaxis, ...]
    return None
    
[docs]
def get_augmented_property(mol, properties):
    """
    Collect augmented (pseudo-task) property values for a molecule.

    Parameters
    ----------
    mol : Any
        Molecule to describe; ``None`` short-circuits to ``None``.
    properties : Iterable[str]
        Requested property names. Each must be a key of ``PSEUDOTASK``.

    Returns
    -------
    list or None
        ``None`` when ``mol`` is ``None``. Otherwise a flat list built in a
        fixed order, independent of the order of ``properties``: MACCS
        fingerprint values (if ``'maccs'`` requested), then Morgan
        fingerprint values (if ``'morgan'`` requested), then the Crippen
        logP (if ``'logP'`` requested).

    Raises
    ------
    ValueError
        If any requested name is not a key of ``PSEUDOTASK``.
    """
    if mol is None:
        return None

    supported_properties = set(PSEUDOTASK.keys())
    unsupported = set(properties) - supported_properties
    if unsupported:
        raise ValueError(f"Unsupported properties: {unsupported}. Supported properties are: {supported_properties}")

    # Fingerprint-valued properties, listed in the order they are emitted.
    fingerprint_sources = (
        ('maccs', getmaccsfingerprint),
        ('morgan', getmorganfingerprint),
    )

    values = []
    for key, compute_fingerprint in fingerprint_sources:
        if key in properties:
            values.extend(compute_fingerprint(mol))

    # logP is a single scalar, so it is appended rather than extended.
    if 'logP' in properties:
        values.append(Crippen.MolLogP(mol))

    return values
[docs]
def graph_from_smiles(smiles_or_mol, properties, augmented_features=None, augmented_properties=None):
    """
    Convert a SMILES string or RDKit molecule into a graph dictionary.

    Parameters
    ----------
    smiles_or_mol : Union[str, rdkit.Chem.rdchem.Mol]
        SMILES string (parsed with ``Chem.MolFromSmiles``) or an already
        constructed RDKit molecule object.
    properties : scalar or array-like, optional
        Target value(s) placed first in ``y``. The input is flattened to
        1-D, so a bare scalar, a flat list, or a nested array are all
        accepted. ``None`` means no user-supplied targets.
    augmented_features : list, optional
        Names of fingerprint features to attach; ``'morgan'`` and
        ``'maccs'`` are recognised. ``None`` disables both.
    augmented_properties : list, optional
        Names of augmented properties, computed by
        ``get_augmented_property`` and appended to ``y`` after
        ``properties``.

    Returns
    -------
    dict
        Graph dictionary with the keys:

        - ``edge_index``: int64 array of shape ``[2, num_edges]``; every
          bond appears once per direction.
        - ``edge_feat``: int64 array of shape
          ``[num_edges, num_edge_features]``.
        - ``node_feat``: int64 array with one feature vector per atom.
        - ``num_nodes``: number of atoms.
        - ``y``: float32 array of shape ``(1, k)``; a single NaN when no
          targets are available.
        - ``morgan`` / ``maccs``: fingerprint array from
          ``add_fingerprint_feature`` or ``None`` when not requested.

    Raises
    ------
    ValueError
        If no molecule is available (``smiles_or_mol`` is ``None`` or the
        SMILES string could not be parsed), or if ``augmented_properties``
        contains a name rejected by ``get_augmented_property``.
    """
    if isinstance(smiles_or_mol, str):
        mol = Chem.MolFromSmiles(smiles_or_mol)
    else:
        mol = smiles_or_mol

    # MolFromSmiles signals a parse failure by returning None; fail here with
    # a clear message instead of an AttributeError on the first method call.
    if mol is None:
        raise ValueError(
            f"Could not obtain a molecule from input: {smiles_or_mol!r}"
        )

    # atoms
    atom_features_list = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    x = np.array(atom_features_list, dtype=np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = bond_to_feature_vector(bond)
        # add edges in both directions, sharing the same bond features
        edges_list.extend(((i, j), (j, i)))
        edge_features_list.extend((edge_feature, edge_feature))

    if edges_list:  # mol has bonds
        # Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    graph = {
        "edge_index": edge_index,
        "edge_feat": edge_attr,
        "node_feat": x,
        "num_nodes": len(x),
    }

    # Handle properties and augmented properties
    props_list = []
    if properties is not None:
        # Flatten so scalar (0-d) targets can be concatenated as well.
        props_list.append(np.array(properties, dtype=np.float32).reshape(-1))
    if augmented_properties is not None:
        aug_props = get_augmented_property(mol, augmented_properties)
        if aug_props:
            props_list.append(np.array(aug_props, dtype=np.float32).reshape(-1))

    if props_list:
        graph['y'] = np.concatenate(props_list).reshape(1, -1)
    else:
        # No targets at all: keep a well-formed (1, 1) placeholder.
        graph['y'] = np.full((1, 1), np.nan, dtype=np.float32)

    # Handle augmented features
    if augmented_features is not None:
        graph['morgan'] = add_fingerprint_feature(
            mol,
            'morgan' if 'morgan' in augmented_features else None,
            getmorganfingerprint,
        )
        graph['maccs'] = add_fingerprint_feature(
            mol,
            'maccs' if 'maccs' in augmented_features else None,
            getmaccsfingerprint,
        )
    else:
        graph['morgan'] = None
        graph['maccs'] = None

    return graph