forked from enviPath/enviPy
[Feature] Show Multi Gen Eval + Batch Prediction (#267)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#267
This commit is contained in:
@ -1744,6 +1744,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
# potentially prefetched edge_set
|
||||
return self.edge_set.all()
|
||||
|
||||
@property
|
||||
def setting_with_overrides(self):
|
||||
mem_copy = Setting.objects.get(pk=self.setting.pk)
|
||||
|
||||
if "setting_overrides" in self.kv:
|
||||
for k, v in self.kv["setting_overrides"].items():
|
||||
setattr(mem_copy, k, f"{v} (this is an override for this particular pathway)")
|
||||
|
||||
return mem_copy
|
||||
|
||||
def _url(self):
|
||||
return "{}/pathway/{}".format(self.package.url, self.uuid)
|
||||
|
||||
@ -1879,25 +1889,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
|
||||
return json.dumps(res)
|
||||
|
||||
def to_csv(self) -> str:
|
||||
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
|
||||
import csv
|
||||
import io
|
||||
|
||||
header = []
|
||||
|
||||
if include_pathway_url:
|
||||
header += ["Pathway URL"]
|
||||
|
||||
header += [
|
||||
"SMILES",
|
||||
"name",
|
||||
"depth",
|
||||
"probability",
|
||||
"rule_names",
|
||||
"rule_ids",
|
||||
"parent_smiles",
|
||||
]
|
||||
|
||||
rows = []
|
||||
rows.append(
|
||||
[
|
||||
"SMILES",
|
||||
"name",
|
||||
"depth",
|
||||
"probability",
|
||||
"rule_names",
|
||||
"rule_ids",
|
||||
"parent_smiles",
|
||||
]
|
||||
)
|
||||
|
||||
if include_header:
|
||||
rows.append(header)
|
||||
|
||||
for n in self.nodes.order_by("depth"):
|
||||
cs = n.default_node_label
|
||||
row = [cs.smiles, cs.name, n.depth]
|
||||
row = []
|
||||
|
||||
if include_pathway_url:
|
||||
row.append(n.pathway.url)
|
||||
|
||||
row += [cs.smiles, cs.name, n.depth]
|
||||
|
||||
edges = self.edges.filter(end_nodes__in=[n])
|
||||
if len(edges):
|
||||
@ -2362,6 +2385,29 @@ class PackageBasedModel(EPModel):
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def mg_pr_curve(self):
|
||||
if self.model_status != self.FINISHED:
|
||||
raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
|
||||
|
||||
if not self.multigen_eval:
|
||||
raise ValueError("MG PR Curve is only available for multigen models")
|
||||
|
||||
res = []
|
||||
|
||||
thresholds = self.eval_results["multigen_average_precision_per_threshold"].keys()
|
||||
|
||||
for t in thresholds:
|
||||
res.append(
|
||||
{
|
||||
"precision": self.eval_results["multigen_average_precision_per_threshold"][t],
|
||||
"recall": self.eval_results["multigen_average_recall_per_threshold"][t],
|
||||
"threshold": float(t),
|
||||
}
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
@cached_property
|
||||
def applicable_rules(self) -> List["Rule"]:
|
||||
"""
|
||||
@ -2565,7 +2611,7 @@ class PackageBasedModel(EPModel):
|
||||
for i, root in enumerate(root_compounds):
|
||||
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
|
||||
|
||||
spw = SPathway(root_nodes=root, prediction_setting=s)
|
||||
spw = SPathway(root_nodes=root.smiles, prediction_setting=s)
|
||||
level = 0
|
||||
|
||||
while not spw.done:
|
||||
@ -3771,6 +3817,12 @@ class UserSettingPermission(Permission):
|
||||
return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
|
||||
|
||||
|
||||
class ExpansionSchemeChoice(models.TextChoices):
|
||||
BFS = "BFS", "Breadth First Search"
|
||||
DFS = "DFS", "Depth First Search"
|
||||
GREEDY = "GREEDY", "Greedy"
|
||||
|
||||
|
||||
class Setting(EnviPathModel):
|
||||
public = models.BooleanField(null=False, blank=False, default=False)
|
||||
global_default = models.BooleanField(null=False, blank=False, default=False)
|
||||
@ -3795,6 +3847,12 @@ class Setting(EnviPathModel):
|
||||
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
|
||||
)
|
||||
|
||||
expansion_scheme = models.CharField(
|
||||
max_length=20,
|
||||
choices=ExpansionSchemeChoice.choices,
|
||||
default=ExpansionSchemeChoice.BFS,
|
||||
)
|
||||
|
||||
def _url(self):
|
||||
return "{}/setting/{}".format(s.SERVER_URL, self.uuid)
|
||||
|
||||
@ -3833,7 +3891,7 @@ class Setting(EnviPathModel):
|
||||
"""Decision Method whether to expand on a certain Node or not"""
|
||||
if pathway.num_nodes() >= self.max_nodes:
|
||||
logger.info(
|
||||
f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}"
|
||||
f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
|
||||
)
|
||||
return []
|
||||
|
||||
@ -3931,3 +3989,8 @@ class JobLog(TimeStampedModel):
|
||||
if self.job_name == "engineer_pathways":
|
||||
return ast.literal_eval(self.task_result)
|
||||
return self.task_result
|
||||
|
||||
def is_result_downloadable(self):
|
||||
downloadable = ["batch_predict"]
|
||||
|
||||
return self.job_name in downloadable
|
||||
|
||||
Reference in New Issue
Block a user