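"""Celery tasks for model building, evaluation, and pathway prediction,
together with helpers that dispatch jobs and track them as JobLog entries."""
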
import csv
import io
import logging
from typing import Any, Callable, List, Optional
from uuid import uuid4

from celery import shared_task
from celery.utils.functional import LRUCache
from django.conf import settings as s
from django.utils import timezone

from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter

logger = logging.getLogger(__name__)

ML_CACHE = LRUCache(3)  # Cache the three most recent ML models to reduce load times.

Package = s.GET_PACKAGE_MODEL()


def get_ml_model(model_pk: int):
    """Return the EPModel with the given pk, served from ML_CACHE when possible."""
    if model_pk not in ML_CACHE:
        ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
    return ML_CACHE[model_pk]


def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run `job` synchronously and record the outcome as a JobLog entry."""
    try:
        x = job(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = uuid4()
        log.job_name = job.__name__
        log.status = "SUCCESS"
        log.done_at = timezone.now()
        # `is not None` keeps falsy results such as 0 or "" from being dropped.
        log.task_result = str(x) if x is not None else None
        log.save()

        return log, x
    except Exception as e:
        logger.exception(e)
        raise


def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Queue `job` via Celery and record it as a JobLog entry in state INITIAL."""
    try:
        x = job.delay(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = x.task_id
        log.job_name = job.__name__
        log.status = "INITIAL"
        log.save()

        return log
    except Exception as e:
        logger.exception(e)
        raise


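# Example usage (a minimal sketch; `request.user` and `model` are illustrative
# and not part of this module):
#
#   log = dispatch(request.user, build_model, model.pk)    # queued via Celery
#   log, result = dispatch_eager(request.user, mul, 6, 7)  # run in-process

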
@shared_task(queue="background")
def mul(a, b):
    # Trivial task, likely a smoke test that workers consume the background queue.
    return a * b


@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Run a one-off prediction for a single SMILES using the cached model."""
    mod = get_ml_model(model_pk)
    res = mod.predict(smiles)
    return res


@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    # Placeholder: mail delivery is not implemented yet.
    pass


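# The bound tasks below update the JobLog entry that dispatch() created for
# their task id, moving it through the INITIAL -> RUNNING -> SUCCESS/FAILED
# lifecycle.

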
@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    mod = EPModel.objects.get(id=model_pk)

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="RUNNING", task_result=mod.url)

    try:
        mod.build_dataset()
        mod.build_model()
    except Exception:
        if JobLog.objects.filter(task_id=self.request.id).exists():
            JobLog.objects.filter(task_id=self.request.id).update(
                status="FAILED", task_result=mod.url
            )

        raise

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=mod.url)

    return mod.url


@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    packages = None

    if package_pks:
        packages = Package.objects.filter(pk__in=package_pks)

    mod = EPModel.objects.get(id=model_pk)
    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="RUNNING", task_result=mod.url)

    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception:
        if JobLog.objects.filter(task_id=self.request.id).exists():
            JobLog.objects.filter(task_id=self.request.id).update(
                status="FAILED", task_result=mod.url
            )

        raise

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=mod.url)

    return mod.url


@shared_task(queue="model")
def retrain(model_pk: int):
    mod = EPModel.objects.get(id=model_pk)
    mod.retrain()


@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
    setting_overrides: Optional[dict] = None,
) -> str:
    """Run pathway prediction for `pw_pk` and return the pathway's URL."""
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    if setting_overrides:
        for k, v in setting_overrides.items():
            setattr(setting, k, v)

    # If the setting has a model, add/restore it from the cache.
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    kv = {"status": "running"}

    if setting_overrides:
        kv["setting_overrides"] = setting_overrides

    pw.kv.update(**kv)
    pw.save()

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="RUNNING", task_result=pw.url)

    try:
        # Depth-limited prediction, expanding the pathway level by level.
        if limit is not None:
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1

                # In incremental mode (limit != -1) stop once the requested
                # depth has been reached; a limit of -1 expands until done.
                if limit != -1:
                    if level >= limit:
                        break

        # Expand a single node of an existing pathway.
        elif node_pk is not None:
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        # Regular prediction over the whole pathway.
        else:
            spw = SPathway(prediction_setting=setting, persist=pw)
            spw.predict()

    except Exception as e:
        pw.kv.update({"status": "failed", "error": str(e)})
        pw.save()

        if JobLog.objects.filter(task_id=self.request.id).exists():
            JobLog.objects.filter(task_id=self.request.id).update(
                status="FAILED", task_result=pw.url
            )

        raise

    pw.kv.update({"status": "completed"})
    pw.save()

    if JobLog.objects.filter(task_id=self.request.id).exists():
        JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=pw.url)

    return pw.url


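# Example invocations (a sketch; the pathway/setting objects are illustrative):
#
#   predict.delay(pw.pk, setting.pk)                 # full prediction
#   predict.delay(pw.pk, setting.pk, limit=1)        # expand one depth level
#   predict.delay(pw.pk, setting.pk, node_pk=n.pk)   # expand a single node

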
@shared_task(bind=True, queue="background")
def identify_missing_rules(
    self,
    pw_pks: List[int],
    rule_package_pk: int,
):
    from utilities.misc import PathwayUtils

    rules = Package.objects.get(pk=rule_package_pk).get_applicable_rules()

    rows: List[Any] = []
    header = [
        "Package Name",
        "Pathway Name",
        "Educt Name",
        "Educt SMILES",
        "Reaction Name",
        "Reaction SMIRKS",
        "Triggered Rules",
        "Reactant SMARTS",
        "Product SMARTS",
        "Product Names",
        "Product SMILES",
    ]

    rows.append(header)

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)

        missing_rules = pu.find_missing_rules(rules)

        package_name = pw.package.name
        pathway_name = pw.name

        for edge_url, rule_chain in missing_rules.items():
            row: List[Any] = [package_name, pathway_name]
            edge = Edge.objects.get(url=edge_url)
            educts = edge.start_nodes.all()

            # NOTE: the header assumes a single educt; edges with several
            # educts append an extra name/SMILES pair per educt.
            for educt in educts:
                row.append(educt.default_node_label.name)
                row.append(educt.default_node_label.smiles)

            row.append(edge.edge_label.name)
            row.append(edge.edge_label.smirks())

            rule_names = []
            reactant_smarts = []
            product_smarts = []

            for chain_entry in rule_chain:
                rule = Rule.objects.get(url=chain_entry[0])
                rule_names.append(rule.name)

                # Convert SMARTS sets to lists for a stable CSV representation.
                rs = rule.reactants_smarts
                if isinstance(rs, set):
                    rs = list(rs)

                ps = rule.products_smarts
                if isinstance(ps, set):
                    ps = list(ps)

                reactant_smarts.append(rs)
                product_smarts.append(ps)

            row.append(rule_names)
            row.append(reactant_smarts)
            row.append(product_smarts)

            products = edge.end_nodes.all()
            product_names = []
            product_smiles = []

            for product in products:
                product_names.append(product.default_node_label.name)
                product_smiles.append(product.default_node_label.smiles)

            row.append(product_names)
            row.append(product_smiles)

            rows.append(row)

    buffer = io.StringIO()

    writer = csv.writer(buffer)
    writer.writerows(rows)

    buffer.seek(0)

    return buffer.getvalue()


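# Example (a sketch; pks are illustrative). The task returns the report as a
# CSV-formatted string:
#
#   csv_text = identify_missing_rules.delay([pw.pk], rule_package.pk).get()

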
@shared_task(bind=True, queue="background")
def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_pk: int):
    from utilities.misc import PathwayUtils

    setting = Setting.objects.get(pk=setting_pk)
    # Temporarily set model_threshold to 0.0 to keep all TPs.
    setting.model_threshold = 0.0

    target = Package.objects.get(pk=target_package_pk)

    intermediate_pathways = []
    predicted_pathways = []

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)

        eng_pw, node_to_snode_mapping, intermediates = pu.engineer(setting)

        # If we've found intermediates, do the following:
        # - Get a copy of the original pathway and add the intermediates
        # - Store the predicted pathway for further investigation
        if intermediates:
            copy_mapping = {}
            copied_pw = pw.copy(target, copy_mapping)
            copied_pw.name = f"{copied_pw.name} (Engineered)"
            copied_pw.description = f"The original Pathway can be found here: {pw.url}"
            copied_pw.save()

            for inter in intermediates:
                start = copy_mapping[inter[0]]
                end = copy_mapping[inter[1]]
                start_snode = inter[2]
                end_snode = inter[3]
                for intermediate_edge in inter[4]:
                    smiles_to_node = {}

                    snodes_to_create = list(
                        set(intermediate_edge.educts + intermediate_edge.products)
                    )

                    for snode in snodes_to_create:
                        # Map the boundary snodes onto the copied start/end nodes.
                        if snode == start_snode or snode == end_snode:
                            smiles_to_node[snode.smiles] = start if snode == start_snode else end
                            continue

                        if snode.smiles not in smiles_to_node:
                            n = Node.create(copied_pw, smiles=snode.smiles, depth=snode.depth)
                            # Used in viz to highlight intermediates.
                            n.kv.update({"is_engineered_intermediate": True})
                            n.save()
                            smiles_to_node[snode.smiles] = n

                    Edge.create(
                        copied_pw,
                        [smiles_to_node[educt.smiles] for educt in intermediate_edge.educts],
                        [smiles_to_node[product.smiles] for product in intermediate_edge.products],
                        rule=intermediate_edge.rule,
                    )

            # Persist the predicted pathway
            pred_pw = pu.spathway_to_pathway(target, eng_pw, name=f"{pw.name} (Predicted)")

            intermediate_pathways.append(copied_pw.url)
            predicted_pathways.append(pred_pw.url)

    return intermediate_pathways, predicted_pathways


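# Example (a sketch; pks are illustrative). Returns two lists of pathway URLs:
#
#   engineered, predicted = engineer_pathways.delay(
#       [pw.pk], setting.pk, target_package.pk
#   ).get()

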
@shared_task(bind=True, queue="background")
def batch_predict(
    self,
    substrates: List[str] | List[List[str]],
    prediction_setting_pk: int,
    target_package_pk: int,
    num_tps: int = 50,
):
    target_package = Package.objects.get(pk=target_package_pk)
    prediction_setting = Setting.objects.get(pk=prediction_setting_pk)

    if len(substrates) == 0:
        raise ValueError("No substrates given!")

    # Substrates may be bare SMILES strings or [smiles, name] pairs.
    is_pair = isinstance(substrates[0], list)

    substrate_and_names = []
    if not is_pair:
        for sub in substrates:
            substrate_and_names.append([sub, None])
    else:
        substrate_and_names = substrates

    # Check the prerequisite that every substrate can be standardized.
    standardized_substrates_and_smiles = []
    for substrate in substrate_and_names:
        try:
            stand_smiles = FormatConverter.standardize(substrate[0])
            standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
        except ValueError:
            raise ValueError(
                f'Pathway prediction failed as standardization of SMILES "{substrate[0]}" failed!'
            )

    pathways = []

    for pair in standardized_substrates_and_smiles:
        pw = Pathway.create(
            target_package,
            pair[0],
            name=pair[1],
            predicted=True,
        )

        # Set mode and setting.
        pw.setting = prediction_setting
        pw.kv.update({"mode": "predict"})
        pw.save()

        # Run the prediction synchronously within this task.
        predict(
            pw.pk,
            prediction_setting.pk,
            limit=None,
            setting_overrides={
                "max_nodes": num_tps,
                "max_depth": num_tps,
                "model_threshold": 0.001,
            },
        )

        pathways.append(pw)

    buffer = io.StringIO()

    for idx, pw in enumerate(pathways):
        # Write the header only for the first pathway.
        buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True))

    buffer.seek(0)

    return buffer.getvalue()
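

# Example (a sketch; SMILES and pks are illustrative). Substrates may be bare
# SMILES strings or [smiles, name] pairs; the task returns one concatenated CSV:
#
#   csv_text = batch_predict.delay(
#       [["CCO", "Ethanol"], ["c1ccccc1", "Benzene"]],
#       prediction_setting_pk=setting.pk,
#       target_package_pk=package.pk,
#       num_tps=25,
#   ).get()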