"""Celery tasks for ML-model building/evaluation and pathway prediction.

All long-running work is dispatched onto dedicated queues ("model",
"predict", "background").  Tasks that are created through ``dispatch``
are tracked via a ``JobLog`` row keyed by the Celery task id.
"""

import csv
import io
import logging
from typing import Any, Callable, List, Optional
from uuid import uuid4

from celery import shared_task
from celery.utils.functional import LRUCache
from django.conf import settings as s
from django.utils import timezone

from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter

logger = logging.getLogger(__name__)

# Cache the three most recent ML models to reduce load times.
ML_CACHE = LRUCache(3)

# The concrete Package model is configurable via Django settings.
Package = s.GET_PACKAGE_MODEL()


def get_ml_model(model_pk: int) -> "EPModel":
    """Return the ``EPModel`` with primary key *model_pk*, memoized in ``ML_CACHE``."""
    if model_pk not in ML_CACHE:
        ML_CACHE[model_pk] = EPModel.objects.get(id=model_pk)
    return ML_CACHE[model_pk]


def _update_joblog(task_id, status: str, task_result: Optional[str] = None) -> None:
    """Update the ``JobLog`` row tracking *task_id*, if one exists.

    ``QuerySet.update`` on a non-matching queryset is a no-op, so no
    separate ``exists()`` check is needed — the previous
    exists-then-update pattern was racy and cost an extra query.
    """
    JobLog.objects.filter(task_id=task_id).update(status=status, task_result=task_result)


def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
    """Run *job* synchronously and record a completed ``JobLog`` for it.

    Returns a ``(JobLog, result)`` tuple.  Exceptions are logged with a
    traceback and re-raised unchanged.
    """
    try:
        result = job(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = uuid4()  # no Celery task id exists for eager execution
        log.job_name = job.__name__
        log.status = "SUCCESS"
        log.done_at = timezone.now()
        # Explicit None check: falsy results (0, "", []) must still be recorded.
        log.task_result = str(result) if result is not None else None
        log.save()
        return log, result
    except Exception as e:
        logger.exception(e)
        raise


def dispatch(user: "User", job: Callable, *args, **kwargs):
    """Queue *job* via Celery and record an initial ``JobLog`` for it.

    Returns the ``JobLog``; the task's own body is expected to move the
    log to RUNNING/SUCCESS/FAILED (see ``_update_joblog``).
    """
    try:
        async_result = job.delay(*args, **kwargs)
        log = JobLog()
        log.user = user
        log.task_id = async_result.task_id
        log.job_name = job.__name__
        log.status = "INITIAL"
        log.save()
        return log
    except Exception as e:
        logger.exception(e)
        raise


@shared_task(queue="background")
def mul(a, b):
    """Trivial task, useful for smoke-testing the worker setup."""
    return a * b


@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
    """Run a single prediction for *smiles* with the (cached) model."""
    return get_ml_model(model_pk).predict(smiles)


@shared_task(queue="background")
def send_registration_mail(user_pk: int):
    """Send a registration mail to the given user.  Not implemented yet."""
    pass


@shared_task(bind=True, queue="model")
def build_model(self, model_pk: int):
    """Build dataset and model for the given ``EPModel``, tracking JobLog state.

    Returns the model's URL (also stored as the JobLog task result).
    """
    mod = EPModel.objects.get(id=model_pk)
    _update_joblog(self.request.id, "RUNNING", mod.url)
    try:
        mod.build_dataset()
        mod.build_model()
    except Exception:
        _update_joblog(self.request.id, "FAILED", mod.url)
        raise
    _update_joblog(self.request.id, "SUCCESS", mod.url)
    return mod.url


@shared_task(bind=True, queue="model")
def evaluate_model(self, model_pk: int, multigen: bool, package_pks: Optional[list] = None):
    """Evaluate the given ``EPModel`` (optionally restricted to *package_pks*).

    Returns the model's URL (also stored as the JobLog task result).
    """
    packages = Package.objects.filter(pk__in=package_pks) if package_pks else None
    mod = EPModel.objects.get(id=model_pk)
    _update_joblog(self.request.id, "RUNNING", mod.url)
    try:
        mod.evaluate_model(multigen, eval_packages=packages)
    except Exception:
        _update_joblog(self.request.id, "FAILED", mod.url)
        raise
    _update_joblog(self.request.id, "SUCCESS", mod.url)
    return mod.url


@shared_task(queue="model")
def retrain(model_pk: int):
    """Retrain the given ``EPModel``."""
    EPModel.objects.get(id=model_pk).retrain()


@shared_task(bind=True, queue="predict")
def predict(
    self,
    pw_pk: int,
    pred_setting_pk: int,
    limit: Optional[int] = None,
    node_pk: Optional[int] = None,
    setting_overrides: Optional[dict] = None,
) -> str:
    """Run a pathway prediction and persist its status into ``pw.kv``.

    Exactly one mode applies, checked in order:
      * ``limit`` set — stepwise prediction up to ``limit`` depth levels
        (``limit == -1`` means incremental: step until the pathway is done);
      * ``node_pk`` set — expand a single node of an existing pathway;
      * otherwise — a regular full prediction.

    *setting_overrides* are applied as attributes on the (unsaved) Setting
    instance for this run only.  Returns the pathway URL.
    """
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    if setting_overrides:
        for key, value in setting_overrides.items():
            setattr(setting, key, value)

    # If the setting has a model, add/restore it from the cache.
    if setting.model is not None:
        setting.model = get_ml_model(setting.model.pk)

    kv = {"status": "running"}
    if setting_overrides:
        kv["setting_overrides"] = setting_overrides
    pw.kv.update(**kv)
    pw.save()

    _update_joblog(self.request.id, "RUNNING", pw.url)

    try:
        if limit is not None:
            # Depth-limited / incremental prediction.
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1
                # limit == -1 is incremental mode: never break on depth.
                if limit != -1 and level >= limit:
                    break
        elif node_pk is not None:
            # Expand a single node of an existing pathway.
            node = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=node)
        else:
            # Regular full prediction.
            spw = SPathway(prediction_setting=setting, persist=pw)
            spw.predict()
    except Exception as e:
        pw.kv.update({"status": "failed", "error": str(e)})
        pw.save()
        _update_joblog(self.request.id, "FAILED", pw.url)
        raise
    pw.kv.update(**{"status": "completed"})
    pw.save()
    _update_joblog(self.request.id, "SUCCESS", pw.url)
    return pw.url


