from tempfile import TemporaryDirectory

import numpy as np
from django.test import TestCase

from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning


class ModelTest(TestCase):
    """End-to-end tests for relative-reasoning models trained on the ``Fixtures`` package."""

    fixtures = ["test_fixtures.jsonl.gz"]

    # Compound both tests run predictions for.
    QUERY_SMILES = "CCN(CC)C(=O)C1=CC(=CC=C1)C"

    @classmethod
    def setUpTestData(cls):
        """Create class-wide data once, after the fixtures have been loaded.

        Overriding ``setUpClass`` and creating rows after ``super()`` leaves
        Django's class-level transaction open if that code raises (the
        matching ``tearDownClass`` is never called). ``setUpTestData`` runs
        inside Django's own error handling, which rolls back on failure.
        """
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def _build_and_evaluate(self, model_cls, **create_kwargs):
        """Create, train, save and evaluate a ``model_cls`` model.

        Rules, training data and evaluation data all come from ``BBD_SUBSET``,
        and the threshold is 0.5. ``create_kwargs`` (name, description,
        model-specific options) are forwarded to ``model_cls.create``.
        Call this while ``MODEL_DIR`` is overridden, as the tests below do.

        Returns the evaluated model.
        """
        mod = model_cls.create(
            self.package,
            [self.BBD_SUBSET],  # rule packages
            [self.BBD_SUBSET],  # data packages
            [self.BBD_SUBSET],  # evaluation packages
            threshold=0.5,
            **create_kwargs,
        )
        mod.build_dataset()
        mod.build_model()
        # Persist the multi-generation flag before running the evaluation.
        mod.multigen_eval = True
        mod.save()
        mod.evaluate_model(n_splits=2)
        return mod

    def test_mlrr(self):
        """An ML-based model predicts exactly the two expected product sets."""
        with TemporaryDirectory() as tmpdir:
            with self.settings(MODEL_DIR=tmpdir):
                mod = self._build_and_evaluate(
                    MLRelativeReasoning,
                    name="ECC - BBD - 0.5",
                    description="Created MLRelativeReasoning in Testcase",
                )
                results = mod.predict(self.QUERY_SMILES)

                # Map each sorted product set to (rule name, probability); if
                # several results yield the same set, the last one wins.
                products = {
                    tuple(sorted(ps.product_set)): (r.rule.name, r.probability)
                    for r in results
                    for ps in r.product_sets
                }
                expected = {
                    ("CC=O", "CCNC(=O)C1=CC(C)=CC=C1"): (
                        "bt0243-4301",
                        np.float64(0.33333333333333337),
                    ),
                    ("CC1=CC=CC(C(=O)O)=C1", "CCNCC"): ("bt0430-4011", np.float64(0.25)),
                }

                self.assertEqual(set(products), set(expected))
                for key, (rule_name, probability) in expected.items():
                    with self.subTest(products=key):
                        actual_rule, actual_probability = products[key]
                        self.assertEqual(actual_rule, rule_name)
                        # Probabilities are computed floats: compare with a tolerance.
                        self.assertAlmostEqual(actual_probability, probability)

    def test_rbrr(self):
        """A rule-based model can be built, evaluated and queried."""
        with TemporaryDirectory() as tmpdir:
            with self.settings(MODEL_DIR=tmpdir):
                mod = self._build_and_evaluate(
                    RuleBasedRelativeReasoning,
                    min_count=5,
                    max_count=0,
                    name="ECC - BBD - 0.5",
                    description="Created RuleBasedRelativeReasoning in Testcase",
                )
                results = mod.predict(self.QUERY_SMILES)