[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1451,7 +1451,7 @@ def create_pathway(
from .tasks import dispatch, predict
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=-1)
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=None)
return redirect(new_pw.url)
except ValueError as e:

View File

@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union, Tuple
from uuid import UUID
import nh3
@ -16,6 +16,7 @@ from epdb.models import (
Edge,
EnzymeLink,
EPModel,
ExpansionSchemeChoice,
Group,
GroupPackagePermission,
Node,
@ -1116,6 +1117,7 @@ class SettingManager(object):
rule_packages: List[Package] = None,
model: EPModel = None,
model_threshold: float = None,
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
):
new_s = Setting()
# Clean for potential XSS
@ -1550,6 +1552,196 @@ class SPathway(object):
return sorted(res, key=lambda x: hash(x))
def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
    """
    Expand every given substrate once and collect the nodes and edges this produces.

    A substrate that has no applicability domain assessment yet is assessed first,
    provided the prediction setting has a model with an applicability domain attached.
    If the pathway is persisted, that assessment is also written to the matching
    database node. The prediction setting is then asked for candidate reactions: each
    product SMILES not yet known to this pathway becomes a new node one level below
    the substrate, and each candidate reaction becomes an edge added to ``self.edges``.

    Parameters:
        substrates (List[SNode]): The nodes to expand.

    Returns:
        Tuple[List[SNode], List[SEdge]]:
            A tuple containing:
            - The nodes created during this call.
            - The edges created during this call, one per candidate reaction.

    Raises:
        ValueError: If a persisted node has no ID although it should have been saved already.
    """
    setting = self.prediction_setting
    created_nodes: List[SNode] = []
    created_edges: List[SEdge] = []

    for substrate in substrates:
        # Every node must carry an assessment once an applicability domain is attached
        if (
            substrate.app_domain_assessment is None
            and setting.model
            and setting.model.app_domain
        ):
            assessment = setting.model.app_domain.assess(substrate.smiles)

            if self.persist is not None:
                db_node = self.snode_persist_lookup[substrate]
                if db_node.id is None:
                    raise ValueError(f"Node {db_node} has no ID... aborting!")

                # Attach a compact description of the node (plus image link) to the assessment
                node_json = db_node.simple_json()
                node_json["image"] = f"{db_node.url}?image=svg"
                assessment["assessment"]["node"] = node_json

                db_node.kv["app_domain_assessment"] = assessment
                db_node.save()

            substrate.app_domain_assessment = assessment

        # One PredictionResult per rule; results that evaluate as falsy are skipped
        for result in setting.expand(self, substrate):
            if not result:
                continue

            # A PredictionResult may hold several candidate reactions
            for reaction in result:
                products = []

                # A candidate reaction may consist of several fragments
                for smiles in reaction:
                    if smiles not in self.smiles_to_node:
                        # Unknown product: assess it if an applicability domain is attached
                        product_assessment = None
                        if setting.model and setting.model.app_domain:
                            product_assessment = setting.model.app_domain.assess(smiles)

                        fresh = SNode(smiles, substrate.depth + 1, product_assessment)
                        self.smiles_to_node[smiles] = fresh
                        created_nodes.append(fresh)

                    products.append(self.smiles_to_node[smiles])

                edge = SEdge(
                    substrate,
                    products,
                    rule=result.rule,
                    probability=result.probability,
                )
                self.edges.add(edge)
                created_edges.append(edge)

    return created_nodes, created_edges
