[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1451,7 +1451,7 @@ def create_pathway(
from .tasks import dispatch, predict from .tasks import dispatch, predict
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=-1) dispatch(request.user, predict, new_pw.pk, setting.pk, limit=None)
return redirect(new_pw.url) return redirect(new_pw.url)
except ValueError as e: except ValueError as e:

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union, Tuple
from uuid import UUID from uuid import UUID
import nh3 import nh3
@ -16,6 +16,7 @@ from epdb.models import (
Edge, Edge,
EnzymeLink, EnzymeLink,
EPModel, EPModel,
ExpansionSchemeChoice,
Group, Group,
GroupPackagePermission, GroupPackagePermission,
Node, Node,
@ -1116,6 +1117,7 @@ class SettingManager(object):
rule_packages: List[Package] = None, rule_packages: List[Package] = None,
model: EPModel = None, model: EPModel = None,
model_threshold: float = None, model_threshold: float = None,
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
): ):
new_s = Setting() new_s = Setting()
# Clean for potential XSS # Clean for potential XSS
@ -1550,6 +1552,196 @@ class SPathway(object):
return sorted(res, key=lambda x: hash(x)) return sorted(res, key=lambda x: hash(x))
def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
    """
    Expand every given substrate once and report what was newly discovered.

    For each substrate that has no applicability domain assessment yet, and if the
    prediction setting's model has an applicability domain attached, the substrate
    is assessed first. When the pathway is persisted, the assessment is also written
    to the kv store of the corresponding stored node. Afterwards the prediction
    setting is asked for candidate reactions; unseen product SMILES become new
    nodes one level below the substrate and every candidate reaction becomes an edge.

    Parameters:
        substrates (List[SNode]): The substrate nodes to expand.

    Returns:
        Tuple[List[SNode], List[SEdge]]:
            - the nodes created during this call (products not seen before),
            - the edges created during this call (one per candidate reaction).

    Raises:
        ValueError: If a persisted node has no ID although it should have been saved already.
    """

    def applicability_domain():
        # Resolved on demand, mirroring the model -> app_domain lookup chain.
        model = self.prediction_setting.model
        return model.app_domain if model else None

    created_nodes: List[SNode] = []
    created_edges: List[SEdge] = []

    for substrate in substrates:
        # For App Domain we have to ensure that each Node is evaluated
        if substrate.app_domain_assessment is None:
            domain = applicability_domain()
            if domain:
                assessment = domain.assess(substrate.smiles)

                if self.persist is not None:
                    n = self.snode_persist_lookup[substrate]
                    if n.id is None:
                        raise ValueError(f"Node {n} has no ID... aborting!")

                    # Attach a compact node description to the stored assessment
                    described = n.simple_json()
                    described["image"] = f"{n.url}?image=svg"
                    assessment["assessment"]["node"] = described
                    n.kv["app_domain_assessment"] = assessment
                    n.save()

                substrate.app_domain_assessment = assessment

        # One PredictionResult per rule
        for prediction in self.prediction_setting.expand(self, substrate):
            if not prediction:
                continue

            # A PredictionResult can hold several candidate reactions
            for reaction in prediction:
                products = []

                # A candidate reaction can consist of several fragments
                for smiles in reaction:
                    if smiles not in self.smiles_to_node:
                        # New nodes get assessed right away if an AppDomain is attached
                        domain = applicability_domain()
                        product_assessment = domain.assess(smiles) if domain else None

                        fresh = SNode(smiles, substrate.depth + 1, product_assessment)
                        self.smiles_to_node[smiles] = fresh
                        created_nodes.append(fresh)

                    products.append(self.smiles_to_node[smiles])

                link = SEdge(
                    substrate,
                    products,
                    rule=prediction.rule,
                    probability=prediction.probability,
                )
                self.edges.add(link)
                created_edges.append(link)

    return created_nodes, created_edges
