diff --git a/epdb/models.py b/epdb/models.py
index 4b8f4198..a03fcb6d 100644
--- a/epdb/models.py
+++ b/epdb/models.py
@@ -3043,9 +3043,10 @@ class EnviFormer(PackageBasedModel):
 
     @cached_property
     def model(self):
         from enviformer import load
         ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
-        return load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
+        mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
+        return mod
 
     def predict(self, smiles) -> List["PredictionResult"]:
         return self.predict_batch([smiles])[0]
@@ -3059,8 +3060,10 @@
             for smiles in smiles_list
         ]
         logger.info(f"Submitting {canon_smiles} to {self.name}")
+        start = datetime.now()
         products_list = self.model.predict_batch(canon_smiles)
-        logger.info(f"Got results {products_list}")
+        end = datetime.now()
+        logger.info(f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}")
 
         results = []
         for products in products_list:
diff --git a/epdb/tasks.py b/epdb/tasks.py
index aabaf8d1..b9845c86 100644
--- a/epdb/tasks.py
+++ b/epdb/tasks.py
@@ -1,12 +1,19 @@
 import logging
 from typing import Optional
-
+from celery.utils.functional import LRUCache
 from celery import shared_task
 
 from epdb.models import Pathway, Node, EPModel, Setting
 from epdb.logic import SPathway
 
 logger = logging.getLogger(__name__)
+ML_CACHE = LRUCache(3)  # Cache the three most recent ML models to reduce load times.
+
+
+def get_ml_model(model_pk: int):
+    if model_pk not in ML_CACHE:
+        ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
+    return ML_CACHE[model_pk]
 
 
 @shared_task(queue="background")
@@ -16,7 +23,7 @@ def mul(a, b):
 
 @shared_task(queue="predict")
 def predict_simple(model_pk: int, smiles: str):
-    mod = EPModel.objects.get(id=model_pk)
+    mod = get_ml_model(model_pk)
     res = mod.predict(smiles)
     return res
@@ -51,6 +58,9 @@ def predict(
 ) -> Pathway:
     pw = Pathway.objects.get(id=pw_pk)
     setting = Setting.objects.get(id=pred_setting_pk)
+    # If the setting has a model add/restore it from the cache
+    if setting.model is not None:
+        setting.model = get_ml_model(setting.model.pk)
     pw.kv.update(**{"status": "running"})
     pw.save()
 
diff --git a/tests/test_enviformer.py b/tests/test_enviformer.py
index 1a688cb1..b81ca2ca 100644
--- a/tests/test_enviformer.py
+++ b/tests/test_enviformer.py
@@ -1,7 +1,27 @@
+from collections import defaultdict
+from datetime import datetime
 from tempfile import TemporaryDirectory
 from django.test import TestCase, tag
 from epdb.logic import PackageManager
-from epdb.models import User, EnviFormer, Package
+from epdb.models import User, EnviFormer, Package, Setting, Pathway
+from epdb.tasks import predict_simple, predict
+
+
+def measure_predict(mod, pathway_pk=None):
+    # Measure and return the prediction time
+    start = datetime.now()
+    if pathway_pk:
+        s = Setting()
+        s.model = mod
+        s.model_threshold = 0.2
+        s.max_depth = 4
+        s.max_nodes = 20
+        s.save()
+        pred_result = predict.delay(pathway_pk, s.pk, limit=s.max_depth)
+    else:
+        pred_result = predict_simple.delay(mod.pk, "C1=CC=C(CSCC2=CC=CC=C2)C=C1")
+    _ = pred_result.get()
+    return round((datetime.now() - start).total_seconds(), 2)
 
 
 @tag("slow")
@@ -33,3 +53,34 @@ class EnviFormerTest(TestCase):
 
         mod.evaluate_model()
         mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
+
+    def test_predict_runtime(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+                mods = []
+                for _ in range(4):
+                    mod = EnviFormer.create(
+                        self.package, data_package_objs, eval_packages_objs, threshold=threshold
+                    )
+                    mod.build_dataset()
+                    mod.build_model()
+                    mods.append(mod)
+
+                # Test prediction time drops after first prediction
+                times = [measure_predict(mods[0]) for _ in range(5)]
+                print(f"First prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
+
+                # Test pathway prediction
+                times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
+                print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
+
+                # Test eviction by performing three predictions with every model, twice.
+                times = defaultdict(list)
+                for _ in range(2):  # Eviction should cause the second iteration here to have to reload the models
+                    for mod in mods:
+                        for _ in range(3):
+                            times[mod.pk].append(measure_predict(mod))
+                print(times)