import logging
from typing import Optional

from celery import shared_task

from epdb.models import Pathway, Node, EPModel, Setting
from epdb.logic import SPathway

logger = logging.getLogger(__name__)


@shared_task(queue="background")
def mul(a, b):
    """Return ``a * b`` (trivial task on the "background" queue)."""
    return a * b


@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Run ``EPModel.predict`` on a single SMILES string and return its result.

    :param model_pk: primary key of the ``EPModel`` to use.
    :param smiles: input passed unchanged to ``EPModel.predict``.
    :return: whatever ``EPModel.predict`` returns.
    """
    mod = EPModel.objects.get(id=model_pk)
    res = mod.predict(smiles)
    return res


@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    """Placeholder task: currently does nothing.

    TODO: implement sending the registration mail for ``user_pk``.
    """
    pass


@shared_task(queue="model")
def build_model(model_pk: int):
    """Build the dataset and then the model for the given ``EPModel``."""
    mod = EPModel.objects.get(id=model_pk)
    # Dataset must exist before the model can be built from it.
    mod.build_dataset()
    mod.build_model()


@shared_task(queue="model")
def evaluate_model(model_pk: int):
    """Run ``EPModel.evaluate_model`` for the given model."""
    mod = EPModel.objects.get(id=model_pk)
    mod.evaluate_model()


@shared_task(queue="model")
def retrain(model_pk: int):
    """Run ``EPModel.retrain`` for the given model."""
    mod = EPModel.objects.get(id=model_pk)
    mod.retrain()


def _set_pathway_status(pw: Pathway, status: str) -> None:
    """Store ``status`` under the ``"status"`` key of ``pw.kv`` and save ``pw``."""
    # Keyword form is used everywhere so every status write calls kv.update
    # the same way (the original failure path used a positional dict).
    pw.kv.update(status=status)
    pw.save()


@shared_task(queue="predict")
def predict(
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
) -> None:
    """Expand a pathway by prediction and track progress in ``pw.kv["status"]``.

    Exactly one mode is used:

    * ``limit`` given: predict level by level, starting at depth 0, until the
      pathway reports ``done``. ``limit == -1`` means "no level limit";
      otherwise stop once ``limit`` levels have been processed. At least one
      step always runs, so ``limit=0`` behaves like ``limit=1``.
    * otherwise ``node_pk`` given: run a single prediction step from that
      node, which must belong to the pathway.

    The status is set to ``"running"`` before prediction. It becomes
    ``"completed"`` on success and ``"failed"`` if any exception is raised.

    :param pw_pk: primary key of the ``Pathway`` to expand.
    :param pred_setting_pk: primary key of the ``Setting``. It is only passed
        to ``SPathway`` in ``limit`` mode, but it is always fetched.
    :param limit: number of levels to predict, or ``-1`` for unbounded.
    :param node_pk: primary key of the ``Node`` to expand from.
    :raises ValueError: if neither ``limit`` nor ``node_pk`` is given.
    :raises Exception: any error during prediction is re-raised unchanged,
        after the pathway is marked ``"failed"``.
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    _set_pathway_status(pw, "running")

    try:
        if limit is not None:
            # Regular prediction: expand breadth-wise, one depth level per step.
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # Incremental mode: stop after `limit` levels unless unbounded (-1).
                if limit != -1 and level >= limit:
                    break
        elif node_pk is not None:
            # Restricting the lookup to `pw` rejects nodes from other pathways.
            node = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=node)
        else:
            raise ValueError("Neither limit nor node_pk given!")
    except Exception:
        logger.exception("Prediction for pathway %s failed", pw_pk)
        _set_pathway_status(pw, "failed")
        # Bare raise keeps the original traceback intact.
        raise

    _set_pathway_status(pw, "completed")