Files
enviPy-bayer/tests/test_dataset.py
2025-11-03 15:24:28 +13:00

129 lines
4.9 KiB
Python

import os
import tempfile

from django.test import TestCase

from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule, Package
from utilities.ml import RuleBasedDataset
class DatasetTest(TestCase):
    """Tests for ``RuleBasedDataset`` built from package data.

    Most tests are smoke tests: they call a dataset operation and print the
    result, passing as long as nothing raises. ``test_dataset_example`` is the
    only test that asserts on the dataset contents.
    """

    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpClass(cls):
        """Run the base class setup, then fetch/create the shared user and packages."""
        super().setUpClass()
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def setUp(self):
        """Create two structures, one rule and one reaction in ``self.package``."""
        self.cs1 = Compound.create(
            self.package,
            name="2,6-Dibromohydroquinone",
            description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b",
            smiles="C1=C(C(=C(C=C1O)Br)O)Br",
        ).default_structure

        self.cs2 = Compound.create(
            self.package,
            smiles="O=C(O)CC(=O)/C=C(/Br)C(=O)O",
        ).default_structure

        self.rule1 = Rule.create(
            rule_type="SimpleAmbitRule",
            package=self.package,
            smirks="[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\\[#6:3]=[#6:2](\\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]",
            description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6",
        )

        # Single-step reaction cs1 -> cs2 annotated with rule1; used by
        # test_dataset_example.
        self.reaction1 = Reaction.create(
            package=self.package,
            educts=[self.cs1],
            products=[self.cs2],
            rules=[self.rule1],
            multi_step=False,
        )

    def test_generate_dataset(self):
        """Test generating dataset does not crash"""
        self.generate_dataset()

    def test_indexing(self):
        """Test indexing a few different ways to check for crashes"""
        ds, _, _ = self.generate_dataset()
        print(ds[5])
        print(ds[2, 5])
        print(ds[3:6, 2:8])
        print(ds[:2, "structure_id"])

    def test_add_rows(self):
        """Test adding one row and adding multiple rows"""
        ds, _, _ = self.generate_dataset()
        ds.add_row(list(ds.df.row(1)))
        ds.add_rows([list(ds.df.row(i)) for i in range(5)])

    def test_times_triggered(self):
        """Check getting times triggered for a rule id"""
        ds, _, rules = self.generate_dataset()
        print(ds.times_triggered(rules[0].uuid))

    def test_block_indices(self):
        """Test the usages of _block_indices"""
        ds, _, _ = self.generate_dataset()
        print(ds.struct_features())
        print(ds.triggered())
        print(ds.observed())

    def test_structure_id(self):
        """Check getting a structure id from row index"""
        ds, _, _ = self.generate_dataset()
        print(ds.structure_id(0))

    def test_x(self):
        """Test getting X portion of the dataframe"""
        ds, _, _ = self.generate_dataset()
        print(ds.X().df.head())

    def test_trig(self):
        """Test getting the triggered portion of the dataframe"""
        ds, _, _ = self.generate_dataset()
        print(ds.trig().df.head())

    def test_y(self):
        """Test getting the Y portion of the dataframe"""
        ds, _, _ = self.generate_dataset()
        print(ds.y().df.head())

    def test_classification_dataset(self):
        """Test making the classification dataset"""
        ds, _, rules = self.generate_dataset()
        compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
        class_ds, products = ds.classification_dataset(compounds, rules)
        print(class_ds.df.head(5))
        print(products[:5])

    def test_to_arff(self):
        """Test exporting the arff version of the dataset"""
        ds, _, _ = self.generate_dataset()
        # Export into a throwaway directory so the test run does not leave a
        # stray .arff file behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmpdir:
            ds.to_arff(os.path.join(tmpdir, "dataset_arff_test.arff"))

    def test_save_load(self):
        """Test saving and loading dataset"""
        # NOTE(review): nothing is saved or loaded here; only dataset
        # generation is exercised. TODO: add a save/load round trip once the
        # RuleBasedDataset persistence API is confirmed.
        self.generate_dataset()

    def test_dataset_example(self):
        """Test with a concrete example checking dataset size"""
        reactions = list(Reaction.objects.filter(package=self.package))
        applicable_rules = [self.rule1]
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
        self.assertEqual(len(ds.y()), 1)
        self.assertEqual(ds.y().df.item(), 1)

    def generate_dataset(self):
        """Generate a RuleBasedDataset from test package data.

        Returns:
            A ``(dataset, reactions, rules)`` tuple, where ``reactions`` and
            ``rules`` are the lists of all reactions and rules in the
            ``BBD_SUBSET`` package that were passed to
            ``RuleBasedDataset.generate_dataset``.
        """
        reactions = list(Reaction.objects.filter(package=self.BBD_SUBSET))
        applicable_rules = list(Rule.objects.filter(package=self.BBD_SUBSET))
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
        return ds, reactions, applicable_rules