def predict(self):
    """
    Expand the pathway node by node until no unprocessed node is left.

    Starting from the root nodes, each node is taken from a work queue, expanded via
    ``_expand`` and marked as processed. Along the way the best known probability of
    every node is tracked: root nodes start at 1.0 and a product receives
    ``probability(educt) * probability(edge)`` whenever that is higher than what is
    known so far. The order in which newly found nodes are visited depends on
    ``prediction_setting.expansion_scheme``:

    - ``"DFS"``: new nodes are put at the front of the queue.
    - ``"BFS"``: new nodes are put at the end of the queue.
    - ``"GREEDY"``: new nodes are added and the whole queue is re-ordered by
      descending node probability.

    If ``self.persist`` is set, the pathway is synced and saved after every expansion
    that produced new nodes or edges. When the queue is exhausted ``self.done`` is
    set to ``True``.

    Raises
    ------
    ValueError
        If an expansion produced new nodes and `prediction_setting.expansion_scheme`
        is none of the known schemes.
    """
    # populate initial queue
    queue = list(self.root_nodes)
    processed = set()

    # initial nodes have prob 1.0
    node_probs: Dict[SNode, float] = {n: 1.0 for n in queue}

    while queue:
        current = queue.pop(0)

        if current in processed:
            continue

        processed.add(current)

        new_nodes, new_edges = self._expand([current])

        # Check if we need to write back data to the database
        if (new_nodes or new_edges) and self.persist:
            self._sync_to_pathway()
            # call save to update the internal modified field
            self.persist.save()

        # NOTE(review): probabilities are only updated when the expansion created new
        # nodes; an expansion that merely adds edges to known nodes leaves them as is.
        if not new_nodes:
            continue

        # All edges have `current` as educt, so its probability is the common factor
        current_prob = node_probs[current]
        for edge in new_edges:
            for prod in edge.products:
                # Either a new product, or a known one reached via a more probable path
                if (
                    prod not in node_probs
                    or current_prob * edge.probability > node_probs[prod]
                ):
                    node_probs[prod] = current_prob * edge.probability

        # Update queue to proceed
        scheme = self.prediction_setting.expansion_scheme
        pending = [n for n in new_nodes if n not in processed]

        if scheme == "DFS":
            # We want to follow this path -> prepend queue
            for n in pending:
                queue.insert(0, n)
        elif scheme == "BFS":
            # Add at the end, everything queued before will be processed
            # before new_nodes
            queue.extend(pending)
        elif scheme == "GREEDY":
            # Add the new nodes, then visit the most probable node first
            queue.extend(pending)
            queue = sorted(queue, key=lambda n: node_probs[n], reverse=True)
        else:
            raise ValueError(
                f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}"
            )

    # Queue exhausted, we're done
    self.done = True
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
substrates: List[SNode] = []
@ -1560,67 +1752,15 @@ class SPathway(object):
if from_node == v:
substrates = [k]
break
else:
raise ValueError(f"Node {from_node} not found in SPathway!")
else:
raise ValueError("Neither from_depth nor from_node_url specified")
new_tp = False
if substrates:
for sub in substrates:
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)
if self.persist is not None:
n = self.snode_persist_lookup[sub]
assert n.id is not None, (
"Node has no id! Should have been saved already... aborting!"
)
node_data = n.simple_json()
node_data["image"] = f"{n.url}?image=svg"
app_domain_assessment["assessment"]["node"] = node_data
n.kv["app_domain_assessment"] = app_domain_assessment
n.save()
sub.app_domain_assessment = app_domain_assessment
candidates = self.prediction_setting.expand(self, sub)
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
for cand_set in candidates:
if cand_set:
new_tp = True
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
for cand in cand_set:
cand_nodes = []
# candidate reactions can have multiple fragments
for c in cand:
if c not in self.smiles_to_node:
# For new nodes do an AppDomain Assessment if an AppDomain is attached
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)
)
self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment
)
node = self.smiles_to_node[c]
cand_nodes.append(node)
edge = SEdge(
sub,
cand_nodes,
rule=cand_set.rule,
probability=cand_set.probability,
)
self.edges.add(edge)
new_nodes, _ = self._expand(substrates)
new_tp = len(new_nodes) > 0
# In case no substrates are found, we're done.
# For "predict from node" we're always done
@ -1704,11 +1844,6 @@ class SPathway(object):
"to": to_indices,
}
# if edge.rule:
# e['rule'] = {
# 'name': edge.rule.name,
# 'id': edge.rule.url,
# }
edges.append(e)
return {

View File

@ -0,0 +1,25 @@
# Generated by Django 5.2.7 on 2025-12-14 11:30
from django.db import migrations, models
class Migration(migrations.Migration):
# Schema migration for the "epdb" app that adds one column to the "setting" model.
# It has to be applied after migration 0012 of the same app.
dependencies = [
("epdb", "0012_node_stereo_removed_pathway_predicted"),
]
# Adds the CharField "expansion_schema" to "setting". The stored value is one of
# "BFS", "DFS" or "GREEDY" (at most 20 characters) and defaults to "BFS".
operations = [
migrations.AddField(
model_name="setting",
name="expansion_schema",
field=models.CharField(
choices=[
("BFS", "Breadth First Search"),
("DFS", "Depth First Search"),
("GREEDY", "Greedy"),
],
default="BFS",
max_length=20,
),
),
]

View File

@ -0,0 +1,17 @@
# Generated by Django 5.2.7 on 2025-12-14 16:02
from django.db import migrations
class Migration(migrations.Migration):
# Schema migration for the "epdb" app that renames a field on the "setting" model.
# It has to be applied after migration 0013, which introduced the field being renamed.
dependencies = [
("epdb", "0013_setting_expansion_schema"),
]
# Renames "expansion_schema" to "expansion_scheme"; no other field attribute is changed here.
operations = [
migrations.RenameField(
model_name="setting",
old_name="expansion_schema",
new_name="expansion_scheme",
),
]

View File

@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# potentially prefetched edge_set
return self.edge_set.all()
@property
def setting_with_overrides(self):
    """
    Return this pathway's setting, loaded again from the database, with overrides applied.

    For every entry stored under ``self.kv["setting_overrides"]`` the attribute of the
    same name is replaced on the loaded object by a string: the override value followed
    by a note that it is an override for this particular pathway. The object is not
    saved here.
    """
    annotated = Setting.objects.get(pk=self.setting.pk)

    if "setting_overrides" not in self.kv:
        return annotated

    for field_name, value in self.kv["setting_overrides"].items():
        setattr(
            annotated,
            field_name,
            f"{value} (this is an override for this particular pathway)",
        )

    return annotated
def _url(self):
    """Build the pathway URL: the URL of its package followed by ``/pathway/<uuid>``."""
    return f"{self.package.url}/pathway/{self.uuid}"
@ -1879,25 +1889,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
return json.dumps(res)
def to_csv(self) -> str:
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
import csv
import io
header = []
if include_pathway_url:
header += ["Pathway URL"]
header += [
"SMILES",
"name",
"depth",
"probability",
"rule_names",
"rule_ids",
"parent_smiles",
]
rows = []
rows.append(
[
"SMILES",
"name",
"depth",
"probability",
"rule_names",
"rule_ids",
"parent_smiles",
]
)
if include_header:
rows.append(header)
for n in self.nodes.order_by("depth"):
cs = n.default_node_label
row = [cs.smiles, cs.name, n.depth]
row = []
if include_pathway_url:
row.append(n.pathway.url)
row += [cs.smiles, cs.name, n.depth]
edges = self.edges.filter(end_nodes__in=[n])
if len(edges):
@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel):
return res
@property
def mg_pr_curve(self):
    """
    Return the multi-generation precision/recall curve as a list of points.

    Each point is a dict with the keys ``precision``, ``recall`` and ``threshold``;
    there is one point per threshold found in
    ``eval_results["multigen_average_precision_per_threshold"]``, in that mapping's
    order, with the threshold key converted to ``float``.

    Raises:
        ValueError: If the model is not in status ``FINISHED`` or is not a multigen model.
    """
    if self.model_status != self.FINISHED:
        raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")

    if not self.multigen_eval:
        raise ValueError("MG PR Curve is only available for multigen models")

    precision_by_threshold = self.eval_results["multigen_average_precision_per_threshold"]

    return [
        {
            "precision": self.eval_results["multigen_average_precision_per_threshold"][key],
            "recall": self.eval_results["multigen_average_recall_per_threshold"][key],
            "threshold": float(key),
        }
        for key in precision_by_threshold
    ]
