forked from enviPath/enviPy
[Feature] Threshold Warning + Cosmetics (#277)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#277
This commit is contained in:
@ -1489,6 +1489,7 @@ class SPathway(object):
|
||||
self.smiles_to_node: Dict[str, SNode] = dict(**{n.smiles: n for n in self.root_nodes})
|
||||
self.edges: Set["SEdge"] = set()
|
||||
self.done = False
|
||||
self.empty_due_to_threshold = False
|
||||
|
||||
@staticmethod
|
||||
def from_pathway(pw: "Pathway", persist: bool = True):
|
||||
@ -1601,9 +1602,24 @@ class SPathway(object):
|
||||
|
||||
sub.app_domain_assessment = app_domain_assessment
|
||||
|
||||
candidates = self.prediction_setting.expand(self, sub)
|
||||
expansion_result = self.prediction_setting.expand(self, sub)
|
||||
|
||||
# We don't have any substrate, but technically we have at least one rule that triggered.
|
||||
# If our substrate is a root node a.k.a. depth == 0 store that info in SPathway
|
||||
if (
|
||||
len(expansion_result["transformations"]) == 0
|
||||
and expansion_result["rule_triggered"]
|
||||
and sub.depth == 0
|
||||
):
|
||||
self.empty_due_to_threshold = True
|
||||
|
||||
# Emit directly
|
||||
if self.persist is not None:
|
||||
self.persist.kv["empty_due_to_threshold"] = True
|
||||
self.persist.save()
|
||||
|
||||
# candidates is a List of PredictionResult. The length of the List is equal to the number of rules
|
||||
for cand_set in candidates:
|
||||
for cand_set in expansion_result["transformations"]:
|
||||
if cand_set:
|
||||
# cand_set is a PredictionResult object that can consist of multiple candidate reactions
|
||||
for cand in cand_set:
|
||||
@ -1727,10 +1743,6 @@ class SPathway(object):
|
||||
for queued_val in queue:
|
||||
node_and_probs.append((queued_val, node_probs[queued_val]))
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
pprint(node_and_probs)
|
||||
|
||||
# re-order the queue and only pick smiles
|
||||
queue = [
|
||||
n[0] for n in sorted(node_and_probs, key=lambda x: x[1], reverse=True)
|
||||
|
||||
@ -23,7 +23,7 @@ from django.db import models, transaction
|
||||
from django.db.models import Count, JSONField, Q, QuerySet
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
from envipy_additional_information import EnviPyModel
|
||||
from envipy_additional_information import EnviPyModel, HalfLife
|
||||
from model_utils.models import TimeStampedModel
|
||||
from polymorphic.models import PolymorphicModel
|
||||
from sklearn.metrics import jaccard_score, precision_score, recall_score
|
||||
@ -795,9 +795,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
|
||||
@property
|
||||
def related_pathways(self):
|
||||
pathways = Node.objects.filter(node_labels__in=[self.default_structure]).values_list(
|
||||
"pathway", flat=True
|
||||
)
|
||||
pathways = self.related_nodes.values_list("pathway", flat=True)
|
||||
return Pathway.objects.filter(package=self.package, id__in=set(pathways)).order_by("name")
|
||||
|
||||
@property
|
||||
@ -807,6 +805,12 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
| Reaction.objects.filter(package=self.package, products__in=[self.default_structure])
|
||||
).order_by("name")
|
||||
|
||||
@property
|
||||
def related_nodes(self):
|
||||
return Node.objects.filter(
|
||||
node_labels__in=[self.default_structure], pathway__package=self.package
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(
|
||||
@ -1042,6 +1046,17 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
|
||||
return new_compound
|
||||
|
||||
def half_lifes(self):
|
||||
hls: Dict[Scenario, List[HalfLife]] = defaultdict(list)
|
||||
|
||||
for n in self.related_nodes:
|
||||
for scen in n.scenarios.all().order_by("name"):
|
||||
for ai in scen.get_additional_information():
|
||||
if isinstance(ai, HalfLife):
|
||||
hls[scen].append(ai)
|
||||
|
||||
return dict(hls)
|
||||
|
||||
class Meta:
|
||||
unique_together = [("uuid", "package")]
|
||||
|
||||
@ -1780,6 +1795,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
def failed(self):
|
||||
return self.status() == "failed"
|
||||
|
||||
def empty_due_to_threshold(self):
|
||||
return self.kv.get("empty_due_to_threshold", False)
|
||||
|
||||
def d3_json(self):
|
||||
# Ideally it would be something like this but
|
||||
# to reduce crossing in edges do a DFS
|
||||
@ -1887,7 +1905,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
"status": self.status(),
|
||||
}
|
||||
|
||||
return json.dumps(res)
|
||||
return res
|
||||
|
||||
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
|
||||
import csv
|
||||
@ -3887,33 +3905,48 @@ class Setting(EnviPathModel):
|
||||
rules = sorted(rules, key=lambda x: x.url)
|
||||
return rules
|
||||
|
||||
def expand(self, pathway, current_node):
|
||||
def expand(self, pathway, current_node) -> Dict[str, Any]:
|
||||
res: Dict[str, Any] = defaultdict(list)
|
||||
|
||||
"""Decision Method whether to expand on a certain Node or not"""
|
||||
if pathway.num_nodes() >= self.max_nodes:
|
||||
logger.info(
|
||||
f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
|
||||
)
|
||||
return []
|
||||
res["expansion_skipped"] = True
|
||||
return res
|
||||
|
||||
if pathway.depth() >= self.max_depth:
|
||||
logger.info(
|
||||
f"Pathway has reached depth {pathway.depth()} which exceeds the limit of {self.max_depth}"
|
||||
)
|
||||
return []
|
||||
res["expansion_skipped"] = True
|
||||
return res
|
||||
|
||||
transformations = []
|
||||
if self.model is not None:
|
||||
pred_results = self.model.predict(current_node.smiles)
|
||||
|
||||
# Store whether there are results that may be removed as they are below
|
||||
# the given threshold
|
||||
if len(pred_results):
|
||||
res["rule_triggered"] = True
|
||||
|
||||
for pred_result in pred_results:
|
||||
if pred_result.probability >= self.model_threshold:
|
||||
transformations.append(pred_result)
|
||||
if (
|
||||
len(pred_result.product_sets)
|
||||
and pred_result.probability >= self.model_threshold
|
||||
):
|
||||
res["transformations"].append(pred_result)
|
||||
else:
|
||||
for rule in self.applicable_rules:
|
||||
tmp_products = rule.apply(current_node.smiles)
|
||||
if tmp_products:
|
||||
transformations.append(PredictionResult(tmp_products, 1.0, rule))
|
||||
res["transformations"].append(PredictionResult(tmp_products, 1.0, rule))
|
||||
|
||||
return transformations
|
||||
if len(res["transformations"]):
|
||||
res["rule_triggered"] = True
|
||||
|
||||
return res
|
||||
|
||||
@transaction.atomic
|
||||
def make_global_default(self):
|
||||
|
||||
@ -1937,6 +1937,7 @@ def package_pathway(request, package_uuid, pathway_uuid):
|
||||
{
|
||||
"status": current_pathway.status(),
|
||||
"modified": current_pathway.modified.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"emptyDueToThreshold": current_pathway.empty_due_to_threshold(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user