Files
enviPy-bayer/epdb/tasks.py
2026-01-29 11:13:34 +13:00

467 lines
14 KiB
Python

import csv
import io
import logging
from typing import Any, Callable, List, Optional
from uuid import uuid4
from celery import shared_task
from celery.utils.functional import LRUCache
from django.conf import settings as s
from django.core.mail import EmailMultiAlternatives
from django.utils import timezone
from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Package model class, resolved through the settings hook GET_PACKAGE_MODEL().
Package = s.GET_PACKAGE_MODEL()

# Cache the three most recent ML models to reduce load times.
ML_CACHE = LRUCache(3)
def get_ml_model(model_pk: int):
    """Return the EPModel with primary key ``model_pk``, going through ML_CACHE.

    On a cache miss the model is loaded with ``EPModel.objects.get`` and
    stored in ``ML_CACHE`` under ``model_pk`` before it is returned.
    """
    if model_pk in ML_CACHE:
        return ML_CACHE[model_pk]

    ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
    return ML_CACHE[model_pk]
def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run ``job`` synchronously and record the run in a new JobLog entry.

    Args:
        user: User stored on the JobLog entry.
        job: Callable invoked as ``job(*args, **kwargs)``; its ``__name__``
            is stored as the job name.
        *args: Positional arguments forwarded to ``job``.
        **kwargs: Keyword arguments forwarded to ``job``.

    Returns:
        A ``(job_log, result)`` tuple: the saved JobLog (status "SUCCESS")
        and whatever ``job`` returned.

    Raises:
        Exception: Any exception raised while running the job or saving the
            log entry is logged and re-raised.
    """
    try:
        result = job(*args, **kwargs)

        job_log = JobLog()
        job_log.user = user
        # There is no broker-assigned id for an eager run, so generate one.
        job_log.task_id = uuid4()
        job_log.job_name = job.__name__
        job_log.status = "SUCCESS"
        job_log.done_at = timezone.now()
        # Falsy results (None, 0, "", ...) are stored as None.
        if result:
            job_log.task_result = str(result)
        else:
            job_log.task_result = None
        job_log.save()
    except Exception as exc:
        logger.exception(exc)
        raise

    return job_log, result
def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Submit ``job`` via ``job.delay`` and record the submission in a JobLog.

    Args:
        user: User stored on the JobLog entry.
        job: Object exposing ``delay(*args, **kwargs)``; its ``__name__`` is
            stored as the job name.
        *args: Positional arguments forwarded to ``job.delay``.
        **kwargs: Keyword arguments forwarded to ``job.delay``.

    Returns:
        The saved JobLog entry (status "INITIAL") whose ``task_id`` is the
        ``task_id`` of the object returned by ``job.delay``.

    Raises:
        Exception: Any exception raised while submitting the job or saving
            the log entry is logged and re-raised.
    """
    try:
        async_result = job.delay(*args, **kwargs)

        job_log = JobLog()
        job_log.user = user
        job_log.task_id = async_result.task_id
        job_log.job_name = job.__name__
        job_log.status = "INITIAL"
        job_log.save()
    except Exception as exc:
        logger.exception(exc)
        raise

    return job_log
@shared_task(queue="background")
def mul(a, b):
    """Return ``a * b``; registered as a task on the "background" queue."""
    product = a * b
    return product
@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Return the prediction of model ``model_pk`` for ``smiles``.

    The model is obtained through ``get_ml_model`` and the value returned by
    its ``predict(smiles)`` call is passed back unchanged.
    """
    return get_ml_model(model_pk).predict(smiles)
@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    """Send the registration e-mail to the user with primary key ``user_pk``.

    The mail goes to the user's e-mail address, with admin@envipath.org as
    both sender and BCC recipient. Sending errors are not suppressed
    (``fail_silently=False``).
    """
    recipient = User.objects.get(id=user_pk)

    admin_address = "admin@envipath.org"
    body_template = """Welcome {username}!,
