[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1451,7 +1451,7 @@ def create_pathway(
from .tasks import dispatch, predict
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=-1)
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=None)
return redirect(new_pw.url)
except ValueError as e:

View File

@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union, Tuple
from uuid import UUID
import nh3
@ -16,6 +16,7 @@ from epdb.models import (
Edge,
EnzymeLink,
EPModel,
ExpansionSchemeChoice,
Group,
GroupPackagePermission,
Node,
@ -1116,6 +1117,7 @@ class SettingManager(object):
rule_packages: List[Package] = None,
model: EPModel = None,
model_threshold: float = None,
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
):
new_s = Setting()
# Clean for potential XSS
@ -1550,22 +1552,32 @@ class SPathway(object):
return sorted(res, key=lambda x: hash(x))
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
substrates: List[SNode] = []
def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
"""
Expands the given substrates by generating new nodes and edges based on prediction settings.
if from_depth is not None:
substrates = self._get_nodes_for_depth(from_depth)
elif from_node is not None:
for k, v in self.snode_persist_lookup.items():
if from_node == v:
substrates = [k]
break
else:
raise ValueError("Neither from_depth nor from_node_url specified")
This method processes a list of substrates and expands them into new nodes and edges using defined
rules and settings. It evaluates each substrate to determine its applicability domain, persists
domain assessments, and generates candidates for further processing. Newly created nodes and edges
are returned, and any applicable information is stored or updated internally during the process.
Parameters:
substrates (List[SNode]): A list of substrate nodes to be expanded.
Returns:
Tuple[List[SNode], List[SEdge]]:
A tuple containing:
- A list of new nodes generated during the expansion.
- A list of new edges representing connections between nodes based on candidate reactions.
Raises:
ValueError: If a node does not have an ID when it should have been saved already.
"""
new_nodes: List[SNode] = []
new_edges: List[SEdge] = []
new_tp = False
if substrates:
for sub in substrates:
# For App Domain we have to ensure that each Node is evaluated
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
@ -1576,9 +1588,9 @@ class SPathway(object):
if self.persist is not None:
n = self.snode_persist_lookup[sub]
assert n.id is not None, (
"Node has no id! Should have been saved already... aborting!"
)
if n.id is None:
raise ValueError(f"Node {n} has no ID... aborting!")
node_data = n.simple_json()
node_data["image"] = f"{n.url}?image=svg"
app_domain_assessment["assessment"]["node"] = node_data
@ -1592,7 +1604,6 @@ class SPathway(object):
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
for cand_set in candidates:
if cand_set:
new_tp = True
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
for cand in cand_set:
cand_nodes = []
@ -1606,10 +1617,9 @@ class SPathway(object):
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)
)
self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment
)
snode = SNode(c, sub.depth + 1, app_domain_assessment)
self.smiles_to_node[c] = snode
new_nodes.append(snode)
node = self.smiles_to_node[c]
cand_nodes.append(node)
@ -1621,6 +1631,136 @@ class SPathway(object):
probability=cand_set.probability,
)
self.edges.add(edge)
new_edges.append(edge)
return new_nodes, new_edges
def predict(self):
"""
Predicts outcomes based on a graph traversal algorithm using the specified expansion schema.
This method iteratively explores the nodes of a graph starting from the root nodes, propagating
probabilities through edges, and updating the probabilities of the connected nodes. The traversal
can follow one of three predefined expansion schemas: Depth-First Search (DFS), Breadth-First Search
(BFS), or a Greedy approach based on node probabilities. The methodology ensures that all reachable
nodes are processed systematically according to the specified schema.
Errors will be raised if the expansion schema is undefined or invalid. Additionally, this method
supports persisting changes by writing back data to the database when configured to do so.
Attributes
----------
done : bool
A flag indicating whether the prediction process is completed.
persist : Any
An optional object that manages persistence operations for saving modifications.
root_nodes : List[SNode]
A collection of initial nodes in the graph from which traversal begins.
prediction_setting : Any
Configuration object specifying settings for graph traversal, such as the choice of
expansion schema.
Raises
------
ValueError
If an invalid or unknown expansion schema is provided in `prediction_setting`.
"""
# populate initial queue
queue = list(self.root_nodes)
processed = set()
# initial nodes have prob 1.0
node_probs: Dict[SNode, float] = {}
node_probs.update({n: 1.0 for n in queue})
while queue:
current = queue.pop(0)
if current in processed:
continue
processed.add(current)
new_nodes, new_edges = self._expand([current])
if new_nodes or new_edges:
# Check if we need to write back data to the database
if self.persist:
self._sync_to_pathway()
# call save to update the internal modified field
self.persist.save()
if new_nodes:
for edge in new_edges:
# All edge have `current` as educt
# Use `current` and adjust probs
current_prob = node_probs[current]
for prod in edge.products:
# Either is a new product or a product and we found a path with a higher prob
if (
prod not in node_probs
or current_prob * edge.probability > node_probs[prod]
):
node_probs[prod] = current_prob * edge.probability
# Update Queue to proceed
if self.prediction_setting.expansion_scheme == "DFS":
for n in new_nodes:
if n not in processed:
# We want to follow this path -> prepend queue
queue.insert(0, n)
elif self.prediction_setting.expansion_scheme == "BFS":
for n in new_nodes:
if n not in processed:
# Add at the end, everything queued before will be processed
# before new_nodese
queue.append(n)
elif self.prediction_setting.expansion_scheme == "GREEDY":
# Simply add them, as we will re-order the queue later
for n in new_nodes:
if n not in processed:
queue.append(n)
node_and_probs = []
for queued_val in queue:
node_and_probs.append((queued_val, node_probs[queued_val]))
from pprint import pprint
pprint(node_and_probs)
# re-order the queue and only pick smiles
queue = [
n[0] for n in sorted(node_and_probs, key=lambda x: x[1], reverse=True)
]
else:
raise ValueError(
f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}"
)
# Queue exhausted, we're done
self.done = True
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
substrates: List[SNode] = []
if from_depth is not None:
substrates = self._get_nodes_for_depth(from_depth)
elif from_node is not None:
for k, v in self.snode_persist_lookup.items():
if from_node == v:
substrates = [k]
break
else:
raise ValueError(f"Node {from_node} not found in SPathway!")
else:
raise ValueError("Neither from_depth nor from_node_url specified")
new_tp = False
if substrates:
new_nodes, _ = self._expand(substrates)
new_tp = len(new_nodes) > 0
# In case no substrates are found, we're done.
# For "predict from node" we're always done
@ -1704,11 +1844,6 @@ class SPathway(object):
"to": to_indices,
}
# if edge.rule:
# e['rule'] = {
# 'name': edge.rule.name,
# 'id': edge.rule.url,
# }
edges.append(e)
return {

View File

@ -0,0 +1,25 @@
# Generated by Django 5.2.7 on 2025-12-14 11:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("epdb", "0012_node_stereo_removed_pathway_predicted"),
]
operations = [
migrations.AddField(
model_name="setting",
name="expansion_schema",
field=models.CharField(
choices=[
("BFS", "Breadth First Search"),
("DFS", "Depth First Search"),
("GREEDY", "Greedy"),
],
default="BFS",
max_length=20,
),
),
]

