[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# potentially prefetched edge_set
return self.edge_set.all()
@property
def setting_with_overrides(self):
    """Return a separately loaded copy of this pathway's Setting with overrides applied.

    The Setting is fetched again by primary key, so the attribute changes
    made here are applied to that fetched object; nothing is saved.

    If ``self.kv`` has a ``"setting_overrides"`` entry, each of its
    key/value pairs is written onto the copy as a string: the override
    value followed by a note that it is specific to this pathway.

    Returns:
        The fetched ``Setting`` instance, with any overrides set on it.
    """
    setting_copy = Setting.objects.get(pk=self.setting.pk)

    overrides = self.kv["setting_overrides"] if "setting_overrides" in self.kv else {}
    for attr_name, override_value in overrides.items():
        annotated = f"{override_value} (this is an override for this particular pathway)"
        setattr(setting_copy, attr_name, annotated)

    return setting_copy
def _url(self):
return "{}/pathway/{}".format(self.package.url, self.uuid)
@ -1879,25 +1889,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
return json.dumps(res)
def to_csv(self) -> str:
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
import csv
import io
header = []
if include_pathway_url:
header += ["Pathway URL"]
header += [
"SMILES",
"name",
"depth",
"probability",
"rule_names",
"rule_ids",
"parent_smiles",
]
rows = []
rows.append(
[
"SMILES",
"name",
"depth",
"probability",
"rule_names",
"rule_ids",
"parent_smiles",
]
)
if include_header:
rows.append(header)
for n in self.nodes.order_by("depth"):
cs = n.default_node_label
row = [cs.smiles, cs.name, n.depth]
row = []
if include_pathway_url:
row.append(n.pathway.url)
row += [cs.smiles, cs.name, n.depth]
edges = self.edges.filter(end_nodes__in=[n])
if len(edges):
@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel):
return res
@property
def mg_pr_curve(self):
    """Return the multi-generation precision/recall curve as a list of points.

    Each point is a dict with the keys ``"precision"``, ``"recall"`` and
    ``"threshold"``. Thresholds are the keys of
    ``eval_results["multigen_average_precision_per_threshold"]`` (converted
    to ``float``), and the recall for each one is read from
    ``eval_results["multigen_average_recall_per_threshold"]``.

    Raises:
        ValueError: if the model is not in the FINISHED status, or if it
            was not evaluated in multigen mode.
    """
    if self.model_status != self.FINISHED:
        raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")

    if not self.multigen_eval:
        raise ValueError("MG PR Curve is only available for multigen models")

    precision_by_threshold = self.eval_results["multigen_average_precision_per_threshold"]
    return [
        {
            "precision": precision,
            # Recall is looked up per threshold, only once a threshold exists.
            "recall": self.eval_results["multigen_average_recall_per_threshold"][threshold],
            "threshold": float(threshold),
        }
        for threshold, precision in precision_by_threshold.items()
    ]
@cached_property
def applicable_rules(self) -> List["Rule"]:
"""
@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel):
for i, root in enumerate(root_compounds):
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
spw = SPathway(root_nodes=root, prediction_setting=s)
spw = SPathway(root_nodes=root.smiles, prediction_setting=s)
level = 0
while not spw.done:
@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission):
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
class ExpansionSchemeChoice(models.TextChoices):
    """Allowed values (stored value, display label) for ``Setting.expansion_scheme``."""

    BFS = ("BFS", "Breadth First Search")
    DFS = ("DFS", "Depth First Search")
    GREEDY = ("GREEDY", "Greedy")
class Setting(EnviPathModel):
public = models.BooleanField(null=False, blank=False, default=False)
global_default = models.BooleanField(null=False, blank=False, default=False)
@ -3795,6 +3847,12 @@ class Setting(EnviPathModel):
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
)
# Expansion scheme of this setting: restricted to the values declared in
# ExpansionSchemeChoice and defaulting to BFS.
expansion_scheme = models.CharField(
max_length=20,
choices=ExpansionSchemeChoice.choices,
default=ExpansionSchemeChoice.BFS,
)
def _url(self):
    """Build this setting's URL from the configured server URL and its own UUID."""
    return f"{s.SERVER_URL}/setting/{self.uuid}"
@ -3833,7 +3891,7 @@ class Setting(EnviPathModel):
"""Decision Method whether to expand on a certain Node or not"""
if pathway.num_nodes() >= self.max_nodes:
logger.info(
f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}"
f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
)
return []
@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel):
if self.job_name == "engineer_pathways":
return ast.literal_eval(self.task_result)
return self.task_result
def is_result_downloadable(self):
    """Tell whether this job's result is downloadable.

    Returns:
        True when ``self.job_name`` is one of the downloadable job names
        (currently only ``"batch_predict"``), otherwise False.
    """
    downloadable_jobs = ("batch_predict",)
    return self.job_name in downloadable_jobs