forked from enviPath/enviPy
[Feature] MultiGen Eval (Backend) (#117)
Fixes #16 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#117
This commit is contained in:
@ -5,7 +5,7 @@ from epdb.models import Compound, User, CompoundStructure
|
||||
|
||||
|
||||
class CompoundTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
@ -7,7 +7,7 @@ from epdb.models import Compound, User, Reaction
|
||||
|
||||
|
||||
class CopyTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -6,7 +6,7 @@ from utilities.ml import Dataset
|
||||
|
||||
|
||||
class DatasetTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
def setUp(self):
|
||||
self.cs1 = Compound.create(
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
from django.test import TestCase
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import User, MLRelativeReasoning, Package
|
||||
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package
|
||||
|
||||
|
||||
class ModelTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -17,28 +18,55 @@ class ModelTest(TestCase):
|
||||
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
|
||||
|
||||
def test_smoke(self):
|
||||
threshold = float(0.5)
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold,
|
||||
'ECC - BBD - 0.5',
|
||||
'Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold=threshold,
|
||||
name='ECC - BBD - 0.5',
|
||||
description='Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
print("Model built!")
|
||||
mod.evaluate_model()
|
||||
print("Model Evaluated")
|
||||
# mod = RuleBasedRelativeReasoning.create(
|
||||
# self.package,
|
||||
# rule_package_objs,
|
||||
# data_package_objs,
|
||||
# eval_packages_objs,
|
||||
# threshold=threshold,
|
||||
# min_count=5,
|
||||
# max_count=0,
|
||||
# name='ECC - BBD - 0.5',
|
||||
# description='Created MLRelativeReasoning in Testcase',
|
||||
# )
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
print(results)
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.multigen_eval = True
|
||||
mod.save()
|
||||
# mod.evaluate_model()
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
|
||||
products = dict()
|
||||
for r in results:
|
||||
for ps in r.product_sets:
|
||||
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
|
||||
|
||||
expected = {
|
||||
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
|
||||
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
|
||||
}
|
||||
|
||||
self.assertEqual(products, expected)
|
||||
|
||||
# from pprint import pprint
|
||||
# pprint(mod.eval_results)
|
||||
|
||||
137
tests/test_multigen_eval.py
Normal file
137
tests/test_multigen_eval.py
Normal file
@ -0,0 +1,137 @@
|
||||
import json
|
||||
from django.test import TestCase
|
||||
from networkx.utils.misc import graphs_equal
|
||||
from epdb.logic import PackageManager, SPathway
|
||||
from epdb.models import Pathway, User, Package
|
||||
from utilities.ml import multigen_eval, pathway_edit_eval, graph_from_pathway
|
||||
|
||||
|
||||
class MultiGenTest(TestCase):
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(MultiGenTest, cls).setUpClass()
|
||||
cls.user: 'User' = User.objects.get(username='anonymous')
|
||||
cls.package: 'Package' = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
|
||||
cls.BBD_SUBSET: 'Package' = Package.objects.get(name='Fixtures')
|
||||
|
||||
def test_equal_pathways(self):
|
||||
"""Test that two identical pathways return a precision and recall of 1.0"""
|
||||
pathways = self.BBD_SUBSET.pathways.all()
|
||||
for pathway in pathways:
|
||||
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
|
||||
continue
|
||||
score, precision, recall = multigen_eval(pathway, pathway)
|
||||
self.assertEqual(precision, 1.0, f"Precision should be one for identical pathways. "
|
||||
f"Failed on pathway: {pathway.name}")
|
||||
self.assertEqual(recall, 1.0, f"Recall should be one for identical pathways. "
|
||||
f"Failed on pathway: {pathway.name}")
|
||||
|
||||
def test_intermediates(self):
|
||||
"""Test that an intermediate can be correctly identified and the metrics are correctly adjusted"""
|
||||
score, precision, recall, intermediates = multigen_eval(*self.intermediate_case(), return_intermediates=True)
|
||||
self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
|
||||
self.assertEqual(precision, 1, "Precision should be 1")
|
||||
self.assertEqual(recall, 1, "Recall should be 1")
|
||||
|
||||
def test_fp(self):
|
||||
"""Test that a false-positive (extra compound) is correctly penalised"""
|
||||
score, precision, recall = multigen_eval(*self.fp_case())
|
||||
self.assertAlmostEqual(precision, 0.75, 3, "Precision should be 0.75")
|
||||
self.assertEqual(recall, 1, "Recall should be 1")
|
||||
|
||||
def test_fn(self):
|
||||
"""Test that a false-negative (missed compound) is correctly penalised"""
|
||||
score, precision, recall = multigen_eval(*self.fn_case())
|
||||
self.assertEqual(precision, 1, "Precision should be 1.0")
|
||||
self.assertAlmostEqual(recall, 0.667, 3, "Recall should be 0.667")
|
||||
|
||||
def test_all(self):
|
||||
"""Test an intermediate, false-positive and false-negative together"""
|
||||
score, precision, recall, intermediates = multigen_eval(*self.all_case(), return_intermediates=True)
|
||||
self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
|
||||
self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6")
|
||||
self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75")
|
||||
|
||||
def test_shallow_pathway(self):
|
||||
pathways = self.BBD_SUBSET.pathways.all()
|
||||
for pathway in pathways:
|
||||
pathway_name = pathway.name
|
||||
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
|
||||
continue
|
||||
shallow_pathway = graph_from_pathway(SPathway.from_pathway(pathway))
|
||||
pathway = graph_from_pathway(pathway)
|
||||
if not graphs_equal(shallow_pathway, pathway):
|
||||
print('\n\nS', shallow_pathway.adj)
|
||||
print('\n\nPW', pathway.adj)
|
||||
# print(shallow_pathway.nodes, pathway.nodes)
|
||||
# print(shallow_pathway.graph, pathway.graph)
|
||||
|
||||
self.assertTrue(graphs_equal(shallow_pathway, pathway), f"Networkx graph from shallow pathway not "
|
||||
f"equal to pathway for pathway {pathway.name}")
|
||||
|
||||
def test_graph_edit_eval(self):
|
||||
"""Performs all the previous tests but with graph_edit_eval
|
||||
Unlike multigen_eval, these test cases have not been hand verified"""
|
||||
pathways = self.BBD_SUBSET.pathways.all()
|
||||
for pathway in pathways:
|
||||
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
|
||||
continue
|
||||
score = pathway_edit_eval(pathway, pathway)
|
||||
self.assertEqual(score, 0.0, "Pathway edit distance should be zero for identical pathways. "
|
||||
f"Failed on pathway: {pathway.name}")
|
||||
inter_score = pathway_edit_eval(*self.intermediate_case())
|
||||
self.assertAlmostEqual(inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case")
|
||||
fp_score = pathway_edit_eval(*self.fp_case())
|
||||
self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case")
|
||||
fn_score = pathway_edit_eval(*self.fn_case())
|
||||
self.assertAlmostEqual(fn_score, 1.25, 3, "Pathway edit distance failed on fn case")
|
||||
all_score = pathway_edit_eval(*self.all_case())
|
||||
self.assertAlmostEqual(all_score, 1.0, 3, "Pathway edit distance failed on all case")
|
||||
|
||||
def intermediate_case(self):
|
||||
"""Create an example with an intermediate in the predicted pathway"""
|
||||
true_pathway = Pathway.create(self.package, "CCO")
|
||||
true_pathway.add_edge([true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)])
|
||||
pred_pathway = Pathway.create(self.package, "CCO")
|
||||
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]],
|
||||
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)])
|
||||
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
|
||||
return true_pathway, pred_pathway
|
||||
|
||||
def fp_case(self):
|
||||
"""Create an example with an extra compound in the predicted pathway"""
|
||||
true_pathway = Pathway.create(self.package, "CCO")
|
||||
true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
|
||||
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
|
||||
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
|
||||
pred_pathway = Pathway.create(self.package, "CCO")
|
||||
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]],
|
||||
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)])
|
||||
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
|
||||
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)])
|
||||
return true_pathway, pred_pathway
|
||||
|
||||
def fn_case(self):
|
||||
"""Create an example with a missing compound in the predicted pathway"""
|
||||
true_pathway = Pathway.create(self.package, "CCO")
|
||||
true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
|
||||
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
|
||||
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
|
||||
pred_pathway = Pathway.create(self.package, "CCO")
|
||||
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)])
|
||||
return true_pathway, pred_pathway
|
||||
|
||||
def all_case(self):
|
||||
"""Create an example with an intermediate, extra compound and missing compound"""
|
||||
true_pathway = Pathway.create(self.package, "CCO")
|
||||
true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
|
||||
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
|
||||
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)])
|
||||
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
|
||||
pred_pathway = Pathway.create(self.package, "CCO")
|
||||
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)])
|
||||
pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)])
|
||||
pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)])
|
||||
return true_pathway, pred_pathway
|
||||
@ -5,7 +5,7 @@ from epdb.models import Compound, User, Reaction, Rule
|
||||
|
||||
|
||||
class ReactionTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -5,10 +5,7 @@ from epdb.models import Rule, User
|
||||
|
||||
|
||||
class RuleTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -7,7 +7,7 @@ from epdb.models import User, SimpleAmbitRule
|
||||
|
||||
|
||||
class SimpleAmbitRuleTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user