Files
enviPy-bayer/tests/test_model.py
2025-10-27 22:34:05 +13:00

73 lines
2.5 KiB
Python

from tempfile import TemporaryDirectory
import numpy as np
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, Package
class ModelTest(TestCase):
    """End-to-end smoke test for an ``MLRelativeReasoning`` model.

    Builds a model from the ``Fixtures`` package, evaluates it, runs a single
    prediction and checks the predicted product sets, the rule that produced
    each one, and its probability.
    """

    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpTestData(cls):
        """Create the objects shared by every test in this class.

        Django calls this once per class, after ``fixtures`` are loaded and
        inside the class-wide transaction, so the package created here is
        rolled back when the class finishes.
        """
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(
            cls.user, "Anon Test Package", "No Desc"
        )
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def test_smoke(self):
        """Build, evaluate and query a model; check the predicted products."""
        # Override MODEL_DIR with a throwaway directory so anything written
        # under it during the test is removed afterwards.
        with TemporaryDirectory() as tmpdir, self.settings(MODEL_DIR=tmpdir):
            threshold = 0.5
            rule_package_objs = [self.BBD_SUBSET]
            data_package_objs = [self.BBD_SUBSET]
            eval_packages_objs = [self.BBD_SUBSET]

            mod = MLRelativeReasoning.create(
                self.package,
                rule_package_objs,
                data_package_objs,
                threshold=threshold,
                name="ECC - BBD - 0.5",
                description="Created MLRelativeReasoning in Testcase",
            )
            # Rule-based alternative, kept for manual comparison runs:
            # mod = RuleBasedRelativeReasoning.create(
            #     self.package,
            #     rule_package_objs,
            #     data_package_objs,
            #     eval_packages_objs,
            #     threshold=threshold,
            #     min_count=5,
            #     max_count=0,
            #     name='ECC - BBD - 0.5',
            #     description='Created MLRelativeReasoning in Testcase',
            # )

            mod.build_dataset()
            mod.build_model()
            # NOTE(review): the meaning of this positional ``True`` is defined
            # in epdb.models; confirm it and pass it by keyword.
            mod.evaluate_model(True, eval_packages_objs)

            results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

            # Key each product set by its sorted SMILES tuple so the check does
            # not depend on result or product ordering. If two results yield
            # the same product set, the later one overwrites the earlier.
            products = {}
            for result in results:
                for product_set in result.product_sets:
                    key = tuple(sorted(product_set.product_set))
                    products[key] = (result.rule.name, result.probability)

            expected = {
                ("CC=O", "CCNC(=O)C1=CC(C)=CC=C1"): (
                    "bt0243-4301",
                    0.33333333333333337,
                ),
                ("CC1=CC=CC(C(=O)O)=C1", "CCNCC"): ("bt0430-4011", 0.25),
            }

            # Exactly the expected product sets: no missing, no extra.
            self.assertEqual(set(products), set(expected))
            for product_set, (rule_name, probability) in expected.items():
                with self.subTest(product_set=product_set):
                    actual_rule, actual_probability = products[product_set]
                    self.assertEqual(actual_rule, rule_name)
                    # Probabilities are computed floats; compare with a
                    # tolerance instead of bit-for-bit equality.
                    np.testing.assert_allclose(
                        actual_probability, probability, rtol=1e-9
                    )