Files
enviPy-bayer/epdb/logic.py
2025-12-20 02:11:47 +13:00

1866 lines
66 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Union, Tuple
from uuid import UUID
import nh3
from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.db import transaction
from pydantic import ValidationError
from epdb.models import (
Compound,
CompoundStructure,
Edge,
EnzymeLink,
EPModel,
ExpansionSchemeChoice,
Group,
GroupPackagePermission,
Node,
Pathway,
Permission,
Reaction,
Rule,
Setting,
User,
UserPackagePermission,
UserSettingPermission,
)
from utilities.chem import FormatConverter
from utilities.misc import PackageExporter, PackageImporter
logger = logging.getLogger(__name__)
Package = s.GET_PACKAGE_MODEL()
class EPDBURLParser:
    """Parse enviPath-style URLs and resolve them to Django model instances.

    On construction the URL is matched against every known model pattern; the
    matching URL prefixes are cached in ``_matches`` for later object lookups.
    """

    UUID_PATTERN = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    MODEL_PATTERNS = {
        "epdb.User": re.compile(rf"^.*/user/{UUID_PATTERN}"),
        "epdb.Group": re.compile(rf"^.*/group/{UUID_PATTERN}"),
        "epdb.Package": re.compile(rf"^.*/package/{UUID_PATTERN}"),
        "epdb.Compound": re.compile(rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}"),
        "epdb.CompoundStructure": re.compile(
            rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}/structure/{UUID_PATTERN}"
        ),
        "epdb.Rule": re.compile(
            rf"^.*/package/{UUID_PATTERN}/(?:simple-ambit-rule|simple-rdkit-rule|parallel-rule|sequential-rule|rule)/{UUID_PATTERN}"
        ),
        "epdb.Reaction": re.compile(rf"^.*/package/{UUID_PATTERN}/reaction/{UUID_PATTERN}$"),
        "epdb.Pathway": re.compile(rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}"),
        "epdb.Node": re.compile(
            rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/node/{UUID_PATTERN}"
        ),
        "epdb.Edge": re.compile(
            rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/edge/{UUID_PATTERN}"
        ),
        "epdb.Scenario": re.compile(rf"^.*/package/{UUID_PATTERN}/scenario/{UUID_PATTERN}"),
        "epdb.EPModel": re.compile(rf"^.*/package/{UUID_PATTERN}/model/{UUID_PATTERN}"),
        "epdb.Setting": re.compile(rf"^.*/setting/{UUID_PATTERN}"),
    }

    def __init__(self, url: str):
        self.url = url
        self._matches: Dict[str, str] = {}
        self._analyze_url()

    def _analyze_url(self):
        # Cache the first matching prefix for every pattern that applies.
        for model_path, pattern in self.MODEL_PATTERNS.items():
            hits = pattern.findall(self.url)
            if hits:
                self._matches[model_path] = hits[0]

    def _get_model_class(self, model_path: str):
        # Translate "epdb.Compound" into the registered Django model class.
        try:
            from django.apps import apps

            app_label, model_name = model_path.split(".")[-2:]
            return apps.get_model(app_label, model_name)
        except (ImportError, LookupError, ValueError):
            raise ValueError(f"Model {model_path} does not exist!")

    def _get_object_by_url(self, model_path: str, url: str):
        # Model instances carry their canonical URL in a `url` field.
        return self._get_model_class(model_path).objects.get(url=url)

    def is_package_url(self) -> bool:
        # True only when the URL *ends* at the package segment.
        return bool(re.compile(rf"^.*/package/{self.UUID_PATTERN}$").findall(self.url))

    def contains_package_url(self):
        # True for URLs that point *below* a package (compound, rule, ...).
        has_package_prefix = bool(self.MODEL_PATTERNS["epdb.Package"].findall(self.url))
        return has_package_prefix and not self.is_package_url()

    def is_user_url(self) -> bool:
        return bool(self.MODEL_PATTERNS["epdb.User"].findall(self.url))

    def is_group_url(self) -> bool:
        return bool(self.MODEL_PATTERNS["epdb.Group"].findall(self.url))

    def is_setting_url(self) -> bool:
        return bool(self.MODEL_PATTERNS["epdb.Setting"].findall(self.url))

    def get_object(self) -> Optional[Any]:
        """Return the most specific object the URL refers to.

        Raises:
            ValueError: when no known pattern matched the URL.
        """
        # Define priority order from most specific to least specific
        priority_order = [
            # 3rd level
            "epdb.CompoundStructure",
            "epdb.Node",
            "epdb.Edge",
            # 2nd level
            "epdb.Compound",
            "epdb.Rule",
            "epdb.Reaction",
            "epdb.Scenario",
            "epdb.EPModel",
            "epdb.Pathway",
            # 1st level
            "epdb.Package",
            "epdb.Setting",
            "epdb.Group",
            "epdb.User",
        ]
        for model_path in priority_order:
            matched_url = self._matches.get(model_path)
            if matched_url is not None:
                return self._get_object_by_url(model_path, matched_url)
        raise ValueError(f"No object found for URL {self.url}")

    def get_objects(self) -> List[Any]:
        """
        Get all Django model objects along the URL path in hierarchical order.
        Returns objects from parent to child (e.g., Package -> Compound -> Structure).
        """
        hierarchy_order = [
            # 1st level
            "epdb.Package",
            "epdb.Setting",
            "epdb.Group",
            "epdb.User",
            # 2nd level
            "epdb.Compound",
            "epdb.Rule",
            "epdb.Reaction",
            "epdb.Scenario",
            "epdb.EPModel",
            "epdb.Pathway",
            # 3rd level
            "epdb.CompoundStructure",
            "epdb.Node",
            "epdb.Edge",
        ]
        return [
            self._get_object_by_url(model_path, self._matches[model_path])
            for model_path in hierarchy_order
            if model_path in self._matches
        ]

    def __str__(self) -> str:
        return f"EPDBURLParser(url='{self.url}')"

    def __repr__(self) -> str:
        return f"EPDBURLParser(url='{self.url}', matches={list(self._matches.keys())})"
