[Feature] MultiGen Eval (Backend) (#117)

Fixes #16

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#117
This commit is contained in:
2025-09-18 18:40:45 +12:00
parent 762a6b7baf
commit 50db2fb372
24 changed files with 816 additions and 2137274 deletions

View File

@ -5,7 +5,7 @@ from epdb.models import Compound, User, CompoundStructure
class CompoundTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
def setUp(self):
pass

View File

@ -7,7 +7,7 @@ from epdb.models import Compound, User, Reaction
class CopyTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):

View File

@ -6,7 +6,7 @@ from utilities.ml import Dataset
class DatasetTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
def setUp(self):
self.cs1 = Compound.create(

View File

@ -1,13 +1,14 @@
from django.test import TestCase
from tempfile import TemporaryDirectory
import numpy as np
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, Package
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package
class ModelTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):
@ -17,28 +18,55 @@ class ModelTest(TestCase):
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
def test_smoke(self):
threshold = float(0.5)
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
# get Package objects from urls
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = []
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = []
mod = MLRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold,
'ECC - BBD - 0.5',
'Created MLRelativeReasoning in Testcase',
)
mod = MLRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold=threshold,
name='ECC - BBD - 0.5',
description='Created MLRelativeReasoning in Testcase',
)
mod.build_dataset()
mod.build_model()
print("Model built!")
mod.evaluate_model()
print("Model Evaluated")
# mod = RuleBasedRelativeReasoning.create(
# self.package,
# rule_package_objs,
# data_package_objs,
# eval_packages_objs,
# threshold=threshold,
# min_count=5,
# max_count=0,
# name='ECC - BBD - 0.5',
# description='Created MLRelativeReasoning in Testcase',
# )
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
print(results)
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
# mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
products = dict()
for r in results:
for ps in r.product_sets:
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
expected = {
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
}
self.assertEqual(products, expected)
# from pprint import pprint
# pprint(mod.eval_results)

137
tests/test_multigen_eval.py Normal file
View File

@ -0,0 +1,137 @@
import json
from django.test import TestCase
from networkx.utils.misc import graphs_equal
from epdb.logic import PackageManager, SPathway
from epdb.models import Pathway, User, Package
from utilities.ml import multigen_eval, pathway_edit_eval, graph_from_pathway
class MultiGenTest(TestCase):
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):
    """Look up the shared user and packages once for the whole test class."""
    super().setUpClass()
    # The anonymous user is fetched first because it is passed to create_package.
    anonymous = User.objects.get(username='anonymous')
    cls.user = anonymous
    # Scratch package in which the hand-built example pathways are created.
    cls.package = PackageManager.create_package(anonymous, 'Anon Test Package', 'No Desc')
    # Existing package named 'Fixtures' (presumably provided by the loaded fixtures).
    cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
def test_equal_pathways(self):
    """Evaluating a pathway against itself must give precision and recall of 1.0."""
    for pathway in self.BBD_SUBSET.pathways.all():
        has_edges = len(pathway.edge_set.all()) != 0
        if not has_edges:
            # Pathways with no edges are not tested.
            continue
        # multigen_eval returns (score, precision, recall); the score is not checked.
        _, precision, recall = multigen_eval(pathway, pathway)
        failed_on = f"Failed on pathway: {pathway.name}"
        self.assertEqual(
            precision,
            1.0,
            "Precision should be one for identical pathways. " + failed_on,
        )
        self.assertEqual(
            recall,
            1.0,
            "Recall should be one for identical pathways. " + failed_on,
        )
def test_intermediates(self):
    """An intermediate in the prediction is reported and leaves precision/recall at 1."""
    true_pathway, pred_pathway = self.intermediate_case()
    outcome = multigen_eval(true_pathway, pred_pathway, return_intermediates=True)
    # With return_intermediates=True a fourth element is returned; the score is unused.
    _, precision, recall, intermediates = outcome
    self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
    self.assertEqual(precision, 1, "Precision should be 1")
    self.assertEqual(recall, 1, "Recall should be 1")
