"""Slow EnviFormer tests: full train/evaluate/predict flow and prediction runtime timings."""

import time
from collections import defaultdict
from datetime import datetime
from tempfile import TemporaryDirectory

from django.conf import settings as s
from django.test import TestCase, tag

from epdb.logic import PackageManager
from epdb.models import EnviFormer, Setting, User
from epdb.tasks import predict, predict_simple

Package = s.GET_PACKAGE_MODEL()


def measure_predict(mod, pathway_pk=None):
    """Run one prediction task, wait for its result, and return the elapsed time.

    If ``pathway_pk`` is given, a new ``Setting`` bound to ``mod`` is saved and
    ``predict`` is dispatched for that pathway. Otherwise ``predict_simple`` is
    dispatched for a fixed SMILES string. The function blocks on ``.get()``, so
    the measured time covers setting creation (if any), dispatch, execution and
    result retrieval.

    :param mod: model to predict with; its ``pk`` is passed to the task.
    :param pathway_pk: primary key of the pathway to predict, or ``None`` for a
        single-compound prediction.
    :return: elapsed seconds, rounded to two decimals.
    """
    # perf_counter is monotonic; datetime.now() can jump with system clock changes.
    start = time.perf_counter()
    if pathway_pk is not None:
        # Named ``setting`` (not ``s``) so it does not shadow the module-level
        # ``django.conf.settings`` alias.
        setting = Setting()
        setting.model = mod
        setting.model_threshold = 0.2
        setting.max_depth = 4
        setting.max_nodes = 20
        setting.save()
        pred_result = predict.delay(pathway_pk, setting.pk, limit=setting.max_depth)
    else:
        pred_result = predict_simple.delay(mod.pk, "C1=CC=C(CSCC2=CC=CC=C2)C=C1")
    # Block until the task has finished so its runtime is part of the measurement.
    pred_result.get()
    return round(time.perf_counter() - start, 2)


@tag("slow")
class EnviFormerTest(TestCase):
    """EnviFormer tests that train on the ``Fixtures`` package loaded from the test fixtures."""

    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def _create_trained_model(self, threshold=0.5):
        """Create an EnviFormer using BBD_SUBSET as data and eval package, then build dataset and model."""
        mod = EnviFormer.create(
            self.package,
            [self.BBD_SUBSET],
            [self.BBD_SUBSET],
            threshold=threshold,
        )
        mod.build_dataset()
        mod.build_model()
        return mod

    def test_model_flow(self):
        """Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference"""
        with TemporaryDirectory() as tmpdir, self.settings(MODEL_DIR=tmpdir):
            mod = self._create_trained_model()
            mod.evaluate_model(True, [self.BBD_SUBSET])
            mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

    def test_predict_runtime(self):
        """Print timings for repeated, pathway, and multi-model predictions.

        This test only prints the measured timings; it makes no assertions on them.
        """
        with TemporaryDirectory() as tmpdir, self.settings(MODEL_DIR=tmpdir):
            mods = [self._create_trained_model() for _ in range(4)]

            # Prediction time is expected to drop after the first prediction.
            times = [measure_predict(mods[0]) for _ in range(5)]
            print(f"First prediction took {times[0]} seconds, subsequent ones took {times[1:]}")

            # Pathway prediction; resolve the pathway pk once, outside the timed calls.
            pathway_pk = self.BBD_SUBSET.pathways[0].pk
            times = [measure_predict(mods[1], pathway_pk) for _ in range(5)]
            print(
                f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}"
            )

            # Test eviction by performing three predictions with every model, twice.
            # Eviction should cause the second pass to have to reload the models.
            times = defaultdict(list)
            for _round in range(2):
                for mod in mods:
                    for _ in range(3):
                        times[mod.pk].append(measure_predict(mod))
            print(times)