from django.test import TestCase
from networkx.utils.misc import graphs_equal

from epdb.logic import PackageManager, SPathway
from epdb.models import Pathway, User, Package
from utilities.ml import multigen_eval, pathway_edit_eval, graph_from_pathway


class MultiGenTest(TestCase):
    """Tests for ``multigen_eval``, ``pathway_edit_eval`` and ``graph_from_pathway``.

    The tests run against the pathways of the "Fixtures" package loaded from
    the fixture file, plus small hand-built true/predicted pathway pairs
    produced by the ``*_case`` helper methods below.
    """

    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpClass(cls):
        """Look up the shared user/packages once for the whole test class."""
        super().setUpClass()
        cls.user: "User" = User.objects.get(username="anonymous")
        # Scratch package that receives the pathways built by the *_case helpers.
        cls.package: "Package" = PackageManager.create_package(
            cls.user, "Anon Test Package", "No Desc"
        )
        cls.BBD_SUBSET: "Package" = Package.objects.get(name="Fixtures")

    def _pathways_with_edges(self):
        """Yield each pathway of the fixture package that has at least one edge.

        Pathways with no edges are skipped by every fixture-driven test.
        """
        for pathway in self.BBD_SUBSET.pathways.all():
            # exists() asks the database for a single row instead of loading
            # every edge just to count them.
            if pathway.edge_set.exists():
                yield pathway

    def test_equal_pathways(self):
        """Test that two identical pathways return a precision and recall of 1.0"""
        for pathway in self._pathways_with_edges():
            # subTest keeps going after a failure so every broken pathway is reported.
            with self.subTest(pathway=pathway.name):
                _, precision, recall = multigen_eval(pathway, pathway)
                self.assertEqual(
                    precision,
                    1.0,
                    "Precision should be one for identical pathways. "
                    f"Failed on pathway: {pathway.name}",
                )
                self.assertEqual(
                    recall,
                    1.0,
                    "Recall should be one for identical pathways. "
                    f"Failed on pathway: {pathway.name}",
                )

    def test_intermediates(self):
        """Test that an intermediate can be correctly identified and the metrics
        are correctly adjusted"""
        _, precision, recall, intermediates = multigen_eval(
            *self.intermediate_case(), return_intermediates=True
        )
        self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
        self.assertEqual(precision, 1, "Precision should be 1")
        self.assertEqual(recall, 1, "Recall should be 1")

    def test_fp(self):
        """Test that a false-positive (extra compound) is correctly penalised"""
        _, precision, recall = multigen_eval(*self.fp_case())
        self.assertAlmostEqual(precision, 0.75, 3, "Precision should be 0.75")
        self.assertEqual(recall, 1, "Recall should be 1")

    def test_fn(self):
        """Test that a false-negative (missed compound) is correctly penalised"""
        _, precision, recall = multigen_eval(*self.fn_case())
        self.assertEqual(precision, 1, "Precision should be 1.0")
        self.assertAlmostEqual(recall, 0.667, 3, "Recall should be 0.667")

    def test_all(self):
        """Test an intermediate, false-positive and false-negative together"""
        _, precision, recall, intermediates = multigen_eval(
            *self.all_case(), return_intermediates=True
        )
        self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
        self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6")
        self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75")

    def test_shallow_pathway(self):
        """Test that a pathway and its ``SPathway`` copy convert to equal graphs"""
        for pathway in self._pathways_with_edges():
            with self.subTest(pathway=pathway.name):
                # Keep the model bound to `pathway`: the failure message needs
                # its name, so the graphs get their own variables.
                shallow_graph = graph_from_pathway(SPathway.from_pathway(pathway))
                full_graph = graph_from_pathway(pathway)
                if not graphs_equal(shallow_graph, full_graph):
                    # Adjacency dumps are only built when the comparison fails.
                    self.fail(
                        "Networkx graph from shallow pathway not "
                        f"equal to pathway for pathway {pathway.name}"
                        f"\n\nS {shallow_graph.adj}\n\nPW {full_graph.adj}"
                    )

    def test_graph_edit_eval(self):
        """Performs all the previous tests but with pathway_edit_eval

        Unlike multigen_eval, these test cases have not been hand verified"""
        for pathway in self._pathways_with_edges():
            with self.subTest(pathway=pathway.name):
                score = pathway_edit_eval(pathway, pathway)
                self.assertEqual(
                    score,
                    0.0,
                    "Pathway edit distance should be zero for identical pathways. "
                    f"Failed on pathway: {pathway.name}",
                )
        inter_score = pathway_edit_eval(*self.intermediate_case())
        self.assertAlmostEqual(
            inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case"
        )
        fp_score = pathway_edit_eval(*self.fp_case())
        self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case")
        fn_score = pathway_edit_eval(*self.fn_case())
        self.assertAlmostEqual(fn_score, 1.25, 3, "Pathway edit distance failed on fn case")
        all_score = pathway_edit_eval(*self.all_case())
        self.assertAlmostEqual(all_score, 1.0, 3, "Pathway edit distance failed on all case")

    def intermediate_case(self):
        """Create an example with an intermediate in the predicted pathway

        True pathway:      CCO -> CC(=O)O
        Predicted pathway: CCO -> CC=O -> CC(=O)O

        Returns the tuple ``(true_pathway, pred_pathway)``.
        """
        true_pathway = Pathway.create(self.package, "CCO")
        true_pathway.add_edge(
            [true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)]
        )
        pred_pathway = Pathway.create(self.package, "CCO")
        pred_pathway.add_edge(
            [pred_pathway.root_nodes.all()[0]],
            [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)],
        )
        pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
        return true_pathway, pred_pathway

    def fp_case(self):
        """Create an example with an extra compound in the predicted pathway

        True pathway:      CCO -> CC=O -> CC(=O)O
        Predicted pathway: CCO -> CC=O -> CC(=O)O, plus the extra edge CC=O -> C

        Returns the tuple ``(true_pathway, pred_pathway)``.
        """
        true_pathway = Pathway.create(self.package, "CCO")
        true_pathway.add_edge(
            [true_pathway.root_nodes.all()[0]],
            [acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
        )
        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
        pred_pathway = Pathway.create(self.package, "CCO")
        pred_pathway.add_edge(
            [pred_pathway.root_nodes.all()[0]],
            [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)],
        )
        pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
        pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)])
        return true_pathway, pred_pathway

    def fn_case(self):
        """Create an example with a missing compound in the predicted pathway

        True pathway:      CCO -> CC=O -> CC(=O)O
        Predicted pathway: CCO -> CC=O

        Returns the tuple ``(true_pathway, pred_pathway)``.
        """
        true_pathway = Pathway.create(self.package, "CCO")
        true_pathway.add_edge(
            [true_pathway.root_nodes.all()[0]],
            [acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
        )
        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
        pred_pathway = Pathway.create(self.package, "CCO")
        pred_pathway.add_edge(
            [pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)]
        )
        return true_pathway, pred_pathway

    def all_case(self):
        """Create an example with an intermediate, extra compound and missing compound

        True pathway:      CCO -> CC=O, then CC=O -> C and CC=O -> CC(=O)O
        Predicted pathway: CCO -> C, then C -> CC=O and C -> c1ccccc1

        Returns the tuple ``(true_pathway, pred_pathway)``.
        """
        true_pathway = Pathway.create(self.package, "CCO")
        true_pathway.add_edge(
            [true_pathway.root_nodes.all()[0]],
            [acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
        )
        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)])
        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
        pred_pathway = Pathway.create(self.package, "CCO")
        pred_pathway.add_edge(
            [pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)]
        )
        # NOTE(review): the target nodes of the next two predicted edges are
        # created on true_pathway, not pred_pathway. This looks like a copy-paste
        # slip, but the expected values in test_all / test_graph_edit_eval were
        # derived with this construction -- confirm before changing it.
        pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)])
        pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)])
        return true_pathway, pred_pathway