# forked from enviPath/enviPy
import logging
|
|
from datetime import datetime
|
|
from typing import Callable, Optional
|
|
from uuid import uuid4
|
|
|
|
from celery import shared_task
|
|
from celery.utils.functional import LRUCache
|
|
|
|
from epdb.logic import SPathway
|
|
from epdb.models import EPModel, JobLog, Node, Package, Pathway, Setting, User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
ML_CACHE = LRUCache(3)  # Keep the three most recently used ML models in memory.


def get_ml_model(model_pk: int):
    """Return the ``EPModel`` with primary key *model_pk*, memoised in ``ML_CACHE``.

    On a cache miss the model is loaded from the database and stored in the
    LRU cache, possibly evicting the least recently used entry.
    """
    try:
        return ML_CACHE[model_pk]
    except KeyError:
        model = EPModel.objects.get(id=model_pk)
        ML_CACHE[model_pk] = model
        return model
|
|
|
|
|
|
def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run *job* synchronously and record a ``SUCCESS`` JobLog entry.

    Unlike :func:`dispatch`, no Celery broker is involved: the callable is
    executed in-process and a synthetic task id is generated for the log row.

    :param user: owner recorded on the JobLog entry.
    :param job: the callable to execute.
    :returns: the job's return value.
    :raises Exception: re-raises whatever *job* raises, after logging it.
        Note that no JobLog row is written on failure.
    """
    try:
        result = job(*args, **kwargs)

        log = JobLog()
        log.user = user
        log.task_id = uuid4()  # no broker-assigned id in eager mode
        log.job_name = job.__name__
        log.status = "SUCCESS"
        log.done_at = datetime.now()
        # Fix: the previous `if result` check silently dropped falsy results
        # (0, "", False, empty containers) — only None means "no result".
        log.task_result = str(result) if result is not None else None
        log.save()

        return result
    except Exception as e:
        logger.exception(e)
        raise  # bare raise preserves the original traceback
|
|
|
|
|
|
def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Queue *job* on its Celery queue and record an ``INITIAL`` JobLog entry.

    :param user: owner recorded on the JobLog entry.
    :param job: a Celery task; dispatched via ``job.delay``.
    :returns: ``AsyncResult.result`` of the freshly queued task.
    :raises Exception: re-raises dispatch failures after logging them.
    """
    try:
        async_result = job.delay(*args, **kwargs)

        log = JobLog()
        log.user = user
        log.task_id = async_result.task_id
        log.job_name = job.__name__
        log.status = "INITIAL"
        log.save()

        # NOTE(review): for a task that was queued a moment ago, .result is
        # usually still None (state PENDING) — confirm callers do not expect
        # the task's eventual return value here.
        return async_result.result
    except Exception as e:
        logger.exception(e)
        raise  # bare raise preserves the original traceback
|
|
|
|
|
|
@shared_task(queue="background")
def mul(a, b):
    """Multiply *a* by *b* on the background queue (smoke-test task)."""
    product = a * b
    return product