Thank you for your interest in enviPath.
The public system is intended for non-commercial use only.
We will review your account details and usually activate your account within 24 hours.
Once activated, you will be notified by email.
If we have any questions, we will contact you at this email address.
Best regards,
enviPath team"""

    message = EmailMultiAlternatives(
        subject="Your enviPath account",
        body=body_template.format(username=recipient.username),
        from_email=admin_address,
        to=[recipient.email],
        bcc=[admin_address],
    )
    message.send(fail_silently=False)
@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    """Build the dataset and then the model for the EPModel ``model_pk``.

    Any JobLog rows whose ``task_id`` equals this task's request id are set
    to "RUNNING" before the build and to "SUCCESS" or "FAILED" afterwards,
    each time with ``task_result`` set to the model URL.

    Args:
        model_pk: Primary key of the EPModel to build.

    Returns:
        The URL of the model (``mod.url``).

    Raises:
        Exception: Whatever ``build_dataset`` / ``build_model`` raise, after
            the job log has been marked "FAILED".
    """
    mod = EPModel.objects.get(id=model_pk)

    # update() on a queryset that matches no rows is a no-op, so there is no
    # need to probe with exists() first (that only costs an extra query).
    job_logs = JobLog.objects.filter(task_id=self.request.id)
    job_logs.update(status="RUNNING", task_result=mod.url)

    try:
        mod.build_dataset()
        mod.build_model()
    except Exception:
        job_logs.update(status="FAILED", task_result=mod.url)
        raise

    job_logs.update(status="SUCCESS", task_result=mod.url)
    return mod.url
@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    """Evaluate the EPModel ``model_pk``, optionally on specific packages.

    Any JobLog rows whose ``task_id`` equals this task's request id are set
    to "RUNNING" before the evaluation and to "SUCCESS" or "FAILED"
    afterwards, each time with ``task_result`` set to the model URL.

    Args:
        model_pk: Primary key of the EPModel to evaluate.
        multigen: Forwarded as first argument to ``EPModel.evaluate_model``.
        package_pks: Primary keys of the packages passed as
            ``eval_packages``; when empty or ``None``, ``eval_packages`` is
            ``None``.

    Returns:
        The URL of the model (``mod.url``).

    Raises:
        Exception: Whatever ``EPModel.evaluate_model`` raises, after the job
            log has been marked "FAILED".
    """
    packages = Package.objects.filter(pk__in=package_pks) if package_pks else None
    mod = EPModel.objects.get(id=model_pk)

    # update() on a queryset that matches no rows is a no-op, so there is no
    # need to probe with exists() first (that only costs an extra query).
    job_logs = JobLog.objects.filter(task_id=self.request.id)
    job_logs.update(status="RUNNING", task_result=mod.url)

    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception:
        job_logs.update(status="FAILED", task_result=mod.url)
        raise

    job_logs.update(status="SUCCESS", task_result=mod.url)
    return mod.url
@shared_task(queue="model")
def retrain(model_pk: int):
    """Call ``retrain()`` on the EPModel with primary key ``model_pk``."""
    EPModel.objects.get(id=model_pk).retrain()
@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
    setting_overrides: Optional[dict] = None,
) -> str:
    """Run a pathway prediction for the Pathway ``pw_pk``.

    Three modes are supported, checked in this order:

    * ``limit`` given: step-wise prediction by depth, stopping once
      ``limit`` levels were predicted (``limit == -1`` means no level cap)
      or the pathway reports ``done``.
    * ``node_pk`` given: a single prediction step starting at that node.
    * otherwise: a full ``SPathway.predict()`` run.

    The pathway's ``kv`` store receives ``status`` ("running", then
    "completed" or "failed"), the applied ``setting_overrides`` and, on
    failure, the ``error`` text. Any JobLog rows whose ``task_id`` equals
    this task's request id are moved through RUNNING -> SUCCESS/FAILED with
    ``task_result`` set to the pathway URL.

    Args:
        pw_pk: Primary key of the Pathway to predict.
        pred_setting_pk: Primary key of the Setting to predict with.
        limit: Number of depth levels to predict, ``-1`` for no cap.
        node_pk: Primary key of a Node of this pathway to expand.
        setting_overrides: Attribute name -> value pairs set on the loaded
            Setting instance; the Setting is not saved by this function.

    Returns:
        The URL of the pathway (``pw.url``).

    Raises:
        Exception: Whatever the prediction raises, after the pathway and the
            job log have been marked as failed.
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    if setting_overrides:
        for k, v in setting_overrides.items():
            setattr(setting, k, v)

    # If the setting has a model add/restore it from the cache
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    kv = {"status": "running"}
    if setting_overrides:
        kv["setting_overrides"] = setting_overrides
    pw.kv.update(**kv)
    pw.save()

    # update() on a queryset that matches no rows is a no-op, so there is no
    # need to probe with exists() first (that only costs an extra query).
    job_logs = JobLog.objects.filter(task_id=self.request.id)
    job_logs.update(status="RUNNING", task_result=pw.url)

    try:
        if limit is not None:
            # Step-wise prediction, one depth level per iteration.
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # Stop once the requested number of levels is reached;
                # -1 disables the cap.
                if limit != -1 and level >= limit:
                    break
        elif node_pk is not None:
            # Expand a single node of the already persisted pathway.
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            # Full prediction in one go.
            spw = SPathway(prediction_setting=setting, persist=pw)
            spw.predict()
    except Exception as exc:
        pw.kv.update(status="failed", error=str(exc))
        pw.save()
        job_logs.update(status="FAILED", task_result=pw.url)
        raise

    pw.kv.update(status="completed")
    pw.save()
    job_logs.update(status="SUCCESS", task_result=pw.url)
    return pw.url
