[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1,3 +1,5 @@
from datetime import datetime
from django.conf import settings as s
from django.test import TestCase, override_settings
@ -23,6 +25,46 @@ class MultiGenTest(TestCase):
cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)"
cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine"
def test_batch_predict(self):
from epdb.tasks import batch_predict
pred_data = [
["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"],
["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen"],
]
batch_predict_setting = self.user.prediction_settings()
target_package = PackageManager.create_package(
self.user,
f"Autogenerated Package for Batch Prediction {datetime.now()}",
"This Package was generated automatically for the batch prediction task.",
)
num_tps = 50
res = batch_predict(
pred_data,
batch_predict_setting.pk,
target_package.pk,
num_tps=num_tps,
)
self.assertTrue(res.startswith("Pathway URL,"))
# Min 3 lines (1 header, 2 root nodes)
self.assertGreaterEqual(len(res.split("\n")), 3)
self.assertEqual(target_package.pathways.count(), 2)
pw = target_package.pathways.first()
self.assertEqual(
pw.setting_with_overrides.max_depth,
f"{num_tps} (this is an override for this particular pathway)",
)
self.assertEqual(
pw.setting_with_overrides.max_nodes,
f"{num_tps} (this is an override for this particular pathway)",
)
def test_engineer_pathway(self):
from epdb.tasks import engineer_pathways

View File

@ -1,20 +1,20 @@
from unittest.mock import Mock, patch
from django.test import TestCase
from epdb.logic import SNode, SEdge
from epdb.logic import SEdge, SNode, SPathway
from epdb.models import Pathway, Setting
from utilities.chem import PredictionResult, ProductSet
class SObjectTest(TestCase):
def setUp(self):
pass
class SNodeTest(TestCase):
def test_snode_eq(self):
snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
assert snode1 == snode2
self.assertEqual(snode1, snode2)
def test_snode_hash(self):
pass
class SEdgeTest(TestCase):
def test_sedge_eq(self):
sedge1 = SEdge(
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
@ -26,4 +26,62 @@ class SObjectTest(TestCase):
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
rule=None,
)
assert sedge1 == sedge2
self.assertEqual(sedge1, sedge2)
class SPathwayTest(TestCase):
def setUp(self):
"""Set up test data for SPathway tests."""
self.test_smiles = "CCN(CC)C(=O)C1=CC(=CC=C1)CO"
self.mock_setting = Mock(spec=Setting)
self.mock_pathway = Mock(spec=Pathway)
def test_predict_step_basic(self):
"""Test basic predict_step functionality."""
spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)
# e.g. bt0002
pr = PredictionResult(
[
ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
],
0.17,
None,
)
with patch.object(self.mock_setting, "expand", return_value=[pr]):
spw.predict_step(from_depth=0)
self.assertEqual(len(spw.smiles_to_node.keys()), 4)
self.assertEqual(len(spw.edges), 3)
def test_to_json(self):
"""Test basic predict_step functionality."""
spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)
# e.g. bt0002
pr = PredictionResult(
[
ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
],
0.17,
None,
)
with patch.object(self.mock_setting, "expand", return_value=[pr]):
spw.predict_step(from_depth=0)
self.assertEqual(len(spw.smiles_to_node.keys()), 4)
self.assertEqual(len(spw.edges), 3)
json_result = spw.to_json()
self.assertIsInstance(json_result, dict)
self.assertIn("nodes", json_result)
self.assertIn("edges", json_result)
self.assertEqual(len(json_result["nodes"]), 4)
self.assertEqual(len(json_result["edges"]), 3)