[Enhancement] Refactor Dataset (#184)

# Summary
I have introduced a new base `class Dataset` in `ml.py` which all datasets should subclass. It stores the dataset as a polars DataFrame with the column names and number of columns determined by the subclass. It implements generic methods such as `add_row`, `at`, `limit` and dataset saving. It also details abstract methods required by the subclasses. These include `X`, `y` and `generate_dataset`.

There are two subclasses that currently exist. `RuleBasedDataset` for the MLRR models and `EnviFormerDataset` for the enviFormer models.

# Old Dataset to New RuleBasedDataset Functionality Translation

- [x] \_\_init\_\_
    - self.columns and self.num_labels moved to base Dataset class
    - self.data moved to base class with name self.df along with initialising from list or from another DataFrame
    - struct_features, triggered and observed remain the same
- [x] \_block\_indices
    - function moved to base Dataset class
- [x] structure_id
    - stays in RuleBasedDataset, now requires an index for the row of interest
- [x] add_row
    - moved to base Dataset class, now calls add_rows so one or more rows can be added at a time
- [x] times_triggered
    - stays in RuleBasedDataset, now does a look up using polars df.filter
- [x] struct_features (see init)
- [x] triggered (see init)
- [x] observed (see init)
- [x] at
    - removed in favour of indexing with getitem
- [x] limit
    - removed in favour of indexing with getitem
- [x] classification_dataset
    - stays in RuleBasedDataset, largely the same just with new dataset construction using add_rows
- [x] generate_dataset
    - stays in RuleBasedDataset, largely the same just with new dataset construction using add_rows
- [x] X
    - moved to base Dataset as @abstract_method, RuleBasedDataset implementation functionally the same but uses polars
- [x] trig
    - stays in RuleBasedDataset, functionally the same but uses polars
- [x] y
    - moved to base Dataset as @abstract_method, RuleBasedDataset implementation functionally the same but uses polars
- [x] \_\_get_item\_\_
    - moved to base dataset, now passes item to the dataframe for polars to handle
- [x] to_arff
    - stays in RuleBasedDataset, functionally the same but uses polars
- [x] \_\_repr\_\_
    - moved to base dataset
- [x] \_\_iter\_\_
    - moved to base Dataset, now uses polars iter_rows

# Base Dataset class Features
The following functions are available in the base Dataset class

- init - Create the dataset from a list of columns and data in format list of list. Or can create a dataset from a polars Dataframe, this is essential for recreating itself during indexing. Can create an empty dataset by just passing column names.
- add_rows - Add rows to the Dataset, we check that the new data length is the same but it is presumed that the column order matches the existing dataframe
- add_row - Add one row, see add_rows
- block_indices - Returns the column indices that start with the given prefix
- columns - Property, returns dataframe.columns
- shape - Property, returns dataframe.shape
- X - Abstract method to be implemented by the subclasses, it should represent the input to a ML model
- y - Abstract method to be implemented by the subclasses, it should represent the target for a ML model
- generate_dataset - Abstract and static method to be implemented by the subclasses, should return an initialised subclass of Dataset
- iter - returns the iterable from dataframe.iter_rows()
- getitem - passes the item argument to the dataframe. If the result of indexing the dataframe is another dataframe, the new dataframe is  packaged into a new Dataset of the same subclass. If the result of indexing is something else (int, float, polar Series) return the result.
- save - Pickle and save the dataframe to the given path
- load - Static method to load the dataset from the given path
- to_numpy - returns the dataframe as a numpy array. Required for compatibility with training of the ECC model
- repr - return a representation of the dataset
- len - return the length of the dataframe
- iter_rows - Return dataframe.iterrows with arguments passed through. Mainly used to get the named iterable which returns rows of the dataframe as dict of column names: column values instead of tuple of column values.
- filter - pass to dataframe.filter and recreates self with the result
- select - pass to dataframe.select and recreates self with the result
- with_columns - pass to dataframe.with_columns and recreates self with the result
- sort - pass to dataframe.sort and recreates self with the result
- item - pass to dataframe.item
- fill_nan - fill the dataframe nan's with value
- height - Property, returns the height (number of rows) of the dataframe

- [x] App domain
- [x] MACCS alternatives

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#184
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
2025-11-07 08:46:17 +13:00
committed by jebus
parent 98d62e1d1f
commit e26d5a21e3
10 changed files with 754 additions and 513 deletions

View File

@ -5,11 +5,14 @@ import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
from abc import ABC, abstractmethod
import networkx as nx
import numpy as np
from envipy_plugins import Descriptor
from numpy.random import default_rng
import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
@ -26,70 +29,281 @@ if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction
class Dataset:
def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns
self.num_labels: int = num_labels
if data is None:
self.data: List[List[str | int | float]] = list()
class Dataset(ABC):
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
self.df = data
else:
self.data = data
# Build either an empty dataframe with columns or fill it with list of list data
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def add_rows(self, rows: List[List[str | int | float]]):
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
if len(self.columns) != len(rows[0]):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
self.df.extend(new_rows)
def _block_indices(self, prefix) -> Tuple[int, int]:
def add_row(self, row: List[str | int | float]):
"""See add_rows"""
self.add_rows([row])
def block_indices(self, prefix) -> List[int]:
"""Find the indexes in column labels that has the prefix"""
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return indices
return min(indices), max(indices)
@property
def columns(self) -> List[str]:
"""Use the polars dataframe columns"""
return self.df.columns
def structure_id(self):
return self.data[0][0]
@property
def shape(self):
return self.df.shape
def add_row(self, row: List[str | int | float]):
if len(self.columns) != len(row):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
self.data.append(row)
@abstractmethod
def X(self, **kwargs):
pass
def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f"trig_{rule_uuid}")
@abstractmethod
def y(self, **kwargs):
pass
times_triggered = 0
for row in self.data:
if row[idx] == 1:
times_triggered += 1
return times_triggered
def struct_features(self) -> Tuple[int, int]:
return self._struct_features
def triggered(self) -> Tuple[int, int]:
return self._triggered
def observed(self) -> Tuple[int, int]:
return self._observed
def at(self, position: int) -> Dataset:
return Dataset(self.columns, self.num_labels, [self.data[position]])
def limit(self, limit: int) -> Dataset:
return Dataset(self.columns, self.num_labels, self.data[:limit])
@staticmethod
@abstractmethod
def generate_dataset(reactions, *args, **kwargs):
pass
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
"""Use polars iter_rows for iterating over the dataset"""
return self.df.iter_rows()
def __getitem__(self, item):
"""Item is passed to polars allowing for advanced indexing.
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
res = self.df[item]
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
return self.__class__(data=res)
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
return res
def save(self, path: "Path | str"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "str | Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def to_numpy(self):
return self.df.to_numpy()
def __repr__(self):
return (
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
)
def __len__(self):
return len(self.df)
def iter_rows(self, named=False):
return self.df.iter_rows(named=named)
def filter(self, *predicates, **constraints):
return self.__class__(data=self.df.filter(*predicates, **constraints))
def select(self, *exprs, **named_exprs):
return self.__class__(data=self.df.select(*exprs, **named_exprs))
def with_columns(self, *exprs, **name_exprs):
return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
multithreaded=multithreaded, maintain_order=maintain_order))
def item(self, row=None, column=None):
return self.df.item(row, column)
def fill_nan(self, value):
return self.__class__(data=self.df.fill_nan(value))
@property
def height(self):
return self.df.height
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
# Pre-calculate the ids of columns for features/labels, useful later in X and y
self._struct_features: List[int] = self.block_indices("feature_")
self._triggered: List[int] = self.block_indices("trig_")
self._observed: List[int] = self.block_indices("obs_")
self.feature_cols: List[int] = self._struct_features + self._triggered
self.num_features: int = len(self.feature_cols)
self.has_probs = False
def times_triggered(self, rule_uuid) -> int:
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> List[int]:
return self._struct_features
def triggered(self) -> List[int]:
return self._triggered
def observed(self) -> List[int]:
return self._observed
def structure_id(self, index: int):
"""Get the UUID of a compound"""
return self.item(index, "structure_id")
def X(self, exclude_id_col=True, na_replacement=0):
"""Get all the feature and trig columns"""
_col_ids = self.feature_cols
if not exclude_id_col:
_col_ids = [0] + _col_ids
res = self[:, _col_ids]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def trig(self, na_replacement=0):
"""Get all the trig columns"""
res = self[:, self._triggered]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
"""Get all the obs columns"""
res = self[:, self._observed]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
if feat_funcs is None:
feat_funcs = [FormatConverter.maccs]
_structures = set() # Get all the structures
for r in reactions:
_structures.update(r.educts.all())
if not educts_only:
_structures.update(r.products.all())
compounds = sorted(_structures, key=lambda x: x.url)
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
logger.debug(f"Standardizing SMILES failed for {smi}")
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
logger.debug(f"Standardizing SMILES failed for {smi}")
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
feat_columns = []
for feat_func in feat_funcs:
if isinstance(feat_func, Descriptor):
feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
else:
feats = feat_func(compounds[0].smiles)
start_i = len(feat_columns)
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
ds_columns = (["structure_id"] +
feat_columns +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
rows = []
for i, comp in enumerate(compounds):
# Features
feats = []
for feat_func in feat_funcs:
if isinstance(feat_func, Descriptor):
feat = feat_func.get_molecule_descriptors(comp.smiles)
else:
feat = feat_func(comp.smiles)
feats.extend(feat)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
rows.append([str(comp.uuid)] + feats + trig + obs)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
return ds
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[Dataset, List[List[PredictionResult]]]:
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
for struct in structures:
@ -113,186 +327,18 @@ class Dataset:
else:
trig.append(0)
prods.append([])
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
new_row = [struct_id] + features + trig + ([-1] * len(trig))
if self.has_probs:
new_row += [-1] * len(trig)
classify_data.append(new_row)
classify_products.append(prods)
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
return ds, classify_products
return Dataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod
def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset:
_structures = set()
for r in reactions:
for e in r.educts.all():
_structures.add(e)
if not educts_only:
for e in r.products:
_structures.add(e)
compounds = sorted(_structures, key=lambda x: x.url)
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
# :shrug:
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
# :shrug:
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
else:
pass
ds = None
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
if ds is None:
header = (
["structure_id"]
+ [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def trig(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
row_key, col_key = key
# Normalize rows
if isinstance(row_key, int):
rows = [self.data[row_key]]
else:
rows = self.data[row_key]
# Normalize columns
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [
[row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def add_probs(self, probs):
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
self.has_probs = True
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
@ -304,7 +350,7 @@ class Dataset:
arff += f"@attribute {c} {{0,1}}\n"
arff += "\n@data\n"
for d in self.data:
for d in self:
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f"{ys},{xs}\n"
@ -313,10 +359,40 @@ class Dataset:
fh.write(arff)
fh.flush()
def __repr__(self):
return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
class EnviFormerDataset(Dataset):
def __init__(self, columns=None, data=None):
super().__init__(columns, data)
def X(self):
"""Return the educts"""
return self["educts"]
def y(self):
"""Return the products"""
return self["products"]
@staticmethod
def generate_dataset(reactions, *args, **kwargs):
# Standardise reactions for the training data
stereo = kwargs.get("stereo", False)
rows = []
for reaction in reactions:
e = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.educts.all()
]
)
p = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.products.all()
]
)
rows.append([e, p])
ds = EnviFormerDataset(["educts", "products"], rows)
return ds
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -498,7 +574,7 @@ class EnsembleClassifierChain:
self.classifiers = []
if self.num_labels is None:
self.num_labels = len(Y[0])
self.num_labels = Y.shape[1]
for p in range(self.num_chains):
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@ -529,7 +605,7 @@ class RelativeReasoning:
def fit(self, X, Y):
n_instances = len(Y)
n_attributes = len(Y[0])
n_attributes = Y.shape[1]
for i in range(n_attributes):
for j in range(n_attributes):
@ -541,8 +617,8 @@ class RelativeReasoning:
countboth = 0
for k in range(n_instances):
vi = Y[k][i]
vj = Y[k][j]
vi = Y[k, i]
vj = Y[k, j]
if vi is None or vj is None:
continue
@ -598,7 +674,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None
self.max_vals = None
def build(self, train_dataset: "Dataset"):
def build(self, train_dataset: "RuleBasedDataset"):
# transform
X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca
@ -612,7 +688,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled)
return instances_pca
def is_applicable(self, classify_instances: "Dataset"):
def is_applicable(self, classify_instances: "RuleBasedDataset"):
instances_pca = self.__transform(classify_instances.X())
is_applicable = []