[Enhancement] Refactor Dataset (#184)

# Summary
I have introduced a new base class `Dataset` in `ml.py`, which all datasets should subclass. It stores the dataset as a polars DataFrame, with the column names and number of columns determined by the subclass. It implements generic methods such as `add_row`, indexing via `__getitem__` (which replaces `at` and `limit`) and dataset saving. It also declares the abstract methods that subclasses must implement: `X`, `y` and `generate_dataset`.

Two subclasses currently exist: `RuleBasedDataset` for the MLRR models and `EnviFormerDataset` for the enviFormer models.
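
The base-class/subclass split described above can be sketched roughly as follows. This is a minimal illustration, not the code in `ml.py`: a plain list of rows stands in for the polars DataFrame so the snippet runs without dependencies, and `ToyDataset` with its `feature`/`label` columns is a hypothetical subclass invented for the example.

```python
from abc import ABC, abstractmethod


class Dataset(ABC):
    """Toy base class; a list of rows stands in for the polars DataFrame."""

    def __init__(self, columns, data=None):
        self.columns = list(columns)
        self.rows = [list(r) for r in (data or [])]

    @abstractmethod
    def X(self):
        """Return the model input."""

    @abstractmethod
    def y(self):
        """Return the model target."""

    @staticmethod
    @abstractmethod
    def generate_dataset(*args, **kwargs):
        """Build and return an initialised subclass instance."""

    def __len__(self):
        return len(self.rows)


class ToyDataset(Dataset):
    """Hypothetical subclass: every column except the last is an input."""

    def X(self):
        return [row[:-1] for row in self.rows]

    def y(self):
        return [row[-1] for row in self.rows]

    @staticmethod
    def generate_dataset(pairs):
        return ToyDataset(["feature", "label"], [[f, l] for f, l in pairs])


ds = ToyDataset.generate_dataset([(0.1, 0), (0.9, 1)])
```

Because `X`, `y` and `generate_dataset` are abstract, instantiating `Dataset` directly raises a `TypeError`, which is what forces each subclass to supply them.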

# Old Dataset to New RuleBasedDataset Functionality Translation

- [x] \_\_init\_\_
    - self.columns and self.num_labels moved to base Dataset class
    - self.data moved to the base class and renamed self.df; it can be initialised from a list or from another DataFrame
    - struct_features, triggered and observed remain the same
- [x] \_block\_indices
    - function moved to base Dataset class
- [x] structure_id
    - stays in RuleBasedDataset, now requires an index for the row of interest
- [x] add_row
    - moved to base Dataset class, now calls add_rows so one or more rows can be added at a time
- [x] times_triggered
    - stays in RuleBasedDataset, now does a lookup using polars df.filter
- [x] struct_features (see init)
- [x] triggered (see init)
- [x] observed (see init)
- [x] at
    - removed in favour of indexing with getitem
- [x] limit
    - removed in favour of indexing with getitem
- [x] classification_dataset
    - stays in RuleBasedDataset, largely the same just with new dataset construction using add_rows
- [x] generate_dataset
    - stays in RuleBasedDataset, largely the same just with new dataset construction using add_rows
- [x] X
    - moved to base Dataset as @abstractmethod, RuleBasedDataset implementation functionally the same but uses polars
- [x] trig
    - stays in RuleBasedDataset, functionally the same but uses polars
- [x] y
    - moved to base Dataset as @abstractmethod, RuleBasedDataset implementation functionally the same but uses polars
- [x] \_\_getitem\_\_
    - moved to base dataset, now passes item to the dataframe for polars to handle
- [x] to_arff
    - stays in RuleBasedDataset, functionally the same but uses polars
- [x] \_\_repr\_\_
    - moved to base dataset
- [x] \_\_iter\_\_
    - moved to base Dataset, now uses polars iter_rows
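
As a rough sketch of two of the items above, the prefix lookup behind `_block_indices` and `add_row` delegating to `add_rows` might look like this. It is an illustration under assumptions, not the enviPy implementation: `Table` is a hypothetical stand-in that uses plain lists instead of polars, the `feature_`, `trig_` and `obs_` prefixes are invented column names, and the length check is interpreted here as "one value per column".

```python
class Table:
    """Hypothetical stand-in for the dataset, backed by plain lists."""

    def __init__(self, columns):
        self.columns = list(columns)
        self.rows = []

    def block_indices(self, prefix):
        """Return the indices of all columns whose name starts with prefix."""
        return [i for i, name in enumerate(self.columns) if name.startswith(prefix)]

    def add_rows(self, rows):
        """Append rows after checking that each has one value per column."""
        rows = [list(r) for r in rows]
        for row in rows:
            if len(row) != len(self.columns):
                raise ValueError(
                    f"expected {len(self.columns)} values, got {len(row)}"
                )
        # Column order is not checked; it is presumed to match self.columns.
        self.rows.extend(rows)

    def add_row(self, row):
        """Add a single row by delegating to add_rows."""
        self.add_rows([row])


# Column names below are assumptions made for this example only.
table = Table(["structure_id", "feature_0", "feature_1", "trig_r1", "obs_r1"])
table.add_row(["s1", 1, 0, 1, 0])
table.add_rows([["s2", 0, 1, 0, 0], ["s3", 1, 1, 1, 1]])
feature_idx = table.block_indices("feature_")
```

Validating every row before extending means a bad row leaves the table unchanged rather than half-updated.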

# Base Dataset class Features
The following functions are available in the base Dataset class:

- init - Creates the dataset from a list of column names and data given as a list of lists. It can also be created from a polars DataFrame, which is essential for the dataset to recreate itself during indexing. Passing only column names creates an empty dataset.
- add_rows - Adds rows to the Dataset. The new data is checked to have the same length as the existing data, but its column order is presumed to match the existing dataframe.
- add_row - Add one row, see add_rows
- block_indices - Returns the column indices that start with the given prefix
- columns - Property, returns dataframe.columns
- shape - Property, returns dataframe.shape
- X - Abstract method to be implemented by the subclasses, it should represent the input to a ML model
- y - Abstract method to be implemented by the subclasses, it should represent the target for a ML model
- generate_dataset - Abstract and static method to be implemented by the subclasses, should return an initialised subclass of Dataset
- iter - returns the iterable from dataframe.iter_rows()
- getitem - passes the item argument to the dataframe. If the result of indexing the dataframe is another dataframe, the new dataframe is packaged into a new Dataset of the same subclass. If the result of indexing is something else (int, float, polars Series), the result is returned as-is.
- save - Pickle and save the dataframe to the given path
- load - Static method to load the dataset from the given path
- to_numpy - returns the dataframe as a numpy array. Required for compatibility with training of the ECC model
- repr - return a representation of the dataset
- len - return the length of the dataframe
- iter_rows - Returns dataframe.iter_rows with arguments passed through. Mainly used to get the named iterable, which returns rows of the dataframe as a dict of column name to column value instead of a tuple of column values.
- filter - passes to dataframe.filter and recreates self with the result
- select - passes to dataframe.select and recreates self with the result
- with_columns - passes to dataframe.with_columns and recreates self with the result
- sort - passes to dataframe.sort and recreates self with the result
- item - passes to dataframe.item
- fill_nan - fills the NaN values in the dataframe with the given value
- height - Property, returns the height (number of rows) of the dataframe
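
The getitem re-wrapping and the save/load round trip listed above can be sketched as follows. This is a dependency-free approximation rather than the real class: `Table` and `RuleTable` are hypothetical names, a list of rows replaces the DataFrame, only slices are treated as "another dataframe", and `load` is written as a classmethod here (the commit describes it as a static method) so that the toy subclass is preserved.

```python
import os
import pickle
from tempfile import TemporaryDirectory


class Table:
    """Stand-in for the base Dataset; a list of rows replaces the DataFrame."""

    def __init__(self, columns, rows=None):
        self.columns = list(columns)
        self.rows = [list(r) for r in (rows or [])]

    def __getitem__(self, item):
        result = self.rows[item]
        if isinstance(item, slice):
            # A slice yields another table, so re-wrap it in the caller's class.
            return type(self)(self.columns, result)
        # Anything else (here: a single row) is returned as-is.
        return result

    def save(self, path):
        """Pickle the columns and rows to the given path."""
        with open(path, "wb") as fh:
            pickle.dump({"columns": self.columns, "rows": self.rows}, fh)

    @classmethod
    def load(cls, path):
        """Rebuild a table from a file written by save."""
        with open(path, "rb") as fh:
            state = pickle.load(fh)
        return cls(state["columns"], state["rows"])


class RuleTable(Table):
    """Hypothetical subclass, used only to show that the type survives."""


ds = RuleTable(["structure_id", "obs_r1"], [["s1", 1], ["s2", 0], ["s3", 1]])
head = ds[0:2]  # still a RuleTable, not a bare Table

with TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "save_dataset.pkl")
    ds.save(path)
    loaded = RuleTable.load(path)
```

Using `type(self)` rather than a hard-coded class name is what lets a subclass get its own type back from indexing. As with any pickle-based format, files should only be loaded from trusted sources.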

- [x] App domain
- [x] MACCS alternatives

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#184
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
2025-11-07 08:46:17 +13:00
committed by jebus
parent 98d62e1d1f
commit e26d5a21e3
10 changed files with 754 additions and 513 deletions

@@ -1,8 +1,10 @@
+import os.path
+from tempfile import TemporaryDirectory
 from django.test import TestCase
 from epdb.logic import PackageManager
-from epdb.models import Reaction, Compound, User, Rule
-from utilities.ml import Dataset
+from epdb.models import Reaction, Compound, User, Rule, Package
+from utilities.chem import FormatConverter
+from utilities.ml import RuleBasedDataset, EnviFormerDataset
 class DatasetTest(TestCase):
@@ -41,12 +43,108 @@ class DatasetTest(TestCase):
         super(DatasetTest, cls).setUpClass()
         cls.user = User.objects.get(username="anonymous")
         cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
+        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
-    def test_smoke(self):
+    def test_generate_dataset(self):
+        """Test generating dataset does not crash"""
+        self.generate_rule_dataset()
+    def test_indexing(self):
+        """Test indexing a few different ways to check for crashes"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds[5])
+        print(ds[2, 5])
+        print(ds[3:6, 2:8])
+        print(ds[:2, "structure_id"])
+    def test_add_rows(self):
+        """Test adding one row and adding multiple rows"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        ds.add_row(list(ds.df.row(1)))
+        ds.add_rows([list(ds.df.row(i)) for i in range(5)])
+    def test_times_triggered(self):
+        """Check getting times triggered for a rule id"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.times_triggered(rules[0].uuid))
+    def test_block_indices(self):
+        """Test the usages of _block_indices"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.struct_features())
+        print(ds.triggered())
+        print(ds.observed())
+    def test_structure_id(self):
+        """Check getting a structure id from row index"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.structure_id(0))
+    def test_x(self):
+        """Test getting X portion of the dataframe"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.X().df.head())
+    def test_trig(self):
+        """Test getting the triggered portion of the dataframe"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.trig().df.head())
+    def test_y(self):
+        """Test getting the Y portion of the dataframe"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.y().df.head())
+    def test_classification_dataset(self):
+        """Test making the classification dataset"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
+        class_ds, products = ds.classification_dataset(compounds, rules)
+        print(class_ds.df.head(5))
+        print(products[:5])
+    def test_extra_features(self):
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
+        print(ds.shape)
+    def test_to_arff(self):
+        """Test exporting the arff version of the dataset"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        ds.to_arff("dataset_arff_test.arff")
+    def test_save_load(self):
+        """Test saving and loading dataset"""
+        with TemporaryDirectory() as tmpdir:
+            ds, reactions, rules = self.generate_rule_dataset()
+            ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
+            ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
+            self.assertTrue(ds.df.equals(ds_loaded.df))
+    def test_dataset_example(self):
+        """Test with a concrete example checking dataset size"""
         reactions = [r for r in Reaction.objects.filter(package=self.package)]
         applicable_rules = [self.rule1]
-        ds = Dataset.generate_dataset(reactions, applicable_rules)
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
         self.assertEqual(len(ds.y()), 1)
-        self.assertEqual(sum(ds.y()[0]), 1)
+        self.assertEqual(ds.y().df.item(), 1)
+    def test_enviformer_dataset(self):
+        ds, reactions = self.generate_enviformer_dataset()
+        print(ds.X().head())
+        print(ds.y().head())
+    def generate_rule_dataset(self):
+        """Generate a RuleBasedDataset from test package data"""
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
+        return ds, reactions, applicable_rules
+    def generate_enviformer_dataset(self):
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        ds = EnviFormerDataset.generate_dataset(reactions)
+        return ds, reactions