forked from enviPath/enviPy
[Feature] ML model caching for reducing prediction overhead (#156)
The caching is now finished. The cache is created in `settings.py`, which gives us the most flexibility for future use. It is currently updated/accessed through `get_ml_model` in `tasks.py`, which can be called from whatever task needs to access ML models this way (currently `predict` and `predict_simple`). This implementation caches all ML models, including the relative reasoning ones. If we don't want that and only want to cache enviFormer, I can change it; however, I don't think there is any harm in caching the other models as well.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#156
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
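The cache type used below, `celery.utils.functional.LRUCache`, is a dict-like mapping that drops the least recently used entry once its limit is exceeded. A minimal standalone sketch of that eviction behaviour (illustration only, not part of the commit):

from celery.utils.functional import LRUCache

cache = LRUCache(limit=2)  # keep at most two entries
cache["a"] = 1
cache["b"] = 2
_ = cache["a"]             # reading "a" marks it as most recently used
cache["c"] = 3             # over the limit, so the stalest key ("b") is evicted
print("b" in cache)        # False
print("a" in cache)        # True

Note that the cache is a module-level object, so each Celery worker process holds its own copy; the first prediction on every worker still pays the full model load time.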
@@ -3043,9 +3043,9 @@ class EnviFormer(PackageBasedModel):
     @cached_property
     def model(self):
         from enviformer import load

         ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
-        return load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
+        mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
+        return mod

     def predict(self, smiles) -> List["PredictionResult"]:
         return self.predict_batch([smiles])[0]
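`model` above is a `cached_property`, so the checkpoint is loaded from disk only once per `EnviFormer` instance and then reused from the instance's `__dict__`; caching the model objects in the task layer is what keeps those loaded weights alive between tasks. A quick sketch of that semantics, assuming the standard library `functools.cached_property` (the import is outside this excerpt):

from functools import cached_property

class Expensive:
    @cached_property
    def model(self):
        print("loading checkpoint...")  # runs only on first access
        return object()                 # stand-in for the loaded weights

e = Expensive()
e.model            # prints "loading checkpoint..."
e.model            # silent: value reused from e.__dict__
Expensive().model  # a fresh instance loads again, hence the instance cache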
@@ -3059,8 +3059,10 @@ class EnviFormer(PackageBasedModel):
             for smiles in smiles_list
         ]
         logger.info(f"Submitting {canon_smiles} to {self.name}")
+        start = datetime.now()
         products_list = self.model.predict_batch(canon_smiles)
-        logger.info(f"Got results {products_list}")
+        end = datetime.now()
+        logger.info(f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}")

         results = []
         for products in products_list:
@@ -1,12 +1,19 @@
 import logging
 from typing import Optional

+from celery.utils.functional import LRUCache
 from celery import shared_task
 from epdb.models import Pathway, Node, EPModel, Setting
 from epdb.logic import SPathway


 logger = logging.getLogger(__name__)

+ML_CACHE = LRUCache(3)  # Cache the three most recent ML models to reduce load times.
+
+
+def get_ml_model(model_pk: int):
+    if model_pk not in ML_CACHE:
+        ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
+    return ML_CACHE[model_pk]


 @shared_task(queue="background")
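As the commit message notes, any task that needs an ML model can route through `get_ml_model`. A hypothetical sketch of a new task opting in (the name `predict_batch_task` is illustrative and not part of this commit):

@shared_task(queue="predict")
def predict_batch_task(model_pk: int, smiles_list):
    # Illustrative only: fetch via the worker-local cache instead of
    # EPModel.objects.get(), so repeat calls skip reloading the weights.
    mod = get_ml_model(model_pk)
    return [mod.predict(smiles) for smiles in smiles_list]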
@@ -16,7 +23,7 @@ def mul(a, b):

 @shared_task(queue="predict")
 def predict_simple(model_pk: int, smiles: str):
-    mod = EPModel.objects.get(id=model_pk)
+    mod = get_ml_model(model_pk)
     res = mod.predict(smiles)
     return res

@@ -51,6 +58,9 @@ def predict(
 ) -> Pathway:
     pw = Pathway.objects.get(id=pw_pk)
     setting = Setting.objects.get(id=pred_setting_pk)
+    # If the setting has a model, add/restore it from the cache
+    if setting.model is not None:
+        setting.model = get_ml_model(setting.model.pk)

     pw.kv.update(**{"status": "running"})
     pw.save()
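Replacing `setting.model` with the cached object matters because `get_ml_model` returns the same `EPModel` instance on every hit within a worker, and the loaded weights hang off that instance via `cached_property`. A sketch of the invariant (illustration only, not part of the commit):

first = get_ml_model(setting.model.pk)
second = get_ml_model(setting.model.pk)
assert first is second  # same object, so its already-loaded .model is reused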
@@ -1,7 +1,27 @@
+from collections import defaultdict
+from datetime import datetime
 from tempfile import TemporaryDirectory
 from django.test import TestCase, tag
 from epdb.logic import PackageManager
-from epdb.models import User, EnviFormer, Package
+from epdb.models import User, EnviFormer, Package, Setting, Pathway
+from epdb.tasks import predict_simple, predict


+def measure_predict(mod, pathway_pk=None):
+    # Measure and return the prediction time
+    start = datetime.now()
+    if pathway_pk:
+        s = Setting()
+        s.model = mod
+        s.model_threshold = 0.2
+        s.max_depth = 4
+        s.max_nodes = 20
+        s.save()
+        pred_result = predict.delay(pathway_pk, s.pk, limit=s.max_depth)
+    else:
+        pred_result = predict_simple.delay(mod.pk, "C1=CC=C(CSCC2=CC=CC=C2)C=C1")
+    _ = pred_result.get()
+    return round((datetime.now() - start).total_seconds(), 2)


 @tag("slow")
@@ -33,3 +53,34 @@ class EnviFormerTest(TestCase):
         mod.evaluate_model()

         mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
+
+    def test_predict_runtime(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+                mods = []
+                for _ in range(4):
+                    mod = EnviFormer.create(
+                        self.package, data_package_objs, eval_packages_objs, threshold=threshold
+                    )
+                    mod.build_dataset()
+                    mod.build_model()
+                    mods.append(mod)
+
+                # Test that prediction time drops after the first prediction
+                times = [measure_predict(mods[0]) for _ in range(5)]
+                print(f"First prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
+
+                # Test pathway prediction
+                times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
+                print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
+
+                # Test eviction by performing three predictions with every model, twice.
+                times = defaultdict(list)
+                for _ in range(2):  # Eviction should force the second iteration to reload the models
+                    for mod in mods:
+                        for _ in range(3):
+                            times[mod.pk].append(measure_predict(mod))
+                print(times)
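Since the new test is tagged `slow`, it can be run selectively; assuming the stock Django test runner:

python manage.py test --tag=slow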