diff --git a/epdb/legacy_api.py b/epdb/legacy_api.py index 228ca2fe..b955ff78 100644 --- a/epdb/legacy_api.py +++ b/epdb/legacy_api.py @@ -1451,7 +1451,7 @@ def create_pathway( from .tasks import dispatch, predict - dispatch(request.user, predict, new_pw.pk, setting.pk, limit=-1) + dispatch(request.user, predict, new_pw.pk, setting.pk, limit=None) return redirect(new_pw.url) except ValueError as e: diff --git a/epdb/logic.py b/epdb/logic.py index de409f8a..729b9eea 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union, Tuple from uuid import UUID import nh3 @@ -16,6 +16,7 @@ from epdb.models import ( Edge, EnzymeLink, EPModel, + ExpansionSchemeChoice, Group, GroupPackagePermission, Node, @@ -1116,6 +1117,7 @@ class SettingManager(object): rule_packages: List[Package] = None, model: EPModel = None, model_threshold: float = None, + expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS, ): new_s = Setting() # Clean for potential XSS @@ -1550,6 +1552,196 @@ class SPathway(object): return sorted(res, key=lambda x: hash(x)) + def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]: + """ + Expands the given substrates by generating new nodes and edges based on prediction settings. + + This method processes a list of substrates and expands them into new nodes and edges using defined + rules and settings. It evaluates each substrate to determine its applicability domain, persists + domain assessments, and generates candidates for further processing. Newly created nodes and edges + are returned, and any applicable information is stored or updated internally during the process. + + Parameters: + substrates (List[SNode]): A list of substrate nodes to be expanded. + + Returns: + Tuple[List[SNode], List[SEdge]]: + A tuple containing: + - A list of new nodes generated during the expansion. + - A list of new edges representing connections between nodes based on candidate reactions. + + Raises: + ValueError: If a node does not have an ID when it should have been saved already. + """ + new_nodes: List[SNode] = [] + new_edges: List[SEdge] = [] + + for sub in substrates: + # For App Domain we have to ensure that each Node is evaluated + if sub.app_domain_assessment is None: + if self.prediction_setting.model: + if self.prediction_setting.model.app_domain: + app_domain_assessment = self.prediction_setting.model.app_domain.assess( + sub.smiles + ) + + if self.persist is not None: + n = self.snode_persist_lookup[sub] + + if n.id is None: + raise ValueError(f"Node {n} has no ID... aborting!") + + node_data = n.simple_json() + node_data["image"] = f"{n.url}?image=svg" + app_domain_assessment["assessment"]["node"] = node_data + + n.kv["app_domain_assessment"] = app_domain_assessment + n.save() + + sub.app_domain_assessment = app_domain_assessment + + candidates = self.prediction_setting.expand(self, sub) + # candidates is a List of PredictionResult. 
The length of the List is equal to the number of rules + for cand_set in candidates: + if cand_set: + # cand_set is a PredictionResult object that can consist of multiple candidate reactions + for cand in cand_set: + cand_nodes = [] + # candidate reactions can have multiple fragments + for c in cand: + if c not in self.smiles_to_node: + # For new nodes do an AppDomain Assessment if an AppDomain is attached + app_domain_assessment = None + if self.prediction_setting.model: + if self.prediction_setting.model.app_domain: + app_domain_assessment = ( + self.prediction_setting.model.app_domain.assess(c) + ) + snode = SNode(c, sub.depth + 1, app_domain_assessment) + self.smiles_to_node[c] = snode + new_nodes.append(snode) + + node = self.smiles_to_node[c] + cand_nodes.append(node) + + edge = SEdge( + sub, + cand_nodes, + rule=cand_set.rule, + probability=cand_set.probability, + ) + self.edges.add(edge) + new_edges.append(edge) + + return new_nodes, new_edges + + def predict(self): + """ + Predicts outcomes based on a graph traversal algorithm using the specified expansion schema. + + This method iteratively explores the nodes of a graph starting from the root nodes, propagating + probabilities through edges, and updating the probabilities of the connected nodes. The traversal + can follow one of three predefined expansion schemas: Depth-First Search (DFS), Breadth-First Search + (BFS), or a Greedy approach based on node probabilities. The methodology ensures that all reachable + nodes are processed systematically according to the specified schema. + + Errors will be raised if the expansion schema is undefined or invalid. Additionally, this method + supports persisting changes by writing back data to the database when configured to do so. + + Attributes + ---------- + done : bool + A flag indicating whether the prediction process is completed. + persist : Any + An optional object that manages persistence operations for saving modifications. + root_nodes : List[SNode] + A collection of initial nodes in the graph from which traversal begins. + prediction_setting : Any + Configuration object specifying settings for graph traversal, such as the choice of + expansion schema. + + Raises + ------ + ValueError + If an invalid or unknown expansion schema is provided in `prediction_setting`. 
+ """ + # populate initial queue + queue = list(self.root_nodes) + processed = set() + + # initial nodes have prob 1.0 + node_probs: Dict[SNode, float] = {} + node_probs.update({n: 1.0 for n in queue}) + + while queue: + current = queue.pop(0) + + if current in processed: + continue + + processed.add(current) + + new_nodes, new_edges = self._expand([current]) + + if new_nodes or new_edges: + # Check if we need to write back data to the database + if self.persist: + self._sync_to_pathway() + # call save to update the internal modified field + self.persist.save() + + if new_nodes: + for edge in new_edges: + # All edge have `current` as educt + # Use `current` and adjust probs + current_prob = node_probs[current] + + for prod in edge.products: + # Either is a new product or a product and we found a path with a higher prob + if ( + prod not in node_probs + or current_prob * edge.probability > node_probs[prod] + ): + node_probs[prod] = current_prob * edge.probability + + # Update Queue to proceed + if self.prediction_setting.expansion_scheme == "DFS": + for n in new_nodes: + if n not in processed: + # We want to follow this path -> prepend queue + queue.insert(0, n) + elif self.prediction_setting.expansion_scheme == "BFS": + for n in new_nodes: + if n not in processed: + # Add at the end, everything queued before will be processed + # before new_nodese + queue.append(n) + elif self.prediction_setting.expansion_scheme == "GREEDY": + # Simply add them, as we will re-order the queue later + for n in new_nodes: + if n not in processed: + queue.append(n) + + node_and_probs = [] + for queued_val in queue: + node_and_probs.append((queued_val, node_probs[queued_val])) + + from pprint import pprint + + pprint(node_and_probs) + + # re-order the queue and only pick smiles + queue = [ + n[0] for n in sorted(node_and_probs, key=lambda x: x[1], reverse=True) + ] + else: + raise ValueError( + f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}" + ) + + # Queue exhausted, we're done + self.done = True + def predict_step(self, from_depth: int = None, from_node: "Node" = None): substrates: List[SNode] = [] @@ -1560,67 +1752,15 @@ class SPathway(object): if from_node == v: substrates = [k] break + else: + raise ValueError(f"Node {from_node} not found in SPathway!") else: raise ValueError("Neither from_depth nor from_node_url specified") new_tp = False if substrates: - for sub in substrates: - if sub.app_domain_assessment is None: - if self.prediction_setting.model: - if self.prediction_setting.model.app_domain: - app_domain_assessment = self.prediction_setting.model.app_domain.assess( - sub.smiles - ) - - if self.persist is not None: - n = self.snode_persist_lookup[sub] - - assert n.id is not None, ( - "Node has no id! Should have been saved already... aborting!" - ) - node_data = n.simple_json() - node_data["image"] = f"{n.url}?image=svg" - app_domain_assessment["assessment"]["node"] = node_data - - n.kv["app_domain_assessment"] = app_domain_assessment - n.save() - - sub.app_domain_assessment = app_domain_assessment - - candidates = self.prediction_setting.expand(self, sub) - # candidates is a List of PredictionResult. 
The length of the List is equal to the number of rules - for cand_set in candidates: - if cand_set: - new_tp = True - # cand_set is a PredictionResult object that can consist of multiple candidate reactions - for cand in cand_set: - cand_nodes = [] - # candidate reactions can have multiple fragments - for c in cand: - if c not in self.smiles_to_node: - # For new nodes do an AppDomain Assessment if an AppDomain is attached - app_domain_assessment = None - if self.prediction_setting.model: - if self.prediction_setting.model.app_domain: - app_domain_assessment = ( - self.prediction_setting.model.app_domain.assess(c) - ) - - self.smiles_to_node[c] = SNode( - c, sub.depth + 1, app_domain_assessment - ) - - node = self.smiles_to_node[c] - cand_nodes.append(node) - - edge = SEdge( - sub, - cand_nodes, - rule=cand_set.rule, - probability=cand_set.probability, - ) - self.edges.add(edge) + new_nodes, _ = self._expand(substrates) + new_tp = len(new_nodes) > 0 # In case no substrates are found, we're done. # For "predict from node" we're always done @@ -1704,11 +1844,6 @@ class SPathway(object): "to": to_indices, } - # if edge.rule: - # e['rule'] = { - # 'name': edge.rule.name, - # 'id': edge.rule.url, - # } edges.append(e) return { diff --git a/epdb/migrations/0013_setting_expansion_schema.py b/epdb/migrations/0013_setting_expansion_schema.py new file mode 100644 index 00000000..9a981795 --- /dev/null +++ b/epdb/migrations/0013_setting_expansion_schema.py @@ -0,0 +1,25 @@ +# Generated by Django 5.2.7 on 2025-12-14 11:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0012_node_stereo_removed_pathway_predicted"), + ] + + operations = [ + migrations.AddField( + model_name="setting", + name="expansion_schema", + field=models.CharField( + choices=[ + ("BFS", "Breadth First Search"), + ("DFS", "Depth First Search"), + ("GREEDY", "Greedy"), + ], + default="BFS", + max_length=20, + ), + ), + ] diff --git a/epdb/migrations/0014_rename_expansion_schema_setting_expansion_scheme.py b/epdb/migrations/0014_rename_expansion_schema_setting_expansion_scheme.py new file mode 100644 index 00000000..b7332fee --- /dev/null +++ b/epdb/migrations/0014_rename_expansion_schema_setting_expansion_scheme.py @@ -0,0 +1,17 @@ +# Generated by Django 5.2.7 on 2025-12-14 16:02 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0013_setting_expansion_schema"), + ] + + operations = [ + migrations.RenameField( + model_name="setting", + old_name="expansion_schema", + new_name="expansion_scheme", + ), + ] diff --git a/epdb/models.py b/epdb/models.py index ef8e9e8b..a2429c22 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): # potentially prefetched edge_set return self.edge_set.all() + @property + def setting_with_overrides(self): + mem_copy = Setting.objects.get(pk=self.setting.pk) + + if "setting_overrides" in self.kv: + for k, v in self.kv["setting_overrides"].items(): + setattr(mem_copy, k, f"{v} (this is an override for this particular pathway)") + + return mem_copy + def _url(self): return "{}/pathway/{}".format(self.package.url, self.uuid) @@ -1879,25 +1889,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): return json.dumps(res) - def to_csv(self) -> str: + def to_csv(self, include_header=True, include_pathway_url=False) -> str: import csv import io + header = [] + + if include_pathway_url: + header += ["Pathway 
URL"] + + header += [ + "SMILES", + "name", + "depth", + "probability", + "rule_names", + "rule_ids", + "parent_smiles", + ] + rows = [] - rows.append( - [ - "SMILES", - "name", - "depth", - "probability", - "rule_names", - "rule_ids", - "parent_smiles", - ] - ) + + if include_header: + rows.append(header) + for n in self.nodes.order_by("depth"): cs = n.default_node_label - row = [cs.smiles, cs.name, n.depth] + row = [] + + if include_pathway_url: + row.append(n.pathway.url) + + row += [cs.smiles, cs.name, n.depth] edges = self.edges.filter(end_nodes__in=[n]) if len(edges): @@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel): return res + @property + def mg_pr_curve(self): + if self.model_status != self.FINISHED: + raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}") + + if not self.multigen_eval: + raise ValueError("MG PR Curve is only available for multigen models") + + res = [] + + thresholds = self.eval_results["multigen_average_precision_per_threshold"].keys() + + for t in thresholds: + res.append( + { + "precision": self.eval_results["multigen_average_precision_per_threshold"][t], + "recall": self.eval_results["multigen_average_recall_per_threshold"][t], + "threshold": float(t), + } + ) + + return res + @cached_property def applicable_rules(self) -> List["Rule"]: """ @@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel): for i, root in enumerate(root_compounds): logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...") - spw = SPathway(root_nodes=root, prediction_setting=s) + spw = SPathway(root_nodes=root.smiles, prediction_setting=s) level = 0 while not spw.done: @@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission): return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}" +class ExpansionSchemeChoice(models.TextChoices): + BFS = "BFS", "Breadth First Search" + DFS = "DFS", "Depth First Search" + GREEDY = "GREEDY", "Greedy" + + class Setting(EnviPathModel): public = models.BooleanField(null=False, blank=False, default=False) global_default = models.BooleanField(null=False, blank=False, default=False) @@ -3795,6 +3847,12 @@ class Setting(EnviPathModel): null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25 ) + expansion_scheme = models.CharField( + max_length=20, + choices=ExpansionSchemeChoice.choices, + default=ExpansionSchemeChoice.BFS, + ) + def _url(self): return "{}/setting/{}".format(s.SERVER_URL, self.uuid) @@ -3833,7 +3891,7 @@ class Setting(EnviPathModel): """Decision Method whether to expand on a certain Node or not""" if pathway.num_nodes() >= self.max_nodes: logger.info( - f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}" + f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}" ) return [] @@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel): if self.job_name == "engineer_pathways": return ast.literal_eval(self.task_result) return self.task_result + + def is_result_downloadable(self): + downloadable = ["batch_predict"] + + return self.job_name in downloadable diff --git a/epdb/tasks.py b/epdb/tasks.py index 4284d02c..ebffb5b7 100644 --- a/epdb/tasks.py +++ b/epdb/tasks.py @@ -11,6 +11,7 @@ from django.utils import timezone from epdb.logic import SPathway from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User +from utilities.chem import FormatConverter logger = logging.getLogger(__name__) ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load 
times. @@ -139,14 +140,25 @@ def predict( pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None, + setting_overrides: Optional[dict] = None, ) -> Pathway: pw = Pathway.objects.get(id=pw_pk) setting = Setting.objects.get(id=pred_setting_pk) + + if setting_overrides: + for k, v in setting_overrides.items(): + setattr(setting, k, v) + # If the setting has a model add/restore it from the cache if setting.model is not None: setting.model = get_ml_model(setting.model.pk) - pw.kv.update(**{"status": "running"}) + kv = {"status": "running"} + + if setting_overrides: + kv["setting_overrides"] = setting_overrides + + pw.kv.update(**kv) pw.save() if JobLog.objects.filter(task_id=self.request.id).exists(): @@ -171,7 +183,8 @@ def predict( spw = SPathway.from_pathway(pw) spw.predict_step(from_node=n) else: - raise ValueError("Neither limit nor node_pk given!") + spw = SPathway(prediction_setting=setting, persist=pw) + spw.predict() except Exception as e: pw.kv.update({"status": "failed"}) @@ -353,3 +366,76 @@ def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_p predicted_pathways.append(pred_pw.url) return intermediate_pathways, predicted_pathways + + +@shared_task(bind=True, queue="background") +def batch_predict( + self, + substrates: List[str] | List[List[str]], + prediction_setting_pk: int, + target_package_pk: int, + num_tps: int = 50, +): + target_package = Package.objects.get(pk=target_package_pk) + prediction_setting = Setting.objects.get(pk=prediction_setting_pk) + + if len(substrates) == 0: + raise ValueError("No substrates given!") + + is_pair = isinstance(substrates[0], list) + + substrate_and_names = [] + if not is_pair: + for sub in substrates: + substrate_and_names.append([sub, None]) + else: + substrate_and_names = substrates + + # Check prerequisite that we can standardize all substrates + standardized_substrates_and_smiles = [] + for substrate in substrate_and_names: + try: + stand_smiles = FormatConverter.standardize(substrate[0]) + standardized_substrates_and_smiles.append([stand_smiles, substrate[1]]) + except ValueError: + raise ValueError( + f'Pathway prediction failed as standardization of SMILES "{substrate}" failed!' 
+ ) + + pathways = [] + + for pair in standardized_substrates_and_smiles: + pw = Pathway.create( + target_package, + pair[0], + name=pair[1], + predicted=True, + ) + + # set mode and setting + pw.setting = prediction_setting + pw.kv.update({"mode": "predict"}) + pw.save() + + predict( + pw.pk, + prediction_setting.pk, + limit=None, + setting_overrides={ + "max_nodes": num_tps, + "max_depth": num_tps, + "model_threshold": 0.001, + }, + ) + + pathways.append(pw) + + buffer = io.StringIO() + + for idx, pw in enumerate(pathways): + # Carry out header only for the first pathway + buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True)) + + buffer.seek(0) + + return buffer.getvalue() diff --git a/epdb/urls.py b/epdb/urls.py index f5144847..5ec9deb8 100644 --- a/epdb/urls.py +++ b/epdb/urls.py @@ -49,6 +49,7 @@ urlpatterns = [ re_path(r"^group$", v.groups, name="groups"), re_path(r"^search$", v.search, name="search"), re_path(r"^predict$", v.predict_pathway, name="predict_pathway"), + re_path(r"^batch-predict$", v.batch_predict_pathway, name="batch_predict_pathway"), # User Detail re_path(rf"^user/(?P{UUID})", v.user, name="user"), # Group Detail diff --git a/epdb/views.py b/epdb/views.py index fe7364b3..917ae6da 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -1,11 +1,12 @@ import json import logging from typing import Any, Dict, List +from datetime import datetime import nh3 from django.conf import settings as s from django.contrib.auth import get_user_model -from django.core.exceptions import BadRequest +from django.core.exceptions import BadRequest, PermissionDenied from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse from django.shortcuts import redirect, render from django.urls import reverse @@ -50,6 +51,7 @@ from .models import ( SimpleAmbitRule, User, UserPackagePermission, + ExpansionSchemeChoice, ) logger = logging.getLogger(__name__) @@ -438,6 +440,18 @@ def predict_pathway(request): return render(request, "predict_pathway.html", context) +def batch_predict_pathway(request): + """Top-level predict pathway view using user's default package.""" + if request.method != "GET": + return HttpResponseNotAllowed(["GET"]) + + context = get_base_context(request) + context["title"] = "enviPath - Batch Predict Pathway" + context["meta"]["current_package"] = context["meta"]["user"].default_package + + return render(request, "batch_predict_pathway.html", context) + + @package_permission_required() def package_predict_pathway(request, package_uuid): """Package-specific predict pathway view.""" @@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid): if pw_mode == "predict" or pw_mode == "incremental": # unlimited pred (will be handled by setting) - limit = -1 + limit = None # For incremental predict first level and return if pw_mode == "incremental": @@ -2877,15 +2891,25 @@ def settings(request): ) if not PackageManager.readable(current_user, params["model"].package): - raise ValueError("") + raise PermissionDenied("You're not allowed to access this model!") + + expansion_scheme = request.POST.get( + "model-based-prediction-setting-expansion-scheme", "BFS" + ) + + if expansion_scheme not in ExpansionSchemeChoice.values: + raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}") + + params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme) elif tp_gen_method == "rule-based-prediction-setting": rule_packages = request.POST.getlist("rule-based-prediction-setting-packages") params["rule_packages"] = [ 
PackageManager.get_package_by_url(current_user, p) for p in rule_packages ] + else: - raise ValueError("") + raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!") created_setting = SettingManager.create_setting( current_user, @@ -2963,6 +2987,59 @@ def jobs(request): ) return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}") + + elif job_name == "batch-predict": + substrates = request.POST.get("substrates") + prediction_setting_url = request.POST.get("prediction-setting") + num_tps = request.POST.get("num-tps") + + if substrates is None or substrates.strip() == "": + raise BadRequest("No substrates provided.") + + pred_data = [] + for pair in substrates.split("\n"): + parts = pair.split(",") + + try: + smiles = FormatConverter.standardize(parts[0]) + except ValueError: + raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!") + + # name is optional + name = parts[1] if len(parts) > 1 else None + pred_data.append([smiles, name]) + + max_tps = 50 + if num_tps is not None and num_tps.strip() != "": + try: + num_tps = int(num_tps) + max_tps = max(min(num_tps, 50), 1) + except ValueError: + raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.") + + batch_predict_setting = SettingManager.get_setting_by_url( + current_user, prediction_setting_url + ) + + target_package = PackageManager.create_package( + current_user, + f"Autogenerated Package for Batch Prediction {datetime.now()}", + "This Package was generated automatically for the batch prediction task.", + ) + + from .tasks import dispatch, batch_predict + + res = dispatch( + current_user, + batch_predict, + pred_data, + batch_predict_setting.pk, + target_package.pk, + num_tps=max_tps, + ) + + return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}") + else: raise BadRequest(f"Job {job_name} is not supported!") else: @@ -2982,6 +3059,21 @@ def job(request, job_uuid): # No op if status is already in a terminal state job.check_for_update() + if request.GET.get("download", False) == "true": + if not job.is_result_downloadable(): + raise BadRequest("Result is not downloadable!") + + if job.job_name == "batch_predict": + filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv" + else: + raise BadRequest("Result is not downloadable!") + + res_str = job.task_result + response = HttpResponse(res_str, content_type="text/csv") + response["Content-Disposition"] = f'attachment; filename="{filename}"' + + return response + context["object_type"] = "joblog" context["breadcrumbs"] = [ {"Home": s.SERVER_URL}, diff --git a/templates/actions/objects/joblog.html b/templates/actions/objects/joblog.html index e69de29b..e3787ed7 100644 --- a/templates/actions/objects/joblog.html +++ b/templates/actions/objects/joblog.html @@ -0,0 +1,10 @@ +{% if job.is_result_downloadable %} +
  • + + Download Result +
  • +{% endif %} diff --git a/templates/batch_predict_pathway.html b/templates/batch_predict_pathway.html new file mode 100644 index 00000000..b9e5412c --- /dev/null +++ b/templates/batch_predict_pathway.html @@ -0,0 +1,168 @@ +{% extends "framework_modern.html" %} +{% load static %} +{% block content %} +
    +

    Batch Predict Pathways

    +
    + {% csrf_token %} + + + +
    + + + + + + + + + + + + + + + + + +
    SMILESName
    + + + +
    + + + +
    + + +
    + + +
    +
    +
    +
    + + +{% endblock content %} diff --git a/templates/modals/collections/new_prediction_setting_modal.html b/templates/modals/collections/new_prediction_setting_modal.html index 156bac2a..41a1dcaa 100644 --- a/templates/modals/collections/new_prediction_setting_modal.html +++ b/templates/modals/collections/new_prediction_setting_modal.html @@ -210,6 +210,27 @@ step="0.05" /> + +
    + + +
    diff --git a/templates/modals/objects/download_job_result_modal.html b/templates/modals/objects/download_job_result_modal.html new file mode 100644 index 00000000..e2a222ff --- /dev/null +++ b/templates/modals/objects/download_job_result_modal.html @@ -0,0 +1,66 @@ +{% load static %} + + + + + + + diff --git a/templates/objects/joblog.html b/templates/objects/joblog.html index 5c36c33d..b4dcf7f5 100644 --- a/templates/objects/joblog.html +++ b/templates/objects/joblog.html @@ -3,7 +3,9 @@ {% block content %} {% block action_modals %} - {# {% include "modals/objects/refresh_job_log.html" %}#} + {% if job.is_result_downloadable %} + {% include "modals/objects/download_job_result_modal.html" %} + {% endif %} {% endblock action_modals %}
    @@ -49,22 +51,20 @@
    Description
    -
    - Status page for Task {{ job.job_name }} -
    +
    Status page for Job {{ job.job_name }}
    -
    Task Status
    +
    Job Status
    {{ job.status }}
    -
    Task ID
    +
    Job ID
    {{ job.task_id }}
    @@ -72,7 +72,7 @@ {% if job.is_in_terminal_state %}
    -
    Task Result
    +
    Job Result
    {% if job.job_name == 'engineer_pathways' %}
    @@ -103,6 +103,68 @@
    + {% elif job.job_name == 'batch_predict' %} +
    + + {% else %} {{ job.parsed_result }} {% endif %} diff --git a/templates/objects/model.html b/templates/objects/model.html index 84524ac8..9cba8eb9 100644 --- a/templates/objects/model.html +++ b/templates/objects/model.html @@ -73,13 +73,29 @@
    - + {% endif %} + +
    + +
    Reaction Packages
    +
    + +
    +
    + {% if model.eval_packages.all|length > 0 %} +
    -
    Reaction Packages
    +
    Eval Packages
    - {% if model.eval_packages.all|length > 0 %} - -
    - -
    Eval Packages
    -
    - -
    -
    - {% endif %} - -
    - -
    Model Status
    -
    {{ model.status }}
    -
    {% endif %} + +
    + +
    Model Status
    +
    {{ model.status }}
    +
    {% if model.ready_for_prediction %} @@ -174,7 +172,6 @@
    {% endif %} - {% if model.model_status == 'FINISHED' %}
    @@ -188,6 +185,19 @@
    + {% if model.multigen_eval %} +
    + +
    + Multi Gen Precision Recall Curve +
    +
    +
    +
    +
    +
    +
    + {% endif %} {% endif %} @@ -244,6 +254,105 @@ } } + function makeChart(selector, data) { + const x = ['Recall']; + const y = ['Precision']; + const thres = ['threshold']; + + function compare(a, b) { + if (a.threshold < b.threshold) + return -1; + else if (a.threshold > b.threshold) + return 1; + else + return 0; + } + + function getIndexForValue(data, val, val_name) { + for (const idx in data) { + if (data[idx][val_name] == val) { + return idx; + } + } + return -1; + } + + if (!data || data.length === 0) { + console.warn('PR curve data is empty'); + return; + } + const dataLength = data.length; + data.sort(compare); + + for (const idx in data) { + const d = data[idx]; + x.push(d.recall); + y.push(d.precision); + thres.push(d.threshold); + } + const chart = c3.generate({ + bindto: selector, + data: { + onclick: function (d, e) { + const idx = d.index; + const thresh = data[dataLength - idx - 1].threshold; + }, + x: 'Recall', + y: 'Precision', + columns: [ + x, + y, + ] + }, + size: { + height: 400, + width: 480 + }, + axis: { + x: { + max: 1, + min: 0, + label: 'Recall', + padding: 0, + tick: { + fit: true, + values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + } + }, + y: { + max: 1, + min: 0, + label: 'Precision', + padding: 0, + tick: { + fit: true, + values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + } + } + }, + point: { + r: 4 + }, + tooltip: { + format: { + title: function (recall) { + const idx = getIndexForValue(data, recall, "recall"); + if (idx != -1) { + return "Threshold: " + data[idx].threshold; + } + return ""; + }, + value: function (precision, ratio, id) { + return undefined; + } + } + }, + zoom: { + enabled: true + } + }); + } + function makeLoadingGif(selector, gifPath) { const element = document.querySelector(selector); if (element) { @@ -260,107 +369,12 @@ } {% if model.model_status == 'FINISHED' %} - // Precision Recall Curve - const sgChart = document.getElementById('sg-chart'); - if (sgChart) { - const x = ['Recall']; - const y = ['Precision']; - const thres = ['threshold']; - - function compare(a, b) { - if (a.threshold < b.threshold) - return -1; - else if (a.threshold > b.threshold) - return 1; - else - return 0; - } - - function getIndexForValue(data, val, val_name) { - for (const idx in data) { - if (data[idx][val_name] == val) { - return idx; - } - } - return -1; - } - - var data = {{ model.pr_curve|safe }}; - if (!data || data.length === 0) { - console.warn('PR curve data is empty'); - return; - } - const dataLength = data.length; - data.sort(compare); - - for (const idx in data) { - const d = data[idx]; - x.push(d.recall); - y.push(d.precision); - thres.push(d.threshold); - } - const chart = c3.generate({ - bindto: '#sg-chart', - data: { - onclick: function (d, e) { - const idx = d.index; - const thresh = data[dataLength - idx - 1].threshold; - }, - x: 'Recall', - y: 'Precision', - columns: [ - x, - y, - ] - }, - size: { - height: 400, - width: 480 - }, - axis: { - x: { - max: 1, - min: 0, - label: 'Recall', - padding: 0, - tick: { - fit: true, - values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] - } - }, - y: { - max: 1, - min: 0, - label: 'Precision', - padding: 0, - tick: { - fit: true, - values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] - } - } - }, - point: { - r: 4 - }, - tooltip: { - format: { - title: function (recall) { - const idx = getIndexForValue(data, recall, "recall"); - if (idx != -1) { - return "Threshold: " + data[idx].threshold; - } - return ""; - }, - value: function (precision, ratio, 
id) { - return undefined; - } - } - }, - zoom: { - enabled: true - } - }); - } + // Precision Recall Curve + makeChart('#sg-chart', {{ model.pr_curve|safe }}); + {% if model.multigen_eval %} + // Multi Gen Precision Recall Curve + makeChart('#mg-chart', {{ model.mg_pr_curve|safe }}); + {% endif %} {% endif %} // Predict button handler diff --git a/templates/objects/pathway.html b/templates/objects/pathway.html index b0ef914b..f48e2ed7 100644 --- a/templates/objects/pathway.html +++ b/templates/objects/pathway.html @@ -393,7 +393,9 @@ Threshold - {{ pathway.setting.model_threshold }} + + {{ pathway.setting_with_overrides.model_threshold }} + @@ -420,11 +422,15 @@ {% endif %} Max Nodes - {{ pathway.setting.max_nodes }} + {{ pathway.setting_with_overrides.max_nodes }} Max Depth - {{ pathway.setting.max_depth }} + {{ pathway.setting_with_overrides.max_depth }} + + + Expansion Scheme + {{ user.default_setting.expansion_scheme }} diff --git a/templates/objects/user.html b/templates/objects/user.html index cc37f118..5c0cb59c 100644 --- a/templates/objects/user.html +++ b/templates/objects/user.html @@ -150,6 +150,10 @@ Max Depth {{ user.default_setting.max_depth }} + + Expansion Scheme + {{ user.default_setting.expansion_scheme }} + diff --git a/tests/test_jobs.py b/tests/test_jobs.py index 1ad23c83..08f0a4eb 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -1,3 +1,5 @@ +from datetime import datetime + from django.conf import settings as s from django.test import TestCase, override_settings @@ -23,6 +25,46 @@ class MultiGenTest(TestCase): cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)" cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine" + def test_batch_predict(self): + from epdb.tasks import batch_predict + + pred_data = [ + ["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"], + ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen"], + ] + + batch_predict_setting = self.user.prediction_settings() + + target_package = PackageManager.create_package( + self.user, + f"Autogenerated Package for Batch Prediction {datetime.now()}", + "This Package was generated automatically for the batch prediction task.", + ) + + num_tps = 50 + res = batch_predict( + pred_data, + batch_predict_setting.pk, + target_package.pk, + num_tps=num_tps, + ) + + self.assertTrue(res.startswith("Pathway URL,")) + # Min 3 lines (1 header, 2 root nodes) + self.assertGreaterEqual(len(res.split("\n")), 3) + self.assertEqual(target_package.pathways.count(), 2) + + pw = target_package.pathways.first() + + self.assertEqual( + pw.setting_with_overrides.max_depth, + f"{num_tps} (this is an override for this particular pathway)", + ) + self.assertEqual( + pw.setting_with_overrides.max_nodes, + f"{num_tps} (this is an override for this particular pathway)", + ) + def test_engineer_pathway(self): from epdb.tasks import engineer_pathways diff --git a/tests/test_sobjects.py b/tests/test_sobjects.py index 0212466e..95441e39 100644 --- a/tests/test_sobjects.py +++ b/tests/test_sobjects.py @@ -1,20 +1,20 @@ +from unittest.mock import Mock, patch + from django.test import TestCase -from epdb.logic import SNode, SEdge +from epdb.logic import SEdge, SNode, SPathway +from epdb.models import Pathway, Setting +from utilities.chem import PredictionResult, ProductSet -class SObjectTest(TestCase): - def setUp(self): - pass - +class SNodeTest(TestCase): def test_snode_eq(self): snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0) snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0) - assert snode1 == snode2 + self.assertEqual(snode1, snode2) - 
def test_snode_hash(self): - pass +class SEdgeTest(TestCase): def test_sedge_eq(self): sedge1 = SEdge( [SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)], @@ -26,4 +26,62 @@ class SObjectTest(TestCase): [SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)], rule=None, ) - assert sedge1 == sedge2 + self.assertEqual(sedge1, sedge2) + + +class SPathwayTest(TestCase): + def setUp(self): + """Set up test data for SPathway tests.""" + self.test_smiles = "CCN(CC)C(=O)C1=CC(=CC=C1)CO" + self.mock_setting = Mock(spec=Setting) + self.mock_pathway = Mock(spec=Pathway) + + def test_predict_step_basic(self): + """Test basic predict_step functionality.""" + spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting) + + # e.g. bt0002 + pr = PredictionResult( + [ + ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]), + ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]), + ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]), + ], + 0.17, + None, + ) + + with patch.object(self.mock_setting, "expand", return_value=[pr]): + spw.predict_step(from_depth=0) + + self.assertEqual(len(spw.smiles_to_node.keys()), 4) + self.assertEqual(len(spw.edges), 3) + + def test_to_json(self): + """Test basic predict_step functionality.""" + spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting) + + # e.g. bt0002 + pr = PredictionResult( + [ + ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]), + ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]), + ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]), + ], + 0.17, + None, + ) + + with patch.object(self.mock_setting, "expand", return_value=[pr]): + spw.predict_step(from_depth=0) + + self.assertEqual(len(spw.smiles_to_node.keys()), 4) + self.assertEqual(len(spw.edges), 3) + + json_result = spw.to_json() + + self.assertIsInstance(json_result, dict) + self.assertIn("nodes", json_result) + self.assertIn("edges", json_result) + self.assertEqual(len(json_result["nodes"]), 4) + self.assertEqual(len(json_result["edges"]), 3)
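For reference, a minimal standalone sketch of the queue handling that the new SPathway.predict() applies for the three expansion schemes (DFS prepends the frontier, BFS appends it, GREEDY re-sorts the queue by accumulated path probability). The Edge dataclass and the toy graph below are illustrative stand-ins, not the project's SNode/SEdge classes; only the ordering logic mirrors the diff.

from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Edge:
    educt: str
    product: str
    probability: float


def traverse(roots: List[str], edges: List[Edge], scheme: str = "BFS") -> Dict[str, float]:
    """Return the best accumulated path probability per node for the given scheme."""
    outgoing: Dict[str, List[Edge]] = {}
    for e in edges:
        outgoing.setdefault(e.educt, []).append(e)

    queue = list(roots)
    processed = set()
    node_probs = {n: 1.0 for n in roots}  # root nodes start with probability 1.0

    while queue:
        current = queue.pop(0)
        if current in processed:
            continue
        processed.add(current)

        new_nodes = []
        for edge in outgoing.get(current, []):
            prob = node_probs[current] * edge.probability
            # keep the highest-probability path to each product
            if prob > node_probs.get(edge.product, 0.0):
                node_probs[edge.product] = prob
            if edge.product not in processed:
                new_nodes.append(edge.product)

        if scheme == "DFS":
            queue = new_nodes + queue      # follow the newest path first
        elif scheme == "BFS":
            queue = queue + new_nodes      # finish the current level first
        elif scheme == "GREEDY":
            # re-order the whole frontier by accumulated probability, best first
            queue = sorted(set(queue + new_nodes), key=lambda n: node_probs[n], reverse=True)
        else:
            raise ValueError(f"Unknown expansion scheme: {scheme}")

    return node_probs


if __name__ == "__main__":
    toy_edges = [
        Edge("A", "B", 0.9), Edge("A", "C", 0.2),
        Edge("B", "D", 0.5), Edge("C", "D", 0.8),
    ]
    print(traverse(["A"], toy_edges, scheme="GREEDY"))
    # {'A': 1.0, 'B': 0.9, 'C': 0.2, 'D': 0.45}

Under GREEDY the frontier is visited in the order B, D, C, because D inherits 0.9 * 0.5 = 0.45 through B before the lower-probability branch through C (0.2 * 0.8 = 0.16) is expanded.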
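The setting_overrides mechanism can be pictured without Django: the predict task setattr()s the overrides onto the Setting it loads, stores them in the pathway's kv, and Pathway.setting_with_overrides later annotates the overridden fields for display, which is what the new test_batch_predict assertions check. A plain-Python sketch with hypothetical FakeSetting/FakePathway stand-ins (not the epdb models):

from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class FakeSetting:
    max_nodes: int = 30
    max_depth: int = 5
    model_threshold: float = 0.25


@dataclass
class FakePathway:
    setting: FakeSetting
    kv: Dict[str, Any] = field(default_factory=dict)

    @property
    def setting_with_overrides(self) -> FakeSetting:
        # mirrors Pathway.setting_with_overrides: copy the setting, annotate overridden fields
        mem_copy = FakeSetting(**vars(self.setting))
        for k, v in self.kv.get("setting_overrides", {}).items():
            setattr(mem_copy, k, f"{v} (this is an override for this particular pathway)")
        return mem_copy


# what batch_predict passes to the predict task for every pathway it creates
overrides = {"max_nodes": 50, "max_depth": 50, "model_threshold": 0.001}
pw = FakePathway(setting=FakeSetting(), kv={"setting_overrides": overrides})

print(pw.setting_with_overrides.max_nodes)
# 50 (this is an override for this particular pathway)
print(pw.setting.max_nodes)  # the stored Setting itself stays untouched
# 30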
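The CSV assembly at the end of batch_predict relies on the two new Pathway.to_csv() keyword arguments: the header is written only for the first pathway (include_header=idx == 0) and every row is prefixed with its pathway URL. A small self-contained sketch, with to_csv_stub standing in for Pathway.to_csv and made-up example URLs:

import csv
import io
from typing import List


def to_csv_stub(rows: List[List[str]], pathway_url: str,
                include_header: bool = True, include_pathway_url: bool = False) -> str:
    """Illustrative stand-in for Pathway.to_csv() with the two new keyword arguments."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    header = (["Pathway URL"] if include_pathway_url else []) + ["SMILES", "name", "depth"]
    if include_header:
        writer.writerow(header)
    for row in rows:
        writer.writerow(([pathway_url] if include_pathway_url else []) + row)
    return buf.getvalue()


# mirror of the loop at the end of batch_predict(): header once, then all pathways appended
pathways = [
    ("https://example.org/pathway/1", [["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine", "0"]]),
    ("https://example.org/pathway/2", [["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen", "0"]]),
]

buffer = io.StringIO()
for idx, (url, rows) in enumerate(pathways):
    buffer.write(to_csv_stub(rows, url, include_header=idx == 0, include_pathway_url=True))

print(buffer.getvalue())
# Pathway URL,SMILES,name,depth
# https://example.org/pathway/1,CN1C=NC2=C1C(=O)N(C(=O)N2C)C,Caffeine,0
# https://example.org/pathway/2,CC(C)CC1=CC=C(C=C1)C(C)C(=O)O,Ibuprofen,0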