[Feature] ML model caching for reducing prediction overhead (#156)

The caching is now finished. The cache is created in `settings.py` giving us the most flexibility for using it in the future.

The cache is currently updated/accessed by `tasks.py/get_ml_model` which can be called from whatever task needs to access ml models in this way (currently, `predict` and `predict_simple`).

This implementation currently caches all ml models including the relative reasoning. If we don't want this and only want to cache enviFormer, i can change it to that. However, I don't think there is a harm in having the other models be cached as well.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#156
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
2025-10-16 08:58:36 +13:00
committed by jebus
parent d5ebb23622
commit 376fd65785
3 changed files with 69 additions and 6 deletions

View File

@ -3043,9 +3043,9 @@ class EnviFormer(PackageBasedModel):
@cached_property
def model(self):
from enviformer import load
ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
return load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
return mod
def predict(self, smiles) -> List["PredictionResult"]:
return self.predict_batch([smiles])[0]
@ -3059,8 +3059,10 @@ class EnviFormer(PackageBasedModel):
for smiles in smiles_list
]
logger.info(f"Submitting {canon_smiles} to {self.name}")
start = datetime.now()
products_list = self.model.predict_batch(canon_smiles)
logger.info(f"Got results {products_list}")
end = datetime.now()
logger.info(f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}")
results = []
for products in products_list:

View File

@ -1,12 +1,19 @@
import logging
from typing import Optional
from celery.utils.functional import LRUCache
from celery import shared_task
from epdb.models import Pathway, Node, EPModel, Setting
from epdb.logic import SPathway
logger = logging.getLogger(__name__)
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
def get_ml_model(model_pk: int):
if model_pk not in ML_CACHE:
ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
return ML_CACHE[model_pk]
@shared_task(queue="background")
@ -16,7 +23,7 @@ def mul(a, b):
@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
mod = EPModel.objects.get(id=model_pk)
mod = get_ml_model(model_pk)
res = mod.predict(smiles)
return res
@ -51,6 +58,9 @@ def predict(
) -> Pathway:
pw = Pathway.objects.get(id=pw_pk)
setting = Setting.objects.get(id=pred_setting_pk)
# If the setting has a model add/restore it from the cache
if setting.model is not None:
setting.model = get_ml_model(setting.model.pk)
pw.kv.update(**{"status": "running"})
pw.save()