forked from enviPath/enviPy
[Feature] MultiGen Eval (Backend) (#117)
Fixes #16 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#117
This commit is contained in:
@ -1,13 +1,14 @@
|
||||
from django.test import TestCase
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import User, MLRelativeReasoning, Package
|
||||
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package
|
||||
|
||||
|
||||
class ModelTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -17,28 +18,55 @@ class ModelTest(TestCase):
|
||||
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
|
||||
|
||||
def test_smoke(self):
|
||||
threshold = float(0.5)
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold,
|
||||
'ECC - BBD - 0.5',
|
||||
'Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold=threshold,
|
||||
name='ECC - BBD - 0.5',
|
||||
description='Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
print("Model built!")
|
||||
mod.evaluate_model()
|
||||
print("Model Evaluated")
|
||||
# mod = RuleBasedRelativeReasoning.create(
|
||||
# self.package,
|
||||
# rule_package_objs,
|
||||
# data_package_objs,
|
||||
# eval_packages_objs,
|
||||
# threshold=threshold,
|
||||
# min_count=5,
|
||||
# max_count=0,
|
||||
# name='ECC - BBD - 0.5',
|
||||
# description='Created MLRelativeReasoning in Testcase',
|
||||
# )
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
print(results)
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.multigen_eval = True
|
||||
mod.save()
|
||||
# mod.evaluate_model()
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
|
||||
products = dict()
|
||||
for r in results:
|
||||
for ps in r.product_sets:
|
||||
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
|
||||
|
||||
expected = {
|
||||
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
|
||||
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
|
||||
}
|
||||
|
||||
self.assertEqual(products, expected)
|
||||
|
||||
# from pprint import pprint
|
||||
# pprint(mod.eval_results)
|
||||
|
||||
Reference in New Issue
Block a user