def test_fp(self):
    """An extra predicted compound (false positive) lowers precision but not recall."""
    true_pathway, pred_pathway = self.fp_case()
    _, precision, recall = multigen_eval(true_pathway, pred_pathway)
    # Compared to 3 decimal places.
    self.assertAlmostEqual(precision, 0.75, places=3, msg="Precision should be 0.75")
    self.assertEqual(recall, 1, "Recall should be 1")
def test_fn(self):
    """A missed compound (false negative) lowers recall but not precision."""
    true_pathway, pred_pathway = self.fn_case()
    _, precision, recall = multigen_eval(true_pathway, pred_pathway)
    self.assertEqual(precision, 1, "Precision should be 1.0")
    # Compared to 3 decimal places.
    self.assertAlmostEqual(recall, 0.667, places=3, msg="Recall should be 0.667")
def test_all(self):
    """Check the combined case: intermediate, false positive and false negative."""
    true_pathway, pred_pathway = self.all_case()
    outcome = multigen_eval(true_pathway, pred_pathway, return_intermediates=True)
    # The leading score element is not asserted on.
    _, precision, recall, intermediates = outcome
    self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
    self.assertAlmostEqual(precision, 0.6, places=3, msg="Precision should be 0.6")
    self.assertAlmostEqual(recall, 0.75, places=3, msg="Recall should be 0.75")
def test_shallow_pathway(self):
    """Check that a pathway and its SPathway counterpart yield equal graphs.

    For every pathway in ``self.BBD_SUBSET`` that has at least one edge, the
    graph built from ``SPathway.from_pathway(pathway)`` is compared with the
    graph built from the pathway itself using ``graphs_equal``.
    """
    for pathway in self.BBD_SUBSET.pathways.all():
        if len(pathway.edge_set.all()) == 0:  # Do not test pathways with no edges
            continue
        # Keep the graphs in their own names: rebinding ``pathway`` to the graph
        # would make the failure message read ``.name`` off the graph object
        # instead of the Pathway model.
        shallow_graph = graph_from_pathway(SPathway.from_pathway(pathway))
        full_graph = graph_from_pathway(pathway)
        equal = graphs_equal(shallow_graph, full_graph)
        if not equal:
            # Dump both adjacency structures to help diagnose the mismatch.
            print('\n\nS', shallow_graph.adj)
            print('\n\nPW', full_graph.adj)
        self.assertTrue(equal, f"Networkx graph from shallow pathway not "
                               f"equal to pathway for pathway {pathway.name}")
def test_graph_edit_eval(self):
    """Repeat the identical-pathway and constructed-case checks with pathway_edit_eval.

    Unlike the multigen_eval expectations, these expected values have not
    been hand verified.
    """
    for pathway in self.BBD_SUBSET.pathways.all():
        if len(pathway.edge_set.all()) == 0:
            # Pathways with no edges are not tested.
            continue
        distance = pathway_edit_eval(pathway, pathway)
        self.assertEqual(
            distance,
            0.0,
            "Pathway edit distance should be zero for identical pathways. "
            f"Failed on pathway: {pathway.name}",
        )

    # (case builder, expected score, failure message), checked in this order.
    constructed_cases = (
        (self.intermediate_case, 1.75, "Pathway edit distance failed on intermediate case"),
        (self.fp_case, 1.25, "Pathway edit distance failed on fp case"),
        (self.fn_case, 1.25, "Pathway edit distance failed on fn case"),
        (self.all_case, 1.0, "Pathway edit distance failed on all case"),
    )
    for build_case, expected, message in constructed_cases:
        true_pathway, pred_pathway = build_case()
        observed = pathway_edit_eval(true_pathway, pred_pathway)
        # Compared to 3 decimal places.
        self.assertAlmostEqual(observed, expected, 3, message)
