forked from enviPath/enviPy
new RuleBasedDataset and EnviFormer dataset working for respective models #120
This commit is contained in:
@ -1,8 +1,9 @@
|
||||
import os.path
|
||||
from tempfile import TemporaryDirectory
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import Reaction, Compound, User, Rule, Package
|
||||
from utilities.ml import RuleBasedDataset
|
||||
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
||||
|
||||
|
||||
class DatasetTest(TestCase):
|
||||
@ -45,11 +46,11 @@ class DatasetTest(TestCase):
|
||||
|
||||
def test_generate_dataset(self):
|
||||
"""Test generating dataset does not crash"""
|
||||
self.generate_dataset()
|
||||
self.generate_rule_dataset()
|
||||
|
||||
def test_indexing(self):
|
||||
"""Test indexing a few different ways to check for crashes"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds[5])
|
||||
print(ds[2, 5])
|
||||
print(ds[3:6, 2:8])
|
||||
@ -57,45 +58,45 @@ class DatasetTest(TestCase):
|
||||
|
||||
def test_add_rows(self):
|
||||
"""Test adding one row and adding multiple rows"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
ds.add_row(list(ds.df.row(1)))
|
||||
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
|
||||
|
||||
def test_times_triggered(self):
|
||||
"""Check getting times triggered for a rule id"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.times_triggered(rules[0].uuid))
|
||||
|
||||
def test_block_indices(self):
|
||||
"""Test the usages of _block_indices"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.struct_features())
|
||||
print(ds.triggered())
|
||||
print(ds.observed())
|
||||
|
||||
def test_structure_id(self):
|
||||
"""Check getting a structure id from row index"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.structure_id(0))
|
||||
|
||||
def test_x(self):
|
||||
"""Test getting X portion of the dataframe"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.X().df.head())
|
||||
|
||||
def test_trig(self):
|
||||
"""Test getting the triggered portion of the dataframe"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.trig().df.head())
|
||||
|
||||
def test_y(self):
|
||||
"""Test getting the Y portion of the dataframe"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.y().df.head())
|
||||
|
||||
def test_classification_dataset(self):
|
||||
"""Test making the classification dataset"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
|
||||
class_ds, products = ds.classification_dataset(compounds, rules)
|
||||
print(class_ds.df.head(5))
|
||||
@ -103,12 +104,16 @@ class DatasetTest(TestCase):
|
||||
|
||||
def test_to_arff(self):
|
||||
"""Test exporting the arff version of the dataset"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
ds.to_arff("dataset_arff_test.arff")
|
||||
|
||||
def test_save_load(self):
|
||||
"""Test saving and loading dataset"""
|
||||
ds, reactions, rules = self.generate_dataset()
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||
ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||
self.assertTrue(ds.df.equals(ds_loaded.df))
|
||||
|
||||
def test_dataset_example(self):
|
||||
"""Test with a concrete example checking dataset size"""
|
||||
@ -120,9 +125,19 @@ class DatasetTest(TestCase):
|
||||
self.assertEqual(len(ds.y()), 1)
|
||||
self.assertEqual(ds.y().df.item(), 1)
|
||||
|
||||
def generate_dataset(self):
|
||||
def test_enviformer_dataset(self):
|
||||
ds, reactions = self.generate_enviformer_dataset()
|
||||
print(ds.X().head())
|
||||
print(ds.y().head())
|
||||
|
||||
def generate_rule_dataset(self):
|
||||
"""Generate a RuleBasedDataset from test package data"""
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||
return ds, reactions, applicable_rules
|
||||
|
||||
def generate_enviformer_dataset(self):
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||
ds = EnviFormerDataset.generate_dataset(reactions)
|
||||
return ds, reactions
|
||||
|
||||
Reference in New Issue
Block a user