forked from enviPath/enviPy
[Feature] Show Multi Gen Eval + Batch Prediction (#267)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#267
This commit is contained in:
@ -1451,7 +1451,7 @@ def create_pathway(
|
|||||||
|
|
||||||
from .tasks import dispatch, predict
|
from .tasks import dispatch, predict
|
||||||
|
|
||||||
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=-1)
|
dispatch(request.user, predict, new_pw.pk, setting.pk, limit=None)
|
||||||
|
|
||||||
return redirect(new_pw.url)
|
return redirect(new_pw.url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|||||||
259
epdb/logic.py
259
epdb/logic.py
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union, Tuple
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import nh3
|
import nh3
|
||||||
@ -16,6 +16,7 @@ from epdb.models import (
|
|||||||
Edge,
|
Edge,
|
||||||
EnzymeLink,
|
EnzymeLink,
|
||||||
EPModel,
|
EPModel,
|
||||||
|
ExpansionSchemeChoice,
|
||||||
Group,
|
Group,
|
||||||
GroupPackagePermission,
|
GroupPackagePermission,
|
||||||
Node,
|
Node,
|
||||||
@ -1116,6 +1117,7 @@ class SettingManager(object):
|
|||||||
rule_packages: List[Package] = None,
|
rule_packages: List[Package] = None,
|
||||||
model: EPModel = None,
|
model: EPModel = None,
|
||||||
model_threshold: float = None,
|
model_threshold: float = None,
|
||||||
|
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
|
||||||
):
|
):
|
||||||
new_s = Setting()
|
new_s = Setting()
|
||||||
# Clean for potential XSS
|
# Clean for potential XSS
|
||||||
@ -1550,6 +1552,196 @@ class SPathway(object):
|
|||||||
|
|
||||||
return sorted(res, key=lambda x: hash(x))
|
return sorted(res, key=lambda x: hash(x))
|
||||||
|
|
||||||
|
def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
|
||||||
|
"""
|
||||||
|
Expands the given substrates by generating new nodes and edges based on prediction settings.
|
||||||
|
|
||||||
|
This method processes a list of substrates and expands them into new nodes and edges using defined
|
||||||
|
rules and settings. It evaluates each substrate to determine its applicability domain, persists
|
||||||
|
domain assessments, and generates candidates for further processing. Newly created nodes and edges
|
||||||
|
are returned, and any applicable information is stored or updated internally during the process.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
substrates (List[SNode]): A list of substrate nodes to be expanded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[SNode], List[SEdge]]:
|
||||||
|
A tuple containing:
|
||||||
|
- A list of new nodes generated during the expansion.
|
||||||
|
- A list of new edges representing connections between nodes based on candidate reactions.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a node does not have an ID when it should have been saved already.
|
||||||
|
"""
|
||||||
|
new_nodes: List[SNode] = []
|
||||||
|
new_edges: List[SEdge] = []
|
||||||
|
|
||||||
|
for sub in substrates:
|
||||||
|
# For App Domain we have to ensure that each Node is evaluated
|
||||||
|
if sub.app_domain_assessment is None:
|
||||||
|
if self.prediction_setting.model:
|
||||||
|
if self.prediction_setting.model.app_domain:
|
||||||
|
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
|
||||||
|
sub.smiles
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.persist is not None:
|
||||||
|
n = self.snode_persist_lookup[sub]
|
||||||
|
|
||||||
|
if n.id is None:
|
||||||
|
raise ValueError(f"Node {n} has no ID... aborting!")
|
||||||
|
|
||||||
|
node_data = n.simple_json()
|
||||||
|
node_data["image"] = f"{n.url}?image=svg"
|
||||||
|
app_domain_assessment["assessment"]["node"] = node_data
|
||||||
|
|
||||||
|
n.kv["app_domain_assessment"] = app_domain_assessment
|
||||||
|
n.save()
|
||||||
|
|
||||||
|
sub.app_domain_assessment = app_domain_assessment
|
||||||
|
|
||||||
|
candidates = self.prediction_setting.expand(self, sub)
|
||||||
|
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
|
||||||
|
for cand_set in candidates:
|
||||||
|
if cand_set:
|
||||||
|
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
|
||||||
|
for cand in cand_set:
|
||||||
|
cand_nodes = []
|
||||||
|
# candidate reactions can have multiple fragments
|
||||||
|
for c in cand:
|
||||||
|
if c not in self.smiles_to_node:
|
||||||
|
# For new nodes do an AppDomain Assessment if an AppDomain is attached
|
||||||
|
app_domain_assessment = None
|
||||||
|
if self.prediction_setting.model:
|
||||||
|
if self.prediction_setting.model.app_domain:
|
||||||
|
app_domain_assessment = (
|
||||||
|
self.prediction_setting.model.app_domain.assess(c)
|
||||||
|
)
|
||||||
|
snode = SNode(c, sub.depth + 1, app_domain_assessment)
|
||||||
|
self.smiles_to_node[c] = snode
|
||||||
|
new_nodes.append(snode)
|
||||||
|
|
||||||
|
node = self.smiles_to_node[c]
|
||||||
|
cand_nodes.append(node)
|
||||||
|
|
||||||
|
edge = SEdge(
|
||||||
|
sub,
|
||||||
|
cand_nodes,
|
||||||
|
rule=cand_set.rule,
|
||||||
|
probability=cand_set.probability,
|
||||||
|
)
|
||||||
|
self.edges.add(edge)
|
||||||
|
new_edges.append(edge)
|
||||||
|
|
||||||
|
return new_nodes, new_edges
|
||||||
|
|
||||||
|
def predict(self):
|
||||||
|
"""
|
||||||
|
Predicts outcomes based on a graph traversal algorithm using the specified expansion schema.
|
||||||
|
|
||||||
|
This method iteratively explores the nodes of a graph starting from the root nodes, propagating
|
||||||
|
probabilities through edges, and updating the probabilities of the connected nodes. The traversal
|
||||||
|
can follow one of three predefined expansion schemas: Depth-First Search (DFS), Breadth-First Search
|
||||||
|
(BFS), or a Greedy approach based on node probabilities. The methodology ensures that all reachable
|
||||||
|
nodes are processed systematically according to the specified schema.
|
||||||
|
|
||||||
|
Errors will be raised if the expansion schema is undefined or invalid. Additionally, this method
|
||||||
|
supports persisting changes by writing back data to the database when configured to do so.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
done : bool
|
||||||
|
A flag indicating whether the prediction process is completed.
|
||||||
|
persist : Any
|
||||||
|
An optional object that manages persistence operations for saving modifications.
|
||||||
|
root_nodes : List[SNode]
|
||||||
|
A collection of initial nodes in the graph from which traversal begins.
|
||||||
|
prediction_setting : Any
|
||||||
|
Configuration object specifying settings for graph traversal, such as the choice of
|
||||||
|
expansion schema.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If an invalid or unknown expansion schema is provided in `prediction_setting`.
|
||||||
|
"""
|
||||||
|
# populate initial queue
|
||||||
|
queue = list(self.root_nodes)
|
||||||
|
processed = set()
|
||||||
|
|
||||||
|
# initial nodes have prob 1.0
|
||||||
|
node_probs: Dict[SNode, float] = {}
|
||||||
|
node_probs.update({n: 1.0 for n in queue})
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
current = queue.pop(0)
|
||||||
|
|
||||||
|
if current in processed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
processed.add(current)
|
||||||
|
|
||||||
|
new_nodes, new_edges = self._expand([current])
|
||||||
|
|
||||||
|
if new_nodes or new_edges:
|
||||||
|
# Check if we need to write back data to the database
|
||||||
|
if self.persist:
|
||||||
|
self._sync_to_pathway()
|
||||||
|
# call save to update the internal modified field
|
||||||
|
self.persist.save()
|
||||||
|
|
||||||
|
if new_nodes:
|
||||||
|
for edge in new_edges:
|
||||||
|
# All edge have `current` as educt
|
||||||
|
# Use `current` and adjust probs
|
||||||
|
current_prob = node_probs[current]
|
||||||
|
|
||||||
|
for prod in edge.products:
|
||||||
|
# Either is a new product or a product and we found a path with a higher prob
|
||||||
|
if (
|
||||||
|
prod not in node_probs
|
||||||
|
or current_prob * edge.probability > node_probs[prod]
|
||||||
|
):
|
||||||
|
node_probs[prod] = current_prob * edge.probability
|
||||||
|
|
||||||
|
# Update Queue to proceed
|
||||||
|
if self.prediction_setting.expansion_scheme == "DFS":
|
||||||
|
for n in new_nodes:
|
||||||
|
if n not in processed:
|
||||||
|
# We want to follow this path -> prepend queue
|
||||||
|
queue.insert(0, n)
|
||||||
|
elif self.prediction_setting.expansion_scheme == "BFS":
|
||||||
|
for n in new_nodes:
|
||||||
|
if n not in processed:
|
||||||
|
# Add at the end, everything queued before will be processed
|
||||||
|
# before new_nodese
|
||||||
|
queue.append(n)
|
||||||
|
elif self.prediction_setting.expansion_scheme == "GREEDY":
|
||||||
|
# Simply add them, as we will re-order the queue later
|
||||||
|
for n in new_nodes:
|
||||||
|
if n not in processed:
|
||||||
|
queue.append(n)
|
||||||
|
|
||||||
|
node_and_probs = []
|
||||||
|
for queued_val in queue:
|
||||||
|
node_and_probs.append((queued_val, node_probs[queued_val]))
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
pprint(node_and_probs)
|
||||||
|
|
||||||
|
# re-order the queue and only pick smiles
|
||||||
|
queue = [
|
||||||
|
n[0] for n in sorted(node_and_probs, key=lambda x: x[1], reverse=True)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Queue exhausted, we're done
|
||||||
|
self.done = True
|
||||||
|
|
||||||
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
|
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
|
||||||
substrates: List[SNode] = []
|
substrates: List[SNode] = []
|
||||||
|
|
||||||
@ -1560,67 +1752,15 @@ class SPathway(object):
|
|||||||
if from_node == v:
|
if from_node == v:
|
||||||
substrates = [k]
|
substrates = [k]
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Node {from_node} not found in SPathway!")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Neither from_depth nor from_node_url specified")
|
raise ValueError("Neither from_depth nor from_node_url specified")
|
||||||
|
|
||||||
new_tp = False
|
new_tp = False
|
||||||
if substrates:
|
if substrates:
|
||||||
for sub in substrates:
|
new_nodes, _ = self._expand(substrates)
|
||||||
if sub.app_domain_assessment is None:
|
new_tp = len(new_nodes) > 0
|
||||||
if self.prediction_setting.model:
|
|
||||||
if self.prediction_setting.model.app_domain:
|
|
||||||
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
|
|
||||||
sub.smiles
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.persist is not None:
|
|
||||||
n = self.snode_persist_lookup[sub]
|
|
||||||
|
|
||||||
assert n.id is not None, (
|
|
||||||
"Node has no id! Should have been saved already... aborting!"
|
|
||||||
)
|
|
||||||
node_data = n.simple_json()
|
|
||||||
node_data["image"] = f"{n.url}?image=svg"
|
|
||||||
app_domain_assessment["assessment"]["node"] = node_data
|
|
||||||
|
|
||||||
n.kv["app_domain_assessment"] = app_domain_assessment
|
|
||||||
n.save()
|
|
||||||
|
|
||||||
sub.app_domain_assessment = app_domain_assessment
|
|
||||||
|
|
||||||
candidates = self.prediction_setting.expand(self, sub)
|
|
||||||
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
|
|
||||||
for cand_set in candidates:
|
|
||||||
if cand_set:
|
|
||||||
new_tp = True
|
|
||||||
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
|
|
||||||
for cand in cand_set:
|
|
||||||
cand_nodes = []
|
|
||||||
# candidate reactions can have multiple fragments
|
|
||||||
for c in cand:
|
|
||||||
if c not in self.smiles_to_node:
|
|
||||||
# For new nodes do an AppDomain Assessment if an AppDomain is attached
|
|
||||||
app_domain_assessment = None
|
|
||||||
if self.prediction_setting.model:
|
|
||||||
if self.prediction_setting.model.app_domain:
|
|
||||||
app_domain_assessment = (
|
|
||||||
self.prediction_setting.model.app_domain.assess(c)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.smiles_to_node[c] = SNode(
|
|
||||||
c, sub.depth + 1, app_domain_assessment
|
|
||||||
)
|
|
||||||
|
|
||||||
node = self.smiles_to_node[c]
|
|
||||||
cand_nodes.append(node)
|
|
||||||
|
|
||||||
edge = SEdge(
|
|
||||||
sub,
|
|
||||||
cand_nodes,
|
|
||||||
rule=cand_set.rule,
|
|
||||||
probability=cand_set.probability,
|
|
||||||
)
|
|
||||||
self.edges.add(edge)
|
|
||||||
|
|
||||||
# In case no substrates are found, we're done.
|
# In case no substrates are found, we're done.
|
||||||
# For "predict from node" we're always done
|
# For "predict from node" we're always done
|
||||||
@ -1704,11 +1844,6 @@ class SPathway(object):
|
|||||||
"to": to_indices,
|
"to": to_indices,
|
||||||
}
|
}
|
||||||
|
|
||||||
# if edge.rule:
|
|
||||||
# e['rule'] = {
|
|
||||||
# 'name': edge.rule.name,
|
|
||||||
# 'id': edge.rule.url,
|
|
||||||
# }
|
|
||||||
edges.append(e)
|
edges.append(e)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
25
epdb/migrations/0013_setting_expansion_schema.py
Normal file
25
epdb/migrations/0013_setting_expansion_schema.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Generated by Django 5.2.7 on 2025-12-14 11:30
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("epdb", "0012_node_stereo_removed_pathway_predicted"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="setting",
|
||||||
|
name="expansion_schema",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[
|
||||||
|
("BFS", "Breadth First Search"),
|
||||||
|
("DFS", "Depth First Search"),
|
||||||
|
("GREEDY", "Greedy"),
|
||||||
|
],
|
||||||
|
default="BFS",
|
||||||
|
max_length=20,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
# Generated by Django 5.2.7 on 2025-12-14 16:02
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("epdb", "0013_setting_expansion_schema"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RenameField(
|
||||||
|
model_name="setting",
|
||||||
|
old_name="expansion_schema",
|
||||||
|
new_name="expansion_scheme",
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
|||||||
# potentially prefetched edge_set
|
# potentially prefetched edge_set
|
||||||
return self.edge_set.all()
|
return self.edge_set.all()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def setting_with_overrides(self):
|
||||||
|
mem_copy = Setting.objects.get(pk=self.setting.pk)
|
||||||
|
|
||||||
|
if "setting_overrides" in self.kv:
|
||||||
|
for k, v in self.kv["setting_overrides"].items():
|
||||||
|
setattr(mem_copy, k, f"{v} (this is an override for this particular pathway)")
|
||||||
|
|
||||||
|
return mem_copy
|
||||||
|
|
||||||
def _url(self):
|
def _url(self):
|
||||||
return "{}/pathway/{}".format(self.package.url, self.uuid)
|
return "{}/pathway/{}".format(self.package.url, self.uuid)
|
||||||
|
|
||||||
@ -1879,25 +1889,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
|||||||
|
|
||||||
return json.dumps(res)
|
return json.dumps(res)
|
||||||
|
|
||||||
def to_csv(self) -> str:
|
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
|
||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
header = []
|
||||||
|
|
||||||
|
if include_pathway_url:
|
||||||
|
header += ["Pathway URL"]
|
||||||
|
|
||||||
|
header += [
|
||||||
|
"SMILES",
|
||||||
|
"name",
|
||||||
|
"depth",
|
||||||
|
"probability",
|
||||||
|
"rule_names",
|
||||||
|
"rule_ids",
|
||||||
|
"parent_smiles",
|
||||||
|
]
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
rows.append(
|
|
||||||
[
|
if include_header:
|
||||||
"SMILES",
|
rows.append(header)
|
||||||
"name",
|
|
||||||
"depth",
|
|
||||||
"probability",
|
|
||||||
"rule_names",
|
|
||||||
"rule_ids",
|
|
||||||
"parent_smiles",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
for n in self.nodes.order_by("depth"):
|
for n in self.nodes.order_by("depth"):
|
||||||
cs = n.default_node_label
|
cs = n.default_node_label
|
||||||
row = [cs.smiles, cs.name, n.depth]
|
row = []
|
||||||
|
|
||||||
|
if include_pathway_url:
|
||||||
|
row.append(n.pathway.url)
|
||||||
|
|
||||||
|
row += [cs.smiles, cs.name, n.depth]
|
||||||
|
|
||||||
edges = self.edges.filter(end_nodes__in=[n])
|
edges = self.edges.filter(end_nodes__in=[n])
|
||||||
if len(edges):
|
if len(edges):
|
||||||
@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mg_pr_curve(self):
|
||||||
|
if self.model_status != self.FINISHED:
|
||||||
|
raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
|
||||||
|
|
||||||
|
if not self.multigen_eval:
|
||||||
|
raise ValueError("MG PR Curve is only available for multigen models")
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
thresholds = self.eval_results["multigen_average_precision_per_threshold"].keys()
|
||||||
|
|
||||||
|
for t in thresholds:
|
||||||
|
res.append(
|
||||||
|
{
|
||||||
|
"precision": self.eval_results["multigen_average_precision_per_threshold"][t],
|
||||||
|
"recall": self.eval_results["multigen_average_recall_per_threshold"][t],
|
||||||
|
"threshold": float(t),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def applicable_rules(self) -> List["Rule"]:
|
def applicable_rules(self) -> List["Rule"]:
|
||||||
"""
|
"""
|
||||||
@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel):
|
|||||||
for i, root in enumerate(root_compounds):
|
for i, root in enumerate(root_compounds):
|
||||||
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
|
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
|
||||||
|
|
||||||
spw = SPathway(root_nodes=root, prediction_setting=s)
|
spw = SPathway(root_nodes=root.smiles, prediction_setting=s)
|
||||||
level = 0
|
level = 0
|
||||||
|
|
||||||
while not spw.done:
|
while not spw.done:
|
||||||
@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission):
|
|||||||
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
|
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
|
||||||
|
|
||||||
|
|
||||||
|
class ExpansionSchemeChoice(models.TextChoices):
|
||||||
|
BFS = "BFS", "Breadth First Search"
|
||||||
|
DFS = "DFS", "Depth First Search"
|
||||||
|
GREEDY = "GREEDY", "Greedy"
|
||||||
|
|
||||||
|
|
||||||
class Setting(EnviPathModel):
|
class Setting(EnviPathModel):
|
||||||
public = models.BooleanField(null=False, blank=False, default=False)
|
public = models.BooleanField(null=False, blank=False, default=False)
|
||||||
global_default = models.BooleanField(null=False, blank=False, default=False)
|
global_default = models.BooleanField(null=False, blank=False, default=False)
|
||||||
@ -3795,6 +3847,12 @@ class Setting(EnviPathModel):
|
|||||||
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
|
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
|
||||||
)
|
)
|
||||||
|
|
||||||
|
expansion_scheme = models.CharField(
|
||||||
|
max_length=20,
|
||||||
|
choices=ExpansionSchemeChoice.choices,
|
||||||
|
default=ExpansionSchemeChoice.BFS,
|
||||||
|
)
|
||||||
|
|
||||||
def _url(self):
|
def _url(self):
|
||||||
return "{}/setting/{}".format(s.SERVER_URL, self.uuid)
|
return "{}/setting/{}".format(s.SERVER_URL, self.uuid)
|
||||||
|
|
||||||
@ -3833,7 +3891,7 @@ class Setting(EnviPathModel):
|
|||||||
"""Decision Method whether to expand on a certain Node or not"""
|
"""Decision Method whether to expand on a certain Node or not"""
|
||||||
if pathway.num_nodes() >= self.max_nodes:
|
if pathway.num_nodes() >= self.max_nodes:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}"
|
f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel):
|
|||||||
if self.job_name == "engineer_pathways":
|
if self.job_name == "engineer_pathways":
|
||||||
return ast.literal_eval(self.task_result)
|
return ast.literal_eval(self.task_result)
|
||||||
return self.task_result
|
return self.task_result
|
||||||
|
|
||||||
|
def is_result_downloadable(self):
|
||||||
|
downloadable = ["batch_predict"]
|
||||||
|
|
||||||
|
return self.job_name in downloadable
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from django.utils import timezone
|
|||||||
|
|
||||||
from epdb.logic import SPathway
|
from epdb.logic import SPathway
|
||||||
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
|
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
|
||||||
|
from utilities.chem import FormatConverter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
|
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
|
||||||
@ -139,14 +140,25 @@ def predict(
|
|||||||
pred_setting_pk: int,
|
pred_setting_pk: int,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
node_pk: Optional[int] = None,
|
node_pk: Optional[int] = None,
|
||||||
|
setting_overrides: Optional[dict] = None,
|
||||||
) -> Pathway:
|
) -> Pathway:
|
||||||
pw = Pathway.objects.get(id=pw_pk)
|
pw = Pathway.objects.get(id=pw_pk)
|
||||||
setting = Setting.objects.get(id=pred_setting_pk)
|
setting = Setting.objects.get(id=pred_setting_pk)
|
||||||
|
|
||||||
|
if setting_overrides:
|
||||||
|
for k, v in setting_overrides.items():
|
||||||
|
setattr(setting, k, v)
|
||||||
|
|
||||||
# If the setting has a model add/restore it from the cache
|
# If the setting has a model add/restore it from the cache
|
||||||
if setting.model is not None:
|
if setting.model is not None:
|
||||||
setting.model = get_ml_model(setting.model.pk)
|
setting.model = get_ml_model(setting.model.pk)
|
||||||
|
|
||||||
pw.kv.update(**{"status": "running"})
|
kv = {"status": "running"}
|
||||||
|
|
||||||
|
if setting_overrides:
|
||||||
|
kv["setting_overrides"] = setting_overrides
|
||||||
|
|
||||||
|
pw.kv.update(**kv)
|
||||||
pw.save()
|
pw.save()
|
||||||
|
|
||||||
if JobLog.objects.filter(task_id=self.request.id).exists():
|
if JobLog.objects.filter(task_id=self.request.id).exists():
|
||||||
@ -171,7 +183,8 @@ def predict(
|
|||||||
spw = SPathway.from_pathway(pw)
|
spw = SPathway.from_pathway(pw)
|
||||||
spw.predict_step(from_node=n)
|
spw.predict_step(from_node=n)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Neither limit nor node_pk given!")
|
spw = SPathway(prediction_setting=setting, persist=pw)
|
||||||
|
spw.predict()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pw.kv.update({"status": "failed"})
|
pw.kv.update({"status": "failed"})
|
||||||
@ -353,3 +366,76 @@ def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_p
|
|||||||
predicted_pathways.append(pred_pw.url)
|
predicted_pathways.append(pred_pw.url)
|
||||||
|
|
||||||
return intermediate_pathways, predicted_pathways
|
return intermediate_pathways, predicted_pathways
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(bind=True, queue="background")
|
||||||
|
def batch_predict(
|
||||||
|
self,
|
||||||
|
substrates: List[str] | List[List[str]],
|
||||||
|
prediction_setting_pk: int,
|
||||||
|
target_package_pk: int,
|
||||||
|
num_tps: int = 50,
|
||||||
|
):
|
||||||
|
target_package = Package.objects.get(pk=target_package_pk)
|
||||||
|
prediction_setting = Setting.objects.get(pk=prediction_setting_pk)
|
||||||
|
|
||||||
|
if len(substrates) == 0:
|
||||||
|
raise ValueError("No substrates given!")
|
||||||
|
|
||||||
|
is_pair = isinstance(substrates[0], list)
|
||||||
|
|
||||||
|
substrate_and_names = []
|
||||||
|
if not is_pair:
|
||||||
|
for sub in substrates:
|
||||||
|
substrate_and_names.append([sub, None])
|
||||||
|
else:
|
||||||
|
substrate_and_names = substrates
|
||||||
|
|
||||||
|
# Check prerequisite that we can standardize all substrates
|
||||||
|
standardized_substrates_and_smiles = []
|
||||||
|
for substrate in substrate_and_names:
|
||||||
|
try:
|
||||||
|
stand_smiles = FormatConverter.standardize(substrate[0])
|
||||||
|
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f'Pathway prediction failed as standardization of SMILES "{substrate}" failed!'
|
||||||
|
)
|
||||||
|
|
||||||
|
pathways = []
|
||||||
|
|
||||||
|
for pair in standardized_substrates_and_smiles:
|
||||||
|
pw = Pathway.create(
|
||||||
|
target_package,
|
||||||
|
pair[0],
|
||||||
|
name=pair[1],
|
||||||
|
predicted=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# set mode and setting
|
||||||
|
pw.setting = prediction_setting
|
||||||
|
pw.kv.update({"mode": "predict"})
|
||||||
|
pw.save()
|
||||||
|
|
||||||
|
predict(
|
||||||
|
pw.pk,
|
||||||
|
prediction_setting.pk,
|
||||||
|
limit=None,
|
||||||
|
setting_overrides={
|
||||||
|
"max_nodes": num_tps,
|
||||||
|
"max_depth": num_tps,
|
||||||
|
"model_threshold": 0.001,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
pathways.append(pw)
|
||||||
|
|
||||||
|
buffer = io.StringIO()
|
||||||
|
|
||||||
|
for idx, pw in enumerate(pathways):
|
||||||
|
# Carry out header only for the first pathway
|
||||||
|
buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True))
|
||||||
|
|
||||||
|
buffer.seek(0)
|
||||||
|
|
||||||
|
return buffer.getvalue()
|
||||||
|
|||||||
@ -49,6 +49,7 @@ urlpatterns = [
|
|||||||
re_path(r"^group$", v.groups, name="groups"),
|
re_path(r"^group$", v.groups, name="groups"),
|
||||||
re_path(r"^search$", v.search, name="search"),
|
re_path(r"^search$", v.search, name="search"),
|
||||||
re_path(r"^predict$", v.predict_pathway, name="predict_pathway"),
|
re_path(r"^predict$", v.predict_pathway, name="predict_pathway"),
|
||||||
|
re_path(r"^batch-predict$", v.batch_predict_pathway, name="batch_predict_pathway"),
|
||||||
# User Detail
|
# User Detail
|
||||||
re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
|
re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
|
||||||
# Group Detail
|
# Group Detail
|
||||||
|
|||||||
100
epdb/views.py
100
epdb/views.py
@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import nh3
|
import nh3
|
||||||
from django.conf import settings as s
|
from django.conf import settings as s
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.core.exceptions import BadRequest
|
from django.core.exceptions import BadRequest, PermissionDenied
|
||||||
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
|
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
|
||||||
from django.shortcuts import redirect, render
|
from django.shortcuts import redirect, render
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
@ -50,6 +51,7 @@ from .models import (
|
|||||||
SimpleAmbitRule,
|
SimpleAmbitRule,
|
||||||
User,
|
User,
|
||||||
UserPackagePermission,
|
UserPackagePermission,
|
||||||
|
ExpansionSchemeChoice,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -438,6 +440,18 @@ def predict_pathway(request):
|
|||||||
return render(request, "predict_pathway.html", context)
|
return render(request, "predict_pathway.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_predict_pathway(request):
|
||||||
|
"""Top-level predict pathway view using user's default package."""
|
||||||
|
if request.method != "GET":
|
||||||
|
return HttpResponseNotAllowed(["GET"])
|
||||||
|
|
||||||
|
context = get_base_context(request)
|
||||||
|
context["title"] = "enviPath - Batch Predict Pathway"
|
||||||
|
context["meta"]["current_package"] = context["meta"]["user"].default_package
|
||||||
|
|
||||||
|
return render(request, "batch_predict_pathway.html", context)
|
||||||
|
|
||||||
|
|
||||||
@package_permission_required()
|
@package_permission_required()
|
||||||
def package_predict_pathway(request, package_uuid):
|
def package_predict_pathway(request, package_uuid):
|
||||||
"""Package-specific predict pathway view."""
|
"""Package-specific predict pathway view."""
|
||||||
@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid):
|
|||||||
|
|
||||||
if pw_mode == "predict" or pw_mode == "incremental":
|
if pw_mode == "predict" or pw_mode == "incremental":
|
||||||
# unlimited pred (will be handled by setting)
|
# unlimited pred (will be handled by setting)
|
||||||
limit = -1
|
limit = None
|
||||||
|
|
||||||
# For incremental predict first level and return
|
# For incremental predict first level and return
|
||||||
if pw_mode == "incremental":
|
if pw_mode == "incremental":
|
||||||
@ -2877,15 +2891,25 @@ def settings(request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not PackageManager.readable(current_user, params["model"].package):
|
if not PackageManager.readable(current_user, params["model"].package):
|
||||||
raise ValueError("")
|
raise PermissionDenied("You're not allowed to access this model!")
|
||||||
|
|
||||||
|
expansion_scheme = request.POST.get(
|
||||||
|
"model-based-prediction-setting-expansion-scheme", "BFS"
|
||||||
|
)
|
||||||
|
|
||||||
|
if expansion_scheme not in ExpansionSchemeChoice.values:
|
||||||
|
raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}")
|
||||||
|
|
||||||
|
params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme)
|
||||||
|
|
||||||
elif tp_gen_method == "rule-based-prediction-setting":
|
elif tp_gen_method == "rule-based-prediction-setting":
|
||||||
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
|
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
|
||||||
params["rule_packages"] = [
|
params["rule_packages"] = [
|
||||||
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("")
|
raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!")
|
||||||
|
|
||||||
created_setting = SettingManager.create_setting(
|
created_setting = SettingManager.create_setting(
|
||||||
current_user,
|
current_user,
|
||||||
@ -2963,6 +2987,59 @@ def jobs(request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
|
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
|
||||||
|
|
||||||
|
elif job_name == "batch-predict":
|
||||||
|
substrates = request.POST.get("substrates")
|
||||||
|
prediction_setting_url = request.POST.get("prediction-setting")
|
||||||
|
num_tps = request.POST.get("num-tps")
|
||||||
|
|
||||||
|
if substrates is None or substrates.strip() == "":
|
||||||
|
raise BadRequest("No substrates provided.")
|
||||||
|
|
||||||
|
pred_data = []
|
||||||
|
for pair in substrates.split("\n"):
|
||||||
|
parts = pair.split(",")
|
||||||
|
|
||||||
|
try:
|
||||||
|
smiles = FormatConverter.standardize(parts[0])
|
||||||
|
except ValueError:
|
||||||
|
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
|
||||||
|
|
||||||
|
# name is optional
|
||||||
|
name = parts[1] if len(parts) > 1 else None
|
||||||
|
pred_data.append([smiles, name])
|
||||||
|
|
||||||
|
max_tps = 50
|
||||||
|
if num_tps is not None and num_tps.strip() != "":
|
||||||
|
try:
|
||||||
|
num_tps = int(num_tps)
|
||||||
|
max_tps = max(min(num_tps, 50), 1)
|
||||||
|
except ValueError:
|
||||||
|
raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.")
|
||||||
|
|
||||||
|
batch_predict_setting = SettingManager.get_setting_by_url(
|
||||||
|
current_user, prediction_setting_url
|
||||||
|
)
|
||||||
|
|
||||||
|
target_package = PackageManager.create_package(
|
||||||
|
current_user,
|
||||||
|
f"Autogenerated Package for Batch Prediction {datetime.now()}",
|
||||||
|
"This Package was generated automatically for the batch prediction task.",
|
||||||
|
)
|
||||||
|
|
||||||
|
from .tasks import dispatch, batch_predict
|
||||||
|
|
||||||
|
res = dispatch(
|
||||||
|
current_user,
|
||||||
|
batch_predict,
|
||||||
|
pred_data,
|
||||||
|
batch_predict_setting.pk,
|
||||||
|
target_package.pk,
|
||||||
|
num_tps=max_tps,
|
||||||
|
)
|
||||||
|
|
||||||
|
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise BadRequest(f"Job {job_name} is not supported!")
|
raise BadRequest(f"Job {job_name} is not supported!")
|
||||||
else:
|
else:
|
||||||
@ -2982,6 +3059,21 @@ def job(request, job_uuid):
|
|||||||
# No op if status is already in a terminal state
|
# No op if status is already in a terminal state
|
||||||
job.check_for_update()
|
job.check_for_update()
|
||||||
|
|
||||||
|
if request.GET.get("download", False) == "true":
|
||||||
|
if not job.is_result_downloadable():
|
||||||
|
raise BadRequest("Result is not downloadable!")
|
||||||
|
|
||||||
|
if job.job_name == "batch_predict":
|
||||||
|
filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv"
|
||||||
|
else:
|
||||||
|
raise BadRequest("Result is not downloadable!")
|
||||||
|
|
||||||
|
res_str = job.task_result
|
||||||
|
response = HttpResponse(res_str, content_type="text/csv")
|
||||||
|
response["Content-Disposition"] = f'attachment; filename="{filename}"'
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
context["object_type"] = "joblog"
|
context["object_type"] = "joblog"
|
||||||
context["breadcrumbs"] = [
|
context["breadcrumbs"] = [
|
||||||
{"Home": s.SERVER_URL},
|
{"Home": s.SERVER_URL},
|
||||||
|
|||||||
@ -0,0 +1,10 @@
|
|||||||
|
{% if job.is_result_downloadable %}
|
||||||
|
<li>
|
||||||
|
<a
|
||||||
|
class="button"
|
||||||
|
onclick="document.getElementById('download_job_result_modal').showModal(); return false;"
|
||||||
|
>
|
||||||
|
<i class="glyphicon glyphicon-floppy-save"></i> Download Result</a
|
||||||
|
>
|
||||||
|
</li>
|
||||||
|
{% endif %}
|
||||||
|
|||||||
168
templates/batch_predict_pathway.html
Normal file
168
templates/batch_predict_pathway.html
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
{% extends "framework_modern.html" %}
|
||||||
|
{% load static %}
|
||||||
|
{% block content %}
|
||||||
|
<div class="mx-auto w-full p-8">
|
||||||
|
<h1 class="h1 mb-4 text-3xl font-bold">Batch Predict Pathways</h1>
|
||||||
|
<form id="smiles-form" method="POST" action="{% url "jobs" %}">
|
||||||
|
{% csrf_token %}
|
||||||
|
<input type="hidden" name="substrates" id="substrates" />
|
||||||
|
<input type="hidden" name="job-name" value="batch-predict" />
|
||||||
|
|
||||||
|
<fieldset class="flex flex-col gap-4 md:flex-3/4">
|
||||||
|
<table class="table table-zebra w-full">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>SMILES</th>
|
||||||
|
<th>Name</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody id="smiles-table-body">
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<label>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
class="input input-bordered w-full smiles-input"
|
||||||
|
placeholder="CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
|
||||||
|
{% if meta.debug %}
|
||||||
|
value="CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
|
||||||
|
{% endif %}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<label>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
class="input input-bordered w-full name-input"
|
||||||
|
placeholder="Caffeine"
|
||||||
|
{% if meta.debug %}
|
||||||
|
value="Caffeine"
|
||||||
|
{% endif %}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<label>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
class="input input-bordered w-full smiles-input"
|
||||||
|
placeholder="CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
|
||||||
|
{% if meta.debug %}
|
||||||
|
value="CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
|
||||||
|
{% endif %}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
<label>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
class="input input-bordered w-full name-input"
|
||||||
|
placeholder="Ibuprofen"
|
||||||
|
{% if meta.debug %}
|
||||||
|
value="Ibuprofen"
|
||||||
|
{% endif %}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
<label class="select mb-2 w-full">
|
||||||
|
<span class="label">Predictor</span>
|
||||||
|
<select id="prediction-setting" name="prediction-setting">
|
||||||
|
<option disabled>Select a Setting</option>
|
||||||
|
{% for s in meta.available_settings %}
|
||||||
|
<option
|
||||||
|
value="{{ s.url }}"
|
||||||
|
{% if s.id == meta.user.default_setting.id %}selected{% endif %}
|
||||||
|
>
|
||||||
|
{{ s.name }}{% if s.id == meta.user.default_setting.id %}
|
||||||
|
(User default)
|
||||||
|
{% endif %}
|
||||||
|
</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
<label class="floating-label" for="num-tps">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
name="num-tps"
|
||||||
|
value="50"
|
||||||
|
step="1"
|
||||||
|
min="1"
|
||||||
|
max="100"
|
||||||
|
id="num-tps"
|
||||||
|
class="input input-md w-full"
|
||||||
|
/>
|
||||||
|
<span>Max Transformation Products</span>
|
||||||
|
</label>
|
||||||
|
<div class="flex justify-end gap-2">
|
||||||
|
<button type="button" id="add-row-btn" class="btn btn-outline">
|
||||||
|
Add row
|
||||||
|
</button>
|
||||||
|
<button type="submit" class="btn btn-primary">Submit</button>
|
||||||
|
</div>
|
||||||
|
</fieldset>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const tableBody = document.getElementById("smiles-table-body");
|
||||||
|
const addRowBtn = document.getElementById("add-row-btn");
|
||||||
|
const form = document.getElementById("smiles-form");
|
||||||
|
const hiddenField = document.getElementById("substrates");
|
||||||
|
|
||||||
|
addRowBtn.addEventListener("click", () => {
|
||||||
|
const row = document.createElement("tr");
|
||||||
|
|
||||||
|
const tdSmiles = document.createElement("td");
|
||||||
|
const tdName = document.createElement("td");
|
||||||
|
|
||||||
|
const smilesInput = document.createElement("input");
|
||||||
|
smilesInput.type = "text";
|
||||||
|
smilesInput.className = "input input-bordered w-full smiles-input";
|
||||||
|
smilesInput.placeholder = "SMILES";
|
||||||
|
|
||||||
|
const nameInput = document.createElement("input");
|
||||||
|
nameInput.type = "text";
|
||||||
|
nameInput.className = "input input-bordered w-full name-input";
|
||||||
|
nameInput.placeholder = "Name";
|
||||||
|
|
||||||
|
tdSmiles.appendChild(smilesInput);
|
||||||
|
tdName.appendChild(nameInput);
|
||||||
|
|
||||||
|
row.appendChild(tdSmiles);
|
||||||
|
row.appendChild(tdName);
|
||||||
|
|
||||||
|
tableBody.appendChild(row);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Before submit, gather table data into the hidden field
|
||||||
|
form.addEventListener("submit", (e) => {
|
||||||
|
const smilesInputs = Array.from(
|
||||||
|
document.querySelectorAll(".smiles-input"),
|
||||||
|
);
|
||||||
|
const nameInputs = Array.from(document.querySelectorAll(".name-input"));
|
||||||
|
|
||||||
|
const lines = [];
|
||||||
|
|
||||||
|
for (let i = 0; i < smilesInputs.length; i++) {
|
||||||
|
const smiles = smilesInputs[i].value.trim();
|
||||||
|
const name = nameInputs[i]?.value.trim() ?? "";
|
||||||
|
|
||||||
|
// Skip emtpy rows
|
||||||
|
if (!smiles && !name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
lines.push(`${smiles},${name}`);
|
||||||
|
}
|
||||||
|
// Value looks like:
|
||||||
|
// "CN1C=NC2=C1C(=O)N(C(=O)N2C)C,Caffeine\nCC(C)CC1=CC=C(C=C1)C(C)C(=O)O,Ibuprofen"
|
||||||
|
hiddenField.value = lines.join("\n");
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
{% endblock content %}
|
||||||
@ -210,6 +210,27 @@
|
|||||||
step="0.05"
|
step="0.05"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="form-control mb-3">
|
||||||
|
<label
|
||||||
|
class="label"
|
||||||
|
for="model-based-prediction-setting-expansion-scheme"
|
||||||
|
>
|
||||||
|
<span class="label-text">Select Expansion Scheme</span>
|
||||||
|
</label>
|
||||||
|
<select
|
||||||
|
id="model-based-prediction-setting-expansion-scheme"
|
||||||
|
name="model-based-prediction-setting-expansion-scheme"
|
||||||
|
class="select select-bordered w-full"
|
||||||
|
>
|
||||||
|
<option value="" disabled selected>
|
||||||
|
Select the Expansion Scheme
|
||||||
|
</option>
|
||||||
|
<option value="BFS">Breadth First Search</option>
|
||||||
|
<option value="DFS">Depth First Search</option>
|
||||||
|
<option value="GREEDY">Greedy</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="form-control">
|
<div class="form-control">
|
||||||
|
|||||||
66
templates/modals/objects/download_job_result_modal.html
Normal file
66
templates/modals/objects/download_job_result_modal.html
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
{% load static %}
|
||||||
|
|
||||||
|
<dialog
|
||||||
|
id="download_job_result_modal"
|
||||||
|
class="modal"
|
||||||
|
x-data="modalForm()"
|
||||||
|
@close="reset()"
|
||||||
|
>
|
||||||
|
<div class="modal-box">
|
||||||
|
<!-- Header -->
|
||||||
|
<h3 class="font-bold text-lg">Download Job Result</h3>
|
||||||
|
|
||||||
|
<!-- Close button (X) -->
|
||||||
|
<form method="dialog">
|
||||||
|
<button
|
||||||
|
class="btn btn-sm btn-circle btn-ghost absolute right-2 top-2"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<!-- Body -->
|
||||||
|
<div class="py-4">
|
||||||
|
<p>By clicking on Download the Result of this Job will be saved.</p>
|
||||||
|
|
||||||
|
<form
|
||||||
|
id="download-job-result-modal-form"
|
||||||
|
accept-charset="UTF-8"
|
||||||
|
action="{{ job.url }}"
|
||||||
|
method="GET"
|
||||||
|
>
|
||||||
|
<input type="hidden" name="download" value="true" />
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Footer -->
|
||||||
|
<div class="modal-action">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn"
|
||||||
|
onclick="this.closest('dialog').close()"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
>
|
||||||
|
Close
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-primary"
|
||||||
|
@click="submit('download-job-result-modal-form'); $el.closest('dialog').close();"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
>
|
||||||
|
<span x-show="!isSubmitting">Download</span>
|
||||||
|
<span
|
||||||
|
x-show="isSubmitting"
|
||||||
|
class="loading loading-spinner loading-sm"
|
||||||
|
></span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Backdrop -->
|
||||||
|
<form method="dialog" class="modal-backdrop">
|
||||||
|
<button :disabled="isSubmitting">close</button>
|
||||||
|
</form>
|
||||||
|
</dialog>
|
||||||
@ -3,7 +3,9 @@
|
|||||||
{% block content %}
|
{% block content %}
|
||||||
|
|
||||||
{% block action_modals %}
|
{% block action_modals %}
|
||||||
{# {% include "modals/objects/refresh_job_log.html" %}#}
|
{% if job.is_result_downloadable %}
|
||||||
|
{% include "modals/objects/download_job_result_modal.html" %}
|
||||||
|
{% endif %}
|
||||||
{% endblock action_modals %}
|
{% endblock action_modals %}
|
||||||
|
|
||||||
<div class="space-y-2 p-4">
|
<div class="space-y-2 p-4">
|
||||||
@ -49,22 +51,20 @@
|
|||||||
<div class="collapse-arrow bg-base-200 collapse">
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
<input type="checkbox" checked />
|
<input type="checkbox" checked />
|
||||||
<div class="collapse-title text-xl font-medium">Description</div>
|
<div class="collapse-title text-xl font-medium">Description</div>
|
||||||
<div class="collapse-content">
|
<div class="collapse-content">Status page for Job {{ job.job_name }}</div>
|
||||||
Status page for Task {{ job.job_name }}
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Job Status -->
|
<!-- Job Status -->
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
<input type="checkbox" checked />
|
<input type="checkbox" checked />
|
||||||
<div class="collapse-title text-xl font-medium">Task Status</div>
|
<div class="collapse-title text-xl font-medium">Job Status</div>
|
||||||
<div class="collapse-content">{{ job.status }}</div>
|
<div class="collapse-content">{{ job.status }}</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Job ID -->
|
<!-- Job ID -->
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
<input type="checkbox" checked />
|
<input type="checkbox" checked />
|
||||||
<div class="collapse-title text-xl font-medium">Task ID</div>
|
<div class="collapse-title text-xl font-medium">Job ID</div>
|
||||||
<div class="collapse-content">{{ job.task_id }}</div>
|
<div class="collapse-content">{{ job.task_id }}</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -72,7 +72,7 @@
|
|||||||
{% if job.is_in_terminal_state %}
|
{% if job.is_in_terminal_state %}
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
<input type="checkbox" checked />
|
<input type="checkbox" checked />
|
||||||
<div class="collapse-title text-xl font-medium">Task Result</div>
|
<div class="collapse-title text-xl font-medium">Job Result</div>
|
||||||
<div class="collapse-content">
|
<div class="collapse-content">
|
||||||
{% if job.job_name == 'engineer_pathways' %}
|
{% if job.job_name == 'engineer_pathways' %}
|
||||||
<div class="card bg-base-100">
|
<div class="card bg-base-100">
|
||||||
@ -103,6 +103,68 @@
|
|||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
{% elif job.job_name == 'batch_predict' %}
|
||||||
|
<div
|
||||||
|
id="table-container"
|
||||||
|
class="overflow-x-auto overflow-y-auto max-h-96 border rounded-lg"
|
||||||
|
></div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const input = `{{ job.task_result }}`;
|
||||||
|
|
||||||
|
function renderCsvTable(str) {
|
||||||
|
const lines = str
|
||||||
|
.split("\n")
|
||||||
|
.map((l) => l.trim())
|
||||||
|
.filter(Boolean);
|
||||||
|
const [headerLine, ...rows] = lines;
|
||||||
|
|
||||||
|
const headers = headerLine.split(",").map((h) => h.trim());
|
||||||
|
|
||||||
|
const table = document.createElement("table");
|
||||||
|
table.className = "table table-zebra w-full";
|
||||||
|
|
||||||
|
const thead = document.createElement("thead");
|
||||||
|
const headerRow = document.createElement("tr");
|
||||||
|
|
||||||
|
headers.forEach((h) => {
|
||||||
|
const th = document.createElement("th");
|
||||||
|
th.textContent = h;
|
||||||
|
headerRow.appendChild(th);
|
||||||
|
});
|
||||||
|
|
||||||
|
thead.appendChild(headerRow);
|
||||||
|
|
||||||
|
const tbody = document.createElement("tbody");
|
||||||
|
|
||||||
|
rows.forEach((rowStr) => {
|
||||||
|
console.log(rowStr.split(","));
|
||||||
|
console.log(headers);
|
||||||
|
const row = document.createElement("tr");
|
||||||
|
const cells = rowStr.split(",").map((c) => c.trim());
|
||||||
|
|
||||||
|
headers.forEach((_, i) => {
|
||||||
|
const td = document.createElement("td");
|
||||||
|
|
||||||
|
const value = cells[i] || "";
|
||||||
|
|
||||||
|
td.textContent = value;
|
||||||
|
|
||||||
|
row.appendChild(td);
|
||||||
|
});
|
||||||
|
console.log(row);
|
||||||
|
tbody.appendChild(row);
|
||||||
|
});
|
||||||
|
|
||||||
|
table.appendChild(thead);
|
||||||
|
table.appendChild(tbody);
|
||||||
|
return table;
|
||||||
|
}
|
||||||
|
|
||||||
|
document
|
||||||
|
.getElementById("table-container")
|
||||||
|
.appendChild(renderCsvTable(input));
|
||||||
|
</script>
|
||||||
{% else %}
|
{% else %}
|
||||||
{{ job.parsed_result }}
|
{{ job.parsed_result }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|||||||
@ -73,13 +73,29 @@
|
|||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<!-- Reaction Packages -->
|
{% endif %}
|
||||||
|
<!-- Reaction Packages -->
|
||||||
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
|
<input type="checkbox" checked />
|
||||||
|
<div class="collapse-title text-xl font-medium">Reaction Packages</div>
|
||||||
|
<div class="collapse-content">
|
||||||
|
<ul class="menu bg-base-100 rounded-box w-full">
|
||||||
|
{% for p in model.data_packages.all %}
|
||||||
|
<li>
|
||||||
|
<a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a>
|
||||||
|
</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% if model.eval_packages.all|length > 0 %}
|
||||||
|
<!-- Eval Packages -->
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
<input type="checkbox" checked />
|
<input type="checkbox" checked />
|
||||||
<div class="collapse-title text-xl font-medium">Reaction Packages</div>
|
<div class="collapse-title text-xl font-medium">Eval Packages</div>
|
||||||
<div class="collapse-content">
|
<div class="collapse-content">
|
||||||
<ul class="menu bg-base-100 rounded-box w-full">
|
<ul class="menu bg-base-100 rounded-box w-full">
|
||||||
{% for p in model.data_packages.all %}
|
{% for p in model.eval_packages.all %}
|
||||||
<li>
|
<li>
|
||||||
<a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a>
|
<a href="{{ p.url }}" class="hover:bg-base-200">{{ p.name }}</a>
|
||||||
</li>
|
</li>
|
||||||
@ -87,31 +103,13 @@
|
|||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{% if model.eval_packages.all|length > 0 %}
|
|
||||||
<!-- Eval Packages -->
|
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
|
||||||
<input type="checkbox" checked />
|
|
||||||
<div class="collapse-title text-xl font-medium">Eval Packages</div>
|
|
||||||
<div class="collapse-content">
|
|
||||||
<ul class="menu bg-base-100 rounded-box w-full">
|
|
||||||
{% for p in model.eval_packages.all %}
|
|
||||||
<li>
|
|
||||||
<a href="{{ p.url }}" class="hover:bg-base-200"
|
|
||||||
>{{ p.name }}</a
|
|
||||||
>
|
|
||||||
</li>
|
|
||||||
{% endfor %}
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{% endif %}
|
|
||||||
<!-- Model Status -->
|
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
|
||||||
<input type="checkbox" checked />
|
|
||||||
<div class="collapse-title text-xl font-medium">Model Status</div>
|
|
||||||
<div class="collapse-content">{{ model.status }}</div>
|
|
||||||
</div>
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
<!-- Model Status -->
|
||||||
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
|
<input type="checkbox" checked />
|
||||||
|
<div class="collapse-title text-xl font-medium">Model Status</div>
|
||||||
|
<div class="collapse-content">{{ model.status }}</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
{% if model.ready_for_prediction %}
|
{% if model.ready_for_prediction %}
|
||||||
<!-- Predict Panel -->
|
<!-- Predict Panel -->
|
||||||
@ -174,7 +172,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if model.model_status == 'FINISHED' %}
|
{% if model.model_status == 'FINISHED' %}
|
||||||
<!-- Single Gen Curve Panel -->
|
<!-- Single Gen Curve Panel -->
|
||||||
<div class="collapse-arrow bg-base-200 collapse">
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
@ -188,6 +185,19 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
{% if model.multigen_eval %}
|
||||||
|
<div class="collapse-arrow bg-base-200 collapse">
|
||||||
|
<input type="checkbox" checked />
|
||||||
|
<div class="collapse-title text-xl font-medium">
|
||||||
|
Multi Gen Precision Recall Curve
|
||||||
|
</div>
|
||||||
|
<div class="collapse-content">
|
||||||
|
<div class="flex justify-center">
|
||||||
|
<div id="mg-chart"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -244,6 +254,105 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function makeChart(selector, data) {
|
||||||
|
const x = ['Recall'];
|
||||||
|
const y = ['Precision'];
|
||||||
|
const thres = ['threshold'];
|
||||||
|
|
||||||
|
function compare(a, b) {
|
||||||
|
if (a.threshold < b.threshold)
|
||||||
|
return -1;
|
||||||
|
else if (a.threshold > b.threshold)
|
||||||
|
return 1;
|
||||||
|
else
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getIndexForValue(data, val, val_name) {
|
||||||
|
for (const idx in data) {
|
||||||
|
if (data[idx][val_name] == val) {
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!data || data.length === 0) {
|
||||||
|
console.warn('PR curve data is empty');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const dataLength = data.length;
|
||||||
|
data.sort(compare);
|
||||||
|
|
||||||
|
for (const idx in data) {
|
||||||
|
const d = data[idx];
|
||||||
|
x.push(d.recall);
|
||||||
|
y.push(d.precision);
|
||||||
|
thres.push(d.threshold);
|
||||||
|
}
|
||||||
|
const chart = c3.generate({
|
||||||
|
bindto: selector,
|
||||||
|
data: {
|
||||||
|
onclick: function (d, e) {
|
||||||
|
const idx = d.index;
|
||||||
|
const thresh = data[dataLength - idx - 1].threshold;
|
||||||
|
},
|
||||||
|
x: 'Recall',
|
||||||
|
y: 'Precision',
|
||||||
|
columns: [
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
]
|
||||||
|
},
|
||||||
|
size: {
|
||||||
|
height: 400,
|
||||||
|
width: 480
|
||||||
|
},
|
||||||
|
axis: {
|
||||||
|
x: {
|
||||||
|
max: 1,
|
||||||
|
min: 0,
|
||||||
|
label: 'Recall',
|
||||||
|
padding: 0,
|
||||||
|
tick: {
|
||||||
|
fit: true,
|
||||||
|
values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
y: {
|
||||||
|
max: 1,
|
||||||
|
min: 0,
|
||||||
|
label: 'Precision',
|
||||||
|
padding: 0,
|
||||||
|
tick: {
|
||||||
|
fit: true,
|
||||||
|
values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
point: {
|
||||||
|
r: 4
|
||||||
|
},
|
||||||
|
tooltip: {
|
||||||
|
format: {
|
||||||
|
title: function (recall) {
|
||||||
|
const idx = getIndexForValue(data, recall, "recall");
|
||||||
|
if (idx != -1) {
|
||||||
|
return "Threshold: " + data[idx].threshold;
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
},
|
||||||
|
value: function (precision, ratio, id) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
zoom: {
|
||||||
|
enabled: true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
function makeLoadingGif(selector, gifPath) {
|
function makeLoadingGif(selector, gifPath) {
|
||||||
const element = document.querySelector(selector);
|
const element = document.querySelector(selector);
|
||||||
if (element) {
|
if (element) {
|
||||||
@ -260,107 +369,12 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
{% if model.model_status == 'FINISHED' %}
|
{% if model.model_status == 'FINISHED' %}
|
||||||
// Precision Recall Curve
|
// Precision Recall Curve
|
||||||
const sgChart = document.getElementById('sg-chart');
|
makeChart('#sg-chart', {{ model.pr_curve|safe }});
|
||||||
if (sgChart) {
|
{% if model.multigen_eval %}
|
||||||
const x = ['Recall'];
|
// Multi Gen Precision Recall Curve
|
||||||
const y = ['Precision'];
|
makeChart('#mg-chart', {{ model.mg_pr_curve|safe }});
|
||||||
const thres = ['threshold'];
|
{% endif %}
|
||||||
|
|
||||||
function compare(a, b) {
|
|
||||||
if (a.threshold < b.threshold)
|
|
||||||
return -1;
|
|
||||||
else if (a.threshold > b.threshold)
|
|
||||||
return 1;
|
|
||||||
else
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
function getIndexForValue(data, val, val_name) {
|
|
||||||
for (const idx in data) {
|
|
||||||
if (data[idx][val_name] == val) {
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
var data = {{ model.pr_curve|safe }};
|
|
||||||
if (!data || data.length === 0) {
|
|
||||||
console.warn('PR curve data is empty');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const dataLength = data.length;
|
|
||||||
data.sort(compare);
|
|
||||||
|
|
||||||
for (const idx in data) {
|
|
||||||
const d = data[idx];
|
|
||||||
x.push(d.recall);
|
|
||||||
y.push(d.precision);
|
|
||||||
thres.push(d.threshold);
|
|
||||||
}
|
|
||||||
const chart = c3.generate({
|
|
||||||
bindto: '#sg-chart',
|
|
||||||
data: {
|
|
||||||
onclick: function (d, e) {
|
|
||||||
const idx = d.index;
|
|
||||||
const thresh = data[dataLength - idx - 1].threshold;
|
|
||||||
},
|
|
||||||
x: 'Recall',
|
|
||||||
y: 'Precision',
|
|
||||||
columns: [
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
]
|
|
||||||
},
|
|
||||||
size: {
|
|
||||||
height: 400,
|
|
||||||
width: 480
|
|
||||||
},
|
|
||||||
axis: {
|
|
||||||
x: {
|
|
||||||
max: 1,
|
|
||||||
min: 0,
|
|
||||||
label: 'Recall',
|
|
||||||
padding: 0,
|
|
||||||
tick: {
|
|
||||||
fit: true,
|
|
||||||
values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
y: {
|
|
||||||
max: 1,
|
|
||||||
min: 0,
|
|
||||||
label: 'Precision',
|
|
||||||
padding: 0,
|
|
||||||
tick: {
|
|
||||||
fit: true,
|
|
||||||
values: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
point: {
|
|
||||||
r: 4
|
|
||||||
},
|
|
||||||
tooltip: {
|
|
||||||
format: {
|
|
||||||
title: function (recall) {
|
|
||||||
const idx = getIndexForValue(data, recall, "recall");
|
|
||||||
if (idx != -1) {
|
|
||||||
return "Threshold: " + data[idx].threshold;
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
},
|
|
||||||
value: function (precision, ratio, id) {
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
zoom: {
|
|
||||||
enabled: true
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
// Predict button handler
|
// Predict button handler
|
||||||
|
|||||||
@ -393,7 +393,9 @@
|
|||||||
<tbody>
|
<tbody>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Threshold</td>
|
<td>Threshold</td>
|
||||||
<td>{{ pathway.setting.model_threshold }}</td>
|
<td>
|
||||||
|
{{ pathway.setting_with_overrides.model_threshold }}
|
||||||
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
@ -420,11 +422,15 @@
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
<tr>
|
<tr>
|
||||||
<td>Max Nodes</td>
|
<td>Max Nodes</td>
|
||||||
<td>{{ pathway.setting.max_nodes }}</td>
|
<td>{{ pathway.setting_with_overrides.max_nodes }}</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Max Depth</td>
|
<td>Max Depth</td>
|
||||||
<td>{{ pathway.setting.max_depth }}</td>
|
<td>{{ pathway.setting_with_overrides.max_depth }}</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Expansion Scheme</td>
|
||||||
|
<td>{{ user.default_setting.expansion_scheme }}</td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
|
|||||||
@ -150,6 +150,10 @@
|
|||||||
<td>Max Depth</td>
|
<td>Max Depth</td>
|
||||||
<td>{{ user.default_setting.max_depth }}</td>
|
<td>{{ user.default_setting.max_depth }}</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Expansion Scheme</td>
|
||||||
|
<td>{{ user.default_setting.expansion_scheme }}</td>
|
||||||
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from django.conf import settings as s
|
from django.conf import settings as s
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
|
|
||||||
@ -23,6 +25,46 @@ class MultiGenTest(TestCase):
|
|||||||
cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)"
|
cls.PW_WITH_INTERMEDIATE_NAME = "1,1,1-Trichloroethane (an/aerobic)"
|
||||||
cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine"
|
cls.PW_WITHOUT_INTERMEDIATE_NAME = "Caffeine"
|
||||||
|
|
||||||
|
def test_batch_predict(self):
|
||||||
|
from epdb.tasks import batch_predict
|
||||||
|
|
||||||
|
pred_data = [
|
||||||
|
["CN1C=NC2=C1C(=O)N(C(=O)N2C)C", "Caffeine"],
|
||||||
|
["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", "Ibuprofen"],
|
||||||
|
]
|
||||||
|
|
||||||
|
batch_predict_setting = self.user.prediction_settings()
|
||||||
|
|
||||||
|
target_package = PackageManager.create_package(
|
||||||
|
self.user,
|
||||||
|
f"Autogenerated Package for Batch Prediction {datetime.now()}",
|
||||||
|
"This Package was generated automatically for the batch prediction task.",
|
||||||
|
)
|
||||||
|
|
||||||
|
num_tps = 50
|
||||||
|
res = batch_predict(
|
||||||
|
pred_data,
|
||||||
|
batch_predict_setting.pk,
|
||||||
|
target_package.pk,
|
||||||
|
num_tps=num_tps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(res.startswith("Pathway URL,"))
|
||||||
|
# Min 3 lines (1 header, 2 root nodes)
|
||||||
|
self.assertGreaterEqual(len(res.split("\n")), 3)
|
||||||
|
self.assertEqual(target_package.pathways.count(), 2)
|
||||||
|
|
||||||
|
pw = target_package.pathways.first()
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
pw.setting_with_overrides.max_depth,
|
||||||
|
f"{num_tps} (this is an override for this particular pathway)",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
pw.setting_with_overrides.max_nodes,
|
||||||
|
f"{num_tps} (this is an override for this particular pathway)",
|
||||||
|
)
|
||||||
|
|
||||||
def test_engineer_pathway(self):
|
def test_engineer_pathway(self):
|
||||||
from epdb.tasks import engineer_pathways
|
from epdb.tasks import engineer_pathways
|
||||||
|
|
||||||
|
|||||||
@ -1,20 +1,20 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from epdb.logic import SNode, SEdge
|
from epdb.logic import SEdge, SNode, SPathway
|
||||||
|
from epdb.models import Pathway, Setting
|
||||||
|
from utilities.chem import PredictionResult, ProductSet
|
||||||
|
|
||||||
|
|
||||||
class SObjectTest(TestCase):
|
class SNodeTest(TestCase):
|
||||||
def setUp(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_snode_eq(self):
|
def test_snode_eq(self):
|
||||||
snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
|
snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
|
||||||
snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
|
snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
|
||||||
assert snode1 == snode2
|
self.assertEqual(snode1, snode2)
|
||||||
|
|
||||||
def test_snode_hash(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
class SEdgeTest(TestCase):
|
||||||
def test_sedge_eq(self):
|
def test_sedge_eq(self):
|
||||||
sedge1 = SEdge(
|
sedge1 = SEdge(
|
||||||
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
|
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
|
||||||
@ -26,4 +26,62 @@ class SObjectTest(TestCase):
|
|||||||
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
|
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
|
||||||
rule=None,
|
rule=None,
|
||||||
)
|
)
|
||||||
assert sedge1 == sedge2
|
self.assertEqual(sedge1, sedge2)
|
||||||
|
|
||||||
|
|
||||||
|
class SPathwayTest(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test data for SPathway tests."""
|
||||||
|
self.test_smiles = "CCN(CC)C(=O)C1=CC(=CC=C1)CO"
|
||||||
|
self.mock_setting = Mock(spec=Setting)
|
||||||
|
self.mock_pathway = Mock(spec=Pathway)
|
||||||
|
|
||||||
|
def test_predict_step_basic(self):
|
||||||
|
"""Test basic predict_step functionality."""
|
||||||
|
spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)
|
||||||
|
|
||||||
|
# e.g. bt0002
|
||||||
|
pr = PredictionResult(
|
||||||
|
[
|
||||||
|
ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
|
||||||
|
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
|
||||||
|
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
|
||||||
|
],
|
||||||
|
0.17,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.mock_setting, "expand", return_value=[pr]):
|
||||||
|
spw.predict_step(from_depth=0)
|
||||||
|
|
||||||
|
self.assertEqual(len(spw.smiles_to_node.keys()), 4)
|
||||||
|
self.assertEqual(len(spw.edges), 3)
|
||||||
|
|
||||||
|
def test_to_json(self):
|
||||||
|
"""Test basic predict_step functionality."""
|
||||||
|
spw = SPathway(root_nodes=self.test_smiles, prediction_setting=self.mock_setting)
|
||||||
|
|
||||||
|
# e.g. bt0002
|
||||||
|
pr = PredictionResult(
|
||||||
|
[
|
||||||
|
ProductSet(["CC1=CC=C(C2OC(CO)C(=O)C(O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
|
||||||
|
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(O)C2=O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
|
||||||
|
ProductSet(["CC1=CC=C(C2OC(CO)C(O)C(=O)C2O)C=C1CC1=CC=C(C2=CC=C(F)C=C2)S1"]),
|
||||||
|
],
|
||||||
|
0.17,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.mock_setting, "expand", return_value=[pr]):
|
||||||
|
spw.predict_step(from_depth=0)
|
||||||
|
|
||||||
|
self.assertEqual(len(spw.smiles_to_node.keys()), 4)
|
||||||
|
self.assertEqual(len(spw.edges), 3)
|
||||||
|
|
||||||
|
json_result = spw.to_json()
|
||||||
|
|
||||||
|
self.assertIsInstance(json_result, dict)
|
||||||
|
self.assertIn("nodes", json_result)
|
||||||
|
self.assertIn("edges", json_result)
|
||||||
|
self.assertEqual(len(json_result["nodes"]), 4)
|
||||||
|
self.assertEqual(len(json_result["edges"]), 3)
|
||||||
|
|||||||
Reference in New Issue
Block a user