class UserManager(object):
    """Registration and lookup helpers for :class:`User` accounts."""

    user_pattern = re.compile(
        r".*/user/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_user_url(url: str):
        """Return True if ``url`` contains a user UUID segment."""
        return bool(re.findall(UserManager.user_pattern, url))

    @staticmethod
    @transaction.atomic
    def create_user(
        username, email, password, set_setting=True, add_to_group=True, *args, **kwargs
    ):
        """Create a user plus a personal default package, atomically.

        Username and email are sanitized with nh3; if sanitizing changed either
        value, the input contained markup and registration is rejected.
        Optionally assigns the global default setting and adds the user to the
        public "enviPath Users" group. A verification mail is sent when the
        account starts inactive (admin approval required).

        Raises:
            ValueError: if the username or email contains disallowed markup.
        """
        # Clean for potential XSS
        clean_username = nh3.clean(username).strip()
        clean_email = nh3.clean(email).strip()
        if clean_username != username or clean_email != email:
            # This will be caught by the try in view.py/register.
            # Bug fix: this check validates username/email — the old message
            # wrongly mentioned the password.
            raise ValueError("Invalid username or email")
        # avoid circular import :S
        from .tasks import send_registration_mail

        extra_fields = {"is_active": not s.ADMIN_APPROVAL_REQUIRED}
        if "is_active" in kwargs:
            extra_fields["is_active"] = kwargs["is_active"]
        if "uuid" in kwargs:
            extra_fields["uuid"] = kwargs["uuid"]
        u = get_user_model().objects.create_user(username, email, password, **extra_fields)
        # Create package: apostrophe-less possessive, e.g. "Alices Package",
        # skipping the trailing "s" for names already ending in s/x/z/ß.
        package_name = f"{u.username}{'' if u.username[-1] in 'sxzß' else 's'} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()
        if not u.is_active:
            # send email for verification
            send_registration_mail.delay(u.pk)
        if set_setting:
            u.default_setting = Setting.objects.get(global_default=True)
            u.save()
        if add_to_group:
            g = Group.objects.get(public=True, name="enviPath Users")
            g.user_member.add(u)
            g.save()
            u.default_group = g
            u.save()
        return u

    @staticmethod
    def get_user(user_url):
        # TODO: not implemented.
        pass

    @staticmethod
    def get_user_by_id(user, user_uuid: str):
        """Return the user with ``user_uuid``; only self-lookup or superusers allowed."""
        if str(user.uuid) != user_uuid and not user.is_superuser:
            raise ValueError("Getting user failed!")
        return get_user_model().objects.get(uuid=user_uuid)

    @staticmethod
    def get_user_lp(user_url: str):
        """Resolve a user URL to a User WITHOUT any permission check."""
        uuid = user_url.strip().split("/")[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users_lp():
        """Return all users WITHOUT any permission check."""
        return get_user_model().objects.all()

    @staticmethod
    def get_users():
        # Intentionally unusable stub.
        raise ValueError("")

    @staticmethod
    def writable(current_user, user):
        """Return True if ``current_user`` may modify ``user``'s account."""
        # NOTE(review): this checks the *target* user's superuser flag; compare
        # GroupManager.writable which checks the caller — confirm intent.
        return (current_user == user) or user.is_superuser
class GroupManager(object):
    """Creation, lookup and membership management for :class:`Group` objects."""

    group_pattern = re.compile(
        r".*/group/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_group_url(url: str):
        """Return True if ``url`` contains a group UUID segment."""
        return bool(re.findall(GroupManager.group_pattern, url))

    @staticmethod
    def create_group(current_user, name, description):
        """Create a group owned by (and containing) ``current_user``."""
        new_group = Group()
        # Sanitize user-supplied text to guard against XSS.
        new_group.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        new_group.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
        new_group.owner = current_user
        new_group.save()
        new_group.user_member.add(current_user)
        new_group.save()
        return new_group

    @staticmethod
    def get_group_lp(group_url: str):
        """Resolve a group URL to a Group WITHOUT any permission check."""
        return Group.objects.get(uuid=group_url.strip().split("/")[-1])

    @staticmethod
    def get_groups_lp():
        """Return every group WITHOUT any permission check."""
        return Group.objects.all()

    @staticmethod
    def get_group_by_url(user, group_url):
        """Resolve a group URL; the group is returned only if ``user`` is a member."""
        return GroupManager.get_group_by_id(user, group_url.split("/")[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        """Return the group with ``group_id`` if ``user`` is a direct member, else None."""
        candidate = Group.objects.get(uuid=group_id)
        return candidate if user in candidate.user_member.all() else None

    @staticmethod
    def get_groups(user):
        """Return all groups in which ``user`` is a direct member."""
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        """Add or remove ``member`` (a user or a nested group); owner/superuser only."""
        if caller != group.owner and not caller.is_superuser:
            raise ValueError("Only the group Owner is allowed to add members!")
        # User and group members live in separate M2M relations.
        relation = group.group_member if isinstance(member, Group) else group.user_member
        if add_or_remove == "add":
            relation.add(member)
        else:
            relation.remove(member)
        group.save()

    @staticmethod
    def writable(user, group):
        """Return True if ``user`` may modify ``group``."""
        return (user == group.owner) or user.is_superuser
class PackageManager(object):
    """Lookup, permission handling, creation and import/export of Package objects."""

    package_pattern = re.compile(
        r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_package_url(url: str):
        """Return True if ``url`` contains a package UUID segment."""
        return bool(re.findall(PackageManager.package_pattern, url))

    @staticmethod
    def get_reviewed_packages():
        """Return all reviewed packages (reviewed packages are world-readable)."""
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        """Return True if ``user`` may read ``package``.

        Granted via any direct user permission, any permission held by one of
        the user's groups, a reviewed package, or superuser status.
        """
        if (
            UserPackagePermission.objects.filter(package=package, user=user).exists()
            # NOTE(review): no .exists() on this branch — still correct because an
            # empty queryset is falsy, but inconsistent with the check above.
            or GroupPackagePermission.objects.filter(
                package=package, group__in=GroupManager.get_groups(user)
            )
            or package.reviewed is True
            or user.is_superuser
        ):
            return True
        return False

    @staticmethod
    def writable(user, package):
        """Return True if ``user`` may modify ``package``.

        NOTE(review): a group-level ALL permission does NOT grant write access
        here, although get_all_writeable_packages() treats group ALL as
        writable — this asymmetry looks like a bug; confirm intended semantics.
        """
        if (
            UserPackagePermission.objects.filter(
                package=package, user=user, permission=Permission.WRITE[0]
            ).exists()
            or GroupPackagePermission.objects.filter(
                package=package,
                group__in=GroupManager.get_groups(user),
                permission=Permission.WRITE[0],
            ).exists()
            or UserPackagePermission.objects.filter(
                package=package, user=user, permission=Permission.ALL[0]
            ).exists()
            or user.is_superuser
        ):
            return True
        return False

    @staticmethod
    def administrable(user, package):
        """Return True if ``user`` holds the ALL permission (directly or via a group)."""
        if (
            UserPackagePermission.objects.filter(
                package=package, user=user, permission=Permission.ALL[0]
            ).exists()
            or GroupPackagePermission.objects.filter(
                package=package,
                group__in=GroupManager.get_groups(user),
                permission=Permission.ALL[0],
            ).exists()
            or user.is_superuser
        ):
            return True
        return False

    @staticmethod
    def has_package_permission(user: "User", package: Union[str, UUID, "Package"], permission: str):
        """Return True if ``user`` has at least ``permission`` on ``package``.

        ``permission`` must be one of "all", "write" or "read"; stronger
        permissions imply weaker ones via the ``perms`` expansion table.
        ``package`` may be passed as an instance, a UUID, or a UUID string.
        """
        if isinstance(package, str) or isinstance(package, UUID):
            package = Package.objects.get(uuid=package)
        groups = GroupManager.get_groups(user)
        perms = {"all": ["all"], "write": ["all", "write"], "read": ["all", "write", "read"]}
        # NOTE(review): an unknown permission string yields valid_perms=None and
        # permission__in=None would fail — confirm callers only pass known keys.
        valid_perms = perms.get(permission)
        if (
            UserPackagePermission.objects.filter(
                package=package, user=user, permission__in=valid_perms
            ).exists()
            or GroupPackagePermission.objects.filter(
                package=package, group__in=groups, permission__in=valid_perms
            ).exists()
            or user.is_superuser
        ):
            return True
        return False

    @staticmethod
    def get_package_lp(package_url):
        """Resolve a package URL WITHOUT a permission check; None if no UUID found."""
        match = re.findall(PackageManager.package_pattern, package_url)
        if match:
            package_id = match[0].split("/")[-1]
            return Package.objects.get(uuid=package_id)
        return None

    @staticmethod
    def get_package_by_url(user, package_url):
        """Resolve a package URL with a readability check on behalf of ``user``."""
        match = re.findall(PackageManager.package_pattern, package_url)
        if match:
            package_id = match[0].split("/")[-1]
            return PackageManager.get_package_by_id(user, package_id)
        else:
            raise ValueError(
                "Requested URL {} does not contain a valid package identifier!".format(package_url)
            )

    @staticmethod
    def get_package_by_id(user, package_id):
        """Return the package with ``package_id`` if ``user`` may read it.

        Raises:
            ValueError: for unknown IDs or insufficient permissions.
        """
        try:
            p = Package.objects.get(uuid=package_id)
            if PackageManager.readable(user, p):
                return p
            else:
                # FIXME: use custom exception to be translatable to 403 in API
                raise ValueError(
                    "Insufficient permissions to access Package with ID {}".format(package_id)
                )
        except Package.DoesNotExist:
            raise ValueError("Package with ID {} does not exist!".format(package_id))

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        """Return every package ``user`` may read (optionally adding reviewed ones)."""
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values("package").distinct()
            )
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user)
                )
                .values("package")
                .distinct()
            )
            qs = user_package_qs | group_package_qs
        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)
        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        """Return every unreviewed package ``user`` may write (WRITE or ALL perms)."""
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            write_user_packs = (
                UserPackagePermission.objects.filter(user=user, permission=Permission.WRITE[0])
                .values("package")
                .distinct()
            )
            owner_user_packs = (
                UserPackagePermission.objects.filter(user=user, permission=Permission.ALL[0])
                .values("package")
                .distinct()
            )
            user_packs = write_user_packs | owner_user_packs
            user_package_qs = Package.objects.filter(id__in=user_packs)
            write_group_packs = (
                GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]
                )
                .values("package")
                .distinct()
            )
            owner_group_packs = (
                GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]
                )
                .values("package")
                .distinct()
            )
            group_packs = write_group_packs | owner_group_packs
            group_package_qs = Package.objects.filter(id__in=group_packs)
            qs = user_package_qs | group_package_qs
        # Reviewed packages are frozen — never writable through this path.
        qs = qs.filter(reviewed=False)
        return qs.distinct()

    @staticmethod
    def get_packages():
        """Return all packages, without any permission filtering."""
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        """Create a package and grant ``current_user`` the ALL permission on it."""
        p = Package()
        # Clean for potential XSS
        p.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        if description is not None and description.strip() != "":
            p.description = nh3.clean(description.strip(), tags=s.ALLOWED_HTML_TAGS).strip()
        p.save()
        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()
        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(
        caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]
    ):
        """Set, change or (``new_perm=None``) revoke ``grantee``'s package permission.

        Only a caller holding ALL on the package (or a superuser) may do this.
        """
        caller_perm = None
        if not caller.is_superuser:
            caller_perm = UserPackagePermission.objects.get(user=caller, package=package).permission
        if caller_perm != Permission.ALL[0] and not caller.is_superuser:
            raise ValueError("Only owner are allowed to modify permissions")
        data = {
            "package": package,
        }
        # User and group grantees are stored in different permission tables.
        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data["user"] = grantee
        else:
            perm_cls = GroupPackagePermission
            data["group"] = grantee
        if new_perm is None:
            # Revocation: delete the single matching permission row, if any.
            qs = perm_cls.objects.filter(**data)
            if qs.count() > 1:
                raise ValueError("Got more Permission objects than expected!")
            if qs.count() != 0:
                logger.info(f"Deleting Perm {qs.first()}")
                qs.delete()
            else:
                logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
        else:
            _ = perm_cls.objects.update_or_create(defaults={"permission": new_perm}, **data)

    @staticmethod
    def grant_read(caller: User, package: Package, grantee: Union[User, Group]):
        """Grant ``grantee`` READ on ``package`` (caller must hold ALL)."""
        PackageManager.update_permissions(caller, package, grantee, Permission.READ[0])

    @staticmethod
    def grant_write(caller: User, package: Package, grantee: Union[User, Group]):
        """Grant ``grantee`` WRITE on ``package`` (caller must hold ALL)."""
        PackageManager.update_permissions(caller, package, grantee, Permission.WRITE[0])

    @staticmethod
    @transaction.atomic
    def import_legacy_package(
        data: dict, owner: User, keep_ids=False, add_import_timestamp=True, trust_reviewed=False
    ):
        """Import a legacy enviPath JSON export into a new package owned by ``owner``.

        Objects are created in dependency order (scenarios, compounds/structures,
        rules, reactions, pathways, enzyme links) while ``mapping`` translates
        legacy IDs to new UUIDs; cross-references are wired up in a second
        "linking" phase, and node depths are recomputed at the end. Everything
        runs inside one transaction.
        """
        from collections import defaultdict
        from datetime import datetime
        from uuid import UUID, uuid4

        from envipy_additional_information import AdditionalInformationConverter

        from .models import (
            Compound,
            CompoundStructure,
            Edge,
            Node,
            ParallelRule,
            Pathway,
            Reaction,
            Scenario,
            SequentialRule,
            SequentialRuleOrdering,
            SimpleAmbitRule,
            SimpleRule,
        )

        pack = Package()
        pack.uuid = UUID(data["id"].split("/")[-1]) if keep_ids else uuid4()
        if add_import_timestamp:
            pack.name = "{} - {}".format(data["name"], datetime.now().strftime("%Y-%m-%d %H:%M"))
        else:
            pack.name = data["name"]
        if trust_reviewed:
            pack.reviewed = True if data["reviewStatus"] == "reviewed" else False
        else:
            pack.reviewed = False
        pack.description = data["description"]
        pack.save()

        # The importing user becomes the owner of the new package.
        up = UserPackagePermission()
        up.user = owner
        up.package = pack
        up.permission = up.ALL[0]
        up.save()

        # Stores old_id to new_id
        mapping = {}
        # Stores new_scen_id to old_parent_scen_id
        parent_mapping = {}
        # Mapping old scen_id to old_obj_id
        scen_mapping = defaultdict(list)
        # Enzymelink Mapping rule_id to enzymelink objects
        enzyme_mapping = defaultdict(list)

        # Store Scenarios
        for scenario in data["scenarios"]:
            scen = Scenario()
            scen.package = pack
            scen.uuid = UUID(scenario["id"].split("/")[-1]) if keep_ids else uuid4()
            scen.name = scenario["name"]
            scen.description = scenario["description"]
            scen.scenario_type = scenario["type"]
            scen.scenario_date = scenario["date"]
            scen.additional_information = dict()
            scen.save()
            mapping[scenario["id"]] = scen.uuid
            new_add_inf = defaultdict(list)
            # TODO Store AI...
            for ex in scenario.get("additionalInformationCollection", {}).get(
                "additionalInformation", []
            ):
                name = ex["name"]
                addinf_data = ex["data"]
                # park the parent scen id for now and link it later
                if name == "referringscenario":
                    parent_mapping[scen.uuid] = addinf_data
                    continue
                # Broken eP Data
                if name == "initialmasssediment" and addinf_data == "missing data":
                    continue
                if name == "columnheight" and addinf_data == "(2)-(2.5);(6)-(8)":
                    continue
                try:
                    res = AdditionalInformationConverter.convert(name, addinf_data)
                    res_cls_name = res.__class__.__name__
                    ai_data = json.loads(res.model_dump_json())
                    ai_data["uuid"] = f"{uuid4()}"
                    new_add_inf[res_cls_name].append(ai_data)
                except ValidationError:
                    logger.error(f"Failed to convert {name} with {addinf_data}")
            scen.additional_information = new_add_inf
            scen.save()
        print("Scenarios imported...")

        # Store compounds and its structures
        for compound in data["compounds"]:
            comp = Compound()
            comp.package = pack
            comp.uuid = UUID(compound["id"].split("/")[-1]) if keep_ids else uuid4()
            comp.name = compound["name"]
            comp.description = compound["description"]
            comp.aliases = compound["aliases"]
            comp.save()
            mapping[compound["id"]] = comp.uuid
            for scen in compound["scenarios"]:
                scen_mapping[scen["id"]].append(comp)
            default_structure = None
            for structure in compound["structures"]:
                struc = CompoundStructure()
                # struc.object_url = Command.get_id(structure, keep_ids)
                struc.compound = comp
                struc.uuid = UUID(structure["id"].split("/")[-1]) if keep_ids else uuid4()
                struc.name = structure["name"]
                struc.description = structure["description"]
                struc.aliases = structure.get("aliases", [])
                struc.smiles = structure["smiles"]
                struc.save()
                for scen in structure["scenarios"]:
                    scen_mapping[scen["id"]].append(struc)
                mapping[structure["id"]] = struc.uuid
                if structure["id"] == compound["defaultStructure"]["id"]:
                    default_structure = struc
                struc.save()
            if default_structure is None:
                raise ValueError("No default structure set")
            comp.default_structure = default_structure
            comp.save()
        print("Compounds imported...")

        # Store simple and parallel-rules.
        # Composite rules are parked so their simple rules exist first.
        par_rules = []
        seq_rules = []
        for rule in data["rules"]:
            if rule["identifier"] == "parallel-rule":
                par_rules.append(rule)
                continue
            if rule["identifier"] == "sequential-rule":
                seq_rules.append(rule)
                continue
            r = SimpleAmbitRule()
            r.uuid = UUID(rule["id"].split("/")[-1]) if keep_ids else uuid4()
            r.package = pack
            r.name = rule["name"]
            r.description = rule["description"]
            r.aliases = rule.get("aliases", [])
            r.smirks = rule["smirks"]
            r.reactant_filter_smarts = rule.get("reactantFilterSmarts", None)
            r.product_filter_smarts = rule.get("productFilterSmarts", None)
            r.save()
            mapping[rule["id"]] = r.uuid
            for scen in rule["scenarios"]:
                scen_mapping[scen["id"]].append(r)
            for enzyme_link in rule.get("enzymeLinks", []):
                enzyme_mapping[r.uuid].append(enzyme_link)
        print("Par: ", len(par_rules))
        print("Seq: ", len(seq_rules))
        for par_rule in par_rules:
            r = ParallelRule()
            r.package = pack
            r.uuid = UUID(par_rule["id"].split("/")[-1]) if keep_ids else uuid4()
            r.name = par_rule["name"]
            r.description = par_rule["description"]
            r.aliases = par_rule.get("aliases", [])
            r.save()
            mapping[par_rule["id"]] = r.uuid
            for scen in par_rule["scenarios"]:
                scen_mapping[scen["id"]].append(r)
            for enzyme_link in par_rule.get("enzymeLinks", []):
                enzyme_mapping[r.uuid].append(enzyme_link)
            for simple_rule in par_rule["simpleRules"]:
                # Silently skips simple rules that were not part of this export.
                if simple_rule["id"] in mapping:
                    r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule["id"]]))
            r.save()
        for seq_rule in seq_rules:
            r = SequentialRule()
            r.package = pack
            r.uuid = UUID(seq_rule["id"].split("/")[-1]) if keep_ids else uuid4()
            r.name = seq_rule["name"]
            r.description = seq_rule["description"]
            r.aliases = seq_rule.get("aliases", [])
            r.save()
            mapping[seq_rule["id"]] = r.uuid
            for scen in seq_rule["scenarios"]:
                scen_mapping[scen["id"]].append(r)
            for enzyme_link in seq_rule.get("enzymeLinks", []):
                enzyme_mapping[r.uuid].append(enzyme_link)
            for i, simple_rule in enumerate(seq_rule["simpleRules"]):
                sro = SequentialRuleOrdering()
                # NOTE(review): assigns the raw JSON dict instead of a SimpleRule
                # instance (compare the commented-out lookup below) — confirm.
                sro.simple_rule = simple_rule
                sro.sequential_rule = r
                sro.order_index = i
                sro.save()
                # r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule['id']]))
            r.save()
        print("Rules imported...")

        for reaction in data["reactions"]:
            r = Reaction()
            r.package = pack
            r.uuid = UUID(reaction["id"].split("/")[-1]) if keep_ids else uuid4()
            r.name = reaction["name"]
            r.description = reaction["description"]
            r.aliases = reaction.get("aliases", [])
            # NOTE(review): the trailing comma wraps the value in a 1-tuple —
            # looks unintended; confirm the field's expected type.
            r.medlinereferences = (reaction["medlinereferences"],)
            r.multi_step = True if reaction["multistep"] == "true" else False
            r.save()
            mapping[reaction["id"]] = r.uuid
            for scen in reaction["scenarios"]:
                scen_mapping[scen["id"]].append(r)
            for educt in reaction["educts"]:
                r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt["id"]]))
            for product in reaction["products"]:
                r.products.add(CompoundStructure.objects.get(uuid=mapping[product["id"]]))
            if "rules" in reaction:
                for rule in reaction["rules"]:
                    try:
                        r.rules.add(Rule.objects.get(uuid=mapping[rule["id"]]))
                    except Exception as e:
                        print(f"Rule with id {rule['id']} not found!")
                        print(e)
            r.save()
        print("Reactions imported...")

        for pathway in data["pathways"]:
            pw = Pathway()
            pw.package = pack
            pw.uuid = UUID(pathway["id"].split("/")[-1]) if keep_ids else uuid4()
            pw.name = pathway["name"]
            pw.description = pathway["description"]
            pw.aliases = pathway.get("aliases", [])
            pw.save()
            mapping[pathway["id"]] = pw.uuid
            for scen in pathway["scenarios"]:
                scen_mapping[scen["id"]].append(pw)
            # Collected per pathway; node->edge links are wired after edges exist.
            out_nodes_mapping = defaultdict(set)
            for node in pathway["nodes"]:
                n = Node()
                n.uuid = UUID(node["id"].split("/")[-1]) if keep_ids else uuid4()
                n.name = node["name"]
                n.description = node.get("description")
                n.aliases = node.get("aliases", [])
                n.pathway = pw
                n.depth = node["depth"]
                n.default_node_label = CompoundStructure.objects.get(
                    uuid=mapping[node["defaultNodeLabel"]["id"]]
                )
                n.save()
                mapping[node["id"]] = n.uuid
                for scen in node["scenarios"]:
                    scen_mapping[scen["id"]].append(n)
                for node_label in node["nodeLabels"]:
                    n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label["id"]]))
                n.save()
                for out_edge in node["outEdges"]:
                    out_nodes_mapping[n.uuid].add(out_edge)
            for edge in pathway["edges"]:
                e = Edge()
                e.uuid = UUID(edge["id"].split("/")[-1]) if keep_ids else uuid4()
                e.name = edge["name"]
                e.pathway = pw
                e.description = edge["description"]
                e.aliases = edge.get("aliases", [])
                e.edge_label = Reaction.objects.get(uuid=mapping[edge["edgeLabel"]["id"]])
                e.save()
                mapping[edge["id"]] = e.uuid
                for scen in edge["scenarios"]:
                    scen_mapping[scen["id"]].append(e)
                for start_node in edge["startNodes"]:
                    e.start_nodes.add(Node.objects.get(uuid=mapping[start_node]))
                for end_node in edge["endNodes"]:
                    e.end_nodes.add(Node.objects.get(uuid=mapping[end_node]))
                e.save()
            for k, v in out_nodes_mapping.items():
                n = Node.objects.get(uuid=k)
                for v1 in v:
                    n.out_edges.add(Edge.objects.get(uuid=mapping[v1]))
                n.save()
        print("Pathways imported...")

        # Linking Phase
        for child, parent in parent_mapping.items():
            child_obj = Scenario.objects.get(uuid=child)
            parent_obj = Scenario.objects.get(uuid=mapping[parent])
            child_obj.parent = parent_obj
            child_obj.save()
        for scen_id, objects in scen_mapping.items():
            scen = Scenario.objects.get(uuid=mapping[scen_id])
            for o in objects:
                o.scenarios.add(scen)
                o.save()
        print("Scenarios linked...")

        # Import Enzyme Links
        for rule_uuid, enzyme_links in enzyme_mapping.items():
            r = Rule.objects.get(uuid=rule_uuid)
            for enzyme in enzyme_links:
                e = EnzymeLink()
                e.uuid = UUID(enzyme["id"].split("/")[-1]) if keep_ids else uuid4()
                e.rule = r
                e.name = enzyme["name"]
                e.ec_number = enzyme["ecNumber"]
                e.classification_level = enzyme["classificationLevel"]
                e.linking_method = enzyme["linkingMethod"]
                e.save()
                for reaction in enzyme["reactionLinkEvidence"]:
                    reaction = Reaction.objects.get(uuid=mapping[reaction["id"]])
                    e.reaction_evidence.add(reaction)
                for edge in enzyme["edgeLinkEvidence"]:
                    edge = Edge.objects.get(uuid=mapping[edge["id"]])
                    # NOTE(review): Edge evidence is added to reaction_evidence —
                    # verify there is no dedicated edge-evidence relation.
                    e.reaction_evidence.add(edge)
                for evidence in enzyme["linkEvidence"]:
                    # Legacy evidence embeds KEGG reaction IDs as ">R12345<".
                    matches = re.findall(r">(R[0-9]+)<", evidence["evidence"])
                    if not matches or len(matches) != 1:
                        logger.warning(f"Could not find reaction id in {evidence['evidence']}")
                        continue
                    e.add_kegg_reaction_id(matches[0])
                e.save()
        print("Enzyme links imported...")

        print("Import statistics:")
        print("Package {} stored".format(pack.url))
        print("Imported {} compounds".format(Compound.objects.filter(package=pack).count()))
        print("Imported {} rules".format(Rule.objects.filter(package=pack).count()))
        print("Imported {} reactions".format(Reaction.objects.filter(package=pack).count()))
        print("Imported {} pathways".format(Pathway.objects.filter(package=pack).count()))
        print("Imported {} Scenarios".format(Scenario.objects.filter(package=pack).count()))

        print("Fixing Node depths...")
        total_pws = Pathway.objects.filter(package=pack).count()
        for p, pw in enumerate(Pathway.objects.filter(package=pack)):
            print(pw.url)
            # Count incoming/outgoing edges per node to find the roots.
            in_count = defaultdict(lambda: 0)
            out_count = defaultdict(lambda: 0)
            for e in pw.edges:
                # TODO check if this will remain
                for react in e.start_nodes.all():
                    out_count[str(react.uuid)] += 1
                for prod in e.end_nodes.all():
                    in_count[str(prod.uuid)] += 1
            root_nodes = []
            for n in pw.nodes:
                num_parents = in_count[str(n.uuid)]
                if num_parents == 0:
                    # must be a root node or unconnected node
                    if n.depth != 0:
                        n.depth = 0
                        n.save()
                    # Only root node may have children
                    if out_count[str(n.uuid)] > 0:
                        root_nodes.append(n)
            levels = [root_nodes]
            seen = set()
            # Do a bfs to determine depths starting with level 0 a.k.a. root nodes
            # NOTE(review): nodes are marked seen only after being expanded as
            # parents, so a child reachable from two same-level parents may be
            # queued twice — presumably harmless (depth writes are idempotent);
            # confirm.
            for i, level_nodes in enumerate(levels):
                new_level = []
                for n in level_nodes:
                    for e in n.out_edges.all():
                        for prod in e.end_nodes.all():
                            if str(prod.uuid) not in seen:
                                old_depth = prod.depth
                                if old_depth != i + 1:
                                    print(f"updating depth from {old_depth} to {i + 1}")
                                    prod.depth = i + 1
                                    prod.save()
                                new_level.append(prod)
                    seen.add(str(n.uuid))
                if new_level:
                    levels.append(new_level)
            print(f"{p + 1}/{total_pws} fixed.")
        return pack

    @staticmethod
    @transaction.atomic
    def import_package(
        data: Dict[str, Any],
        owner: User,
        preserve_uuids=False,
        add_import_timestamp=True,
        trust_reviewed=False,
    ) -> Package:
        """Import a package via PackageImporter and grant ``owner`` the ALL permission."""
        importer = PackageImporter(data, preserve_uuids, add_import_timestamp, trust_reviewed)
        imported_package = importer.do_import()
        up = UserPackagePermission()
        up.user = owner
        up.package = imported_package
        up.permission = up.ALL[0]
        up.save()
        return imported_package

    @staticmethod
    def export_package(
        package: Package, include_models: bool = False, include_external_identifiers: bool = True
    ) -> Dict[str, Any]:
        """Serialize ``package`` via PackageExporter.

        NOTE(review): include_models and include_external_identifiers are
        accepted but never forwarded to PackageExporter — confirm whether the
        exporter should receive them.
        """
        return PackageExporter(package).do_export()