@shared_task(bind=True, queue="background")
def identify_missing_rules(
    self,
    pw_pks: List[int],
    rule_package_pk: int,
):
    """Report, as CSV text, edges of the given pathways not covered by the
    applicable rules of *rule_package_pk*.

    NOTE(review): for edges with more than one educt the per-educt
    name/SMILES columns shift the remaining columns relative to the
    header — confirm whether downstream consumers rely on this layout.
    """
    from utilities.misc import PathwayUtils

    rules = Package.objects.get(pk=rule_package_pk).get_applicable_rules()

    rows: List[Any] = [
        [
            "Package Name",
            "Pathway Name",
            "Educt Name",
            "Educt SMILES",
            "Reaction Name",
            "Reaction SMIRKS",
            "Triggered Rules",
            "Reactant SMARTS",
            "Product SMARTS",
            "Product Names",
            "Product SMILES",
        ]
    ]

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        missing_rules = pu.find_missing_rules(rules)
        package_name = pw.package.name
        pathway_name = pw.name

        for edge_url, rule_chain in missing_rules.items():
            row: List[Any] = [package_name, pathway_name]
            edge = Edge.objects.get(url=edge_url)

            for educt in edge.start_nodes.all():
                row.append(educt.default_node_label.name)
                row.append(educt.default_node_label.smiles)

            row.append(edge.edge_label.name)
            row.append(edge.edge_label.smirks())

            rule_names = []
            reactant_smarts = []
            product_smarts = []
            for chain_entry in rule_chain:
                rule = Rule.objects.get(url=chain_entry[0])
                rule_names.append(rule.name)
                # SMARTS may come back as sets; normalize to lists for CSV output.
                rs = rule.reactants_smarts
                reactant_smarts.append(list(rs) if isinstance(rs, set) else rs)
                ps = rule.products_smarts
                product_smarts.append(list(ps) if isinstance(ps, set) else ps)
            row.append(rule_names)
            row.append(reactant_smarts)
            row.append(product_smarts)

            products = edge.end_nodes.all()
            row.append([p.default_node_label.name for p in products])
            row.append([p.default_node_label.smiles for p in products])
            rows.append(row)

    buffer = io.StringIO()
    csv.writer(buffer).writerows(rows)
    buffer.seek(0)
    return buffer.getvalue()


