[Feature] MultiGen Eval (Backend) (#117)

Fixes #16

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#117
This commit is contained in:
2025-09-18 18:40:45 +12:00
parent 762a6b7baf
commit 50db2fb372
24 changed files with 816 additions and 2137274 deletions

View File

@ -1,13 +1,14 @@
from django.test import TestCase
from tempfile import TemporaryDirectory
import numpy as np
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, Package
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package
class ModelTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):
@ -17,28 +18,55 @@ class ModelTest(TestCase):
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
def test_smoke(self):
threshold = float(0.5)
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
# get Package objects from urls
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = []
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = []
mod = MLRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold,
'ECC - BBD - 0.5',
'Created MLRelativeReasoning in Testcase',
)
mod = MLRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold=threshold,
name='ECC - BBD - 0.5',
description='Created MLRelativeReasoning in Testcase',
)
mod.build_dataset()
mod.build_model()
print("Model built!")
mod.evaluate_model()
print("Model Evaluated")
# mod = RuleBasedRelativeReasoning.create(
# self.package,
# rule_package_objs,
# data_package_objs,
# eval_packages_objs,
# threshold=threshold,
# min_count=5,
# max_count=0,
# name='ECC - BBD - 0.5',
# description='Created MLRelativeReasoning in Testcase',
# )
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
print(results)
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
# mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
products = dict()
for r in results:
for ps in r.product_sets:
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
expected = {
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
}
self.assertEqual(products, expected)
# from pprint import pprint
# pprint(mod.eval_results)