enviPy-bayer/epdb/tasks.py
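"""Celery tasks for the epdb app.

Covers eager and asynchronous job dispatch with JobLog bookkeeping, model
building, evaluation and retraining, pathway prediction, and a CSV report of
missing rules. (Summary added for orientation; it only restates what the
tasks below do.)
"""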

import csv
import io
import logging
from typing import Any, Callable, List, Optional
from uuid import uuid4

from celery import shared_task
from celery.utils.functional import LRUCache
from django.conf import settings as s
from django.utils import timezone

from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User

logger = logging.getLogger(__name__)

ML_CACHE = LRUCache(3)  # Cache the three most recent ML models to reduce load times.

Package = s.GET_PACKAGE_MODEL()

def get_ml_model(model_pk: int):
    """Return the EPModel with the given primary key, served from the in-memory LRU cache."""
    if model_pk not in ML_CACHE:
        ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
    return ML_CACHE[model_pk]

def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run ``job`` synchronously in the current process and record the result in a JobLog entry."""
    try:
        x = job(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = uuid4()
        log.job_name = job.__name__
        log.status = "SUCCESS"
        log.done_at = timezone.now()
        log.task_result = str(x) if x else None
        log.save()
        return x
    except Exception as e:
        logger.exception(e)
        raise e

def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Queue ``job`` via Celery and record the submission in a JobLog entry."""
    try:
        x = job.delay(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = x.task_id
        log.job_name = job.__name__
        log.status = "INITIAL"
        log.save()
        return x.result
    except Exception as e:
        logger.exception(e)
        raise e

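# Example call sites (hypothetical; the real callers live outside this module):
#
#   dispatch(request.user, build_model, model.pk)   # enqueue on the "model" queue
#   dispatch_eager(request.user, mul, 6, 7)         # run inline, e.g. in tests or eager mode
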
@shared_task(queue="background")
def mul(a, b):
    return a * b

@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    mod = get_ml_model(model_pk)
    res = mod.predict(smiles)
    return res

@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    pass

@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    """Build the model's dataset, train it, and track progress in the JobLog."""
    mod = EPModel.objects.get(id=model_pk)
    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="RUNNING", task_result=mod.url)
    try:
        mod.build_dataset()
        mod.build_model()
    except Exception as e:
        if JobLog.objects.filter(task_id=self.request.id).exists():
            JobLog.objects.filter(task_id=self.request.id).update(
                status="FAILED", task_result=mod.url
            )
        raise e
    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=mod.url)
    return mod.url

@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    """Evaluate the model (optionally against the given packages) and track progress in the JobLog."""
    packages = None
    if package_pks:
        packages = Package.objects.filter(pk__in=package_pks)
    mod = EPModel.objects.get(id=model_pk)
    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="RUNNING", task_result=mod.url)
    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception as e:
        if JobLog.objects.filter(task_id=self.request.id).exists():
            JobLog.objects.filter(task_id=self.request.id).update(
                status="FAILED", task_result=mod.url
            )
        raise e
    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=mod.url)
    return mod.url

@shared_task(queue="model")
def retrain(model_pk: int):
    mod = EPModel.objects.get(id=model_pk)
    mod.retrain()

@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
) -> Pathway:
    """Run pathway prediction, either level by level up to ``limit`` or from a single node."""
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    # If the setting has a model, add/restore it from the cache.
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    pw.kv.update(**{"status": "running"})
    pw.save()

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="RUNNING", task_result=pw.url)

    try:
        # regular prediction
        if limit is not None:
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # break once the requested depth is reached, in case we are in
                # incremental mode (limit != -1)
                if limit != -1:
                    if level >= limit:
                        break
        elif node_pk is not None:
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            raise ValueError("Neither limit nor node_pk given!")
    except Exception as e:
        pw.kv.update({"status": "failed"})
        pw.save()
        if JobLog.objects.filter(task_id=self.request.id).exists():
            JobLog.objects.filter(task_id=self.request.id).update(
                status="FAILED", task_result=pw.url
            )
        raise e

    pw.kv.update(**{"status": "completed"})
    pw.save()

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=pw.url)
    return pw.url

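# Example (hypothetical call site): enqueue a depth-limited prediction for a pathway.
#
#   dispatch(user, predict, pw.pk, setting.pk, limit=4)
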
@shared_task(bind=True, queue="background")
def identify_missing_rules(
    self,
    pw_pks: List[int],
    rule_package_pk: int,
):
    """Build a CSV report of missing-rule candidates for the given pathways, one row per edge."""
    from utilities.misc import PathwayUtils

    rules = Package.objects.get(pk=rule_package_pk).get_applicable_rules()

    rows: List[Any] = []
    header = [
        "Package Name",
        "Pathway Name",
        "Educt Name",
        "Educt SMILES",
        "Reaction Name",
        "Reaction SMIRKS",
        "Triggered Rules",
        "Reactant SMARTS",
        "Product SMARTS",
        "Product Names",
        "Product SMILES",
    ]
    rows.append(header)

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        missing_rules = pu.find_missing_rules(rules)
        package_name = pw.package.name
        pathway_name = pw.name
        for edge_url, rule_chain in missing_rules.items():
            row: List[Any] = [package_name, pathway_name]
            edge = Edge.objects.get(url=edge_url)

            educts = edge.start_nodes.all()
            for educt in educts:
                row.append(educt.default_node_label.name)
                row.append(educt.default_node_label.smiles)

            row.append(edge.edge_label.name)
            row.append(edge.edge_label.smirks())

            rule_names = []
            reactant_smarts = []
            product_smarts = []
            for r in rule_chain:
                r = Rule.objects.get(url=r[0])
                rule_names.append(r.name)
                rs = r.reactants_smarts
                if isinstance(rs, set):
                    rs = list(rs)
                ps = r.products_smarts
                if isinstance(ps, set):
                    ps = list(ps)
                reactant_smarts.append(rs)
                product_smarts.append(ps)
            row.append(rule_names)
            row.append(reactant_smarts)
            row.append(product_smarts)

            products = edge.end_nodes.all()
            product_names = []
            product_smiles = []
            for product in products:
                product_names.append(product.default_node_label.name)
                product_smiles.append(product.default_node_label.smiles)
            row.append(product_names)
            row.append(product_smiles)

            rows.append(row)

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerows(rows)
    buffer.seek(0)
    return buffer.getvalue()