[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union, Tuple
from uuid import UUID
import nh3
@ -16,6 +16,7 @@ from epdb.models import (
Edge,
EnzymeLink,
EPModel,
ExpansionSchemeChoice,
Group,
GroupPackagePermission,
Node,
@ -1116,6 +1117,7 @@ class SettingManager(object):
rule_packages: List[Package] = None,
model: EPModel = None,
model_threshold: float = None,
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
):
new_s = Setting()
# Clean for potential XSS
@ -1550,6 +1552,196 @@ class SPathway(object):
return sorted(res, key=lambda x: hash(x))
def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
"""
Expands the given substrates by generating new nodes and edges based on prediction settings.
This method processes a list of substrates and expands them into new nodes and edges using defined
rules and settings. It evaluates each substrate to determine its applicability domain, persists
domain assessments, and generates candidates for further processing. Newly created nodes and edges
are returned, and any applicable information is stored or updated internally during the process.
Parameters:
substrates (List[SNode]): A list of substrate nodes to be expanded.
Returns:
Tuple[List[SNode], List[SEdge]]:
A tuple containing:
- A list of new nodes generated during the expansion.
- A list of new edges representing connections between nodes based on candidate reactions.
Raises:
ValueError: If a node does not have an ID when it should have been saved already.
"""
new_nodes: List[SNode] = []
new_edges: List[SEdge] = []
for sub in substrates:
# For App Domain we have to ensure that each Node is evaluated
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)
if self.persist is not None:
n = self.snode_persist_lookup[sub]
if n.id is None:
raise ValueError(f"Node {n} has no ID... aborting!")
node_data = n.simple_json()
node_data["image"] = f"{n.url}?image=svg"
app_domain_assessment["assessment"]["node"] = node_data
n.kv["app_domain_assessment"] = app_domain_assessment
n.save()
sub.app_domain_assessment = app_domain_assessment
candidates = self.prediction_setting.expand(self, sub)
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
for cand_set in candidates:
if cand_set:
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
for cand in cand_set:
cand_nodes = []
# candidate reactions can have multiple fragments
for c in cand:
if c not in self.smiles_to_node:
# For new nodes do an AppDomain Assessment if an AppDomain is attached
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)
)
snode = SNode(c, sub.depth + 1, app_domain_assessment)
self.smiles_to_node[c] = snode
new_nodes.append(snode)
node = self.smiles_to_node[c]
cand_nodes.append(node)
edge = SEdge(
sub,
cand_nodes,
rule=cand_set.rule,
probability=cand_set.probability,
)
self.edges.add(edge)
new_edges.append(edge)
return new_nodes, new_edges
def predict(self):
"""
Predicts outcomes based on a graph traversal algorithm using the specified expansion schema.
This method iteratively explores the nodes of a graph starting from the root nodes, propagating
probabilities through edges, and updating the probabilities of the connected nodes. The traversal
can follow one of three predefined expansion schemas: Depth-First Search (DFS), Breadth-First Search
(BFS), or a Greedy approach based on node probabilities. The methodology ensures that all reachable
nodes are processed systematically according to the specified schema.
Errors will be raised if the expansion schema is undefined or invalid. Additionally, this method
supports persisting changes by writing back data to the database when configured to do so.
Attributes
----------
done : bool
A flag indicating whether the prediction process is completed.
persist : Any
An optional object that manages persistence operations for saving modifications.
root_nodes : List[SNode]
A collection of initial nodes in the graph from which traversal begins.
prediction_setting : Any
Configuration object specifying settings for graph traversal, such as the choice of
expansion schema.
Raises
------
ValueError
If an invalid or unknown expansion schema is provided in `prediction_setting`.
"""
# populate initial queue
queue = list(self.root_nodes)
processed = set()
# initial nodes have prob 1.0
node_probs: Dict[SNode, float] = {}
node_probs.update({n: 1.0 for n in queue})
while queue:
current = queue.pop(0)
if current in processed:
continue
processed.add(current)
new_nodes, new_edges = self._expand([current])
if new_nodes or new_edges:
# Check if we need to write back data to the database
if self.persist:
self._sync_to_pathway()
# call save to update the internal modified field
self.persist.save()
if new_nodes:
for edge in new_edges:
# All edge have `current` as educt
# Use `current` and adjust probs
current_prob = node_probs[current]
for prod in edge.products:
# Either is a new product or a product and we found a path with a higher prob
if (
prod not in node_probs
or current_prob * edge.probability > node_probs[prod]
):
node_probs[prod] = current_prob * edge.probability
# Update Queue to proceed
if self.prediction_setting.expansion_scheme == "DFS":
for n in new_nodes:
if n not in processed:
# We want to follow this path -> prepend queue
queue.insert(0, n)
elif self.prediction_setting.expansion_scheme == "BFS":
for n in new_nodes:
if n not in processed:
# Add at the end, everything queued before will be processed
# before new_nodese
queue.append(n)
elif self.prediction_setting.expansion_scheme == "GREEDY":
# Simply add them, as we will re-order the queue later
for n in new_nodes:
if n not in processed:
queue.append(n)
node_and_probs = []
for queued_val in queue:
node_and_probs.append((queued_val, node_probs[queued_val]))
from pprint import pprint
pprint(node_and_probs)
# re-order the queue and only pick smiles
queue = [
n[0] for n in sorted(node_and_probs, key=lambda x: x[1], reverse=True)
]
else:
raise ValueError(
f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}"
)
# Queue exhausted, we're done
self.done = True
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
substrates: List[SNode] = []
@ -1560,67 +1752,15 @@ class SPathway(object):
if from_node == v:
substrates = [k]
break
else:
raise ValueError(f"Node {from_node} not found in SPathway!")
else:
raise ValueError("Neither from_depth nor from_node_url specified")
new_tp = False
if substrates:
for sub in substrates:
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)
if self.persist is not None:
n = self.snode_persist_lookup[sub]
assert n.id is not None, (
"Node has no id! Should have been saved already... aborting!"
)
node_data = n.simple_json()
node_data["image"] = f"{n.url}?image=svg"
app_domain_assessment["assessment"]["node"] = node_data
n.kv["app_domain_assessment"] = app_domain_assessment
n.save()
sub.app_domain_assessment = app_domain_assessment
candidates = self.prediction_setting.expand(self, sub)
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
for cand_set in candidates:
if cand_set:
new_tp = True
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
for cand in cand_set:
cand_nodes = []
# candidate reactions can have multiple fragments
for c in cand:
if c not in self.smiles_to_node:
# For new nodes do an AppDomain Assessment if an AppDomain is attached
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)
)
self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment
)
node = self.smiles_to_node[c]
cand_nodes.append(node)
edge = SEdge(
sub,
cand_nodes,
rule=cand_set.rule,
probability=cand_set.probability,
)
self.edges.add(edge)
new_nodes, _ = self._expand(substrates)
new_tp = len(new_nodes) > 0
# In case no substrates are found, we're done.
# For "predict from node" we're always done
@ -1704,11 +1844,6 @@ class SPathway(object):
"to": to_indices,
}
# if edge.rule:
# e['rule'] = {
# 'name': edge.rule.name,
# 'id': edge.rule.url,
# }
edges.append(e)
return {