import csv
import io
import logging
from typing import Any, Callable, List, Optional
from uuid import uuid4

from celery import shared_task
from celery.utils.functional import LRUCache
from django.conf import settings as s
from django.core.mail import EmailMultiAlternatives
from django.utils import timezone

from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter

logger = logging.getLogger(__name__)

# Cache the three most recent ML models to reduce load times.
ML_CACHE = LRUCache(3)

Package = s.GET_PACKAGE_MODEL()


def get_ml_model(model_pk: int):
    """Return the :class:`EPModel` with primary key *model_pk*.

    Uses the module-level LRU cache so repeated predictions against the
    same model avoid reloading it from the database.
    """
    if model_pk not in ML_CACHE:
        ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
    return ML_CACHE[model_pk]


def _update_job_log(task_id, status: str, task_result: Optional[str] = None) -> None:
    """Set status/result on the JobLog row for *task_id*, if one exists.

    ``QuerySet.update`` on an empty queryset is a no-op, so no existence
    check is needed — this replaces the repeated
    ``filter(...).exists()`` + ``filter(...).update(...)`` double query.
    """
    JobLog.objects.filter(task_id=task_id).update(status=status, task_result=task_result)


def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run *job* synchronously and record a SUCCESS JobLog entry.

    Returns a ``(log, result)`` tuple. Exceptions raised by *job* are
    logged and re-raised unchanged.
    """
    try:
        x = job(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = uuid4()
        log.job_name = job.__name__
        log.status = "SUCCESS"
        log.done_at = timezone.now()
        # Record the result whenever there is one; an `is not None` check
        # (rather than truthiness) keeps falsy-but-real results like 0 or "".
        log.task_result = str(x) if x is not None else None
        log.save()
        return log, x
    except Exception as e:
        logger.exception(e)
        # Bare raise preserves the original traceback.
        raise


def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Queue *job* on Celery and record an INITIAL JobLog entry.

    Returns the created :class:`JobLog`; the worker updates its status later.
    """
    try:
        x = job.delay(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = x.task_id
        log.job_name = job.__name__
        log.status = "INITIAL"
        log.save()
        return log
    except Exception as e:
        logger.exception(e)
        raise


@shared_task(queue="background")
def mul(a, b):
    """Trivial task used to smoke-test the worker/queue setup."""
    return a * b


@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Run a single-compound prediction with the (cached) model."""
    mod = get_ml_model(model_pk)
    res = mod.predict(smiles)
    return res


@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    """Send the registration welcome email to the given user."""
    u = User.objects.get(id=user_pk)
    # NOTE(review): the template's original line breaks appear to have been
    # lost in a formatting pass — verify the rendered email body.
    tpl = """Welcome {username}!, Thank you for your interest in enviPath. The public system is intended for non-commercial use only. We will review your account details and usually activate your account within 24 hours. Once activated, you will be notified by email. 
If we have any questions, we will contact you at this email address. Best regards, enviPath team"""
    msg = EmailMultiAlternatives(
        "Your enviPath account",
        tpl.format(username=u.username),
        "admin@envipath.org",
        [u.email],
        bcc=["admin@envipath.org"],
    )
    msg.send(fail_silently=False)


@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    """Build dataset and model for an :class:`EPModel`, tracking JobLog status.

    Returns the model URL on success; marks the JobLog FAILED and re-raises
    on any error.
    """
    mod = EPModel.objects.get(id=model_pk)
    _update_job_log(self.request.id, "RUNNING", mod.url)
    try:
        mod.build_dataset()
        mod.build_model()
    except Exception:
        _update_job_log(self.request.id, "FAILED", mod.url)
        raise
    _update_job_log(self.request.id, "SUCCESS", mod.url)
    return mod.url


@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    """Evaluate an :class:`EPModel`, optionally restricted to given packages.

    :param multigen: whether to run multi-generation evaluation.
    :param package_pks: optional package PKs to evaluate against.
    Returns the model URL on success.
    """
    packages = None
    if package_pks:
        packages = Package.objects.filter(pk__in=package_pks)
    mod = EPModel.objects.get(id=model_pk)
    _update_job_log(self.request.id, "RUNNING", mod.url)
    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception:
        _update_job_log(self.request.id, "FAILED", mod.url)
        raise
    _update_job_log(self.request.id, "SUCCESS", mod.url)
    return mod.url


@shared_task(queue="model")
def retrain(model_pk: int):
    """Retrain the given :class:`EPModel`."""
    mod = EPModel.objects.get(id=model_pk)
    mod.retrain()


@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
    setting_overrides: Optional[dict] = None,
) -> str:
    """Run pathway prediction and persist progress on the Pathway's kv store.

    Exactly one of three modes is used, in this priority:
    - ``limit`` is set: stepwise depth-by-depth prediction, stopping after
      ``limit`` levels (``-1`` means no depth cap / incremental mode);
    - ``node_pk`` is set: a single prediction step expanding that node;
    - otherwise: full prediction in one call.

    Returns the pathway URL (note: a string, not a Pathway instance).
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    # Overrides are applied in-memory only; the Setting row is not saved.
    if setting_overrides:
        for k, v in setting_overrides.items():
            setattr(setting, k, v)

    # If the setting has a model add/restore it from the cache
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    kv = {"status": "running"}
    if setting_overrides:
        kv["setting_overrides"] = setting_overrides
    pw.kv.update(**kv)
    pw.save()

    _update_job_log(self.request.id, "RUNNING", pw.url)

    try:
        # regular prediction
        if limit is not None:
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # break in case we are in incremental mode
                if limit != -1:
                    if level >= limit:
                        break
        elif node_pk is not None:
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            spw = SPathway(prediction_setting=setting, persist=pw)
            spw.predict()
    except Exception as e:
        pw.kv.update({"status": "failed"})
        pw.kv.update(**{"error": str(e)})
        pw.save()
        _update_job_log(self.request.id, "FAILED", pw.url)
        raise
    pw.kv.update(**{"status": "completed"})
    pw.save()
    _update_job_log(self.request.id, "SUCCESS", pw.url)
    return pw.url


@shared_task(bind=True, queue="background")
def identify_missing_rules(
    self,
    pw_pks: List[int],
    rule_package_pk: int,
):
    """Build a CSV report of pathway edges not covered by a package's rules.

    For each pathway in *pw_pks*, finds edges whose reactions are not
    triggered by the rule package's applicable rules and emits one CSV row
    per missing edge. Returns the CSV content as a string.
    """
    from utilities.misc import PathwayUtils

    rules = Package.objects.get(pk=rule_package_pk).get_applicable_rules()

    rows: List[Any] = []
    header = [
        "Package Name",
        "Pathway Name",
        "Educt Name",
        "Educt SMILES",
        "Reaction Name",
        "Reaction SMIRKS",
        "Triggered Rules",
        "Reactant SMARTS",
        "Product SMARTS",
        "Product Names",
        "Product SMILES",
    ]
    rows.append(header)

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        missing_rules = pu.find_missing_rules(rules)
        package_name = pw.package.name
        pathway_name = pw.name
        for edge_url, rule_chain in missing_rules.items():
            row: List[Any] = [package_name, pathway_name]
            edge = Edge.objects.get(url=edge_url)

            # NOTE: multi-educt edges append extra name/SMILES cells, so rows
            # may have more columns than the header.
            educts = edge.start_nodes.all()
            for educt in educts:
                row.append(educt.default_node_label.name)
                row.append(educt.default_node_label.smiles)
            row.append(edge.edge_label.name)
            row.append(edge.edge_label.smirks())

            rule_names = []
            reactant_smarts = []
            product_smarts = []
            for r in rule_chain:
                r = Rule.objects.get(url=r[0])
                rule_names.append(r.name)
                # Sets are converted to lists for a stable CSV representation.
                rs = r.reactants_smarts
                if isinstance(rs, set):
                    rs = list(rs)
                ps = r.products_smarts
                if isinstance(ps, set):
                    ps = list(ps)
                reactant_smarts.append(rs)
                product_smarts.append(ps)
            row.append(rule_names)
            row.append(reactant_smarts)
            row.append(product_smarts)

            products = edge.end_nodes.all()
            product_names = []
            product_smiles = []
            for product in products:
                product_names.append(product.default_node_label.name)
                product_smiles.append(product.default_node_label.smiles)
            row.append(product_names)
            row.append(product_smiles)
            rows.append(row)

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerows(rows)
    buffer.seek(0)
    return buffer.getvalue()


@shared_task(bind=True, queue="background")
def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_pk: int):
    """Engineer pathways by inserting predicted intermediate steps.

    For each source pathway, runs ``PathwayUtils.engineer``; when
    intermediates are found, copies the pathway into *target_package_pk*,
    inserts the intermediate nodes/edges, and persists the predicted
    pathway. Returns ``(intermediate_pathway_urls, predicted_pathway_urls)``.
    """
    from utilities.misc import PathwayUtils

    setting = Setting.objects.get(pk=setting_pk)
    # Temporarily set model_threshold to 0.0 to keep all tps
    setting.model_threshold = 0.0
    target = Package.objects.get(pk=target_package_pk)

    intermediate_pathways = []
    predicted_pathways = []
    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        eng_pw, node_to_snode_mapping, intermediates = pu.engineer(setting)
        # If we've found intermediates, do the following
        # - Get a copy of the original pathway and add intermediates
        # - Store the predicted pathway for further investigation
        if len(intermediates):
            copy_mapping = {}
            copied_pw = pw.copy(target, copy_mapping)
            copied_pw.name = f"{copied_pw.name} (Engineered)"
            copied_pw.description = f"The original Pathway can be found here: {pw.url}"
            copied_pw.save()

            # inter layout: (start_node, end_node, start_snode, end_snode, edges)
            for inter in intermediates:
                start = copy_mapping[inter[0]]
                end = copy_mapping[inter[1]]
                start_snode = inter[2]
                end_snode = inter[3]
                for idx, intermediate_edge in enumerate(inter[4]):
                    smiles_to_node = {}
                    snodes_to_create = list(
                        set(intermediate_edge.educts + intermediate_edge.products)
                    )
                    for snode in snodes_to_create:
                        # Endpoints map to the already-copied nodes.
                        if snode == start_snode or snode == end_snode:
                            smiles_to_node[snode.smiles] = start if snode == start_snode else end
                            continue
                        if snode.smiles not in smiles_to_node:
                            n = Node.create(copied_pw, smiles=snode.smiles, depth=snode.depth)
                            # Used in viz to highlight intermediates
                            n.kv.update({"is_engineered_intermediate": True})
                            n.save()
                            smiles_to_node[snode.smiles] = n
                    Edge.create(
                        copied_pw,
                        [smiles_to_node[educt.smiles] for educt in intermediate_edge.educts],
                        [smiles_to_node[product.smiles] for product in intermediate_edge.products],
                        rule=intermediate_edge.rule,
                    )

            # Persist the predicted pathway
            pred_pw = pu.spathway_to_pathway(target, eng_pw, name=f"{pw.name} (Predicted)")
            intermediate_pathways.append(copied_pw.url)
            predicted_pathways.append(pred_pw.url)
    return intermediate_pathways, predicted_pathways


