Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for torch_molecule.base.generator
from dataclasses import dataclass , field
from abc import ABC , abstractmethod
from typing import Optional , ClassVar , Union , List , Dict , Any , Tuple , Callable , Type , Literal
import torch
import numpy as np
from .base import BaseModel
@dataclass
class BaseMolecularGenerator(BaseModel, ABC):
    """Base class for molecular generation.

    Concrete generators must implement :meth:`fit` and :meth:`generate`,
    both of which are declared abstract here.
    """

    # Model identifier. Concrete subclasses presumably override this default
    # with their own name (assumption — confirm against subclasses).
    model_name: str = field(default="BaseMolecularGenerator")

    @abstractmethod
    def fit(
        self, X: List[str], y: Optional[np.ndarray] = None
    ) -> "BaseMolecularGenerator":
        """Fit the generator to a set of training molecules.

        Parameters
        ----------
        X : List[str]
            Training molecules given as strings (presumably SMILES — confirm
            against concrete subclasses).
        y : Optional[np.ndarray], default=None
            Optional array accompanying ``X``; presumably per-molecule
            property values used for conditional generation — confirm
            against concrete subclasses.

        Returns
        -------
        BaseMolecularGenerator
            The fitted generator (conventionally ``self``; confirm in the
            implementing subclass).
        """
        pass

    @abstractmethod
    def generate(self, n_samples: int, **kwargs) -> List[str]:
        """Generate molecular structures.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate.
        **kwargs
            Generator-specific options; the accepted keys are defined by
            each concrete subclass.

        Returns
        -------
        List[str]
            The generated molecules, represented as strings.
        """
        pass