@shared_task(bind=True, queue="background")
def identify_missing_rules(
    self,
    pw_pks: List[int],
    rule_package_pk: int,
):
    """Report, as CSV text, the edges for which ``find_missing_rules`` returns rule chains.

    For every Pathway in ``pw_pks`` the applicable rules of the package
    ``rule_package_pk`` are handed to ``PathwayUtils.find_missing_rules``.
    Each returned edge becomes one CSV row describing the educts, the
    reaction, the rules of the reported chain and the products.

    Args:
        pw_pks: Primary keys of the pathways to inspect.
        rule_package_pk: Primary key of the package supplying the rules.

    Returns:
        The CSV document as a string; the first row is the header.
    """
    from utilities.misc import PathwayUtils

    rules = Package.objects.get(pk=rule_package_pk).get_applicable_rules()

    rows: List[Any] = [
        [
            "Package Name",
            "Pathway Name",
            "Educt Name",
            "Educt SMILES",
            "Reaction Name",
            "Reaction SMIRKS",
            "Triggered Rules",
            "Reactant SMARTS",
            "Product SMARTS",
            "Product Names",
            "Product SMILES",
        ]
    ]

    for pathway in Pathway.objects.filter(pk__in=pw_pks):
        missing_rules = PathwayUtils(pathway).find_missing_rules(rules)
        package_name = pathway.package.name
        pathway_name = pathway.name

        for edge_url, rule_chain in missing_rules.items():
            record: List[Any] = [package_name, pathway_name]
            edge = Edge.objects.get(url=edge_url)

            # NOTE(review): one name/SMILES pair is appended per educt, so an
            # edge with several start nodes yields more cells than the header
            # has columns - confirm whether that is intended.
            for educt in edge.start_nodes.all():
                record.extend(
                    [educt.default_node_label.name, educt.default_node_label.smiles]
                )

            record.extend([edge.edge_label.name, edge.edge_label.smirks()])

            rule_names = []
            reactant_smarts = []
            product_smarts = []
            for chain_entry in rule_chain:
                # The first item of each chain entry is the rule URL.
                rule = Rule.objects.get(url=chain_entry[0])
                rule_names.append(rule.name)

                # Sets are turned into lists before being written out.
                reactants = rule.reactants_smarts
                if isinstance(reactants, set):
                    reactants = list(reactants)
                products = rule.products_smarts
                if isinstance(products, set):
                    products = list(products)

                reactant_smarts.append(reactants)
                product_smarts.append(products)
            record.extend([rule_names, reactant_smarts, product_smarts])

            product_names = []
            product_smiles = []
            for end_node in edge.end_nodes.all():
                product_names.append(end_node.default_node_label.name)
                product_smiles.append(end_node.default_node_label.smiles)
            record.extend([product_names, product_smiles])

            rows.append(record)

    out = io.StringIO()
    csv.writer(out).writerows(rows)
    return out.getvalue()