|
|
|
|
|
|
@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Run a single-compound prediction with the (cached) model *model_pk*.

    :param model_pk: primary key of the EPModel to use.
    :param smiles: SMILES string of the compound to predict.
    :returns: whatever ``EPModel.predict`` returns for *smiles*.
    """
    model = get_ml_model(model_pk)
    return model.predict(smiles)
|
|
|
|
|
|
@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    """Send a registration e-mail to the user with primary key *user_pk*.

    Not implemented yet — currently a no-op placeholder.
    """
    pass
|
|
|
|
|
|
@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    """Celery task: build the dataset and then the model for an EPModel.

    JobLog rows keyed on this task's id are advanced through
    RUNNING -> SUCCESS/FAILED. ``QuerySet.update()`` is a no-op when no row
    matches, so the previous ``exists()`` pre-checks were redundant queries
    and have been dropped.

    :param model_pk: primary key of the EPModel to build.
    :returns: the model's URL on success.
    :raises Exception: re-raises any build error after marking the log FAILED.
    """
    mod = EPModel.objects.get(id=model_pk)
    # Lazy queryset: each .update() below issues a fresh UPDATE statement.
    job_logs = JobLog.objects.filter(task_id=self.request.id)
    job_logs.update(status="RUNNING", task_result=mod.url)

    try:
        mod.build_dataset()
        mod.build_model()
    except Exception:
        job_logs.update(status="FAILED", task_result=mod.url)
        raise  # bare raise preserves the original traceback

    job_logs.update(status="SUCCESS", task_result=mod.url)
    return mod.url
|
|
|
|
|
|
@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    """Celery task: evaluate an EPModel, optionally against specific packages.

    JobLog rows keyed on this task's id are advanced through
    RUNNING -> SUCCESS/FAILED. ``QuerySet.update()`` is a no-op when no row
    matches, so the previous ``exists()`` pre-checks were redundant queries
    and have been dropped.

    :param model_pk: primary key of the EPModel to evaluate.
    :param multigen: forwarded to ``EPModel.evaluate_model``.
    :param package_pks: optional primary keys restricting the evaluation
        packages; ``None``/empty means evaluate with the default set.
    :returns: the model's URL on success.
    :raises Exception: re-raises any evaluation error after marking the log FAILED.
    """
    packages = None
    if package_pks:
        packages = Package.objects.filter(pk__in=package_pks)

    mod = EPModel.objects.get(id=model_pk)
    # Lazy queryset: each .update() below issues a fresh UPDATE statement.
    job_logs = JobLog.objects.filter(task_id=self.request.id)
    job_logs.update(status="RUNNING", task_result=mod.url)

    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception:
        job_logs.update(status="FAILED", task_result=mod.url)
        raise  # bare raise preserves the original traceback

    job_logs.update(status="SUCCESS", task_result=mod.url)
    return mod.url
|
|
|
|
|
|
@shared_task(queue="model")
def retrain(model_pk: int):
    """Retrain the EPModel identified by *model_pk* on the model queue."""
    model = EPModel.objects.get(id=model_pk)
    model.retrain()
|
|
|
|
|
|
@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
) -> Pathway:
    """Celery task: run a pathway prediction.

    Two mutually exclusive modes:

    * ``limit`` given — predict level by level from the root. ``limit == -1``
      means incremental mode (run until the pathway is done); any other value
      stops after ``limit`` levels.
    * ``node_pk`` given — expand a single node of the existing pathway.

    The pathway's ``kv`` store is advanced through
    running -> completed/failed, and any JobLog row keyed on this task's id
    through RUNNING -> SUCCESS/FAILED (``QuerySet.update()`` is a no-op when
    no row matches, so no ``exists()`` pre-check is needed).

    :param pw_pk: primary key of the Pathway to predict into.
    :param pred_setting_pk: primary key of the prediction Setting.
    :param limit: depth limit, or -1 for incremental mode, or None.
    :param node_pk: node to expand, or None.
    :returns: the pathway's URL.
        NOTE(review): the annotation says ``Pathway`` but ``pw.url`` is
        returned — confirm which one callers expect.
    :raises ValueError: if neither ``limit`` nor ``node_pk`` is given
        (this also marks the pathway/job as failed, as in the original flow).
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    # If the setting has a model, add/restore it from the LRU cache.
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    pw.kv.update(status="running")
    pw.save()

    # Lazy queryset: each .update() below issues a fresh UPDATE statement.
    job_logs = JobLog.objects.filter(task_id=self.request.id)
    job_logs.update(status="RUNNING", task_result=pw.url)

    try:
        if limit is not None:
            # Regular depth-limited prediction from the pathway root.
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # Stop early unless we are in incremental mode (limit == -1).
                if limit != -1 and level >= limit:
                    break
        elif node_pk is not None:
            # Expand exactly one node of the existing pathway.
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            raise ValueError("Neither limit nor node_pk given!")
    except Exception:
        pw.kv.update(status="failed")
        pw.save()
        job_logs.update(status="FAILED", task_result=pw.url)
        raise  # bare raise preserves the original traceback

    pw.kv.update(status="completed")
    pw.save()
    job_logs.update(status="SUCCESS", task_result=pw.url)
    return pw.url
|