@shared_task(bind=True, queue="background")
def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_pk: int):
    """Engineer intermediates into copies of the given pathways.

    For each pathway with engineered intermediates, a copy (with the
    intermediates inserted) and the predicted pathway are persisted into
    the target package.  Returns ``(intermediate_pathway_urls,
    predicted_pathway_urls)``.
    """
    from utilities.misc import PathwayUtils

    setting = Setting.objects.get(pk=setting_pk)
    # Temporarily set model_threshold to 0.0 to keep all tps.
    setting.model_threshold = 0.0

    target = Package.objects.get(pk=target_package_pk)

    intermediate_pathways = []
    predicted_pathways = []
    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        eng_pw, _node_to_snode_mapping, intermediates = pu.engineer(setting)

        # If we've found intermediates, do the following:
        # - Get a copy of the original pathway and add intermediates
        # - Store the predicted pathway for further investigation
        if len(intermediates):
            copy_mapping = {}
            copied_pw = pw.copy(target, copy_mapping)
            copied_pw.name = f"{copied_pw.name} (Engineered)"
            copied_pw.description = f"The original Pathway can be found here: {pw.url}"
            copied_pw.save()

            for inter in intermediates:
                start = copy_mapping[inter[0]]
                end = copy_mapping[inter[1]]
                start_snode = inter[2]
                end_snode = inter[3]

                for intermediate_edge in inter[4]:
                    smiles_to_node = {}
                    snodes_to_create = list(
                        set(intermediate_edge.educts + intermediate_edge.products)
                    )
                    for snode in snodes_to_create:
                        # Endpoints map onto the copied pathway's existing nodes.
                        if snode == start_snode or snode == end_snode:
                            smiles_to_node[snode.smiles] = start if snode == start_snode else end
                            continue
                        if snode.smiles not in smiles_to_node:
                            n = Node.create(copied_pw, smiles=snode.smiles, depth=snode.depth)
                            # Used in viz to highlight intermediates.
                            n.kv.update({"is_engineered_intermediate": True})
                            n.save()
                            smiles_to_node[snode.smiles] = n

                    Edge.create(
                        copied_pw,
                        [smiles_to_node[educt.smiles] for educt in intermediate_edge.educts],
                        [smiles_to_node[product.smiles] for product in intermediate_edge.products],
                        rule=intermediate_edge.rule,
                    )

            # Persist the predicted pathway.
            pred_pw = pu.spathway_to_pathway(target, eng_pw, name=f"{pw.name} (Predicted)")
            intermediate_pathways.append(copied_pw.url)
            predicted_pathways.append(pred_pw.url)

    return intermediate_pathways, predicted_pathways


@shared_task(bind=True, queue="background")
def batch_predict(
    self,
    substrates: List[str] | List[List[str]],
    prediction_setting_pk: int,
    target_package_pk: int,
    num_tps: int = 50,
):
    """Predict pathways for a batch of substrates and return them as CSV text.

    *substrates* is either a list of SMILES strings or a list of
    ``[smiles, name]`` pairs.  Raises ``ValueError`` if the batch is empty
    or any SMILES fails standardization.
    """
    target_package = Package.objects.get(pk=target_package_pk)
    prediction_setting = Setting.objects.get(pk=prediction_setting_pk)

    if len(substrates) == 0:
        raise ValueError("No substrates given!")

    # Normalize input to [smiles, name] pairs (name may be None).
    if isinstance(substrates[0], list):
        substrate_and_names = substrates
    else:
        substrate_and_names = [[sub, None] for sub in substrates]

    # Check prerequisite that we can standardize all substrates.
    standardized_substrates_and_smiles = []
    for smiles, name in substrate_and_names:
        try:
            stand_smiles = FormatConverter.standardize(smiles)
        except ValueError as e:
            # Report the offending SMILES itself, not the [smiles, name] pair.
            raise ValueError(
                f'Pathway prediction failed as standardization of SMILES "{smiles}" failed!'
            ) from e
        standardized_substrates_and_smiles.append([stand_smiles, name])

    pathways = []
    for stand_smiles, name in standardized_substrates_and_smiles:
        pw = Pathway.create(
            target_package,
            stand_smiles,
            name=name,
            predicted=True,
        )
        # Set mode and setting.
        pw.setting = prediction_setting
        pw.kv.update({"mode": "predict"})
        pw.save()

        # Direct call runs the bound task synchronously in this worker.
        predict(
            pw.pk,
            prediction_setting.pk,
            limit=None,
            setting_overrides={
                "max_nodes": num_tps,
                "max_depth": num_tps,
                "model_threshold": 0.001,
            },
        )
        pathways.append(pw)

    buffer = io.StringIO()
    for idx, pw in enumerate(pathways):
        # Carry out header only for the first pathway.
        buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True))
    buffer.seek(0)
    return buffer.getvalue()