Files
enviPy-bayer/tests/test_jobs.py
2025-12-15 08:48:28 +13:00

138 lines
4.8 KiB
Python

from datetime import datetime
from django.conf import settings as s
from django.test import TestCase, override_settings
from epdb.logic import PackageManager
from epdb.models import Pathway, User
# Obtain the Package model class by calling the GET_PACKAGE_MODEL hook on the
# Django settings object rather than importing it from a fixed module.
# NOTE(review): presumably this lets a deployment swap the Package model - confirm.
Package = s.GET_PACKAGE_MODEL()
@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models", CELERY_TASK_ALWAYS_EAGER=True)
class MultiGenTest(TestCase):
    """Tests for the ``batch_predict`` and ``engineer_pathways`` tasks in ``epdb.tasks``.

    Both tasks are invoked directly as plain callables. For the duration of the
    test case ``MODEL_DIR`` points at the ``models`` directory below the first
    fixture directory and ``CELERY_TASK_ALWAYS_EAGER`` is enabled.
    """

    fixtures = ["test_fixtures_incl_model.jsonl.gz"]

    # SMILES that every node flagged as an engineered intermediate must carry.
    ENGINEERED_INTERMEDIATE_SMILES = "CCO"

    @classmethod
    def setUpClass(cls):
        """Create the shared user/package handles and pathway names used by the tests."""
        super().setUpClass()
        cls.user: "User" = User.objects.get(username="anonymous")
        cls.package: "Package" = PackageManager.create_package(
            cls.user, "Anon Test Package", "No Desc"
        )
        cls.BBD_SUBSET: "Package" = Package.objects.get(name="Fixtures")
        # Fixture pathway for which engineering is expected to yield an intermediate.
        cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)"
        # Fixture pathway for which engineering is expected to yield nothing.
        cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine"

    def _count_engineered_intermediates(self, pathway_urls):
        """Return how many nodes of the given pathways are engineered intermediates.

        Each pathway is looked up by URL. Every node whose ``kv`` mapping has a
        truthy ``is_engineered_intermediate`` entry is asserted to have the
        expected SMILES on its default node label and is then counted.
        """
        found = 0
        for pathway_url in pathway_urls:
            pathway = Pathway.objects.get(url=pathway_url)
            for node in pathway.nodes:
                if node.kv.get("is_engineered_intermediate"):
                    self.assertEqual(
                        node.default_node_label.smiles,
                        self.ENGINEERED_INTERMEDIATE_SMILES,
                    )
                    found += 1
        return found

    def test_batch_predict(self):
        """Batch prediction creates one pathway per input and records the overrides."""
        # Kept as a function-level import, as in the original test.
        # NOTE(review): presumably so the task module is loaded while the
        # overridden settings are active - confirm.
        from epdb.tasks import batch_predict

        pred_data = [
            ["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"],
            ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen"],
        ]
        batch_predict_setting = self.user.prediction_settings()
        target_package = PackageManager.create_package(
            self.user,
            f"Autogenerated Package for Batch Prediction {datetime.now()}",
            "This Package was generated automatically for the batch prediction task.",
        )
        num_tps = 50

        res = batch_predict(
            pred_data,
            batch_predict_setting.pk,
            target_package.pk,
            num_tps=num_tps,
        )

        self.assertTrue(res.startswith("Pathway URL,"))
        # Min 3 lines (1 header, 2 root nodes). Blank lines are ignored so that a
        # trailing newline cannot stand in for a missing data row.
        data_lines = [line for line in res.splitlines() if line.strip()]
        self.assertGreaterEqual(len(data_lines), 3)
        self.assertEqual(target_package.pathways.count(), 2)

        pw = target_package.pathways.first()
        # Both limits are expected to be reported with the same override marker.
        expected_override = f"{num_tps} (this is an override for this particular pathway)"
        self.assertEqual(pw.setting_with_overrides.max_depth, expected_override)
        self.assertEqual(pw.setting_with_overrides.max_nodes, expected_override)

    def test_engineer_pathway(self):
        """Pathway engineering yields intermediates, skips non-candidates and deduplicates."""
        from epdb.tasks import engineer_pathways

        setting_pk = self.user.prediction_settings().pk
        package_pk = self.package.pk

        # A pathway with an intermediate produces one engineered and one predicted pathway.
        pw_to_engineer = Pathway.objects.get(name=self.PW_WITH_INTERMEDIATE_NAME)
        engineered, predicted = engineer_pathways([pw_to_engineer.pk], setting_pk, package_pk)
        self.assertEqual(len(engineered), 1)
        self.assertEqual(len(predicted), 1)
        # Require that the intermediate is actually present; checking only the
        # SMILES of flagged nodes would pass vacuously when no node is flagged.
        self.assertEqual(self._count_engineered_intermediates(engineered), 1)

        # A pathway without an intermediate produces nothing.
        pw_to_engineer = Pathway.objects.get(name=self.PW_WITHOUT_INTERMEDIATE_NAME)
        engineered, predicted = engineer_pathways([pw_to_engineer.pk], setting_pk, package_pk)
        self.assertEqual(len(engineered), 0)
        self.assertEqual(len(predicted), 0)

        # Test pathway deduplication in eng pathway process
        pw1 = Pathway.objects.get(name=self.PW_WITH_INTERMEDIATE_NAME)
        # Add pw1 twice
        engineered, predicted = engineer_pathways([pw1.pk, pw1.pk], setting_pk, package_pk)
        self.assertEqual(len(engineered), 1)
        self.assertEqual(len(predicted), 1)
        # The duplicate entry must collapse into a single engineered pathway
        # that contains exactly one intermediate.
        self.assertEqual(self._count_engineered_intermediates(engineered), 1)

        # Get a copy to have two pathways with potential intermediates as the fixture
        # only contains one
        mapping = {}
        pw2 = pw1.copy(self.package, mapping=mapping)
        engineered, predicted = engineer_pathways([pw1.pk, pw2.pk], setting_pk, package_pk)
        self.assertEqual(len(engineered), 2)
        self.assertEqual(len(predicted), 2)
        # Check that both pathways contain the intermediate
        self.assertEqual(self._count_engineered_intermediates(engineered), 2)