class SettingManager(object):
    """Lookup and creation of :class:`Setting` objects with permission checks."""

    setting_pattern = re.compile(
        r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$"
    )

    @staticmethod
    def get_setting_by_url(user, setting_url):
        """Resolve a setting URL to a Setting, enforcing read permission for ``user``."""
        match = re.findall(SettingManager.setting_pattern, setting_url)
        if match:
            setting_id = match[0].split("/")[-1]
            return SettingManager.get_setting_by_id(user, setting_id)
        else:
            raise ValueError(
                "Requested URL {} does not contain a valid setting identifier!".format(setting_url)
            )

    @staticmethod
    def get_setting_by_id(user, setting_id):
        """Return the Setting if it is public/global-default or ``user`` holds a permission.

        Raises:
            ValueError: on insufficient permissions.
        """
        s = Setting.objects.get(uuid=setting_id)
        if (
            s.global_default
            or s.public
            or user.is_superuser
            or UserSettingPermission.objects.filter(user=user, setting=s).exists()
        ):
            return s
        raise ValueError("Insufficient permissions to access Setting with ID {}".format(setting_id))

    @staticmethod
    def get_all_settings(user):
        """Return all settings visible to ``user``: own, public and global defaults."""
        sp = UserSettingPermission.objects.filter(user=user).values("setting")
        return (
            Setting.objects.filter(id__in=sp)
            | Setting.objects.filter(public=True)
            | Setting.objects.filter(global_default=True)
        ).distinct()

    @staticmethod
    @transaction.atomic
    def create_setting(
        user: User,
        name: str = None,
        description: str = None,
        max_nodes: int = None,
        max_depth: int = None,
        rule_packages: List[Package] = None,
        model: EPModel = None,
        model_threshold: float = None,
        expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
    ):
        """Create a Setting owned by ``user`` (name/description are sanitized).

        Returns the newly created Setting; the creator receives the ALL
        permission on it.
        """
        new_s = Setting()
        # Clean for potential XSS
        new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
        new_s.max_nodes = max_nodes
        new_s.max_depth = max_depth
        new_s.model = model
        new_s.model_threshold = model_threshold
        # Bug fix: the expansion_scheme argument was accepted but never stored,
        # silently discarding the caller's choice.
        new_s.expansion_scheme = expansion_scheme
        new_s.save()
        if rule_packages is not None:
            for r in rule_packages:
                new_s.rule_packages.add(r)
            new_s.save()
        usp = UserSettingPermission()
        usp.user = user
        usp.setting = new_s
        usp.permission = Permission.ALL[0]
        usp.save()
        return new_s

    @staticmethod
    def get_default_setting(user: User):
        # TODO: not implemented.
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        # TODO: not implemented.
        pass