@shared_task(bind=True, queue="background")
def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_pk: int):
    """Engineer the given pathways and store the results in the target package.

    For every pathway ``PathwayUtils.engineer`` is called. When it reports
    intermediates, the pathway is copied into the target package, the
    intermediate nodes and edges are added to the copy, and the predicted
    pathway is persisted in the target package as well.

    Args:
        pw_pks: Primary keys of the pathways to process.
        setting_pk: Primary key of the Setting passed to ``engineer``.
        target_package_pk: Primary key of the package receiving the copies
            and the predicted pathways.

    Returns:
        A tuple ``(intermediate_pathways, predicted_pathways)`` with the URLs
        of the engineered copies and of the persisted predicted pathways.
    """
    from utilities.misc import PathwayUtils

    setting = Setting.objects.get(pk=setting_pk)
    # Temporarily set model_threshold to 0.0 to keep all tps
    # (only on this in-memory instance; the Setting is not saved here).
    setting.model_threshold = 0.0

    target = Package.objects.get(pk=target_package_pk)

    intermediate_pathways = []
    predicted_pathways = []

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        eng_pw, _node_to_snode_mapping, intermediates = pu.engineer(setting)

        # Nothing to do for pathways without intermediates.
        if not len(intermediates):
            continue

        # Get a copy of the original pathway and add the intermediates to it.
        copy_mapping = {}
        copied_pw = pw.copy(target, copy_mapping)
        copied_pw.name = f"{copied_pw.name} (Engineered)"
        copied_pw.description = f"The original Pathway can be found here: {pw.url}"
        copied_pw.save()

        for inter in intermediates:
            # inter[0] / inter[1] are looked up in copy_mapping to get the
            # copied start and end nodes; inter[2] / inter[3] are the
            # corresponding snodes and inter[4] holds the intermediate edges.
            start = copy_mapping[inter[0]]
            end = copy_mapping[inter[1]]
            start_snode = inter[2]
            end_snode = inter[3]

            for intermediate_edge in inter[4]:
                smiles_to_node = {}

                for snode in set(intermediate_edge.educts + intermediate_edge.products):
                    if snode == start_snode:
                        smiles_to_node[snode.smiles] = start
                    elif snode == end_snode:
                        smiles_to_node[snode.smiles] = end
                    elif snode.smiles not in smiles_to_node:
                        new_node = Node.create(copied_pw, smiles=snode.smiles, depth=snode.depth)
                        # Used in viz to highlight intermediates
                        new_node.kv.update({"is_engineered_intermediate": True})
                        new_node.save()
                        smiles_to_node[snode.smiles] = new_node

                Edge.create(
                    copied_pw,
                    [smiles_to_node[educt.smiles] for educt in intermediate_edge.educts],
                    [smiles_to_node[product.smiles] for product in intermediate_edge.products],
                    rule=intermediate_edge.rule,
                )

        # Persist the predicted pathway
        pred_pw = pu.spathway_to_pathway(target, eng_pw, name=f"{pw.name} (Predicted)")

        intermediate_pathways.append(copied_pw.url)
        predicted_pathways.append(pred_pw.url)

    return intermediate_pathways, predicted_pathways
@shared_task(bind=True, queue="background")
def batch_predict(
    self,
    substrates: List[str] | List[List[str]],
    prediction_setting_pk: int,
    target_package_pk: int,
    num_tps: int = 50,
):
    """Predict one pathway per substrate and return all of them as CSV text.

    All substrates are standardized first, so that an invalid SMILES aborts
    the batch before any pathway is created. Then a Pathway is created in
    the target package for each substrate and ``predict`` is invoked for it
    directly (not via ``.delay()``).

    Args:
        substrates: Either SMILES strings or ``[smiles, name]`` pairs; the
            shape is decided by looking at the first element only.
        prediction_setting_pk: Primary key of the Setting used for prediction.
        target_package_pk: Primary key of the package receiving the pathways.
        num_tps: Value used for the ``max_nodes`` and ``max_depth`` overrides.
            NOTE(review): the same number is used for both - confirm that
            this is intended.

    Returns:
        The concatenated ``Pathway.to_csv`` output of all pathways; only the
        first pathway contributes the header.

    Raises:
        ValueError: If ``substrates`` is empty or a SMILES cannot be
            standardized.
    """
    target_package = Package.objects.get(pk=target_package_pk)
    prediction_setting = Setting.objects.get(pk=prediction_setting_pk)

    if not substrates:
        raise ValueError("No substrates given!")

    if isinstance(substrates[0], (list, tuple)):
        substrate_and_names = substrates
    else:
        substrate_and_names = [[smiles, None] for smiles in substrates]

    # Check prerequisite that we can standardize all substrates
    standardized_substrates_and_smiles = []
    for entry in substrate_and_names:
        smiles = entry[0]
        try:
            stand_smiles = FormatConverter.standardize(smiles)
        except ValueError as exc:
            # Report the offending SMILES itself (not the [smiles, name] pair)
            # and keep the original error as the cause.
            raise ValueError(
                f'Pathway prediction failed as standardization of SMILES "{smiles}" failed!'
            ) from exc
        standardized_substrates_and_smiles.append([stand_smiles, entry[1]])

    pathways = []
    for stand_smiles, name in standardized_substrates_and_smiles:
        pw = Pathway.create(
            target_package,
            stand_smiles,
            name=name,
            predicted=True,
        )

        # set mode and setting
        pw.setting = prediction_setting
        pw.kv.update({"mode": "predict"})
        pw.save()

        predict(
            pw.pk,
            prediction_setting.pk,
            limit=None,
            setting_overrides={
                "max_nodes": num_tps,
                "max_depth": num_tps,
                "model_threshold": 0.001,
            },
        )
        pathways.append(pw)

    # Carry out header only for the first pathway
    return "".join(
        pw.to_csv(include_header=idx == 0, include_pathway_url=True)
        for idx, pw in enumerate(pathways)
    )