def predict(self):
    """
    Expand the pathway from its root nodes until no queued node is left.

    Nodes are taken from a work queue one at a time and expanded via `_expand`.
    While doing so, the best known path probability of every product is tracked
    (root nodes start at 1.0; a product gets `parent probability * edge probability`
    if that is higher than what is known so far). The expansion scheme of the
    prediction setting decides where newly found nodes are queued:

    - "DFS": new nodes are put at the front of the queue.
    - "BFS": new nodes are put at the end of the queue.
    - "GREEDY": new nodes are queued and the whole queue is re-ordered by
      descending path probability.

    If `persist` is set, the in-memory state is synced to the stored pathway and
    the pathway is saved after every expansion that produced nodes or edges.
    On completion `self.done` is set to True.

    Raises
    ------
    ValueError
        If the expansion scheme of `prediction_setting` is unknown. The scheme is
        checked up front so that nothing is expanded or persisted beforehand.
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    scheme = self.prediction_setting.expansion_scheme
    if scheme not in ("BFS", "DFS", "GREEDY"):
        raise ValueError(f"Unknown expansion schema: {scheme}")

    # populate initial queue
    queue = deque(self.root_nodes)
    processed = set()

    # initial nodes have prob 1.0
    node_probs = {n: 1.0 for n in queue}

    while queue:
        current = queue.popleft()

        # A node can be queued more than once; expand it only the first time
        if current in processed:
            continue
        processed.add(current)

        new_nodes, new_edges = self._expand([current])
        if not (new_nodes or new_edges):
            continue

        # Check if we need to write back data to the database
        if self.persist:
            self._sync_to_pathway()
            # call save to update the internal modified field
            self.persist.save()

        if not new_nodes:
            continue

        for edge in new_edges:
            # All edges have `current` as educt, so its probability is the base
            current_prob = node_probs[current]
            for prod in edge.products:
                # Either a new product, or a known product reached via a more probable path
                if (
                    prod not in node_probs
                    or current_prob * edge.probability > node_probs[prod]
                ):
                    node_probs[prod] = current_prob * edge.probability

        pending = [n for n in new_nodes if n not in processed]

        if scheme == "DFS":
            # We want to follow this path first -> prepend one by one
            for n in pending:
                queue.appendleft(n)
        else:
            # BFS: everything queued before is processed before the new nodes
            queue.extend(pending)
            if scheme == "GREEDY":
                # Most probable node first; the sort is stable, so ties keep queue order
                queue = deque(sorted(queue, key=node_probs.__getitem__, reverse=True))

    # Queue exhausted, we're done
    self.done = True
def predict_step(self, from_depth: int = None, from_node: "Node" = None): def predict_step(self, from_depth: int = None, from_node: "Node" = None):
substrates: List[SNode] = [] substrates: List[SNode] = []
@ -1560,67 +1752,15 @@ class SPathway(object):
if from_node == v: if from_node == v:
substrates = [k] substrates = [k]
break break
else:
raise ValueError(f"Node {from_node} not found in SPathway!")
else: else:
raise ValueError("Neither from_depth nor from_node_url specified") raise ValueError("Neither from_depth nor from_node_url specified")
new_tp = False new_tp = False
if substrates: if substrates:
for sub in substrates: new_nodes, _ = self._expand(substrates)
if sub.app_domain_assessment is None: new_tp = len(new_nodes) > 0
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)
if self.persist is not None:
n = self.snode_persist_lookup[sub]
assert n.id is not None, (
"Node has no id! Should have been saved already... aborting!"
)
node_data = n.simple_json()
node_data["image"] = f"{n.url}?image=svg"
app_domain_assessment["assessment"]["node"] = node_data
n.kv["app_domain_assessment"] = app_domain_assessment
n.save()
sub.app_domain_assessment = app_domain_assessment
candidates = self.prediction_setting.expand(self, sub)
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
for cand_set in candidates:
if cand_set:
new_tp = True
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
for cand in cand_set:
cand_nodes = []
# candidate reactions can have multiple fragments
for c in cand:
if c not in self.smiles_to_node:
# For new nodes do an AppDomain Assessment if an AppDomain is attached
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)
)
self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment
)
node = self.smiles_to_node[c]
cand_nodes.append(node)
edge = SEdge(
sub,
cand_nodes,
rule=cand_set.rule,
probability=cand_set.probability,
)
self.edges.add(edge)
# In case no substrates are found, we're done. # In case no substrates are found, we're done.
# For "predict from node" we're always done # For "predict from node" we're always done
@ -1704,11 +1844,6 @@ class SPathway(object):
"to": to_indices, "to": to_indices,
} }
# if edge.rule:
# e['rule'] = {
# 'name': edge.rule.name,
# 'id': edge.rule.url,
# }
edges.append(e) edges.append(e)
return { return {

View File

@ -0,0 +1,25 @@
# Generated by Django 5.2.7 on 2025-12-14 11:30
from django.db import migrations, models
class Migration(migrations.Migration):
    # Adds a CharField named `expansion_schema` to the `setting` model.
    # The field is limited to the three listed choices and defaults to "BFS".

    # Has to run after the previous epdb migration.
    dependencies = [
        ("epdb", "0012_node_stereo_removed_pathway_predicted"),
    ]

    operations = [
        migrations.AddField(
            model_name="setting",
            name="expansion_schema",
            field=models.CharField(
                # (stored value, human readable label)
                choices=[
                    ("BFS", "Breadth First Search"),
                    ("DFS", "Depth First Search"),
                    ("GREEDY", "Greedy"),
                ],
                default="BFS",
                max_length=20,
            ),
        ),
    ]

View File

@ -0,0 +1,17 @@
# Generated by Django 5.2.7 on 2025-12-14 16:02
from django.db import migrations
class Migration(migrations.Migration):
    # Renames the `expansion_schema` field of the `setting` model to `expansion_scheme`.

    # Has to run after the migration that introduced `expansion_schema`.
    dependencies = [
        ("epdb", "0013_setting_expansion_schema"),
    ]

    operations = [
        migrations.RenameField(
            model_name="setting",
            old_name="expansion_schema",
            new_name="expansion_scheme",
        ),
    ]

View File

@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# potentially prefetched edge_set # potentially prefetched edge_set
return self.edge_set.all() return self.edge_set.all()
@property
def setting_with_overrides(self):
    """
    Return a freshly loaded copy of this pathway's setting for display purposes.

    If the pathway's kv store contains "setting_overrides", each overridden
    attribute on the copy is replaced by a string made of the override value and
    a note that it is an override. The copy is only modified in memory; it is not
    saved here.
    """
    mem_copy = Setting.objects.get(pk=self.setting.pk)

    if "setting_overrides" not in self.kv:
        return mem_copy

    overrides = self.kv["setting_overrides"]
    for k, v in overrides.items():
        annotated = f"{v} (this is an override for this particular pathway)"
        setattr(mem_copy, k, annotated)

    return mem_copy
def _url(self): def _url(self):
return "{}/pathway/{}".format(self.package.url, self.uuid) return "{}/pathway/{}".format(self.package.url, self.uuid)
@ -1879,25 +1889,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
return json.dumps(res) return json.dumps(res)
def to_csv(self) -> str: def to_csv(self, include_header=True, include_pathway_url=False) -> str:
import csv import csv
import io import io
header = []
if include_pathway_url:
header += ["Pathway URL"]
header += [
"SMILES",
"name",
"depth",
"probability",
"rule_names",
"rule_ids",
"parent_smiles",
]
rows = [] rows = []
rows.append(
[ if include_header:
"SMILES", rows.append(header)
"name",
"depth",
"probability",
"rule_names",
"rule_ids",
"parent_smiles",
]
)
for n in self.nodes.order_by("depth"): for n in self.nodes.order_by("depth"):
cs = n.default_node_label cs = n.default_node_label
row = [cs.smiles, cs.name, n.depth] row = []
if include_pathway_url:
row.append(n.pathway.url)
row += [cs.smiles, cs.name, n.depth]
edges = self.edges.filter(end_nodes__in=[n]) edges = self.edges.filter(end_nodes__in=[n])
if len(edges): if len(edges):
@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel):
return res return res
@property
def mg_pr_curve(self):
    """
    Return the multi generation precision/recall curve as a list of points.

    Each point is a dict with the keys "precision", "recall" and "threshold"
    (threshold converted to float), one per threshold key found in
    `eval_results["multigen_average_precision_per_threshold"]`.

    Raises:
        ValueError: If the model is not in status FINISHED or has no multigen evaluation.
    """
    if self.model_status != self.FINISHED:
        raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")

    if not self.multigen_eval:
        raise ValueError("MG PR Curve is only available for multigen models")

    precision_by_threshold = self.eval_results["multigen_average_precision_per_threshold"]

    return [
        {
            "precision": precision_by_threshold[t],
            "recall": self.eval_results["multigen_average_recall_per_threshold"][t],
            "threshold": float(t),
        }
        for t in precision_by_threshold
    ]
@cached_property @cached_property
def applicable_rules(self) -> List["Rule"]: def applicable_rules(self) -> List["Rule"]:
""" """
@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel):
for i, root in enumerate(root_compounds): for i, root in enumerate(root_compounds):
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...") logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
spw = SPathway(root_nodes=root, prediction_setting=s) spw = SPathway(root_nodes=root.smiles, prediction_setting=s)
level = 0 level = 0
while not spw.done: while not spw.done:
@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission):
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}" return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
class ExpansionSchemeChoice(models.TextChoices):
    """Selectable expansion schemes; each member is a (stored value, display label) pair."""

    BFS = "BFS", "Breadth First Search"
    DFS = "DFS", "Depth First Search"
    GREEDY = "GREEDY", "Greedy"
class Setting(EnviPathModel): class Setting(EnviPathModel):
public = models.BooleanField(null=False, blank=False, default=False) public = models.BooleanField(null=False, blank=False, default=False)
global_default = models.BooleanField(null=False, blank=False, default=False) global_default = models.BooleanField(null=False, blank=False, default=False)
@ -3795,6 +3847,12 @@ class Setting(EnviPathModel):
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25 null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
) )
expansion_scheme = models.CharField(
max_length=20,
choices=ExpansionSchemeChoice.choices,
default=ExpansionSchemeChoice.BFS,
)
def _url(self): def _url(self):
return "{}/setting/{}".format(s.SERVER_URL, self.uuid) return "{}/setting/{}".format(s.SERVER_URL, self.uuid)
@ -3833,7 +3891,7 @@ class Setting(EnviPathModel):
"""Decision Method whether to expand on a certain Node or not""" """Decision Method whether to expand on a certain Node or not"""
if pathway.num_nodes() >= self.max_nodes: if pathway.num_nodes() >= self.max_nodes:
logger.info( logger.info(
f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}" f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
) )
return [] return []
@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel):
if self.job_name == "engineer_pathways": if self.job_name == "engineer_pathways":
return ast.literal_eval(self.task_result) return ast.literal_eval(self.task_result)
return self.task_result return self.task_result
def is_result_downloadable(self):
    """Return True if the result of this job can be offered as a download."""
    # Only jobs listed here produce a downloadable result.
    return self.job_name in ("batch_predict",)

View File

@ -11,6 +11,7 @@ from django.utils import timezone
from epdb.logic import SPathway from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from utilities.chem import FormatConverter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times. ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
@ -139,14 +140,25 @@ def predict(
pred_setting_pk: int, pred_setting_pk: int,
limit: Optional[int] = None, limit: Optional[int] = None,
node_pk: Optional[int] = None, node_pk: Optional[int] = None,
setting_overrides: Optional[dict] = None,
) -> Pathway: ) -> Pathway:
pw = Pathway.objects.get(id=pw_pk) pw = Pathway.objects.get(id=pw_pk)
setting = Setting.objects.get(id=pred_setting_pk) setting = Setting.objects.get(id=pred_setting_pk)
if setting_overrides:
for k, v in setting_overrides.items():
setattr(setting, k, v)
# If the setting has a model add/restore it from the cache # If the setting has a model add/restore it from the cache
if setting.model is not None: if setting.model is not None:
setting.model = get_ml_model(setting.model.pk) setting.model = get_ml_model(setting.model.pk)
pw.kv.update(**{"status": "running"}) kv = {"status": "running"}
if setting_overrides:
kv["setting_overrides"] = setting_overrides
pw.kv.update(**kv)
pw.save() pw.save()
if JobLog.objects.filter(task_id=self.request.id).exists(): if JobLog.objects.filter(task_id=self.request.id).exists():
@ -171,7 +183,8 @@ def predict(
spw = SPathway.from_pathway(pw) spw = SPathway.from_pathway(pw)
spw.predict_step(from_node=n) spw.predict_step(from_node=n)
else: else:
raise ValueError("Neither limit nor node_pk given!") spw = SPathway(prediction_setting=setting, persist=pw)
spw.predict()
except Exception as e: except Exception as e:
pw.kv.update({"status": "failed"}) pw.kv.update({"status": "failed"})
@ -353,3 +366,76 @@ def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_p
predicted_pathways.append(pred_pw.url) predicted_pathways.append(pred_pw.url)
return intermediate_pathways, predicted_pathways return intermediate_pathways, predicted_pathways
@shared_task(bind=True, queue="background")
def batch_predict(
    self,
    substrates: List[str] | List[List[str]],
    prediction_setting_pk: int,
    target_package_pk: int,
    num_tps: int = 50,
):
    """
    Predict one pathway per substrate and return all of them as a single CSV string.

    Parameters:
        substrates: Either a list of SMILES strings or a list of `[smiles, name]` pairs.
            The shape is decided by looking at the first entry only.
        prediction_setting_pk: Primary key of the Setting used for all predictions.
        target_package_pk: Primary key of the Package the pathways are created in.
        num_tps: Passed on as override for both `max_nodes` and `max_depth`.

    Returns:
        The CSV export of every created pathway, concatenated in input order. Only the
        first pathway contributes the header row; every row carries the pathway URL.

    Raises:
        ValueError: If no substrates are given or a SMILES cannot be standardized.
            All substrates are standardized before any pathway is created.
    """
    # NOTE(review): `Package` is not part of the epdb.models import visible in this
    # diff - confirm it is in scope in this module.
    target_package = Package.objects.get(pk=target_package_pk)
    prediction_setting = Setting.objects.get(pk=prediction_setting_pk)

    if not substrates:
        raise ValueError("No substrates given!")

    # Normalize the input to [smiles, name] pairs. Tuples are accepted as pairs as well.
    if isinstance(substrates[0], (list, tuple)):
        substrate_and_names = substrates
    else:
        substrate_and_names = [[sub, None] for sub in substrates]

    # Check prerequisite that we can standardize all substrates
    standardized_substrates_and_smiles = []
    for substrate in substrate_and_names:
        try:
            stand_smiles = FormatConverter.standardize(substrate[0])
        except ValueError as e:
            # Report the offending SMILES only (not the whole pair) and keep the cause
            raise ValueError(
                f'Pathway prediction failed as standardization of SMILES "{substrate[0]}" failed!'
            ) from e
        standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])

    pathways = []
    for pair in standardized_substrates_and_smiles:
        pw = Pathway.create(
            target_package,
            pair[0],
            name=pair[1],
            predicted=True,
        )

        # set mode and setting
        pw.setting = prediction_setting
        pw.kv.update({"mode": "predict"})
        pw.save()

        # Called directly (not dispatched), so the prediction runs within this task
        predict(
            pw.pk,
            prediction_setting.pk,
            limit=None,
            setting_overrides={
                "max_nodes": num_tps,
                "max_depth": num_tps,
                "model_threshold": 0.001,
            },
        )

        pathways.append(pw)

    # Carry out header only for the first pathway
    return "".join(
        pw.to_csv(include_header=idx == 0, include_pathway_url=True)
        for idx, pw in enumerate(pathways)
    )

View File

@ -49,6 +49,7 @@ urlpatterns = [
re_path(r"^group$", v.groups, name="groups"), re_path(r"^group$", v.groups, name="groups"),
re_path(r"^search$", v.search, name="search"), re_path(r"^search$", v.search, name="search"),
re_path(r"^predict$", v.predict_pathway, name="predict_pathway"), re_path(r"^predict$", v.predict_pathway, name="predict_pathway"),
re_path(r"^batch-predict$", v.batch_predict_pathway, name="batch_predict_pathway"),
# User Detail # User Detail
re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"), re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
# Group Detail # Group Detail

View File

@ -1,11 +1,12 @@
import json import json
import logging import logging
from typing import Any, Dict, List from typing import Any, Dict, List
from datetime import datetime
import nh3 import nh3
from django.conf import settings as s from django.conf import settings as s
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render from django.shortcuts import redirect, render
from django.urls import reverse from django.urls import reverse
@ -50,6 +51,7 @@ from .models import (
SimpleAmbitRule, SimpleAmbitRule,
User, User,
UserPackagePermission, UserPackagePermission,
ExpansionSchemeChoice,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -438,6 +440,18 @@ def predict_pathway(request):
return render(request, "predict_pathway.html", context) return render(request, "predict_pathway.html", context)
def batch_predict_pathway(request):
    """Render the batch prediction form; the user's default package is set as current package.

    Only GET is supported, any other method is answered with "405 Method Not Allowed".
    """
    if request.method == "GET":
        context = get_base_context(request)
        context["title"] = "enviPath - Batch Predict Pathway"

        meta = context["meta"]
        meta["current_package"] = meta["user"].default_package

        return render(request, "batch_predict_pathway.html", context)

    return HttpResponseNotAllowed(["GET"])
@package_permission_required() @package_permission_required()
def package_predict_pathway(request, package_uuid): def package_predict_pathway(request, package_uuid):
"""Package-specific predict pathway view.""" """Package-specific predict pathway view."""
@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid):
if pw_mode == "predict" or pw_mode == "incremental": if pw_mode == "predict" or pw_mode == "incremental":
# unlimited pred (will be handled by setting) # unlimited pred (will be handled by setting)
limit = -1 limit = None
# For incremental predict first level and return # For incremental predict first level and return
if pw_mode == "incremental": if pw_mode == "incremental":
@ -2877,15 +2891,25 @@ def settings(request):
) )
if not PackageManager.readable(current_user, params["model"].package): if not PackageManager.readable(current_user, params["model"].package):
raise ValueError("") raise PermissionDenied("You're not allowed to access this model!")
expansion_scheme = request.POST.get(
"model-based-prediction-setting-expansion-scheme", "BFS"
)
if expansion_scheme not in ExpansionSchemeChoice.values:
raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}")
params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme)
elif tp_gen_method == "rule-based-prediction-setting": elif tp_gen_method == "rule-based-prediction-setting":
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages") rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
params["rule_packages"] = [ params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages PackageManager.get_package_by_url(current_user, p) for p in rule_packages
] ]
else: else:
raise ValueError("") raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!")
created_setting = SettingManager.create_setting( created_setting = SettingManager.create_setting(
current_user, current_user,
@ -2963,6 +2987,59 @@ def jobs(request):
) )
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}") return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
elif job_name == "batch-predict":
substrates = request.POST.get("substrates")
prediction_setting_url = request.POST.get("prediction-setting")
num_tps = request.POST.get("num-tps")
if substrates is None or substrates.strip() == "":
raise BadRequest("No substrates provided.")
pred_data = []
for pair in substrates.split("\n"):
parts = pair.split(",")
try:
smiles = FormatConverter.standardize(parts[0])
except ValueError:
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
# name is optional
name = parts[1] if len(parts) > 1 else None
pred_data.append([smiles, name])
max_tps = 50
if num_tps is not None and num_tps.strip() != "":
try:
num_tps = int(num_tps)
max_tps = max(min(num_tps, 50), 1)
except ValueError:
raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.")
batch_predict_setting = SettingManager.get_setting_by_url(
current_user, prediction_setting_url
)
target_package = PackageManager.create_package(
current_user,
f"Autogenerated Package for Batch Prediction {datetime.now()}",
"This Package was generated automatically for the batch prediction task.",
)
from .tasks import dispatch, batch_predict
res = dispatch(
current_user,
batch_predict,
pred_data,
batch_predict_setting.pk,
target_package.pk,
num_tps=max_tps,
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
else: else:
raise BadRequest(f"Job {job_name} is not supported!") raise BadRequest(f"Job {job_name} is not supported!")
else: else:
@ -2982,6 +3059,21 @@ def job(request, job_uuid):
# No op if status is already in a terminal state # No op if status is already in a terminal state
job.check_for_update() job.check_for_update()
if request.GET.get("download", False) == "true":
if not job.is_result_downloadable():
raise BadRequest("Result is not downloadable!")
if job.job_name == "batch_predict":
filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv"
else:
raise BadRequest("Result is not downloadable!")
res_str = job.task_result
response = HttpResponse(res_str, content_type="text/csv")
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
context["object_type"] = "joblog" context["object_type"] = "joblog"
context["breadcrumbs"] = [ context["breadcrumbs"] = [
{"Home": s.SERVER_URL}, {"Home": s.SERVER_URL},

View File

@ -0,0 +1,10 @@
{% if job.is_result_downloadable %}
<li>
<a
class="button"
onclick="document.getElementById('download_job_result_modal').showModal(); return false;"
>
<i class="glyphicon glyphicon-floppy-save"></i> Download Result</a
>
</li>
{% endif %}

View File

@ -0,0 +1,168 @@
{% extends "framework_modern.html" %}
{% load static %}
{% block content %}
<div class="mx-auto w-full p-8">
<h1 class="h1 mb-4 text-3xl font-bold">Batch Predict Pathways</h1>
<form id="smiles-form" method="POST" action="{% url "jobs" %}">
{% csrf_token %}
<input type="hidden" name="substrates" id="substrates" />
<input type="hidden" name="job-name" value="batch-predict" />
<fieldset class="flex flex-col gap-4 md:flex-3/4">
<table class="table table-zebra w-full">
<thead>
<tr>
<th>SMILES</th>
<th>Name</th>
</tr>
</thead>
<tbody id="smiles-table-body">
<tr>
<td>
<label>
<input
type="text"
class="input input-bordered w-full smiles-input"
placeholder="CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
{% if meta.debug %}
value="CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
{% endif %}
/>
</label>
</td>
<td>
<label>
<input
type="text"
class="input input-bordered w-full name-input"
placeholder="Caffeine"
{% if meta.debug %}
value="Caffeine"
{% endif %}
/>
</label>
</td>
</tr>
<tr>
<td>
<label>
<input
type="text"
class="input input-bordered w-full smiles-input"
placeholder="CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
{% if meta.debug %}
value="CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
{% endif %}
/>
</label>
</td>
<td>
<label>
<input
type="text"
class="input input-bordered w-full name-input"
placeholder="Ibuprofen"
{% if meta.debug %}
value="Ibuprofen"
{% endif %}
/>
</label>
</td>
</tr>
</tbody>
</table>
<label class="select mb-2 w-full">
<span class="label">Predictor</span>
<select id="prediction-setting" name="prediction-setting">
<option disabled>Select a Setting</option>
{% for s in meta.available_settings %}
<option
value="{{ s.url }}"
{% if s.id == meta.user.default_setting.id %}selected{% endif %}
>
{{ s.name }}{% if s.id == meta.user.default_setting.id %}
(User default)
{% endif %}
</option>
{% endfor %}
</select>
</label>
<label class="floating-label" for="num-tps">
  {# The jobs view clamps num-tps to the range 1..50, so the input offers the same range. #}
  <input
    type="number"
    name="num-tps"
    value="50"
    step="1"
    min="1"
    max="50"
    id="num-tps"
    class="input input-md w-full"
  />
  <span>Max Transformation Products</span>
</label>
<div class="flex justify-end gap-2">
<button type="button" id="add-row-btn" class="btn btn-outline">
Add row
</button>
<button type="submit" class="btn btn-primary">Submit</button>
</div>
</fieldset>
</form>
</div>
<script>
const tableBody = document.getElementById("smiles-table-body");
const addRowBtn = document.getElementById("add-row-btn");
const form = document.getElementById("smiles-form");
const hiddenField = document.getElementById("substrates");
// Append an empty SMILES/Name row to the table when "Add row" is clicked
addRowBtn.addEventListener("click", () => {
  // Builds a <td> holding a single text input with the given classes and placeholder
  const buildCell = (className, placeholder) => {
    const input = document.createElement("input");
    input.type = "text";
    input.className = className;
    input.placeholder = placeholder;

    const cell = document.createElement("td");
    cell.appendChild(input);
    return cell;
  };

  const row = document.createElement("tr");
  row.appendChild(
    buildCell("input input-bordered w-full smiles-input", "SMILES"),
  );
  row.appendChild(buildCell("input input-bordered w-full name-input", "Name"));
  tableBody.appendChild(row);
});
// Before submit, gather table data into the hidden field.
// SMILES and name inputs are paired by their position in the document, so
// row i contributes smilesInputs[i] and nameInputs[i].
form.addEventListener("submit", (e) => {
  const smilesInputs = Array.from(
    document.querySelectorAll(".smiles-input"),
  );
  const nameInputs = Array.from(document.querySelectorAll(".name-input"));
  const lines = [];
  for (let i = 0; i < smilesInputs.length; i++) {
    const smiles = smilesInputs[i].value.trim();
    // A missing name input yields an empty name
    const name = nameInputs[i]?.value.trim() ?? "";
    // Skip empty rows
    if (!smiles && !name) {
      continue;
    }
    // One "smiles,name" line per row.
    // NOTE(review): name is not escaped - a comma or line break inside a name
    // would change how the line is split by the receiving view; confirm.
    lines.push(`${smiles},${name}`);
  }
  // Value looks like:
  // "CN1C=NC2=C1C(=O)N(C(=O)N2C)C,Caffeine\nCC(C)CC1=CC=C(C=C1)C(C)C(=O)O,Ibuprofen"
  hiddenField.value = lines.join("\n");
});
</script>
{% endblock content %}

View File

@ -210,6 +210,27 @@
step="0.05" step="0.05"
/> />
</div> </div>
<div class="form-control mb-3">
<label
class="label"
for="model-based-prediction-setting-expansion-scheme"
>
<span class="label-text">Select Expansion Scheme</span>
</label>
<select
id="model-based-prediction-setting-expansion-scheme"
name="model-based-prediction-setting-expansion-scheme"
class="select select-bordered w-full"
>
<option value="" disabled selected>
Select the Expansion Scheme
</option>
<option value="BFS">Breadth First Search</option>
<option value="DFS">Depth First Search</option>
<option value="GREEDY">Greedy</option>
</select>
</div>
</div> </div>
<div class="form-control"> <div class="form-control">

View File

@ -0,0 +1,66 @@
{% load static %}
{# Confirmation dialog for downloading a job result: submits a GET to job.url with download=true. #}
{# NOTE(review): modalForm(), reset(), submit() and isSubmitting are expected to come from a shared Alpine component defined elsewhere - confirm. #}
<dialog
  id="download_job_result_modal"
  class="modal"
  x-data="modalForm()"
  @close="reset()"
>
  <div class="modal-box">
    <!-- Header -->
    <h3 class="font-bold text-lg">Download Job Result</h3>

    <!-- Close button (X): needs visible content and an accessible name -->
    <form method="dialog">
      <button
        class="btn btn-sm btn-circle btn-ghost absolute right-2 top-2"
        aria-label="Close"
        :disabled="isSubmitting"
      >
        ✕
      </button>
    </form>

    <!-- Body -->
    <div class="py-4">
      <p>By clicking on Download the Result of this Job will be saved.</p>
      <!-- The hidden field turns the plain job URL into a download request -->
      <form
        id="download-job-result-modal-form"
        accept-charset="UTF-8"
        action="{{ job.url }}"
        method="GET"
      >
        <input type="hidden" name="download" value="true" />
      </form>
    </div>

    <!-- Footer -->
    <div class="modal-action">
      <button
        type="button"
        class="btn"
        onclick="this.closest('dialog').close()"
        :disabled="isSubmitting"
      >
        Close
      </button>
      <button
        type="button"
        class="btn btn-primary"
        @click="submit('download-job-result-modal-form'); $el.closest('dialog').close();"
        :disabled="isSubmitting"
      >
        <span x-show="!isSubmitting">Download</span>
        <span
          x-show="isSubmitting"
          class="loading loading-spinner loading-sm"
        ></span>
      </button>
    </div>
  </div>

  <!-- Backdrop: clicking outside the box closes the dialog -->
  <form method="dialog" class="modal-backdrop">
    <button :disabled="isSubmitting">close</button>
  </form>
</dialog>

View File

@ -3,7 +3,9 @@
{% block content %} {% block content %}
{% block action_modals %} {% block action_modals %}
{# {% include "modals/objects/refresh_job_log.html" %}#} {% if job.is_result_downloadable %}
{% include "modals/objects/download_job_result_modal.html" %}
{% endif %}
{% endblock action_modals %} {% endblock action_modals %}
<div class="space-y-2 p-4"> <div class="space-y-2 p-4">
@ -49,22 +51,20 @@
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Description</div> <div class="collapse-title text-xl font-medium">Description</div>
<div class="collapse-content"> <div class="collapse-content">Status page for Job {{ job.job_name }}</div>
Status page for Task {{ job.job_name }}
</div>
</div> </div>
<!-- Job Status --> <!-- Job Status -->
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Task Status</div> <div class="collapse-title text-xl font-medium">Job Status</div>
<div class="collapse-content">{{ job.status }}</div> <div class="collapse-content">{{ job.status }}</div>
</div> </div>
<!-- Job ID --> <!-- Job ID -->
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Task ID</div> <div class="collapse-title text-xl font-medium">Job ID</div>
<div class="collapse-content">{{ job.task_id }}</div> <div class="collapse-content">{{ job.task_id }}</div>
</div> </div>
@ -72,7 +72,7 @@
{% if job.is_in_terminal_state %} {% if job.is_in_terminal_state %}
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Task Result</div> <div class="collapse-title text-xl font-medium">Job Result</div>
<div class="collapse-content"> <div class="collapse-content">
{% if job.job_name == 'engineer_pathways' %} {% if job.job_name == 'engineer_pathways' %}
<div class="card bg-base-100"> <div class="card bg-base-100">
@ -103,6 +103,68 @@
</ul> </ul>
</div> </div>
</div> </div>
{% elif job.job_name == 'batch_predict' %}
<div
id="table-container"
class="overflow-x-auto overflow-y-auto max-h-96 border rounded-lg"
></div>
<script>
  // CSV text produced by the batch_predict job. escapejs emits a value that is
  // safe inside a quoted JS string and keeps backslashes, quotes, ampersands
  // and backticks intact; a template literal combined with HTML autoescaping
  // would interpret or mangle them.
  const input = "{{ job.task_result|escapejs }}";

  /**
   * Build a <table> element from simple comma-separated text.
   *
   * The first non-empty line is used as the header row; every following
   * non-empty line becomes a body row that is padded or truncated to the
   * number of header columns. Cell text is assigned via textContent, so it
   * is never parsed as HTML.
   *
   * NOTE(review): fields are split on every comma, so quoted fields that
   * contain commas are not supported - confirm the job never emits them.
   *
   * @param {string} str CSV text, one record per line.
   * @returns {HTMLTableElement} The populated table (empty if str has no lines).
   */
  function renderCsvTable(str) {
    const lines = String(str ?? "")
      .split("\n")
      .map((l) => l.trim())
      .filter(Boolean);

    const [headerLine, ...rows] = lines;
    // Without any line there is no header; render an empty table rather than throwing.
    const headers = headerLine ? headerLine.split(",").map((h) => h.trim()) : [];

    const table = document.createElement("table");
    table.className = "table table-zebra w-full";

    const thead = document.createElement("thead");
    const headerRow = document.createElement("tr");
    headers.forEach((h) => {
      const th = document.createElement("th");
      th.textContent = h;
      headerRow.appendChild(th);
    });
    thead.appendChild(headerRow);

    const tbody = document.createElement("tbody");
    rows.forEach((rowStr) => {
      const row = document.createElement("tr");
      const cells = rowStr.split(",").map((c) => c.trim());
      // Iterate over the headers so every row has exactly one cell per column.
      headers.forEach((_, i) => {
        const td = document.createElement("td");
        td.textContent = cells[i] || "";
        row.appendChild(td);
      });
      tbody.appendChild(row);
    });

    table.appendChild(thead);
    table.appendChild(tbody);
    return table;
  }

  document
    .getElementById("table-container")
    .appendChild(renderCsvTable(input));
</script>
{% else %} {% else %}
{{ job.parsed_result }} {{ job.parsed_result }}
{% endif %} {% endif %}

View File

@ -73,13 +73,29 @@
</ul> </ul>
</div> </div>
</div> </div>
<!-- Reaction Packages --> {% endif %}
<!-- Reaction Packages -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Reaction Packages</div>
<div class="collapse-content">
<ul class="menu bg-base-100 rounded-box w-full">
{% for p in model.data_packages.all %}
<li>
<a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a>
</li>
{% endfor %}
</ul>
</div>
</div>
{% if model.eval_packages.all|length > 0 %}
<!-- Eval Packages -->
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Reaction Packages</div> <div class="collapse-title text-xl font-medium">Eval Packages</div>
<div class="collapse-content"> <div class="collapse-content">
<ul class="menu bg-base-100 rounded-box w-full"> <ul class="menu bg-base-100 rounded-box w-full">
{% for p in model.data_packages.all %} {% for p in model.eval_packages.all %}
<li> <li>
<a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a> <a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a>
</li> </li>
@ -87,31 +103,13 @@
</ul> </ul>
</div> </div>
</div> </div>
{% if model.eval_packages.all|length > 0 %}
<!-- Eval Packages -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Eval Packages</div>
<div class="collapse-content">
<ul class="menu bg-base-100 rounded-box w-full">
{% for p in model.eval_packages.all %}
<li>
<a href="{{ p.url }}" class="hover:bg-base-200"
>{{ p.name }}</a
>
</li>
{% endfor %}
</ul>
</div>
</div>
{% endif %}
<!-- Model Status -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Model Status</div>
<div class="collapse-content">{{ model.status }}</div>
</div>
{% endif %} {% endif %}
<!-- Model Status -->
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">Model Status</div>
<div class="collapse-content">{{ model.status }}</div>
</div>
{% if model.ready_for_prediction %} {% if model.ready_for_prediction %}
<!-- Predict Panel --> <!-- Predict Panel -->
@ -174,7 +172,6 @@
</div> </div>
</div> </div>
{% endif %} {% endif %}
{% if model.model_status == 'FINISHED' %} {% if model.model_status == 'FINISHED' %}
<!-- Single Gen Curve Panel --> <!-- Single Gen Curve Panel -->
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
@ -188,6 +185,19 @@
</div> </div>
</div> </div>
</div> </div>
{% if model.multigen_eval %}
<div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked />
<div class="collapse-title text-xl font-medium">
Multi Gen Precision Recall Curve
</div>
<div class="collapse-content">
<div class="flex justify-center">
<div id="mg-chart"></div>
</div>
</div>
</div>
{% endif %}
{% endif %} {% endif %}
</div> </div>
@ -244,6 +254,105 @@
} }
} }
/**
 * Draw a precision/recall curve with c3 inside the element matched by `selector`.
 *
 * Points are plotted in ascending threshold order. Hovering a point shows the
 * threshold of the first plotted point that has the same recall value.
 *
 * @param {string} selector CSS selector of the chart container (passed to c3 as `bindto`).
 * @param {Array<Object>} data Curve points with `recall`, `precision` and `threshold` properties.
 */
function makeChart(selector, data) {
    if (!data || data.length === 0) {
        console.warn('PR curve data is empty');
        return;
    }

    // Nothing to draw into: skip instead of letting c3 fail on a missing container.
    if (typeof selector === 'string' && !document.querySelector(selector)) {
        return;
    }

    // Sort a copy so the caller's array is left untouched.
    const points = data.slice().sort(function (a, b) {
        if (a.threshold < b.threshold) {
            return -1;
        }
        if (a.threshold > b.threshold) {
            return 1;
        }
        return 0;
    });

    // c3 column format: series name first, followed by its values.
    const x = ['Recall'].concat(points.map((p) => p.recall));
    const y = ['Precision'].concat(points.map((p) => p.precision));

    // Both axes cover [0, 1] with a tick every 0.1; build a fresh config per axis.
    function unitAxis(label) {
        return {
            max: 1,
            min: 0,
            label: label,
            padding: 0,
            tick: {
                fit: true,
                values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
            }
        };
    }

    c3.generate({
        bindto: selector,
        data: {
            x: 'Recall',
            y: 'Precision',
            columns: [
                x,
                y,
            ]
        },
        size: {
            height: 400,
            width: 480
        },
        axis: {
            x: unitAxis('Recall'),
            y: unitAxis('Precision')
        },
        point: {
            r: 4
        },
        tooltip: {
            format: {
                title: function (recall) {
                    // Loose equality is kept on purpose: it matches the original lookup.
                    const match = points.find((p) => p.recall == recall);
                    return match ? "Threshold: " + match.threshold : "";
                },
                value: function (precision, ratio, id) {
                    return undefined;
                }
            }
        },
        zoom: {
            enabled: true
        }
    });
}
function makeLoadingGif(selector, gifPath) { function makeLoadingGif(selector, gifPath) {
const element = document.querySelector(selector); const element = document.querySelector(selector);
if (element) { if (element) {
@ -260,107 +369,12 @@
} }
{% if model.model_status == 'FINISHED' %} {% if model.model_status == 'FINISHED' %}
// Precision Recall Curve // Precision Recall Curve
const sgChart = document.getElementById('sg-chart'); makeChart('#sg-chart', {{ model.pr_curve|safe }});
if (sgChart) { {% if model.multigen_eval %}
const x = ['Recall']; // Multi Gen Precision Recall Curve
const y = ['Precision']; makeChart('#mg-chart', {{ model.mg_pr_curve|safe }});
const thres = ['threshold']; {% endif %}
function compare(a, b) {
if (a.threshold < b.threshold)
return -1;
else if (a.threshold > b.threshold)
return 1;
else
return 0;
}
function getIndexForValue(data, val, val_name) {
for (const idx in data) {
if (data[idx][val_name] == val) {
return idx;
}
}
return -1;
}
var data = {{ model.pr_curve|safe }};
if (!data || data.length === 0) {
console.warn('PR curve data is empty');
return;
}
const dataLength = data.length;
data.sort(compare);
for (const idx in data) {
const d = data[idx];
x.push(d.recall);
y.push(d.precision);
thres.push(d.threshold);
}
const chart = c3.generate({
bindto: '#sg-chart',
data: {
onclick: function (d, e) {
const idx = d.index;
const thresh = data[dataLength - idx - 1].threshold;
},
x: 'Recall',
y: 'Precision',
columns: [
x,
y,
]
},
size: {
height: 400,
width: 480
},
axis: {
x: {
max: 1,
min: 0,
label: 'Recall',
padding: 0,
tick: {
fit: true,
values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
}
},
y: {
max: 1,
min: 0,
label: 'Precision',
padding: 0,
tick: {
fit: true,
values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
}
}
},
point: {
r: 4
},
tooltip: {
format: {
title: function (recall) {
const idx = getIndexForValue(data, recall, "recall");
if (idx != -1) {
return "Threshold: " + data[idx].threshold;
}
return "";
},
value: function (precision, ratio, id) {
return undefined;
}
}
},
zoom: {
enabled: true
}
});
}
{% endif %} {% endif %}
// Predict button handler // Predict button handler

View File

@ -393,7 +393,9 @@
<tbody> <tbody>
<tr> <tr>
<td>Threshold</td> <td>Threshold</td>
<td>{{ pathway.setting.model_threshold }}</td> <td>
{{ pathway.setting_with_overrides.model_threshold }}
</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
@ -420,11 +422,15 @@
{% endif %} {% endif %}
<tr> <tr>
<td>Max Nodes</td> <td>Max Nodes</td>
<td>{{ pathway.setting.max_nodes }}</td> <td>{{ pathway.setting_with_overrides.max_nodes }}</td>
</tr> </tr>
<tr> <tr>
<td>Max Depth</td> <td>Max Depth</td>
<td>{{ pathway.setting.max_depth }}</td> <td>{{ pathway.setting_with_overrides.max_depth }}</td>
</tr>
<tr>
<td>Expansion Scheme</td>
{# Show the scheme stored on this pathway's own setting, not the viewing user's default setting. #}
<td>{{ pathway.setting.expansion_scheme }}</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>

View File

@ -150,6 +150,10 @@
<td>Max Depth</td> <td>Max Depth</td>
<td>{{ user.default_setting.max_depth }}</td> <td>{{ user.default_setting.max_depth }}</td>
</tr> </tr>
<tr>
<td>Expansion Scheme</td>
<td>{{ user.default_setting.expansion_scheme }}</td>
</tr>
</tbody> </tbody>
</table> </table>
</div> </div>

View File

@ -1,3 +1,5 @@
from datetime import datetime
from django.conf import settings as s from django.conf import settings as s
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
@ -23,6 +25,46 @@ class MultiGenTest(TestCase):
cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)" cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)"
cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine" cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine"
def test_batch_predict(self):
    """Run batch_predict for two compounds and check its CSV text, the created pathways and their overrides."""
    from epdb.tasks import batch_predict

    max_tps = 50
    compounds = [
        ["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"],
        ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen"],
    ]

    setting = self.user.prediction_settings()
    package = PackageManager.create_package(
        self.user,
        f"Autogenerated Package for Batch Prediction {datetime.now()}",
        "This Package was generated automatically for the batch prediction task.",
    )

    csv_text = batch_predict(compounds, setting.pk, package.pk, num_tps=max_tps)

    self.assertTrue(csv_text.startswith("Pathway URL,"))
    # One header line plus at least one line per root compound.
    line_count = len(csv_text.split("\n"))
    self.assertGreaterEqual(line_count, 3)
    # Each input compound yields its own pathway in the target package.
    self.assertEqual(package.pathways.count(), 2)

    # num_tps is expected to show up as an override on both limits.
    first_pathway = package.pathways.first()
    self.assertEqual(
        first_pathway.setting_with_overrides.max_depth,
        f"{max_tps} (this is an override for this particular pathway)",
    )
    self.assertEqual(
        first_pathway.setting_with_overrides.max_nodes,
        f"{max_tps} (this is an override for this particular pathway)",
    )
def test_engineer_pathway(self): def test_engineer_pathway(self):
from epdb.tasks import engineer_pathways from epdb.tasks import engineer_pathways

View File

@ -1,20 +1,20 @@
from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
from epdb.logic import SNode, SEdge from epdb.logic import SEdge, SNode, SPathway
from epdb.models import Pathway, Setting
from utilities.chem import PredictionResult, ProductSet
class SObjectTest(TestCase): class SNodeTest(TestCase):
def setUp(self):
pass
def test_snode_eq(self): def test_snode_eq(self):
snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0) snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0) snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
assert snode1 == snode2 self.assertEqual(snode1, snode2)
def test_snode_hash(self):
pass
class SEdgeTest(TestCase):
def test_sedge_eq(self): def test_sedge_eq(self):
sedge1 = SEdge( sedge1 = SEdge(
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)], [SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
@ -26,4 +26,62 @@ class SObjectTest(TestCase):
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)], [SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
rule=None, rule=None,
) )
assert sedge1 == sedge2 self.assertEqual(sedge1, sedge2)
class SPathwayTest(TestCase):
def setUp(self):
    """Create the root SMILES and the spec'd Setting/Pathway mocks used by the SPathway tests."""
    self.mock_pathway = Mock(spec=Pathway)
    self.mock_setting = Mock(spec=Setting)
    self.test_smiles = "CCN(CC)C(=O)C1=CC(=CC=C1)CO"
