import logging
from datetime import datetime
from typing import Callable, Optional
from uuid import uuid4

from celery import shared_task
from celery.utils.functional import LRUCache

from epdb.logic import SPathway
from epdb.models import EPModel, JobLog, Node, Package, Pathway, Setting, User

logger = logging.getLogger(__name__)

ML_CACHE = LRUCache(3)  # Cache the three most recent ML models to reduce load times.


def get_ml_model(model_pk: int):
    """Return the EPModel with primary key ``model_pk``, serving it from ML_CACHE when possible.

    On a cache miss the model is loaded from the database and inserted into the cache.
    """
    try:
        return ML_CACHE[model_pk]
    except KeyError:
        pass
    # Return the local reference rather than re-reading ML_CACHE: a concurrent insert
    # could evict this entry before a second lookup and turn it into a KeyError.
    model = EPModel.objects.get(id=model_pk)
    ML_CACHE[model_pk] = model
    return model


def _update_job_log(task_id, status: str, task_result) -> None:
    """Set ``status`` and ``task_result`` on the JobLog row(s) whose task_id is ``task_id``."""
    # QuerySet.update() on an empty queryset simply updates nothing, so no separate
    # exists() round-trip is needed for tasks that were started without a JobLog row.
    JobLog.objects.filter(task_id=task_id).update(status=status, task_result=task_result)


def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run ``job(*args, **kwargs)`` synchronously and record a SUCCESS JobLog for ``user``.

    Returns whatever ``job`` returns. If ``job`` raises, the exception is logged and
    re-raised and no JobLog row is written.
    """
    try:
        x = job(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = uuid4()  # no Celery task id exists for an in-process run
        log.job_name = job.__name__
        log.status = "SUCCESS"
        # NOTE(review): naive local time; if the project uses timezone-aware datetimes
        # (Django USE_TZ=True), an aware timestamp may be preferable — confirm.
        log.done_at = datetime.now()
        # Compare against None so falsy-but-real results (0, False, "") are still recorded.
        log.task_result = str(x) if x is not None else None
        log.save()
        return x
    except Exception as e:
        logger.exception(e)
        raise


def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Enqueue the Celery task ``job`` and record an INITIAL JobLog row for ``user``.

    The JobLog row is keyed by the Celery task id, which lets the task itself update
    the row later. Returns ``.result`` of the AsyncResult returned by ``job.delay``.
    Any exception is logged and re-raised.
    """
    try:
        x = job.delay(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = x.task_id
        log.job_name = job.__name__
        log.status = "INITIAL"
        log.save()
        # NOTE(review): this is read right after enqueueing, so it is presumably None
        # unless tasks run eagerly — confirm that is what callers expect.
        return x.result
    except Exception as e:
        logger.exception(e)
        raise


@shared_task(queue="background")
def mul(a, b):
    """Return ``a * b``."""
    return a * b


@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Run the (cached) model ``model_pk`` on ``smiles`` and return its prediction."""
    mod = get_ml_model(model_pk)
    res = mod.predict(smiles)
    return res


@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    """Placeholder task; currently does nothing."""


@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    """Build the dataset and then the model for EPModel ``model_pk``.

    Mirrors progress (RUNNING -> SUCCESS/FAILED) onto this task's JobLog row, if
    one exists. Returns the model's URL. Exceptions are re-raised after marking FAILED.
    """
    mod = EPModel.objects.get(id=model_pk)
    _update_job_log(self.request.id, "RUNNING", mod.url)

    try:
        mod.build_dataset()
        mod.build_model()
    except Exception:
        _update_job_log(self.request.id, "FAILED", mod.url)
        raise

    _update_job_log(self.request.id, "SUCCESS", mod.url)
    return mod.url


@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    """Evaluate EPModel ``model_pk``, optionally restricted to the packages in ``package_pks``.

    An empty or missing ``package_pks`` passes ``eval_packages=None``. Progress is
    mirrored onto this task's JobLog row, if one exists. Returns the model's URL.
    Exceptions are re-raised after marking FAILED.
    """
    packages = None
    if package_pks:
        packages = Package.objects.filter(pk__in=package_pks)

    mod = EPModel.objects.get(id=model_pk)
    _update_job_log(self.request.id, "RUNNING", mod.url)

    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception:
        _update_job_log(self.request.id, "FAILED", mod.url)
        raise

    _update_job_log(self.request.id, "SUCCESS", mod.url)
    return mod.url


@shared_task(queue="model")
def retrain(model_pk: int):
    """Call ``retrain()`` on EPModel ``model_pk``. Does not touch JobLog."""
    mod = EPModel.objects.get(id=model_pk)
    mod.retrain()


@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
) -> str:
    """Run pathway prediction for Pathway ``pw_pk`` with prediction Setting ``pred_setting_pk``.

    Exactly one mode is used:
      * ``limit`` given: predict level by level from depth 0 until the pathway is done.
        ``limit == -1`` means no depth cap; any other value stops after ``limit``
        levels (note: a value < 1 other than -1 still runs one step).
      * otherwise ``node_pk`` given: run a single prediction step from that node,
        which must belong to the pathway.

    The pathway's ``kv["status"]`` and this task's JobLog row (if any) are set to
    running/completed or failed. Returns the pathway's URL.

    Raises:
        ValueError: if neither ``limit`` nor ``node_pk`` is given.
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    # If the setting has a model add/restore it from the cache
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    pw.kv.update({"status": "running"})
    pw.save()
    _update_job_log(self.request.id, "RUNNING", pw.url)

    try:
        if limit is not None:
            # regular prediction
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # break in case we are in incremental mode (limit != -1)
                if limit != -1 and level >= limit:
                    break
        elif node_pk is not None:
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            raise ValueError("Neither limit nor node_pk given!")
    except Exception:
        pw.kv.update({"status": "failed"})
        pw.save()
        _update_job_log(self.request.id, "FAILED", pw.url)
        raise

    pw.kv.update({"status": "completed"})
    pw.save()
    _update_job_log(self.request.id, "SUCCESS", pw.url)
    return pw.url