# Source code for torch_molecule.datasets.load_local_csv
import csv
import gzip
import os
from typing import List, Optional, Tuple

import numpy as np

from .constant import SMILESDataset
# Absolute path of the directory containing this module; dataset files are
# looked up in the "data" subdirectory beneath it.
current_file_path = os.path.dirname(os.path.abspath(__file__))
def _load_from_local_csv(
    filename: str,
    input_cols: List[str],
    target_cols: List[str]
) -> Tuple[List[List[str]], np.ndarray]:
    """
    Load input and target columns from a gzip-compressed CSV in the package data directory.

    The file is looked up as ``<directory of this module>/data/<filename>``.

    Args:
        filename (str): Name of the CSV.gz file in the data directory
        input_cols (List[str]): List of input column names (e.g., SMILES)
        target_cols (List[str]): List of target column names

    Returns:
        Tuple[List[List[str]], np.ndarray]:
            - input rows: one list per CSV row holding the raw cell values of
              ``input_cols``, in the order given
            - targets: float array of shape ``(n_rows, len(target_cols))``;
              empty, missing or non-numeric cells are stored as ``np.nan``

    Raises:
        FileNotFoundError: If the file does not exist in the data directory.
        ValueError: If any requested column is absent from the CSV header,
            including the case where the file has no header row at all.
    """
    data_path = os.path.join(current_file_path, "data", filename)
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Dataset not found at {data_path}")

    input_data: List[List[str]] = []
    property_data: List[List[float]] = []

    # newline='' leaves newline handling to the csv module, as its documentation
    # requires (otherwise newlines inside quoted fields may be altered).
    with gzip.open(data_path, 'rt', encoding='utf-8', newline='') as f:
        csv_reader = csv.DictReader(f)

        # DictReader.fieldnames is None for a completely empty file; treat that
        # as "no columns" so the caller gets the regular missing-column error.
        available_cols = list(csv_reader.fieldnames or [])
        missing_cols = [
            col for col in [*input_cols, *target_cols] if col not in available_cols
        ]
        if missing_cols:
            raise ValueError(f"Required columns {missing_cols} not found in dataset. Available columns: {available_cols}")

        for row in csv_reader:
            input_data.append([row[col] for col in input_cols])
            property_data.append([_cell_to_float(row[col]) for col in target_cols])

    # Force a 2-D float array even when the file has a header but no data rows;
    # a plain np.array([]) would collapse to shape (0,).
    property_numpy = np.array(property_data, dtype=float).reshape(
        len(property_data), len(target_cols)
    )
    return input_data, property_numpy


def _cell_to_float(value) -> float:
    """Convert one CSV cell to float; empty, missing or unparseable cells become NaN."""
    if value is None or value == '':
        return np.nan
    try:
        return float(value)
    except (ValueError, TypeError):
        return np.nan
[docs]
def load_gasperm(
    target_cols: Optional[List[str]] = None,
) -> SMILESDataset:
    """
    Load gas permeability dataset from local CSV.gz file within torch_molecule package.

    Args:
        target_cols (Optional[List[str]]): List of target column names.
            When omitted (``None``), defaults to ["CH4", "CO2", "H2", "N2", "O2"].

    Returns:
        SMILESDataset:
            Dataset object built with
            - data: List[str], the value of the "SMILES" column for every row
            - target: np.ndarray, one row per entry and one column per target
    """
    # Build the default inside the call so no list object is shared between
    # invocations (avoids the mutable-default-argument pitfall).
    if target_cols is None:
        target_cols = ["CH4", "CO2", "H2", "N2", "O2"]

    input_data, property_numpy = _load_from_local_csv(
        "polymer_gas_permeability.csv.gz", ["SMILES"], target_cols
    )
    # Each input row holds exactly one value (the SMILES column); flatten it.
    smiles_list = [row[0] for row in input_data]
    return SMILESDataset(data=smiles_list, target=property_numpy)