@shared_task(bind=True, queue="background")
def batch_predict(
    self,
    substrates: List[str] | List[List[str]],
    prediction_setting_pk: int,
    target_package_pk: int,
    num_tps: int = 50,
):
    """Predict pathways for a batch of substrates and return a combined CSV.

    :param substrates: either plain SMILES strings, or ``[smiles, name]`` pairs.
    :param num_tps: cap on transformation products (used as max_nodes/max_depth).
    :raises ValueError: if no substrates are given or standardization fails.
    """
    target_package = Package.objects.get(pk=target_package_pk)
    prediction_setting = Setting.objects.get(pk=prediction_setting_pk)

    if len(substrates) == 0:
        raise ValueError("No substrates given!")

    # Normalize to [smiles, name] pairs.
    is_pair = isinstance(substrates[0], list)
    substrate_and_names = []
    if not is_pair:
        for sub in substrates:
            substrate_and_names.append([sub, None])
    else:
        substrate_and_names = substrates

    # Check prerequisite that we can standardize all substrates
    standardized_substrates_and_smiles = []
    for substrate in substrate_and_names:
        try:
            stand_smiles = FormatConverter.standardize(substrate[0])
            standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
        except ValueError as e:
            # Report the SMILES itself, not the [smiles, name] pair; chain the cause.
            raise ValueError(
                f'Pathway prediction failed as standardization of SMILES "{substrate[0]}" failed!'
            ) from e

    pathways = []
    for pair in standardized_substrates_and_smiles:
        pw = Pathway.create(
            target_package,
            pair[0],
            name=pair[1],
            predicted=True,
        )
        # set mode and setting
        pw.setting = prediction_setting
        pw.kv.update({"mode": "predict"})
        pw.save()
        # Direct call runs the prediction synchronously in this worker.
        predict(
            pw.pk,
            prediction_setting.pk,
            limit=None,
            setting_overrides={
                "max_nodes": num_tps,
                "max_depth": num_tps,
                "model_threshold": 0.001,
            },
        )
        pathways.append(pw)

    buffer = io.StringIO()
    for idx, pw in enumerate(pathways):
        # Carry out header only for the first pathway
        buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True))
    buffer.seek(0)
    return buffer.getvalue()