class SearchManager(object):
@staticmethod
def search(packages: Union[Package, List[Package]], searchterm: str, mode: str):
match mode:
case "text":
return SearchManager._search_text(packages, searchterm)
case "default":
return SearchManager._search_default_smiles(packages, searchterm)
case "exact":
return SearchManager._search_exact_smiles(packages, searchterm)
case "canonical":
return SearchManager._search_canonical_smiles(packages, searchterm)
case "inchikey":
return SearchManager._search_inchikey(packages, searchterm)
case _:
raise ValueError(f"Unknown search mode {mode}!")
@staticmethod
def _search_inchikey(packages: Union[Package, List[Package]], searchterm: str):
from django.db.models import Q
search_cond = Q(inchikey=searchterm)
compound_qs = Compound.objects.filter(
Q(package__in=packages) & Q(compoundstructure__inchikey=searchterm)
).distinct()
compound_structure_qs = CompoundStructure.objects.filter(
Q(compound__package__in=packages) & search_cond
)
reactions_qs = Reaction.objects.filter(
Q(package__in=packages)
& (Q(educts__inchikey=searchterm) | Q(products__inchikey=searchterm))
).distinct()
pathway_qs = Pathway.objects.filter(
Q(package__in=packages)
& (
Q(edge__edge_label__educts__inchikey=searchterm)
| Q(edge__edge_label__products__inchikey=searchterm)
)
).distinct()
return {
"Compounds": [
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
],
"Compound Structures": [
{"name": c.name, "description": c.description, "url": c.url}
for c in compound_structure_qs
],
"Reactions": [
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
],
"Pathways": [
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
],
}
@staticmethod
def _search_exact_smiles(packages: Union[Package, List[Package]], searchterm: str):
from django.db.models import Q
search_cond = Q(smiles=searchterm)
compound_qs = Compound.objects.filter(
Q(package__in=packages) & Q(compoundstructure__smiles=searchterm)
).distinct()
compound_structure_qs = CompoundStructure.objects.filter(
Q(compound__package__in=packages) & search_cond
)
reactions_qs = Reaction.objects.filter(
Q(package__in=packages)
& (Q(educts__smiles=searchterm) | Q(products__smiles=searchterm))
).distinct()
pathway_qs = Pathway.objects.filter(
Q(package__in=packages)
& (
Q(edge__edge_label__educts__smiles=searchterm)
| Q(edge__edge_label__products__smiles=searchterm)
)
).distinct()
return {
"Compounds": [
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
],
"Compound Structures": [
{"name": c.name, "description": c.description, "url": c.url}
for c in compound_structure_qs
],
"Reactions": [
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
],
"Pathways": [
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
],
}
@staticmethod
def _search_default_smiles(packages: Union[Package, List[Package]], searchterm: str):
from django.db.models import Q
inchi_front = FormatConverter.InChIKey(searchterm)[:14]
search_cond = Q(inchikey__startswith=inchi_front)
compound_qs = Compound.objects.filter(
Q(package__in=packages) & Q(compoundstructure__inchikey__startswith=inchi_front)
).distinct()
compound_structure_qs = CompoundStructure.objects.filter(
Q(compound__package__in=packages) & search_cond
)
reactions_qs = Reaction.objects.filter(
Q(package__in=packages)
& (
Q(educts__inchikey__startswith=inchi_front)
| Q(products__inchikey__startswith=inchi_front)
)
).distinct()
pathway_qs = Pathway.objects.filter(
Q(package__in=packages)
& (
Q(edge__edge_label__educts__inchikey__startswith=inchi_front)
| Q(edge__edge_label__products__inchikey__startswith=inchi_front)
)
).distinct()
return {
"Compounds": [
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
],
"Compound Structures": [
{"name": c.name, "description": c.description, "url": c.url}
for c in compound_structure_qs
],
"Reactions": [
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
],
"Pathways": [
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
],
}
@staticmethod
def _search_canonical_smiles(packages: Union[Package, List[Package]], searchterm: str):
from django.db.models import Q
search_cond = Q(canonical_smiles=searchterm)
compound_qs = Compound.objects.filter(
Q(package__in=packages) & Q(compoundstructure__canonical_smiles=searchterm)
).distinct()
compound_structure_qs = CompoundStructure.objects.filter(
Q(compound__package__in=packages) & search_cond
)
reactions_qs = Reaction.objects.filter(
Q(package__in=packages)
& (Q(educts__canonical_smiles=searchterm) | Q(products__canonical_smiles=searchterm))
).distinct()
pathway_qs = Pathway.objects.filter(
Q(package__in=packages)
& (
Q(edge__edge_label__educts__canonical_smiles=searchterm)
| Q(edge__edge_label__products__canonical_smiles=searchterm)
)
).distinct()
return {
"Compounds": [
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
],
"Compound Structures": [
{"name": c.name, "description": c.description, "url": c.url}
for c in compound_structure_qs
],
"Reactions": [
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
],
"Pathways": [
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
],
}
@staticmethod
def _search_text(packages: Union[Package, List[Package]], searchterm: str):
from django.db.models import Q
search_cond = Q(name__icontains=searchterm) | Q(description__icontains=searchterm)
cond = Q(package__in=packages) & search_cond
compound_qs = Compound.objects.filter(cond)
compound_structure_qs = CompoundStructure.objects.filter(
Q(compound__package__in=packages) & search_cond
)
rule_qs = Rule.objects.filter(cond)
reactions_qs = Reaction.objects.filter(cond)
pathway_qs = Pathway.objects.filter(cond)
res = {
"Compounds": [
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
],
"Compound Structures": [
{"name": c.name, "description": c.description, "url": c.url}
for c in compound_structure_qs
],
"Rules": [
{"name": r.name, "description": r.description, "url": r.url} for r in rule_qs
],
"Reactions": [
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
],
"Pathways": [
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
],
}
return res
class SNode(object):
    """A pathway node identified solely by its SMILES string.

    Equality and hashing consider only ``smiles``; ``depth`` and the optional
    applicability-domain assessment are carried along as metadata only.
    """

    def __init__(self, smiles: str, depth: int, app_domain_assessment: dict = None):
        self.smiles = smiles
        # Distance (in expansion steps) from a root node.
        self.depth = depth
        # Result of an applicability-domain assessment, if one was performed.
        self.app_domain_assessment = app_domain_assessment

    def __hash__(self):
        return hash(self.smiles)

    def __eq__(self, other):
        # Nodes compare equal purely on SMILES, regardless of depth.
        return isinstance(other, self.__class__) and self.smiles == other.smiles

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"
class SEdge(object):
    """A hyper-edge connecting educt SNodes to product SNodes.

    Optionally labelled with the Rule that produced it and a probability.
    Equality and hashing are order-independent over the educt/product SMILES
    plus the rule; probability is ignored for identity.
    """

    def __init__(
        self,
        educts: Union["SNode", List["SNode"]],
        products: Union["SNode", List["SNode"]],
        rule: Optional["Rule"] = None,
        probability: Optional[float] = None,
    ):
        """Store edge members, normalizing single nodes to lists on BOTH sides.

        Bug fix: the original only wrapped *educts* in a list; a single SNode
        passed as *products* would break product_smiles()/__hash__/__eq__,
        which all iterate over ``self.products``. (The annotation
        ``Union[SNode | List[SNode]]`` was also malformed.)
        """
        if not isinstance(educts, list):
            educts = [educts]
        if not isinstance(products, list):
            products = [products]
        self.educts = educts
        self.products = products
        self.rule = rule
        self.probability = probability

    def product_smiles(self):
        """Return the SMILES of all product nodes (in stored order)."""
        return [p.smiles for p in self.products]

    def __hash__(self):
        # Sum of member hashes -> commutative, so educt/product order is irrelevant.
        full_hash = 0
        for n in sorted(self.educts, key=lambda x: x.smiles):
            full_hash += hash(n)
        for n in sorted(self.products, key=lambda x: x.smiles):
            full_hash += hash(n)
        if self.rule is not None:
            full_hash += hash(self.rule)
        return full_hash

    def __eq__(self, other):
        """Edges are equal when rule and the sorted educt/product SMILES match."""
        if not isinstance(other, SEdge):
            return False
        # Covers the None/None, None/set and set/set cases in one comparison
        # (the original spelled out the None combinations redundantly).
        if self.rule != other.rule:
            return False
        if len(self.educts) != len(other.educts):
            return False
        for n1, n2 in zip(
            sorted(self.educts, key=lambda x: x.smiles),
            sorted(other.educts, key=lambda x: x.smiles),
        ):
            if n1.smiles != n2.smiles:
                return False
        if len(self.products) != len(other.products):
            return False
        for n1, n2 in zip(
            sorted(self.products, key=lambda x: x.smiles),
            sorted(other.products, key=lambda x: x.smiles),
        ):
            if n1.smiles != n2.smiles:
                return False
        return True

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"
class SPathway(object):
def __init__(
self,
root_nodes: Optional[Union[str, SNode, List[str | SNode]]] = None,
persist: Optional["Pathway"] = None,
prediction_setting: Optional[Setting] = None,
):
self.root_nodes = []
self.persist = persist
self.snode_persist_lookup: Dict[SNode, Node] = dict()
self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
self.prediction_setting = prediction_setting
if persist:
for n in persist.root_nodes:
snode = SNode(n.default_node_label.smiles, n.depth)
self.root_nodes.append(snode)
self.snode_persist_lookup[snode] = n
else:
if not isinstance(root_nodes, list):
root_nodes = [root_nodes]
for n in root_nodes:
if isinstance(n, str):
self.root_nodes.append(SNode(n, 0))
elif isinstance(n, SNode):
self.root_nodes.append(n)
self.smiles_to_node: Dict[str, SNode] = dict(**{n.smiles: n for n in self.root_nodes})
self.edges: Set["SEdge"] = set()
self.done = False
self.empty_due_to_threshold = False
@staticmethod
def from_pathway(pw: "Pathway", persist: bool = True):
"""Initializes a SPathway with a state given by a Pathway"""
spw = SPathway(
root_nodes=pw.root_nodes, persist=pw if persist else None, prediction_setting=pw.setting
)
# root_nodes are already added in __init__, add remaining nodes
for n in pw.nodes:
snode = SNode(n.default_node_label.smiles, n.depth)
if snode.smiles not in spw.smiles_to_node:
spw.smiles_to_node[snode.smiles] = snode
spw.snode_persist_lookup[snode] = n
for e in pw.edges:
sub = []
prod = []
for n in e.start_nodes.all():
sub.append(spw.smiles_to_node[n.default_node_label.smiles])
for n in e.end_nodes.all():
prod.append(spw.smiles_to_node[n.default_node_label.smiles])
rule = None
if e.edge_label.rules.all():
rule = e.edge_label.rules.all().first()
prob = None
if e.kv.get("probability"):
prob = float(e.kv["probability"])
sedge = SEdge(sub, prod, rule=rule, probability=prob)
spw.edges.add(sedge)
spw.sedge_persist_lookup[sedge] = e
return spw
def num_nodes(self):
return len(self.smiles_to_node.keys())
def depth(self):
return max([v.depth for v in self.smiles_to_node.values()])
def _get_nodes_for_depth(self, depth: int) -> List[SNode]:
if depth == 0:
return self.root_nodes
res = []
for n in self.smiles_to_node.values():
if n.depth == depth:
res.append(n)
return sorted(res, key=lambda x: x.smiles)
def _get_edges_for_depth(self, depth: int) -> List[SEdge]:
res = []
for e in self.edges:
for n in e.educts:
if n.depth == depth:
res.append(e)
return sorted(res, key=lambda x: hash(x))
    def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
        """
        Expands the given substrates by generating new nodes and edges based on prediction settings.
        This method processes a list of substrates and expands them into new nodes and edges using defined
        rules and settings. It evaluates each substrate to determine its applicability domain, persists
        domain assessments, and generates candidates for further processing. Newly created nodes and edges
        are returned, and any applicable information is stored or updated internally during the process.
        Parameters:
            substrates (List[SNode]): A list of substrate nodes to be expanded.
        Returns:
            Tuple[List[SNode], List[SEdge]]:
                A tuple containing:
                - A list of new nodes generated during the expansion.
                - A list of new edges representing connections between nodes based on candidate reactions.
        Raises:
            ValueError: If a node does not have an ID when it should have been saved already.
        """
        new_nodes: List[SNode] = []
        new_edges: List[SEdge] = []
        for sub in substrates:
            # For App Domain we have to ensure that each Node is evaluated
            # (only when a model with an attached app_domain is configured).
            if sub.app_domain_assessment is None:
                if self.prediction_setting.model:
                    if self.prediction_setting.model.app_domain:
                        app_domain_assessment = self.prediction_setting.model.app_domain.assess(
                            sub.smiles
                        )
                        # When backed by the DB, attach the assessment (plus a
                        # rendered-image URL) to the persisted Node's kv store.
                        if self.persist is not None:
                            n = self.snode_persist_lookup[sub]
                            if n.id is None:
                                raise ValueError(f"Node {n} has no ID... aborting!")
                            node_data = n.simple_json()
                            node_data["image"] = f"{n.url}?image=svg"
                            app_domain_assessment["assessment"]["node"] = node_data
                            n.kv["app_domain_assessment"] = app_domain_assessment
                            n.save()
                        # Cache on the SNode so re-expansion skips re-assessment.
                        sub.app_domain_assessment = app_domain_assessment
            expansion_result = self.prediction_setting.expand(self, sub)
            # We don't have any substrate, but technically we have at least one rule that triggered.
            # If our substrate is a root node a.k.a. depth == 0 store that info in SPathway
            if (
                len(expansion_result["transformations"]) == 0
                and expansion_result["rule_triggered"]
                and sub.depth == 0
            ):
                self.empty_due_to_threshold = True
                # Emit directly
                if self.persist is not None:
                    self.persist.kv["empty_due_to_threshold"] = True
                    self.persist.save()
            # candidates is a List of PredictionResult. The length of the List is equal to the number of rules
            for cand_set in expansion_result["transformations"]:
                if cand_set:
                    # cand_set is a PredictionResult object that can consist of multiple candidate reactions
                    for cand in cand_set:
                        cand_nodes = []
                        # candidate reactions can have multiple fragments
                        for c in cand:
                            if c not in self.smiles_to_node:
                                # For new nodes do an AppDomain Assessment if an AppDomain is attached
                                app_domain_assessment = None
                                if self.prediction_setting.model:
                                    if self.prediction_setting.model.app_domain:
                                        app_domain_assessment = (
                                            self.prediction_setting.model.app_domain.assess(c)
                                        )
                                # New product node is one level deeper than its substrate.
                                snode = SNode(c, sub.depth + 1, app_domain_assessment)
                                self.smiles_to_node[c] = snode
                                new_nodes.append(snode)
                            node = self.smiles_to_node[c]
                            cand_nodes.append(node)
                        # One edge per candidate reaction: substrate -> all fragments.
                        edge = SEdge(
                            sub,
                            cand_nodes,
                            rule=cand_set.rule,
                            probability=cand_set.probability,
                        )
                        self.edges.add(edge)
                        new_edges.append(edge)
        return new_nodes, new_edges
    def predict(self):
        """
        Predicts outcomes based on a graph traversal algorithm using the specified expansion schema.
        This method iteratively explores the nodes of a graph starting from the root nodes, propagating
        probabilities through edges, and updating the probabilities of the connected nodes. The traversal
        can follow one of three predefined expansion schemas: Depth-First Search (DFS), Breadth-First Search
        (BFS), or a Greedy approach based on node probabilities. The methodology ensures that all reachable
        nodes are processed systematically according to the specified schema.
        Errors will be raised if the expansion schema is undefined or invalid. Additionally, this method
        supports persisting changes by writing back data to the database when configured to do so.
        Attributes
        ----------
        done : bool
            A flag indicating whether the prediction process is completed.
        persist : Any
            An optional object that manages persistence operations for saving modifications.
        root_nodes : List[SNode]
            A collection of initial nodes in the graph from which traversal begins.
        prediction_setting : Any
            Configuration object specifying settings for graph traversal, such as the choice of
            expansion schema.
        Raises
        ------
        ValueError
            If an invalid or unknown expansion schema is provided in `prediction_setting`.
        """
        # populate initial queue
        queue = list(self.root_nodes)
        processed = set()
        # initial nodes have prob 1.0
        node_probs: Dict[SNode, float] = {}
        node_probs.update({n: 1.0 for n in queue})
        while queue:
            current = queue.pop(0)
            if current in processed:
                continue
            processed.add(current)
            new_nodes, new_edges = self._expand([current])
            if new_nodes or new_edges:
                # Check if we need to write back data to the database
                if self.persist:
                    self._sync_to_pathway()
                    # call save to update the internal modified field
                    self.persist.save()
            if new_nodes:
                for edge in new_edges:
                    # All edge have `current` as educt
                    # Use `current` and adjust probs
                    current_prob = node_probs[current]
                    # NOTE(review): edge.probability is assumed non-None here;
                    # SEdge allows probability=None, which would raise a
                    # TypeError on multiplication — TODO confirm all expansion
                    # paths set a probability.
                    for prod in edge.products:
                        # Either is a new product or a product and we found a path with a higher prob
                        if (
                            prod not in node_probs
                            or current_prob * edge.probability > node_probs[prod]
                        ):
                            node_probs[prod] = current_prob * edge.probability
                # Update Queue to proceed
                if self.prediction_setting.expansion_scheme == "DFS":
                    for n in new_nodes:
                        if n not in processed:
                            # We want to follow this path -> prepend queue
                            queue.insert(0, n)
                elif self.prediction_setting.expansion_scheme == "BFS":
                    for n in new_nodes:
                        if n not in processed:
                            # Add at the end, everything queued before will be processed
                            # before new_nodese
                            queue.append(n)
                elif self.prediction_setting.expansion_scheme == "GREEDY":
                    # Simply add them, as we will re-order the queue later
                    for n in new_nodes:
                        if n not in processed:
                            queue.append(n)
                    node_and_probs = []
                    for queued_val in queue:
                        node_and_probs.append((queued_val, node_probs[queued_val]))
                    # re-order the queue and only pick smiles
                    queue = [
                        n[0] for n in sorted(node_and_probs, key=lambda x: x[1], reverse=True)
                    ]
                else:
                    raise ValueError(
                        f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}"
                    )
        # Queue exhausted, we're done
        self.done = True
