Files
enviPy-bayer/epdb/tasks.py
jebus 5477b5b3d4 [Feature] Rule Based Model (#92)
Fixes #89

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#92
2025-09-09 19:32:12 +12:00

87 lines
2.1 KiB
Python

import logging
from typing import Optional
from celery.signals import worker_process_init
from celery import shared_task
from epdb.models import Pathway, Node, Edge, EPModel, Setting
from epdb.logic import SPathway
from utilities.chem import FormatConverter
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@shared_task(queue='background')
def mul(a, b):
    """Return ``a * b``; registered as a Celery task on the ``background`` queue."""
    product = a * b
    return product
@shared_task(queue='predict')
def predict_simple(model_pk: int, smiles: str):
    """Load the ``EPModel`` with primary key ``model_pk`` and predict for ``smiles``.

    Runs on the ``predict`` queue. The task result is whatever
    ``EPModel.predict`` returns for the given SMILES string.
    """
    model = EPModel.objects.get(id=model_pk)
    return model.predict(smiles)
@shared_task(queue='background')
def send_registration_mail(user_pk: int):
    """Stub task on the ``background`` queue; currently has no effect.

    NOTE(review): judging by the name, this is meant to send a registration
    mail to the user identified by ``user_pk``. It is not implemented yet:
    the body does nothing and the task returns ``None``. TODO implement.
    """
    return None
@shared_task(queue='model')
def build_model(model_pk: int):
    """Build the dataset, then the model, for the ``EPModel`` with primary key ``model_pk``.

    Runs on the ``model`` queue and returns ``None``.
    """
    model = EPModel.objects.get(id=model_pk)
    # Dataset first, then the model itself.
    model.build_dataset()
    model.build_model()
@shared_task(queue='model')
def evaluate_model(model_pk: int):
    """Call ``evaluate_model()`` on the ``EPModel`` with primary key ``model_pk``.

    Runs on the ``model`` queue and returns ``None``.
    """
    EPModel.objects.get(id=model_pk).evaluate_model()
@shared_task(queue='model')
def retrain(model_pk: int):
    """Call ``retrain()`` on the ``EPModel`` with primary key ``model_pk``.

    Runs on the ``model`` queue and returns ``None``.
    """
    EPModel.objects.get(id=model_pk).retrain()
@shared_task(queue='predict')
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> None:
    """Run a pathway prediction and record its progress in ``pw.kv['status']``.

    The status is set to ``'running'`` before any prediction work. It becomes
    ``'completed'`` on success, or ``'failed'`` if an exception is raised (the
    exception is then logged and re-raised).

    Exactly one mode runs, checked in this order:

    * ``limit`` is not ``None``: create a new ``SPathway`` using the prediction
      setting and persisting into the pathway. Predict level by level from
      depth 0 until ``spw.done``. ``limit == -1`` means no depth cap. Any other
      value stops once ``level >= limit``. The check happens after each step,
      so ``limit=0`` still predicts one level.
    * ``node_pk`` is not ``None``: rebuild an ``SPathway`` from the stored
      pathway and run a single prediction step from that node. The node is
      looked up with ``pathway=pw``.

    :param pw_pk: Primary key of the ``Pathway`` to predict into.
    :param pred_setting_pk: Primary key of the ``Setting``. It is always
        fetched but only used in ``limit`` mode.
    :param limit: Number of levels to predict, or ``-1`` for no limit.
    :param node_pk: Primary key of the ``Node`` to expand from.
    :returns: ``None``; results are persisted via the pathway.
    :raises ValueError: If neither ``limit`` nor ``node_pk`` is given. The
        pathway is marked ``'failed'`` first.
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    pw.kv.update(**{'status': 'running'})
    pw.save()

    try:
        if limit is not None:
            # Regular level-wise prediction.
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # Incremental mode: any limit other than -1 caps the depth.
                if limit != -1 and level >= limit:
                    break
        elif node_pk is not None:
            # Single-step expansion from a node of this pathway.
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            raise ValueError("Neither limit nor node_pk given!")
    except Exception:
        logger.exception("Prediction for pathway %s failed", pw_pk)
        pw.kv.update({'status': 'failed'})
        pw.save()
        # Bare ``raise`` re-raises the active exception with its original traceback.
        raise

    pw.kv.update(**{'status': 'completed'})
    pw.save()