# Summary
I have introduced a new base class `Dataset` in `ml.py`, which all datasets should subclass. It stores the dataset as a polars DataFrame, with the column names and number of columns determined by the subclass. It implements generic functionality such as `add_row`, indexing via `__getitem__`, and dataset saving and loading, and it declares the abstract methods subclasses must implement: `X`, `y` and `generate_dataset`.
Two subclasses currently exist: `RuleBasedDataset` for the MLRR models and `EnviFormerDataset` for the enviFormer models.
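As context for the checklist below, a minimal sketch of the base class shape, assuming standard `abc` wiring. The constructor signature and the polars calls are illustrative assumptions; only the class and method names come from this PR:
```python
from abc import ABC, abstractmethod

import polars as pl


class Dataset(ABC):
    """Base class storing the data as a polars DataFrame; subclasses fix the schema."""

    def __init__(self, columns, data=None):
        # Accept an existing DataFrame (used when re-wrapping indexing results),
        # a list of row lists, or nothing at all for an empty dataset.
        if isinstance(data, pl.DataFrame):
            self.df = data
        else:
            self.df = pl.DataFrame(data or [], schema=columns, orient="row")

    @abstractmethod
    def X(self):
        """The input to an ML model."""

    @abstractmethod
    def y(self):
        """The target for an ML model."""

    @staticmethod
    @abstractmethod
    def generate_dataset(*args, **kwargs):
        """Return an initialised Dataset subclass built from domain objects."""
```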
# Old Dataset to New RuleBasedDataset Functionality Translation
- [x] `__init__`
    - `self.columns` and `self.num_labels` moved to the base `Dataset` class
    - `self.data` moved to the base class and renamed `self.df`, along with initialisation from a list or from another DataFrame
    - `struct_features`, `triggered` and `observed` remain the same
- [x] `_block_indices`
    - function moved to the base `Dataset` class
- [x] `structure_id`
    - stays in `RuleBasedDataset`, now requires an index for the row of interest
- [x] `add_row`
    - moved to the base `Dataset` class, now calls `add_rows` so one or more rows can be added at a time
- [x] `times_triggered`
    - stays in `RuleBasedDataset`, now does a lookup using polars `df.filter`
- [x] `struct_features` (see `__init__`)
- [x] `triggered` (see `__init__`)
- [x] `observed` (see `__init__`)
- [x] `at`
    - removed in favour of indexing with `__getitem__` (see the indexing sketch after this list)
- [x] `limit`
    - removed in favour of indexing with `__getitem__`
- [x] `classification_dataset`
    - stays in `RuleBasedDataset`, largely the same but with new dataset construction using `add_rows`
- [x] `generate_dataset`
    - stays in `RuleBasedDataset`, largely the same but with new dataset construction using `add_rows`
- [x] `X`
    - moved to the base `Dataset` as an `@abstractmethod`; the `RuleBasedDataset` implementation is functionally the same but uses polars
- [x] `trig`
    - stays in `RuleBasedDataset`, functionally the same but uses polars
- [x] `y`
    - moved to the base `Dataset` as an `@abstractmethod`; the `RuleBasedDataset` implementation is functionally the same but uses polars
- [x] `__getitem__`
    - moved to the base `Dataset`, now passes the item to the dataframe for polars to handle
- [x] `to_arff`
    - stays in `RuleBasedDataset`, functionally the same but uses polars
- [x] `__repr__`
    - moved to the base `Dataset`
- [x] `__iter__`
    - moved to the base `Dataset`, now uses polars `iter_rows`
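Since `at` and `limit` are gone, call sites index the dataset directly. A short usage sketch, assuming a `ds` built by `RuleBasedDataset.generate_dataset`; the indexing forms are the ones exercised in the test file below, while the old-API equivalences in the comments are my reading of the removed methods:
```python
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)

row = ds[5]                   # a single row, roughly the old ds.at(5)
cell = ds[2, 5]               # scalar at row 2, column 5
block = ds[3:6, 2:8]          # row/column slice, re-wrapped as a new RuleBasedDataset
ids = ds[:2, "structure_id"]  # one column of the first two rows, in the spirit of ds.limit(2)
```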
# Base `Dataset` Class Features
The following functions are available in the base `Dataset` class:
- `__init__` - Creates the dataset from a list of column names and data as a list of lists, or from an existing polars DataFrame, which is essential for recreating itself during indexing. Passing only column names creates an empty dataset.
- `add_rows` - Adds rows to the Dataset; each new row's length is checked, but the column order is presumed to match the existing dataframe
- `add_row` - Adds one row, see `add_rows`
- `_block_indices` - Returns the column indices that start with the given prefix
- `columns` - Property, returns `dataframe.columns`
- `shape` - Property, returns `dataframe.shape`
- `X` - Abstract method to be implemented by the subclasses; it should represent the input to an ML model
- `y` - Abstract method to be implemented by the subclasses; it should represent the target for an ML model
- `generate_dataset` - Abstract static method to be implemented by the subclasses; should return an initialised subclass of `Dataset`
- `__iter__` - Returns the iterable from `dataframe.iter_rows()`
- `__getitem__` - Passes the item argument to the dataframe. If the result of indexing is another dataframe, it is packaged into a new Dataset of the same subclass; anything else (int, float, polars Series) is returned as-is. A sketch of this re-wrapping pattern follows this list.
- `save` - Pickles and saves the dataframe to the given path
- `load` - Static method to load the dataset from the given path
- `to_numpy` - Returns the dataframe as a numpy array; required for compatibility with training of the ECC model
- `__repr__` - Returns a representation of the dataset
- `__len__` - Returns the length of the dataframe
- `iter_rows` - Returns `dataframe.iter_rows()` with arguments passed through; mainly used to get the named iterable, which yields rows of the dataframe as dicts of column name to column value instead of tuples of column values
- `filter` - Passes through to `dataframe.filter` and recreates self with the result
- `select` - Passes through to `dataframe.select` and recreates self with the result
- `with_columns` - Passes through to `dataframe.with_columns` and recreates self with the result
- `sort` - Passes through to `dataframe.sort` and recreates self with the result
- `item` - Passes through to `dataframe.item`
- `fill_nan` - Fills the dataframe's NaN values with the given value
- `height` - Property, returns the height (number of rows) of the dataframe
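Continuing the sketch above, `__getitem__` and the pass-through methods (`filter`, `select`, `with_columns`, `sort`) share one pattern: delegate to the polars DataFrame, then re-wrap any DataFrame result in the same subclass. A pared-down illustration of that pattern, not the actual implementation (the constructor signature is the same assumption as in the earlier sketch):
```python
import polars as pl


class Dataset:
    """Stand-in for the base class, showing only the re-wrapping pattern."""

    def __init__(self, columns, data=None):
        # Same assumed constructor as the earlier sketch
        if isinstance(data, pl.DataFrame):
            self.df = data
        else:
            self.df = pl.DataFrame(data or [], schema=columns, orient="row")

    def filter(self, *args, **kwargs):
        # Delegate to polars, then recreate self with the result
        return type(self)(self.df.columns, self.df.filter(*args, **kwargs))

    def __getitem__(self, item):
        result = self.df[item]
        if isinstance(result, pl.DataFrame):
            # A dataframe result is packaged into a new Dataset of the same subclass
            return type(self)(result.columns, result)
        # ints, floats and polars Series are returned as-is
        return result
```
Using `type(self)` rather than a hard-coded class is what keeps a `RuleBasedDataset` a `RuleBasedDataset` after filtering or slicing.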
- [x] App domain
- [x] MACCS alternatives
Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#184
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
The accompanying test file for the new dataset classes (151 lines, 6.0 KiB, Python):
```python
import os.path
from tempfile import TemporaryDirectory

from django.test import TestCase

from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule, Package
from utilities.chem import FormatConverter
from utilities.ml import RuleBasedDataset, EnviFormerDataset


class DatasetTest(TestCase):
    fixtures = ["test_fixtures.jsonl.gz"]

    def setUp(self):
        self.cs1 = Compound.create(
            self.package,
            name="2,6-Dibromohydroquinone",
            description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b",
            smiles="C1=C(C(=C(C=C1O)Br)O)Br",
        ).default_structure

        self.cs2 = Compound.create(
            self.package,
            smiles="O=C(O)CC(=O)/C=C(/Br)C(=O)O",
        ).default_structure

        self.rule1 = Rule.create(
            rule_type="SimpleAmbitRule",
            package=self.package,
            smirks="[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\\[#6:3]=[#6:2](\\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]",
            description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6",
        )

        self.reaction1 = Reaction.create(
            package=self.package,
            educts=[self.cs1],
            products=[self.cs2],
            rules=[self.rule1],
            multi_step=False,
        )

    @classmethod
    def setUpClass(cls):
        super(DatasetTest, cls).setUpClass()
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def test_generate_dataset(self):
        """Test generating dataset does not crash"""
        self.generate_rule_dataset()

    def test_indexing(self):
        """Test indexing a few different ways to check for crashes"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds[5])
        print(ds[2, 5])
        print(ds[3:6, 2:8])
        print(ds[:2, "structure_id"])

    def test_add_rows(self):
        """Test adding one row and adding multiple rows"""
        ds, reactions, rules = self.generate_rule_dataset()
        ds.add_row(list(ds.df.row(1)))
        ds.add_rows([list(ds.df.row(i)) for i in range(5)])

    def test_times_triggered(self):
        """Check getting times triggered for a rule id"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.times_triggered(rules[0].uuid))

    def test_block_indices(self):
        """Test the usages of _block_indices"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.struct_features())
        print(ds.triggered())
        print(ds.observed())

    def test_structure_id(self):
        """Check getting a structure id from row index"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.structure_id(0))

    def test_x(self):
        """Test getting X portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.X().df.head())

    def test_trig(self):
        """Test getting the triggered portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.trig().df.head())

    def test_y(self):
        """Test getting the Y portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.y().df.head())

    def test_classification_dataset(self):
        """Test making the classification dataset"""
        ds, reactions, rules = self.generate_rule_dataset()
        compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
        class_ds, products = ds.classification_dataset(compounds, rules)
        print(class_ds.df.head(5))
        print(products[:5])

    def test_extra_features(self):
        """Test generating a dataset with extra fingerprint feature functions"""
        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
        print(ds.shape)

    def test_to_arff(self):
        """Test exporting the arff version of the dataset"""
        ds, reactions, rules = self.generate_rule_dataset()
        ds.to_arff("dataset_arff_test.arff")

    def test_save_load(self):
        """Test saving and loading dataset"""
        with TemporaryDirectory() as tmpdir:
            ds, reactions, rules = self.generate_rule_dataset()
            ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
            ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
            self.assertTrue(ds.df.equals(ds_loaded.df))

    def test_dataset_example(self):
        """Test with a concrete example checking dataset size"""
        reactions = [r for r in Reaction.objects.filter(package=self.package)]
        applicable_rules = [self.rule1]

        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)

        self.assertEqual(len(ds.y()), 1)
        self.assertEqual(ds.y().df.item(), 1)

    def test_enviformer_dataset(self):
        """Test generating an EnviFormerDataset and getting X and y"""
        ds, reactions = self.generate_enviformer_dataset()
        print(ds.X().head())
        print(ds.y().head())

    def generate_rule_dataset(self):
        """Generate a RuleBasedDataset from test package data"""
        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
        return ds, reactions, applicable_rules

    def generate_enviformer_dataset(self):
        """Generate an EnviFormerDataset from test package data"""
        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
        ds = EnviFormerDataset.generate_dataset(reactions)
        return ds, reactions
```