def predict_step(self, from_depth: int = None, from_node: "Node" = None):
substrates: List[SNode] = []
if from_depth is not None:
substrates = self._get_nodes_for_depth(from_depth)
elif from_node is not None:
for k, v in self.snode_persist_lookup.items():
if from_node == v:
substrates = [k]
break
else:
raise ValueError(f"Node {from_node} not found in SPathway!")
else:
raise ValueError("Neither from_depth nor from_node_url specified")
new_tp = False
if substrates:
new_nodes, _ = self._expand(substrates)
new_tp = len(new_nodes) > 0
# In case no substrates are found, we're done.
# For "predict from node" we're always done
if len(substrates) == 0 or from_node is not None:
self.done = True
# Check if we need to write back data to the database
if new_tp and self.persist:
self._sync_to_pathway()
# call save to update the internal modified field
self.persist.save()
def get_edge_for_educt_smiles(self, smiles: str) -> List[SEdge]:
res = []
for e in self.edges:
for n in e.educts:
if n.smiles == smiles:
res.append(e)
return res
def _sync_to_pathway(self) -> None:
logger.info("Updating Pathway with SPathway")
for snode in self.smiles_to_node.values():
if snode not in self.snode_persist_lookup:
n = Node.create(self.persist, snode.smiles, snode.depth)
if snode.app_domain_assessment is not None:
app_domain_assessment = snode.app_domain_assessment
assert n.id is not None, (
"Node has no id! Should have been saved already... aborting!"
)
node_data = n.simple_json()
node_data["image"] = f"{n.url}?image=svg"
app_domain_assessment["assessment"]["node"] = node_data
n.kv["app_domain_assessment"] = app_domain_assessment
n.save()
self.snode_persist_lookup[snode] = n
for sedge in self.edges:
if sedge not in self.sedge_persist_lookup:
educt_nodes = []
for snode in sedge.educts:
educt_nodes.append(self.snode_persist_lookup[snode])
product_nodes = []
for snode in sedge.products:
product_nodes.append(self.snode_persist_lookup[snode])
e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)
if sedge.probability:
e.kv.update({"probability": sedge.probability})
e.save()
self.sedge_persist_lookup[sedge] = e
logger.info("Update done!")
def to_json(self):
nodes = []
edges = []
idx_lookup = {}
for i, smiles in enumerate(self.smiles_to_node):
n = self.smiles_to_node[smiles]
idx_lookup[smiles] = i
nodes.append({"depth": n.depth, "smiles": n.smiles, "id": i})
for edge in self.edges:
from_idx = idx_lookup[edge.educts[0].smiles]
to_indices = [idx_lookup[p.smiles] for p in edge.products]
e = {
"from": from_idx,
"to": to_indices,
}
edges.append(e)
return {
"nodes": nodes,
"edges": edges,
}