def test_predict_step_basic(self):
    """Check that one predict_step from depth 0 yields four nodes and three edges for three product sets."""
    pathway = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)

    # e.g. bt0002
    product_smiles = [
        "CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1",
        "CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1",
        "CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1",
    ]
    prediction = PredictionResult(
        [ProductSet([smiles]) for smiles in product_smiles],
        0.17,
        None,
    )

    # Patch expand on the mocked setting so it returns the canned prediction.
    with patch.object(self.mock_setting, "expand", return_value=[prediction]):
        pathway.predict_step(from_depth=0)

    node_count = len(pathway.smiles_to_node)
    edge_count = len(pathway.edges)
    self.assertEqual(node_count, 4)
    self.assertEqual(edge_count, 3)
def test_to_json(self):
    """Test that to_json returns a dict listing the nodes and edges created by predict_step."""
    spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)

    # e.g. bt0002
    pr = PredictionResult(
        [
            ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
            ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
            ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
        ],
        0.17,
        None,
    )

    # Patch expand on the mocked setting so it returns the canned prediction.
    with patch.object(self.mock_setting, "expand", return_value=[pr]):
        spw.predict_step(from_depth=0)

    # Sanity-check the in-memory pathway before serializing it.
    self.assertEqual(len(spw.smiles_to_node.keys()), 4)
    self.assertEqual(len(spw.edges), 3)

    json_result = spw.to_json()

    # The serialized form must expose the same node and edge counts.
    self.assertIsInstance(json_result, dict)
    self.assertIn("nodes", json_result)
    self.assertIn("edges", json_result)
    self.assertEqual(len(json_result["nodes"]), 4)
    self.assertEqual(len(json_result["edges"]), 3)