def intermediate_case(self):
    """Build a (true, predicted) pair where the prediction has an extra intermediate.

    True pathway:      CCO -> CC(=O)O
    Predicted pathway: CCO -> CC=O -> CC(=O)O

    Returns:
        Tuple ``(true_pathway, pred_pathway)``.
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_product = true_pathway.add_node("CC(=O)O", depth=1)
    true_pathway.add_edge([true_root], [true_product])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    # CC=O is the additional step that the true pathway does not contain.
    pred_intermediate = pred_pathway.add_node("CC=O", depth=1)
    pred_pathway.add_edge([pred_root], [pred_intermediate])
    pred_product = pred_pathway.add_node("CC(=O)O", depth=2)
    pred_pathway.add_edge([pred_intermediate], [pred_product])

    return true_pathway, pred_pathway
def fp_case(self):
    """Build a (true, predicted) pair where the prediction has one extra compound.

    True pathway:      CCO -> CC=O -> CC(=O)O
    Predicted pathway: CCO -> CC=O -> CC(=O)O, plus CC=O -> C

    Returns:
        Tuple ``(true_pathway, pred_pathway)``.
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_mid = true_pathway.add_node("CC=O", depth=1)
    true_pathway.add_edge([true_root], [true_mid])
    true_leaf = true_pathway.add_node("CC(=O)O", depth=2)
    true_pathway.add_edge([true_mid], [true_leaf])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    pred_mid = pred_pathway.add_node("CC=O", depth=1)
    pred_pathway.add_edge([pred_root], [pred_mid])
    # The first product matches the true pathway; the second ("C") is the extra one.
    for smiles in ("CC(=O)O", "C"):
        pred_leaf = pred_pathway.add_node(smiles, depth=2)
        pred_pathway.add_edge([pred_mid], [pred_leaf])

    return true_pathway, pred_pathway
def fn_case(self):
    """Build a (true, predicted) pair where the prediction misses one compound.

    True pathway:      CCO -> CC=O -> CC(=O)O
    Predicted pathway: CCO -> CC=O

    Returns:
        Tuple ``(true_pathway, pred_pathway)``.
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_mid = true_pathway.add_node("CC=O", depth=1)
    true_pathway.add_edge([true_root], [true_mid])
    true_leaf = true_pathway.add_node("CC(=O)O", depth=2)
    true_pathway.add_edge([true_mid], [true_leaf])

    # The prediction stops after the first step, so CC(=O)O is absent from it.
    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    pred_leaf = pred_pathway.add_node("CC=O", depth=1)
    pred_pathway.add_edge([pred_root], [pred_leaf])

    return true_pathway, pred_pathway
def all_case(self):
    """Build a (true, predicted) pair combining intermediate, extra and missing compounds.

    True pathway:      CCO -> CC=O, then CC=O -> C and CC=O -> CC(=O)O
    Predicted pathway: CCO -> C, then C -> CC=O and C -> c1ccccc1

    Returns:
        Tuple ``(true_pathway, pred_pathway)``.
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_mid = true_pathway.add_node("CC=O", depth=1)
    true_pathway.add_edge([true_root], [true_mid])
    for smiles in ("C", "CC(=O)O"):
        true_leaf = true_pathway.add_node(smiles, depth=2)
        true_pathway.add_edge([true_mid], [true_leaf])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    pred_mid = pred_pathway.add_node("C", depth=1)
    pred_pathway.add_edge([pred_root], [pred_mid])
    # NOTE(review): the two depth-2 nodes below are created through
    # true_pathway.add_node but linked with pred_pathway.add_edge. This looks
    # like a copy-paste slip for pred_pathway.add_node - confirm before
    # changing, since the expected values in test_all and
    # test_graph_edit_eval were obtained with this behaviour.
    for smiles in ("CC=O", "c1ccccc1"):
        pred_leaf = true_pathway.add_node(smiles, depth=2)
        pred_pathway.add_edge([pred_mid], [pred_leaf])

    return true_pathway, pred_pathway

View File

@ -5,7 +5,7 @@ from epdb.models import Compound, User, Reaction, Rule
class ReactionTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):

View File

@ -5,10 +5,7 @@ from epdb.models import Rule, User
class RuleTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
def setUp(self):
pass
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):

View File

@ -7,7 +7,7 @@ from epdb.models import User, SimpleAmbitRule
class SimpleAmbitRuleTest(TestCase):
fixtures = ["test_fixtures.json.gz"]
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):