"""Tests for the dataset builders in ``utilities.ml``.

Covers ``RuleBasedDataset`` (generation, indexing, row insertion, the
feature/trigger/observation views, classification datasets, ARFF export and a
save/load round-trip) and ``EnviFormerDataset`` generation.
"""
import os.path
from tempfile import TemporaryDirectory

from django.test import TestCase

from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule, Package
from utilities.ml import RuleBasedDataset, EnviFormerDataset


class DatasetTest(TestCase):
    """Smoke and sanity tests for the ML dataset classes.

    Most tests build a dataset from the reactions and rules of the
    ``Fixtures`` package (loaded from ``test_fixtures.jsonl.gz``) and only
    check that the operation does not raise; ``test_dataset_example`` and
    ``test_save_load`` make concrete assertions.
    """

    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpTestData(cls):
        """Create class-wide data shared by every test in this class.

        Django calls this once per class, after the fixtures have been loaded
        and inside the class-level transaction. If it raises, Django rolls
        that transaction back itself, which a hand-written ``setUpClass``
        override that raises after ``super().setUpClass()`` does not get.
        """
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        # Fixture package whose reactions/rules feed the generate_* helpers.
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def setUp(self):
        """Create a single-reaction example inside ``self.package``.

        Adds two compounds, one ``SimpleAmbitRule`` and a reaction with
        ``cs1`` as educt, ``cs2`` as product and ``rule1`` as its rule;
        ``test_dataset_example`` builds a dataset from exactly this reaction.
        """
        self.cs1 = Compound.create(
            self.package,
            name="2,6-Dibromohydroquinone",
            description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b",
            smiles="C1=C(C(=C(C=C1O)Br)O)Br",
        ).default_structure

        self.cs2 = Compound.create(
            self.package,
            smiles="O=C(O)CC(=O)/C=C(/Br)C(=O)O",
        ).default_structure

        self.rule1 = Rule.create(
            rule_type="SimpleAmbitRule",
            package=self.package,
            smirks="[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\\[#6:3]=[#6:2](\\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]",
            description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6",
        )

        self.reaction1 = Reaction.create(
            package=self.package,
            educts=[self.cs1],
            products=[self.cs2],
            rules=[self.rule1],
            multi_step=False,
        )

    def test_generate_dataset(self):
        """Test generating dataset does not crash"""
        self.generate_rule_dataset()

    def test_indexing(self):
        """Test indexing a few different ways to check for crashes"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds[5])
        print(ds[2, 5])
        print(ds[3:6, 2:8])
        print(ds[:2, "structure_id"])

    def test_add_rows(self):
        """Test adding one row and adding multiple rows"""
        ds, reactions, rules = self.generate_rule_dataset()
        ds.add_row(list(ds.df.row(1)))
        ds.add_rows([list(ds.df.row(i)) for i in range(5)])

    def test_times_triggered(self):
        """Check getting times triggered for a rule id"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.times_triggered(rules[0].uuid))

    def test_block_indices(self):
        """Test the usages of _block_indices"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.struct_features())
        print(ds.triggered())
        print(ds.observed())

    def test_structure_id(self):
        """Check getting a structure id from row index"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.structure_id(0))

    def test_x(self):
        """Test getting X portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.X().df.head())

    def test_trig(self):
        """Test getting the triggered portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.trig().df.head())

    def test_y(self):
        """Test getting the Y portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.y().df.head())

    def test_classification_dataset(self):
        """Test making the classification dataset"""
        ds, reactions, rules = self.generate_rule_dataset()
        compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
        class_ds, products = ds.classification_dataset(compounds, rules)
        print(class_ds.df.head(5))
        print(products[:5])

    def test_to_arff(self):
        """Test exporting the arff version of the dataset"""
        ds, reactions, rules = self.generate_rule_dataset()
        # Export into a temporary directory so the test does not leave an
        # artifact behind in the current working directory.
        with TemporaryDirectory() as tmpdir:
            arff_path = os.path.join(tmpdir, "dataset_arff_test.arff")
            ds.to_arff(arff_path)
            self.assertTrue(os.path.exists(arff_path))

    def test_save_load(self):
        """Test saving and loading dataset"""
        with TemporaryDirectory() as tmpdir:
            ds, reactions, rules = self.generate_rule_dataset()
            save_path = os.path.join(tmpdir, "save_dataset.pkl")
            ds.save(save_path)
            ds_loaded = RuleBasedDataset.load(save_path)
            self.assertTrue(ds.df.equals(ds_loaded.df))

    def test_dataset_example(self):
        """Test with a concrete example checking dataset size"""
        reactions = list(Reaction.objects.filter(package=self.package))
        applicable_rules = [self.rule1]

        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)

        self.assertEqual(len(ds.y()), 1)
        self.assertEqual(ds.y().df.item(), 1)

    def test_enviformer_dataset(self):
        """Test generating an EnviFormerDataset and reading its X and y views"""
        ds, reactions = self.generate_enviformer_dataset()
        print(ds.X().head())
        print(ds.y().head())

    def generate_rule_dataset(self):
        """Generate a RuleBasedDataset from test package data"""
        reactions = list(Reaction.objects.filter(package=self.BBD_SUBSET))
        applicable_rules = list(Rule.objects.filter(package=self.BBD_SUBSET))
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
        return ds, reactions, applicable_rules

    def generate_enviformer_dataset(self):
        """Generate an EnviFormerDataset from the fixture package's reactions.

        Returns a ``(dataset, reactions)`` tuple.
        """
        reactions = list(Reaction.objects.filter(package=self.BBD_SUBSET))
        ds = EnviFormerDataset.generate_dataset(reactions)
        return ds, reactions