@cached_property
def applicable_rules(self) -> List["Rule"]:
"""
@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel):
for i, root in enumerate(root_compounds):
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
spw = SPathway(root_nodes=root, prediction_setting=s)
spw = SPathway(root_nodes=root.smiles, prediction_setting=s)
level = 0
while not spw.done:
@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission):
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
class ExpansionSchemeChoice(models.TextChoices):
# Closed set of values offered as choices for Setting.expansion_scheme.
# Each member is declared as (stored value, human-readable label).
BFS = "BFS", "Breadth First Search"
DFS = "DFS", "Depth First Search"
GREEDY = "GREEDY", "Greedy"
class Setting(EnviPathModel):
public = models.BooleanField(null=False, blank=False, default=False)
global_default = models.BooleanField(null=False, blank=False, default=False)
@ -3795,6 +3847,12 @@ class Setting(EnviPathModel):
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
)
expansion_scheme = models.CharField(
max_length=20,
choices=ExpansionSchemeChoice.choices,
default=ExpansionSchemeChoice.BFS,
)
def _url(self):
    """Build the setting URL: the server URL followed by ``/setting/<uuid>``."""
    return f"{s.SERVER_URL}/setting/{self.uuid}"
@ -3833,7 +3891,7 @@ class Setting(EnviPathModel):
"""Decision Method whether to expand on a certain Node or not"""
if pathway.num_nodes() >= self.max_nodes:
logger.info(
f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}"
f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
)
return []
@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel):
if self.job_name == "engineer_pathways":
return ast.literal_eval(self.task_result)
return self.task_result
def is_result_downloadable(self):
    """Tell whether the result of this job may be offered as a download.

    Only jobs whose ``job_name`` is listed below qualify.
    """
    downloadable_jobs = ("batch_predict",)
    return any(self.job_name == name for name in downloadable_jobs)

View File

@ -11,6 +11,7 @@ from django.utils import timezone
from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter
logger = logging.getLogger(__name__)
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
@ -139,14 +140,25 @@ def predict(
pred_setting_pk: int,
limit: Optional[int] = None,
node_pk: Optional[int] = None,
setting_overrides: Optional[dict] = None,
) -> Pathway:
pw = Pathway.objects.get(id=pw_pk)
setting = Setting.objects.get(id=pred_setting_pk)
if setting_overrides:
for k, v in setting_overrides.items():
setattr(setting, k, v)
# If the setting has a model add/restore it from the cache
if setting.model is not None:
setting.model = get_ml_model(setting.model.pk)
pw.kv.update(**{"status": "running"})
kv = {"status": "running"}
if setting_overrides:
kv["setting_overrides"] = setting_overrides
pw.kv.update(**kv)
pw.save()
if JobLog.objects.filter(task_id=self.request.id).exists():
@ -171,7 +183,8 @@ def predict(
spw = SPathway.from_pathway(pw)
spw.predict_step(from_node=n)
else:
raise ValueError("Neither limit nor node_pk given!")
spw = SPathway(prediction_setting=setting, persist=pw)
spw.predict()
except Exception as e:
pw.kv.update({"status": "failed"})
@ -353,3 +366,76 @@ def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_p
predicted_pathways.append(pred_pw.url)
return intermediate_pathways, predicted_pathways
@shared_task(bind=True, queue="background")
def batch_predict(
    self,
    substrates: List[str] | List[List[str]],
    prediction_setting_pk: int,
    target_package_pk: int,
    num_tps: int = 50,
):
    """
    Predict one pathway per substrate and return all results as a single CSV string.

    Args:
        substrates: Either plain SMILES strings or ``[smiles, name]`` pairs. Which of the
            two is decided by looking at the first element only.
        prediction_setting_pk: Primary key of the Setting used for every prediction.
        target_package_pk: Primary key of the Package the pathways are created in.
        num_tps: Passed to each prediction as override for both ``max_nodes`` and
            ``max_depth``.

    Returns:
        The CSV output of all created pathways concatenated in input order. Only the
        first pathway contributes a header row; every row carries the pathway URL.

    Raises:
        ValueError: If no substrates are given or a SMILES cannot be standardized.
            All SMILES are standardized before the first pathway is created, so a bad
            SMILES fails the task without any pathway having been created by it.
    """
    # Validate the cheap precondition before touching the database
    if not substrates:
        raise ValueError("No substrates given!")

    target_package = Package.objects.get(pk=target_package_pk)
    prediction_setting = Setting.objects.get(pk=prediction_setting_pk)

    # Normalise the input to [smiles, name] pairs; plain SMILES get no name
    if isinstance(substrates[0], list):
        substrate_and_names = substrates
    else:
        substrate_and_names = [[sub, None] for sub in substrates]

    # Check prerequisite that we can standardize all substrates
    standardized_substrates_and_names = []
    for substrate in substrate_and_names:
        try:
            stand_smiles = FormatConverter.standardize(substrate[0])
        except ValueError as e:
            # Report the offending SMILES only (not the whole pair) and keep the cause
            raise ValueError(
                f'Pathway prediction failed as standardization of SMILES "{substrate[0]}" failed!'
            ) from e
        standardized_substrates_and_names.append([stand_smiles, substrate[1]])

    pathways = []
    for stand_smiles, name in standardized_substrates_and_names:
        pw = Pathway.create(
            target_package,
            stand_smiles,
            name=name,
            predicted=True,
        )

        # set mode and setting
        pw.setting = prediction_setting
        pw.kv.update({"mode": "predict"})
        pw.save()

        # `predict` is called directly here instead of being dispatched as its own job
        predict(
            pw.pk,
            prediction_setting.pk,
            limit=None,
            setting_overrides={
                "max_nodes": num_tps,
                "max_depth": num_tps,
                "model_threshold": 0.001,
            },
        )

        pathways.append(pw)

    # Carry out header only for the first pathway
    return "".join(
        pw.to_csv(include_header=idx == 0, include_pathway_url=True)
        for idx, pw in enumerate(pathways)
    )

View File

@ -49,6 +49,7 @@ urlpatterns = [
re_path(r"^group$", v.groups, name="groups"),
re_path(r"^search$", v.search, name="search"),
re_path(r"^predict$", v.predict_pathway, name="predict_pathway"),
re_path(r"^batch-predict$", v.batch_predict_pathway, name="batch_predict_pathway"),
# User Detail
re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
# Group Detail

View File

@ -1,11 +1,12 @@
import json
import logging
from typing import Any, Dict, List
from datetime import datetime
import nh3
from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest
from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render
from django.urls import reverse
@ -50,6 +51,7 @@ from .models import (
SimpleAmbitRule,
User,
UserPackagePermission,
ExpansionSchemeChoice,
)
logger = logging.getLogger(__name__)
@ -438,6 +440,18 @@ def predict_pathway(request):
return render(request, "predict_pathway.html", context)
def batch_predict_pathway(request):
    """Render the batch prediction page, using the user's default package as current package.

    Only GET is accepted; any other method is answered with "method not allowed".
    """
    if request.method == "GET":
        context = get_base_context(request)
        context["title"] = "enviPath - Batch Predict Pathway"

        meta = context["meta"]
        meta["current_package"] = meta["user"].default_package

        return render(request, "batch_predict_pathway.html", context)

    return HttpResponseNotAllowed(["GET"])
@package_permission_required()
def package_predict_pathway(request, package_uuid):
"""Package-specific predict pathway view."""
@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid):
if pw_mode == "predict" or pw_mode == "incremental":
# unlimited pred (will be handled by setting)
limit = -1
limit = None
# For incremental predict first level and return
if pw_mode == "incremental":
@ -2877,15 +2891,25 @@ def settings(request):
)
if not PackageManager.readable(current_user, params["model"].package):
raise ValueError("")
raise PermissionDenied("You're not allowed to access this model!")
expansion_scheme = request.POST.get(
"model-based-prediction-setting-expansion-scheme", "BFS"
)
if expansion_scheme not in ExpansionSchemeChoice.values:
raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}")
params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme)
elif tp_gen_method == "rule-based-prediction-setting":
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
else:
raise ValueError("")
raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!")
created_setting = SettingManager.create_setting(
current_user,
@ -2963,6 +2987,59 @@ def jobs(request):
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
elif job_name == "batch-predict":
substrates = request.POST.get("substrates")
prediction_setting_url = request.POST.get("prediction-setting")
num_tps = request.POST.get("num-tps")
if substrates is None or substrates.strip() == "":
raise BadRequest("No substrates provided.")
pred_data = []
for pair in substrates.split("\n"):
parts = pair.split(",")
try:
smiles = FormatConverter.standardize(parts[0])
except ValueError:
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
# name is optional
name = parts[1] if len(parts) > 1 else None
pred_data.append([smiles, name])
max_tps = 50
if num_tps is not None and num_tps.strip() != "":
try:
num_tps = int(num_tps)
max_tps = max(min(num_tps, 50), 1)
except ValueError:
raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.")
batch_predict_setting = SettingManager.get_setting_by_url(
current_user, prediction_setting_url
)
target_package = PackageManager.create_package(
current_user,
f"Autogenerated Package for Batch Prediction {datetime.now()}",
"This Package was generated automatically for the batch prediction task.",
)
from .tasks import dispatch, batch_predict
res = dispatch(
current_user,
batch_predict,
pred_data,
batch_predict_setting.pk,
target_package.pk,
num_tps=max_tps,
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
else:
raise BadRequest(f"Job {job_name} is not supported!")
else:
@ -2982,6 +3059,21 @@ def job(request, job_uuid):
# No op if status is already in a terminal state
job.check_for_update()
if request.GET.get("download", False) == "true":
if not job.is_result_downloadable():
raise BadRequest("Result is not downloadable!")
if job.job_name == "batch_predict":
filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv"
else:
raise BadRequest("Result is not downloadable!")
res_str = job.task_result
response = HttpResponse(res_str, content_type="text/csv")
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
context["object_type"] = "joblog"
context["breadcrumbs"] = [
{"Home": s.SERVER_URL},