View File

@ -0,0 +1,17 @@
# Generated by Django 5.2.7 on 2025-12-14 16:02
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("epdb", "0013_setting_expansion_schema"),
]
operations = [
migrations.RenameField(
model_name="setting",
old_name="expansion_schema",
new_name="expansion_scheme",
),
]

View File

@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# potentially prefetched edge_set
return self.edge_set.all()
@property
def setting_with_overrides(self):
mem_copy = Setting.objects.get(pk=self.setting.pk)
if "setting_overrides" in self.kv:
for k, v in self.kv["setting_overrides"].items():
setattr(mem_copy, k, f"{v} (this is an override for this particular pathway)")
return mem_copy
def _url(self):
return "{}/pathway/{}".format(self.package.url, self.uuid)
@ -1879,13 +1889,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
return json.dumps(res)
def to_csv(self) -> str:
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
import csv
import io
rows = []
rows.append(
[
header = []
if include_pathway_url:
header += ["Pathway URL"]
header += [
"SMILES",
"name",
"depth",
@ -1894,10 +1907,20 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
"rule_ids",
"parent_smiles",
]
)
rows = []
if include_header:
rows.append(header)
for n in self.nodes.order_by("depth"):
cs = n.default_node_label
row = [cs.smiles, cs.name, n.depth]
row = []
if include_pathway_url:
row.append(n.pathway.url)
row += [cs.smiles, cs.name, n.depth]
edges = self.edges.filter(end_nodes__in=[n])
if len(edges):
@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel):
return res
@property
def mg_pr_curve(self):
if self.model_status != self.FINISHED:
raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
if not self.multigen_eval:
raise ValueError("MG PR Curve is only available for multigen models")
res = []
thresholds = self.eval_results["multigen_average_precision_per_threshold"].keys()
for t in thresholds:
res.append(
{
"precision": self.eval_results["multigen_average_precision_per_threshold"][t],
"recall": self.eval_results["multigen_average_recall_per_threshold"][t],
"threshold": float(t),
}
)
return res
@cached_property
def applicable_rules(self) -> List["Rule"]:
"""
@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel):
for i, root in enumerate(root_compounds):
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
spw = SPathway(root_nodes=root, prediction_setting=s)
spw = SPathway(root_nodes=root.smiles, prediction_setting=s)
level = 0
while not spw.done:
@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission):
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
class ExpansionSchemeChoice(models.TextChoices):
BFS = "BFS", "Breadth First Search"
DFS = "DFS", "Depth First Search"
GREEDY = "GREEDY", "Greedy"
class Setting(EnviPathModel):
public = models.BooleanField(null=False, blank=False, default=False)
global_default = models.BooleanField(null=False, blank=False, default=False)
@ -3795,6 +3847,12 @@ class Setting(EnviPathModel):
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
)
expansion_scheme = models.CharField(
max_length=20,
choices=ExpansionSchemeChoice.choices,
default=ExpansionSchemeChoice.BFS,
)
def _url(self):
return "{}/setting/{}".format(s.SERVER_URL, self.uuid)
@ -3833,7 +3891,7 @@ class Setting(EnviPathModel):
"""Decision Method whether to expand on a certain Node or not"""
if pathway.num_nodes() >= self.max_nodes:
logger.info(
f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}"
f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
)
return []
@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel):
if self.job_name == "engineer_pathways":
return ast.literal_eval(self.task_result)
return self.task_result
def is_result_downloadable(self):
downloadable = ["batch_predict"]
return self.job_name in downloadable

View File

@ -11,6 +11,7 @@ from django.utils import timezone
from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter
logger = logging.getLogger(__name__)
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
@ -139,14 +140,25 @@ def predict(
pred_setting_pk: int,
limit: Optional[int] = None,
node_pk: Optional[int] = None,
setting_overrides: Optional[dict] = None,
) -> Pathway:
pw = Pathway.objects.get(id=pw_pk)
setting = Setting.objects.get(id=pred_setting_pk)
if setting_overrides:
for k, v in setting_overrides.items():
setattr(setting, k, v)
# If the setting has a model add/restore it from the cache
if setting.model is not None:
setting.model = get_ml_model(setting.model.pk)
pw.kv.update(**{"status": "running"})
kv = {"status": "running"}
if setting_overrides:
kv["setting_overrides"] = setting_overrides
pw.kv.update(**kv)
pw.save()
if JobLog.objects.filter(task_id=self.request.id).exists():
@ -171,7 +183,8 @@ def predict(
spw = SPathway.from_pathway(pw)
spw.predict_step(from_node=n)
else:
raise ValueError("Neither limit nor node_pk given!")
spw = SPathway(prediction_setting=setting, persist=pw)
spw.predict()
except Exception as e:
pw.kv.update({"status": "failed"})
@ -353,3 +366,76 @@ def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_p
predicted_pathways.append(pred_pw.url)
return intermediate_pathways, predicted_pathways
@shared_task(bind=True, queue="background")
def batch_predict(
self,
substrates: List[str] | List[List[str]],
prediction_setting_pk: int,
target_package_pk: int,
num_tps: int = 50,
):
target_package = Package.objects.get(pk=target_package_pk)
prediction_setting = Setting.objects.get(pk=prediction_setting_pk)
if len(substrates) == 0:
raise ValueError("No substrates given!")
is_pair = isinstance(substrates[0], list)
substrate_and_names = []
if not is_pair:
for sub in substrates:
substrate_and_names.append([sub, None])
else:
substrate_and_names = substrates
# Check prerequisite that we can standardize all substrates
standardized_substrates_and_smiles = []
for substrate in substrate_and_names:
try:
stand_smiles = FormatConverter.standardize(substrate[0])
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
except ValueError:
raise ValueError(
f'Pathway prediction failed as standardization of SMILES "{substrate}" failed!'
)
pathways = []
for pair in standardized_substrates_and_smiles:
pw = Pathway.create(
target_package,
pair[0],
name=pair[1],
predicted=True,
)
# set mode and setting
pw.setting = prediction_setting
pw.kv.update({"mode": "predict"})
pw.save()
predict(
pw.pk,
prediction_setting.pk,
limit=None,
setting_overrides={
"max_nodes": num_tps,
"max_depth": num_tps,
"model_threshold": 0.001,
},
)
pathways.append(pw)
buffer = io.StringIO()
for idx, pw in enumerate(pathways):
# Carry out header only for the first pathway
buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True))
buffer.seek(0)
return buffer.getvalue()

View File

@ -49,6 +49,7 @@ urlpatterns = [
re_path(r"^group$", v.groups, name="groups"),
re_path(r"^search$", v.search, name="search"),
re_path(r"^predict$", v.predict_pathway, name="predict_pathway"),
re_path(r"^batch-predict$", v.batch_predict_pathway, name="batch_predict_pathway"),
# User Detail
re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
# Group Detail

View File

@ -1,11 +1,12 @@
import json
import logging
from typing import Any, Dict, List
from datetime import datetime
import nh3
from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest
from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render
from django.urls import reverse
@ -50,6 +51,7 @@ from .models import (
SimpleAmbitRule,
User,
UserPackagePermission,
ExpansionSchemeChoice,
)
logger = logging.getLogger(__name__)
@ -438,6 +440,18 @@ def predict_pathway(request):
return render(request, "predict_pathway.html", context)
def batch_predict_pathway(request):
"""Top-level predict pathway view using user's default package."""
if request.method != "GET":
return HttpResponseNotAllowed(["GET"])
context = get_base_context(request)
context["title"] = "enviPath - Batch Predict Pathway"
context["meta"]["current_package"] = context["meta"]["user"].default_package
return render(request, "batch_predict_pathway.html", context)
@package_permission_required()
def package_predict_pathway(request, package_uuid):
"""Package-specific predict pathway view."""
@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid):
if pw_mode == "predict" or pw_mode == "incremental":
# unlimited pred (will be handled by setting)
limit = -1
limit = None
# For incremental predict first level and return
if pw_mode == "incremental":
@ -2877,15 +2891,25 @@ def settings(request):
)
if not PackageManager.readable(current_user, params["model"].package):
raise ValueError("")
raise PermissionDenied("You're not allowed to access this model!")
expansion_scheme = request.POST.get(
"model-based-prediction-setting-expansion-scheme", "BFS"
)
if expansion_scheme not in ExpansionSchemeChoice.values:
raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}")
params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme)
elif tp_gen_method == "rule-based-prediction-setting":
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
else:
raise ValueError("")
raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!")
created_setting = SettingManager.create_setting(
current_user,
@ -2963,6 +2987,59 @@ def jobs(request):
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
elif job_name == "batch-predict":
substrates = request.POST.get("substrates")
prediction_setting_url = request.POST.get("prediction-setting")
num_tps = request.POST.get("num-tps")
if substrates is None or substrates.strip() == "":
raise BadRequest("No substrates provided.")
pred_data = []
for pair in substrates.split("\n"):
parts = pair.split(",")
try:
smiles = FormatConverter.standardize(parts[0])
except ValueError:
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
# name is optional
name = parts[1] if len(parts) > 1 else None
pred_data.append([smiles, name])
max_tps = 50
if num_tps is not None and num_tps.strip() != "":
try:
num_tps = int(num_tps)
max_tps = max(min(num_tps, 50), 1)
except ValueError:
raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.")
batch_predict_setting = SettingManager.get_setting_by_url(
current_user, prediction_setting_url
)
target_package = PackageManager.create_package(
current_user,
f"Autogenerated Package for Batch Prediction {datetime.now()}",
"This Package was generated automatically for the batch prediction task.",
)
from .tasks import dispatch, batch_predict
res = dispatch(
current_user,
batch_predict,
pred_data,
batch_predict_setting.pk,
target_package.pk,
num_tps=max_tps,
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
else:
raise BadRequest(f"Job {job_name} is not supported!")
else:
@ -2982,6 +3059,21 @@ def job(request, job_uuid):
# No op if status is already in a terminal state
job.check_for_update()
if request.GET.get("download", False) == "true":
if not job.is_result_downloadable():
raise BadRequest("Result is not downloadable!")
if job.job_name == "batch_predict":
filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv"
else:
raise BadRequest("Result is not downloadable!")
res_str = job.task_result
response = HttpResponse(res_str, content_type="text/csv")
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
context["object_type"] = "joblog"
context["breadcrumbs"] = [
{"Home": s.SERVER_URL},

View File

@ -0,0 +1,10 @@
{% if job.is_result_downloadable %}
<li>
<a
class="button"
onclick="document.getElementById('download_job_result_modal').showModal(); return false;"
>
<i class="glyphicon glyphicon-floppy-save"></i> Download Result</a
>
</li>
{% endif %}

View File

@ -0,0 +1,168 @@
{% extends "framework_modern.html" %}
{% load static %}
{% block content %}
<div class="mx-auto w-full p-8">
<h1 class="h1 mb-4 text-3xl font-bold">Batch Predict Pathways</h1>
<form id="smiles-form" method="POST" action="{% url "jobs" %}">
{% csrf_token %}
<input type="hidden" name="substrates" id="substrates" />
<input type="hidden" name="job-name" value="batch-predict" />
<fieldset class="flex flex-col gap-4 md:flex-3/4">
<table class="table table-zebra w-full">
<thead>
<tr>
<th>SMILES</th>
<th>Name</th>
</tr>
</thead>
<tbody id="smiles-table-body">
<tr>
<td>
<label>
<input
type="text"
class="input input-bordered w-full smiles-input"
placeholder="CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
{% if meta.debug %}
value="CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
{% endif %}
/>
</label>
</td>
<td>
<label>
<input
type="text"
class="input input-bordered w-full name-input"
placeholder="Caffeine"
{% if meta.debug %}
value="Caffeine"
{% endif %}
/>
</label>
</td>
</tr>
<tr>
<td>
<label>
<input
type="text"
class="input input-bordered w-full smiles-input"
placeholder="CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
{% if meta.debug %}
value="CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
{% endif %}
/>
</label>
</td>
<td>
<label>
<input
type="text"
class="input input-bordered w-full name-input"
placeholder="Ibuprofen"
{% if meta.debug %}
value="Ibuprofen"
{% endif %}
/>
</label>
</td>
</tr>
</tbody>
</table>
<label class="select mb-2 w-full">
<span class="label">Predictor</span>
<select id="prediction-setting" name="prediction-setting">
<option disabled>Select a Setting</option>
{% for s in meta.available_settings %}
<option
value="{{ s.url }}"
{% if s.id == meta.user.default_setting.id %}selected{% endif %}
>
{{ s.name }}{% if s.id == meta.user.default_setting.id %}
(User default)
{% endif %}
</option>
{% endfor %}
</select>
</label>
<label class="floating-label" for="num-tps">
<input
type="number"
name="num-tps"
value="50"
step="1"
min="1"
max="100"
id="num-tps"
class="input input-md w-full"
/>
<span>Max Transformation Products</span>
</label>
<div class="flex justify-end gap-2">
<button type="button" id="add-row-btn" class="btn btn-outline">
Add row
</button>
<button type="submit" class="btn btn-primary">Submit</button>
</div>
</fieldset>
</form>
</div>
<script>
const tableBody = document.getElementById("smiles-table-body");
const addRowBtn = document.getElementById("add-row-btn");
const form = document.getElementById("smiles-form");
const hiddenField = document.getElementById("substrates");
addRowBtn.addEventListener("click", () => {
const row = document.createElement("tr");
const tdSmiles = document.createElement("td");
const tdName = document.createElement("td");
const smilesInput = document.createElement("input");
smilesInput.type = "text";
smilesInput.className = "input input-bordered w-full smiles-input";
smilesInput.placeholder = "SMILES";
const nameInput = document.createElement("input");
nameInput.type = "text";
nameInput.className = "input input-bordered w-full name-input";
nameInput.placeholder = "Name";
tdSmiles.appendChild(smilesInput);
tdName.appendChild(nameInput);
row.appendChild(tdSmiles);
row.appendChild(tdName);
tableBody.appendChild(row);
});
// Before submit, gather table data into the hidden field
form.addEventListener("submit", (e) => {
const smilesInputs = Array.from(
document.querySelectorAll(".smiles-input"),
);
const nameInputs = Array.from(document.querySelectorAll(".name-input"));
const lines = [];
for (let i = 0; i < smilesInputs.length; i++) {
const smiles = smilesInputs[i].value.trim();
const name = nameInputs[i]?.value.trim() ?? "";
// Skip emtpy rows
if (!smiles && !name) {
continue;
}
lines.push(`${smiles},${name}`);
}
// Value looks like:
// "CN1C=NC2=C1C(=O)N(C(=O)N2C)C,Caffeine\nCC(C)CC1=CC=C(C=C1)C(C)C(=O)O,Ibuprofen"
hiddenField.value = lines.join("\n");
});
</script>
{% endblock content %}

View File

@ -210,6 +210,27 @@
step="0.05"
/>
</div>
<div class="form-control mb-3">
<label
class="label"
for="model-based-prediction-setting-expansion-scheme"
>
<span class="label-text">Select Expansion Scheme</span>
</label>
<select
id="model-based-prediction-setting-expansion-scheme"
name="model-based-prediction-setting-expansion-scheme"
class="select select-bordered w-full"
>
<option value="" disabled selected>
Select the Expansion Scheme
</option>
<option value="BFS">Breadth First Search</option>
<option value="DFS">Depth First Search</option>
<option value="GREEDY">Greedy</option>
</select>
</div>
</div>
<div class="form-control">

View File

@ -0,0 +1,66 @@
{% load static %}
<dialog
id="download_job_result_modal"
class="modal"
x-data="modalForm()"
@close="reset()"
>
<div class="modal-box">
<!-- Header -->
<h3 class="font-bold text-lg">Download Job Result</h3>
<!-- Close button (X) -->
<form method="dialog">
<button
class="btn btn-sm btn-circle btn-ghost absolute right-2 top-2"
:disabled="isSubmitting"
>
</button>
</form>
<!-- Body -->
<div class="py-4">
<p>By clicking on Download the Result of this Job will be saved.</p>
<form
id="download-job-result-modal-form"
accept-charset="UTF-8"
action="{{ job.url }}"
method="GET"
>
<input type="hidden" name="download" value="true" />
</form>
</div>
<!-- Footer -->
<div class="modal-action">
<button
type="button"
class="btn"
onclick="this.closest('dialog').close()"
:disabled="isSubmitting"
>
Close
</button>
<button
type="button"
class="btn btn-primary"
@click="submit('download-job-result-modal-form'); $el.closest('dialog').close();"
:disabled="isSubmitting"
>
<span x-show="!isSubmitting">Download</span>
<span
x-show="isSubmitting"
class="loading loading-spinner loading-sm"
></span>
</button>
</div>
</div>
<!-- Backdrop -->
<form method="dialog" class="modal-backdrop">
<button :disabled="isSubmitting">close</button>
</form>
</dialog>

View File

@ -3,7 +3,9 @@
{% block content %}
{% block action_modals %}
{# {% include "modals/objects/refresh_job_log.html" %}#}
{% if job.is_result_downloadable %}
{% include "modals/objects/download_job_result_modal.html" %}
{% endif %}
{% endblock action_modals %}
<div class="space-y-2 p-4">
@ -49,22 +51,20 @@
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Description</div>
<div class="collapse-content">
Status page for Task {{ job.job_name }}
</div>
<div class="collapse-content">Status page for Job {{ job.job_name }}</div>
</div>
<!-- Job Status -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Task Status</div>
<div class="collapse-title text-xl font-medium">Job Status</div>
<div class="collapse-content">{{ job.status }}</div>
</div>
<!-- Job ID -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Task ID</div>
<div class="collapse-title text-xl font-medium">Job ID</div>
<div class="collapse-content">{{ job.task_id }}</div>
</div>
@ -72,7 +72,7 @@
{% if job.is_in_terminal_state %}
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Task Result</div>
<div class="collapse-title text-xl font-medium">Job Result</div>
<div class="collapse-content">
{% if job.job_name == 'engineer_pathways' %}
<div class="card bg-base-100">
@ -103,6 +103,68 @@
</ul>
</div>
</div>
{% elif job.job_name == 'batch_predict' %}
<div
id="table-container"
class="overflow-x-auto overflow-y-auto max-h-96 border rounded-lg"
></div>
<script>
const input = `{{ job.task_result }}`;
function renderCsvTable(str) {
const lines = str
.split("\n")
.map((l) => l.trim())
.filter(Boolean);
const [headerLine, ...rows] = lines;
const headers = headerLine.split(",").map((h) => h.trim());
const table = document.createElement("table");
table.className = "table table-zebra w-full";
const thead = document.createElement("thead");
const headerRow = document.createElement("tr");
headers.forEach((h) => {
const th = document.createElement("th");
th.textContent = h;
headerRow.appendChild(th);
});
thead.appendChild(headerRow);
const tbody = document.createElement("tbody");
rows.forEach((rowStr) => {
console.log(rowStr.split(","));
console.log(headers);
const row = document.createElement("tr");
const cells = rowStr.split(",").map((c) => c.trim());
headers.forEach((_, i) => {
const td = document.createElement("td");
const value = cells[i] || "";
td.textContent = value;
row.appendChild(td);
});
console.log(row);
tbody.appendChild(row);
});
table.appendChild(thead);
table.appendChild(tbody);
return table;
}
document
.getElementById("table-container")
.appendChild(renderCsvTable(input));
</script>
{% else %}
{{ job.parsed_result }}
{% endif %}

View File

@ -73,6 +73,7 @@
</ul>
</div>
</div>
{% endif %}
<!-- Reaction Packages -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
@ -96,9 +97,7 @@
<ul class="menu bg-base-100 rounded-box w-full">
{% for p in model.eval_packages.all %}
<li>
<a href="{{ p.url }}" class="hover:bg-base-200"
>{{ p.name }}</a
>
<a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a>
</li>
{% endfor %}
</ul>
@ -111,7 +110,6 @@
<div class="collapse-title text-xl font-medium">Model Status</div>
<div class="collapse-content">{{ model.status }}</div>
</div>
{% endif %}
{% if model.ready_for_prediction %}
<!-- Predict Panel -->
@ -174,7 +172,6 @@
</div>
</div>
{% endif %}
{% if model.model_status == 'FINISHED' %}
<!-- Single Gen Curve Panel -->
<div class="collapse-arrow bg-base-200 collapse">
@ -188,6 +185,19 @@
</div>
</div>
</div>
{% if model.multigen_eval %}
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">
Multi Gen Precision Recall Curve
</div>
<div class="collapse-content">
<div class="flex justify-center">
<div id="mg-chart"></div>
</div>
</div>
</div>
{% endif %}
{% endif %}
</div>
@ -244,25 +254,7 @@
}
}
function makeLoadingGif(selector, gifPath) {
const element = document.querySelector(selector);
if (element) {
element.innerHTML = '<img src="' + gifPath + '" alt="Loading...">';
}
}
document.addEventListener('DOMContentLoaded', function() {
// Show actions button if there are actions
const actionsButton = document.getElementById('actionsButton');
const actionsList = actionsButton?.querySelector('ul');
if (actionsList && actionsList.children.length > 0) {
actionsButton?.classList.remove('hidden');
}
{% if model.model_status == 'FINISHED' %}
// Precision Recall Curve
const sgChart = document.getElementById('sg-chart');
if (sgChart) {
function makeChart(selector, data) {
const x = ['Recall'];
const y = ['Precision'];
const thres = ['threshold'];
@ -285,7 +277,6 @@
return -1;
}
var data = {{ model.pr_curve|safe }};
if (!data || data.length === 0) {
console.warn('PR curve data is empty');
return;
@ -300,7 +291,7 @@
thres.push(d.threshold);
}
const chart = c3.generate({
bindto: '#sg-chart',
bindto: selector,
data: {
onclick: function (d, e) {
const idx = d.index;
@ -361,6 +352,29 @@
}
});
}
function makeLoadingGif(selector, gifPath) {
const element = document.querySelector(selector);
if (element) {
element.innerHTML = '<img src="' + gifPath + '" alt="Loading...">';
}
}
document.addEventListener('DOMContentLoaded', function() {
// Show actions button if there are actions
const actionsButton = document.getElementById('actionsButton');
const actionsList = actionsButton?.querySelector('ul');
if (actionsList && actionsList.children.length > 0) {
actionsButton?.classList.remove('hidden');
}
{% if model.model_status == 'FINISHED' %}
// Precision Recall Curve
makeChart('#sg-chart', {{ model.pr_curve|safe }});
{% if model.multigen_eval %}
// Multi Gen Precision Recall Curve
makeChart('#mg-chart', {{ model.mg_pr_curve|safe }});
{% endif %}
{% endif %}
// Predict button handler

View File

@ -393,7 +393,9 @@
<tbody>
<tr>
<td>Threshold</td>
<td>{{ pathway.setting.model_threshold }}</td>
<td>
{{ pathway.setting_with_overrides.model_threshold }}
</td>
</tr>
</tbody>
</table>
@ -420,11 +422,15 @@
{% endif %}
<tr>
<td>Max Nodes</td>
<td>{{ pathway.setting.max_nodes }}</td>
<td>{{ pathway.setting_with_overrides.max_nodes }}</td>
</tr>
<tr>
<td>Max Depth</td>
<td>{{ pathway.setting.max_depth }}</td>
<td>{{ pathway.setting_with_overrides.max_depth }}</td>
</tr>
<tr>
<td>Expansion Scheme</td>
<td>{{ user.default_setting.expansion_scheme }}</td>
</tr>
</tbody>
</table>

View File

@ -150,6 +150,10 @@
<td>Max Depth</td>
<td>{{ user.default_setting.max_depth }}</td>
</tr>
<tr>
<td>Expansion Scheme</td>
<td>{{ user.default_setting.expansion_scheme }}</td>
</tr>
</tbody>
</table>
</div>

View File

@ -1,3 +1,5 @@
from datetime import datetime
from django.conf import settings as s
from django.test import TestCase, override_settings
@ -23,6 +25,46 @@ class MultiGenTest(TestCase):
cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)"
cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine"
def test_batch_predict(self):
from epdb.tasks import batch_predict
pred_data = [
["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"],
["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen"],
]
batch_predict_setting = self.user.prediction_settings()
target_package = PackageManager.create_package(
self.user,
f"Autogenerated Package for Batch Prediction {datetime.now()}",
"This Package was generated automatically for the batch prediction task.",
)
num_tps = 50
res = batch_predict(
pred_data,
batch_predict_setting.pk,
target_package.pk,
num_tps=num_tps,
)
self.assertTrue(res.startswith("Pathway URL,"))
# Min 3 lines (1 header, 2 root nodes)
self.assertGreaterEqual(len(res.split("\n")), 3)
self.assertEqual(target_package.pathways.count(), 2)
pw = target_package.pathways.first()
self.assertEqual(
pw.setting_with_overrides.max_depth,
f"{num_tps} (this is an override for this particular pathway)",
)
self.assertEqual(
pw.setting_with_overrides.max_nodes,
f"{num_tps} (this is an override for this particular pathway)",
)
def test_engineer_pathway(self):
from epdb.tasks import engineer_pathways

View File

@ -1,20 +1,20 @@
from unittest.mock import Mock, patch
from django.test import TestCase
from epdb.logic import SNode, SEdge
from epdb.logic import SEdge, SNode, SPathway
from epdb.models import Pathway, Setting
from utilities.chem import PredictionResult, ProductSet
class SObjectTest(TestCase):
def setUp(self):
pass
class SNodeTest(TestCase):
def test_snode_eq(self):
snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
assert snode1 == snode2
self.assertEqual(snode1, snode2)
def test_snode_hash(self):
pass
class SEdgeTest(TestCase):
def test_sedge_eq(self):
sedge1 = SEdge(
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
@ -26,4 +26,62 @@ class SObjectTest(TestCase):
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
rule=None,
)
assert sedge1 == sedge2
self.assertEqual(sedge1, sedge2)
class SPathwayTest(TestCase):
def setUp(self):
"""Set up test data for SPathway tests."""
self.test_smiles = "CCN(CC)C(=O)C1=CC(=CC=C1)CO"
self.mock_setting = Mock(spec=Setting)
self.mock_pathway = Mock(spec=Pathway)
def test_predict_step_basic(self):
"""Test basic predict_step functionality."""
spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)
# e.g. bt0002
pr = PredictionResult(
[
ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
],
0.17,
None,
)
with patch.object(self.mock_setting, "expand", return_value=[pr]):
spw.predict_step(from_depth=0)
self.assertEqual(len(spw.smiles_to_node.keys()), 4)
self.assertEqual(len(spw.edges), 3)
def test_to_json(self):
"""Test basic predict_step functionality."""
spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)
# e.g. bt0002
pr = PredictionResult(
[
ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
],
0.17,
None,
)
with patch.object(self.mock_setting, "expand", return_value=[pr]):
spw.predict_step(from_depth=0)
self.assertEqual(len(spw.smiles_to_node.keys()), 4)
self.assertEqual(len(spw.edges), 3)
json_result = spw.to_json()
self.assertIsInstance(json_result, dict)
self.assertIn("nodes", json_result)
self.assertIn("edges", json_result)
self.assertEqual(len(json_result["nodes"]), 4)
self.assertEqual(len(json_result["edges"]), 3)