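"""Manager and helper classes for the epdb Django app.

Overview (a sketch of the module's contents, derived from the code below):

* ``EPDBURLParser`` -- resolves enviPath-style object URLs to Django model instances.
* ``UserManager``, ``GroupManager``, ``PackageManager``, ``SettingManager`` -- permission-aware
  helpers around the corresponding models.
* ``SearchManager`` -- text, SMILES and InChIKey search across packages.
* ``SNode``, ``SEdge``, ``SPathway`` -- in-memory pathway representation used during
  pathway prediction and synced back to persistent ``Pathway`` objects.
"""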
import re
import logging
import json
from typing import Union, List, Optional, Set, Dict, Any
from uuid import UUID

from django.contrib.auth import get_user_model
from django.db import transaction
from django.conf import settings as s
from pydantic import ValidationError

from epdb.models import (
    User,
    Package,
    UserPackagePermission,
    GroupPackagePermission,
    Permission,
    Group,
    Setting,
    EPModel,
    UserSettingPermission,
    Rule,
    Pathway,
    Node,
    Edge,
    Compound,
    Reaction,
    CompoundStructure,
)
from utilities.chem import FormatConverter
from utilities.misc import PackageImporter, PackageExporter

logger = logging.getLogger(__name__)


class EPDBURLParser:
    UUID_PATTERN = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"

    MODEL_PATTERNS = {
        "epdb.User": re.compile(rf"^.*/user/{UUID_PATTERN}"),
        "epdb.Group": re.compile(rf"^.*/group/{UUID_PATTERN}"),
        "epdb.Package": re.compile(rf"^.*/package/{UUID_PATTERN}"),
        "epdb.Compound": re.compile(rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}"),
        "epdb.CompoundStructure": re.compile(
            rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}/structure/{UUID_PATTERN}"
        ),
        "epdb.Rule": re.compile(
            rf"^.*/package/{UUID_PATTERN}/(?:simple-ambit-rule|simple-rdkit-rule|parallel-rule|sequential-rule|rule)/{UUID_PATTERN}"
        ),
        "epdb.Reaction": re.compile(rf"^.*/package/{UUID_PATTERN}/reaction/{UUID_PATTERN}$"),
        "epdb.Pathway": re.compile(rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}"),
        "epdb.Node": re.compile(
            rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/node/{UUID_PATTERN}"
        ),
        "epdb.Edge": re.compile(
            rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/edge/{UUID_PATTERN}"
        ),
        "epdb.Scenario": re.compile(rf"^.*/package/{UUID_PATTERN}/scenario/{UUID_PATTERN}"),
        "epdb.EPModel": re.compile(rf"^.*/package/{UUID_PATTERN}/model/{UUID_PATTERN}"),
        "epdb.Setting": re.compile(rf"^.*/setting/{UUID_PATTERN}"),
    }

    def __init__(self, url: str):
        self.url = url
        self._matches = {}
        self._analyze_url()

    def _analyze_url(self):
        for model_path, pattern in self.MODEL_PATTERNS.items():
            match = pattern.findall(self.url)
            if match:
                self._matches[model_path] = match[0]

    def _get_model_class(self, model_path: str):
        try:
            from django.apps import apps

            app_label, model_name = model_path.split(".")[-2:]
            return apps.get_model(app_label, model_name)
        except (ImportError, LookupError, ValueError):
            raise ValueError(f"Model {model_path} does not exist!")

    def _get_object_by_url(self, model_path: str, url: str):
        model_class = self._get_model_class(model_path)
        return model_class.objects.get(url=url)

    def is_package_url(self) -> bool:
        return bool(re.compile(rf"^.*/package/{self.UUID_PATTERN}$").findall(self.url))

    def contains_package_url(self):
        return (
            bool(self.MODEL_PATTERNS["epdb.Package"].findall(self.url))
            and not self.is_package_url()
        )

    def is_user_url(self) -> bool:
        return bool(self.MODEL_PATTERNS["epdb.User"].findall(self.url))

    def is_group_url(self) -> bool:
        return bool(self.MODEL_PATTERNS["epdb.Group"].findall(self.url))

    def is_setting_url(self) -> bool:
        return bool(self.MODEL_PATTERNS["epdb.Setting"].findall(self.url))

    def get_object(self) -> Optional[Any]:
        # Define priority order from most specific to least specific
        priority_order = [
            # 3rd level
            "epdb.CompoundStructure",
            "epdb.Node",
            "epdb.Edge",
            # 2nd level
            "epdb.Compound",
            "epdb.Rule",
            "epdb.Reaction",
            "epdb.Scenario",
            "epdb.EPModel",
            "epdb.Pathway",
            # 1st level
            "epdb.Package",
            "epdb.Setting",
            "epdb.Group",
            "epdb.User",
        ]

        for model_path in priority_order:
            if model_path in self._matches:
                url = self._matches[model_path]
                return self._get_object_by_url(model_path, url)

        raise ValueError(f"No object found for URL {self.url}")

    def get_objects(self) -> List[Any]:
        """
        Get all Django model objects along the URL path in hierarchical order.
        Returns objects from parent to child (e.g., Package -> Compound -> Structure).
        """
        objects = []

        hierarchy_order = [
            # 1st level
            "epdb.Package",
            "epdb.Setting",
            "epdb.Group",
            "epdb.User",
            # 2nd level
            "epdb.Compound",
            "epdb.Rule",
            "epdb.Reaction",
            "epdb.Scenario",
            "epdb.EPModel",
            "epdb.Pathway",
            # 3rd level
            "epdb.CompoundStructure",
            "epdb.Node",
            "epdb.Edge",
        ]

        for model_path in hierarchy_order:
            if model_path in self._matches:
                url = self._matches[model_path]
                objects.append(self._get_object_by_url(model_path, url))

        return objects

    def __str__(self) -> str:
        return f"EPDBURLParser(url='{self.url}')"

    def __repr__(self) -> str:
        return f"EPDBURLParser(url='{self.url}', matches={list(self._matches.keys())})"


class UserManager(object):
    user_pattern = re.compile(
        r".*/user/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_user_url(url: str):
        return bool(re.findall(UserManager.user_pattern, url))

    @staticmethod
    @transaction.atomic
    def create_user(
        username, email, password, set_setting=True, add_to_group=True, *args, **kwargs
    ):
        # avoid circular import :S
        from .tasks import send_registration_mail

        extra_fields = {"is_active": not s.ADMIN_APPROVAL_REQUIRED}

        if "is_active" in kwargs:
            extra_fields["is_active"] = kwargs["is_active"]

        if "uuid" in kwargs:
            extra_fields["uuid"] = kwargs["uuid"]

        u = get_user_model().objects.create_user(username, email, password, **extra_fields)

        # Create package
package_name = f"{u.username}{'’' if u.username[-1] in 'sxzß' else 's'} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()

        if not u.is_active:
            # send email for verification
            send_registration_mail.delay(u.pk)

        if set_setting:
            u.default_setting = Setting.objects.get(global_default=True)
            u.save()

        if add_to_group:
            g = Group.objects.get(public=True, name="enviPath Users")
            g.user_member.add(u)
            g.save()
            u.default_group = g
            u.save()

        return u

    @staticmethod
    def get_user(user_url):
        pass

    @staticmethod
    def get_user_by_id(user, user_uuid: str):
        if str(user.uuid) != user_uuid and not user.is_superuser:
            raise ValueError("Getting user failed!")
        return get_user_model().objects.get(uuid=user_uuid)

    @staticmethod
    def get_user_lp(user_url: str):
        uuid = user_url.strip().split("/")[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users_lp():
        return get_user_model().objects.all()

    @staticmethod
    def get_users():
        raise ValueError("")

    @staticmethod
    def writable(current_user, user):
        return (current_user == user) or user.is_superuser
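
# Registration sketch (illustrative; the credentials are placeholders). create_user runs in a
# transaction and, besides the User, creates a personal default Package and -- depending on the
# flags -- assigns the global default Setting and the public "enviPath Users" group:
#
#   user = UserManager.create_user("alice", "alice@example.org", "secret",
#                                  set_setting=True, add_to_group=True)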


class GroupManager(object):
    group_pattern = re.compile(
        r".*/group/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_group_url(url: str):
        return bool(re.findall(GroupManager.group_pattern, url))

    @staticmethod
    def create_group(current_user, name, description):
        g = Group()
        g.name = name
        g.description = description
        g.owner = current_user
        g.save()

        g.user_member.add(current_user)
        g.save()

        return g

    @staticmethod
    def get_group_lp(group_url: str):
        uuid = group_url.strip().split("/")[-1]
        return Group.objects.get(uuid=uuid)

    @staticmethod
    def get_groups_lp():
        return Group.objects.all()

    @staticmethod
    def get_group_by_url(user, group_url):
        return GroupManager.get_group_by_id(user, group_url.split("/")[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        g = Group.objects.get(uuid=group_id)

        if user in g.user_member.all():
            return g
        return None

    @staticmethod
    def get_groups(user):
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        if caller != group.owner and not caller.is_superuser:
raise ValueError("Only the group Owner is allowed to add members!")

        if isinstance(member, Group):
            if add_or_remove == "add":
                group.group_member.add(member)
            else:
                group.group_member.remove(member)
        else:
            if add_or_remove == "add":
                group.user_member.add(member)
            else:
                group.user_member.remove(member)

        group.save()

    @staticmethod
    def writable(user, group):
        return (user == group.owner) or user.is_superuser


class PackageManager(object):
    package_pattern = re.compile(
        r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_package_url(url: str):
        return bool(re.findall(PackageManager.package_pattern, url))

    @staticmethod
    def get_reviewed_packages():
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        if (
            UserPackagePermission.objects.filter(package=package, user=user).exists()
            or GroupPackagePermission.objects.filter(
                package=package, group__in=GroupManager.get_groups(user)
            ).exists()
            or package.reviewed is True
            or user.is_superuser
        ):
            return True

        return False

    @staticmethod
    def writable(user, package):
        if (
            UserPackagePermission.objects.filter(
                package=package, user=user, permission=Permission.WRITE[0]
            ).exists()
            or GroupPackagePermission.objects.filter(
                package=package,
                group__in=GroupManager.get_groups(user),
                permission=Permission.WRITE[0],
            ).exists()
            or UserPackagePermission.objects.filter(
                package=package, user=user, permission=Permission.ALL[0]
            ).exists()
            or user.is_superuser
        ):
            return True
        return False

    @staticmethod
    def administrable(user, package):
        if (
            UserPackagePermission.objects.filter(
                package=package, user=user, permission=Permission.ALL[0]
            ).exists()
            or GroupPackagePermission.objects.filter(
                package=package,
                group__in=GroupManager.get_groups(user),
                permission=Permission.ALL[0],
            ).exists()
            or user.is_superuser
        ):
            return True
        return False

    @staticmethod
    def has_package_permission(user: "User", package: Union[str, UUID, "Package"], permission: str):
        if isinstance(package, str) or isinstance(package, UUID):
            package = Package.objects.get(uuid=package)

        groups = GroupManager.get_groups(user)

        perms = {"all": ["all"], "write": ["all", "write"], "read": ["all", "write", "read"]}

        valid_perms = perms.get(permission)

        if (
            UserPackagePermission.objects.filter(
                package=package, user=user, permission__in=valid_perms
            ).exists()
            or GroupPackagePermission.objects.filter(
                package=package, group__in=groups, permission__in=valid_perms
            ).exists()
            or user.is_superuser
        ):
            return True

        return False
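
    # Illustrative sketch (not part of the original API surface): the three permission levels
    # nest, so "all" implies "write" implies "read". A caller could check access like this
    # (``some_user`` and ``some_package_uuid`` are hypothetical placeholders):
    #
    #   if PackageManager.has_package_permission(some_user, some_package_uuid, "write"):
    #       ...  # safe to modify objects inside the package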

    @staticmethod
    def get_package_lp(package_url):
        match = re.findall(PackageManager.package_pattern, package_url)
        if match:
            package_id = match[0].split("/")[-1]
            return Package.objects.get(uuid=package_id)
        return None

    @staticmethod
    def get_package_by_url(user, package_url):
        match = re.findall(PackageManager.package_pattern, package_url)

        if match:
            package_id = match[0].split("/")[-1]
            return PackageManager.get_package_by_id(user, package_id)
        else:
            raise ValueError(
                "Requested URL {} does not contain a valid package identifier!".format(package_url)
            )

    @staticmethod
    def get_package_by_id(user, package_id):
        try:
            p = Package.objects.get(uuid=package_id)
            if PackageManager.readable(user, p):
                return p
            else:
                raise ValueError(
                    "Insufficient permissions to access Package with ID {}".format(package_id)
                )
        except Package.DoesNotExist:
            raise ValueError("Package with ID {} does not exist!".format(package_id))

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values("package").distinct()
            )
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user)
                )
                .values("package")
                .distinct()
            )
            qs = user_package_qs | group_package_qs

        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            write_user_packs = (
                UserPackagePermission.objects.filter(user=user, permission=Permission.WRITE[0])
                .values("package")
                .distinct()
            )
            owner_user_packs = (
                UserPackagePermission.objects.filter(user=user, permission=Permission.ALL[0])
                .values("package")
                .distinct()
            )

            user_packs = write_user_packs | owner_user_packs
            user_package_qs = Package.objects.filter(id__in=user_packs)

            write_group_packs = (
                GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]
                )
                .values("package")
                .distinct()
            )
            owner_group_packs = (
                GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]
                )
                .values("package")
                .distinct()
            )

            group_packs = write_group_packs | owner_group_packs
            group_package_qs = Package.objects.filter(id__in=group_packs)

            qs = user_package_qs | group_package_qs

        qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_packages():
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        p = Package()
        p.name = name
        p.description = description
        p.save()

        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()

        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(
        caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]
    ):
        caller_perm = None
        if not caller.is_superuser:
            caller_perm = UserPackagePermission.objects.get(user=caller, package=package).permission

        if caller_perm != Permission.ALL[0] and not caller.is_superuser:
raise ValueError("Only owner are allowed to modify permissions")

        data = {
            "package": package,
        }

        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data["user"] = grantee
        else:
            perm_cls = GroupPackagePermission
            data["group"] = grantee

        if new_perm is None:
            qs = perm_cls.objects.filter(**data)
            if qs.count() > 1:
                raise ValueError("Got more Permission objects than expected!")
            if qs.count() != 0:
                logger.info(f"Deleting Perm {qs.first()}")
                qs.delete()
            else:
                logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
        else:
            _ = perm_cls.objects.update_or_create(defaults={"permission": new_perm}, **data)

    @staticmethod
    @transaction.atomic
    def import_legacy_package(
        data: dict, owner: User, keep_ids=False, add_import_timestamp=True, trust_reviewed=False
    ):
        from uuid import UUID, uuid4
        from datetime import datetime
        from collections import defaultdict
        from .models import (
            Package,
            Compound,
            CompoundStructure,
            SimpleRule,
            SimpleAmbitRule,
            ParallelRule,
            SequentialRule,
            SequentialRuleOrdering,
            Reaction,
            Pathway,
            Node,
            Edge,
            Scenario,
        )
        from envipy_additional_information import AdditionalInformationConverter

        pack = Package()
        pack.uuid = UUID(data["id"].split("/")[-1]) if keep_ids else uuid4()

        if add_import_timestamp:
            pack.name = "{} - {}".format(data["name"], datetime.now().strftime("%Y-%m-%d %H:%M"))
        else:
            pack.name = data["name"]

        if trust_reviewed:
            pack.reviewed = True if data["reviewStatus"] == "reviewed" else False
        else:
            pack.reviewed = False

        pack.description = data["description"]
        pack.save()

        up = UserPackagePermission()
        up.user = owner
        up.package = pack
        up.permission = up.ALL[0]
        up.save()

        # Stores old_id to new_id
        mapping = {}
        # Stores new_scen_id to old_parent_scen_id
        parent_mapping = {}
        # Mapping old scen_id to old_obj_id
        scen_mapping = defaultdict(list)

        # Store Scenarios
        for scenario in data["scenarios"]:
            scen = Scenario()
            scen.package = pack
            scen.uuid = UUID(scenario["id"].split("/")[-1]) if keep_ids else uuid4()
            scen.name = scenario["name"]
            scen.description = scenario["description"]
            scen.scenario_type = scenario["type"]
            scen.scenario_date = scenario["date"]
            scen.additional_information = dict()
            scen.save()

            mapping[scenario["id"]] = scen.uuid

            new_add_inf = defaultdict(list)
            # TODO Store AI...
            for ex in scenario.get("additionalInformationCollection", {}).get(
                "additionalInformation", []
            ):
                name = ex["name"]
                addinf_data = ex["data"]

                # park the parent scen id for now and link it later
                if name == "referringscenario":
                    parent_mapping[scen.uuid] = addinf_data
                    continue

                # Broken eP Data
                if name == "initialmasssediment" and addinf_data == "missing data":
                    continue

                # TODO Enzymes arent ready yet
                if name == "enzyme":
                    continue

                try:
                    res = AdditionalInformationConverter.convert(name, addinf_data)
                    res_cls_name = res.__class__.__name__
                    ai_data = json.loads(res.model_dump_json())
                    ai_data["uuid"] = f"{uuid4()}"
                    new_add_inf[res_cls_name].append(ai_data)
                except ValidationError:
                    logger.error(f"Failed to convert {name} with {addinf_data}")

            scen.additional_information = new_add_inf
            scen.save()

        print("Scenarios imported...")

        # Store compounds and its structures
        for compound in data["compounds"]:
            comp = Compound()
            comp.package = pack
            comp.uuid = UUID(compound["id"].split("/")[-1]) if keep_ids else uuid4()
            comp.name = compound["name"]
            comp.description = compound["description"]
            comp.aliases = compound["aliases"]
            comp.save()

            mapping[compound["id"]] = comp.uuid

            for scen in compound["scenarios"]:
                scen_mapping[scen["id"]].append(comp)

            default_structure = None

            for structure in compound["structures"]:
                struc = CompoundStructure()
                # struc.object_url = Command.get_id(structure, keep_ids)
                struc.compound = comp
                struc.uuid = UUID(structure["id"].split("/")[-1]) if keep_ids else uuid4()
                struc.name = structure["name"]
                struc.description = structure["description"]
                struc.smiles = structure["smiles"]
                struc.save()

                for scen in structure["scenarios"]:
                    scen_mapping[scen["id"]].append(struc)

                mapping[structure["id"]] = struc.uuid

                if structure["id"] == compound["defaultStructure"]["id"]:
                    default_structure = struc

                struc.save()

            if default_structure is None:
                raise ValueError("No default structure set")

            comp.default_structure = default_structure
            comp.save()

        print("Compounds imported...")

        # Store simple and parallel-rules
        par_rules = []
        seq_rules = []

        for rule in data["rules"]:
            if rule["identifier"] == "parallel-rule":
                par_rules.append(rule)
                continue
            if rule["identifier"] == "sequential-rule":
                seq_rules.append(rule)
                continue
            r = SimpleAmbitRule()
            r.uuid = UUID(rule["id"].split("/")[-1]) if keep_ids else uuid4()
            r.package = pack
            r.name = rule["name"]
            r.description = rule["description"]
            r.smirks = rule["smirks"]
            r.reactant_filter_smarts = rule.get("reactantFilterSmarts", None)
            r.product_filter_smarts = rule.get("productFilterSmarts", None)
            r.save()

            mapping[rule["id"]] = r.uuid

            for scen in rule["scenarios"]:
                scen_mapping[scen["id"]].append(r)

        print("Par: ", len(par_rules))
        print("Seq: ", len(seq_rules))

        for par_rule in par_rules:
            r = ParallelRule()
            r.package = pack
            r.uuid = UUID(par_rule["id"].split("/")[-1]) if keep_ids else uuid4()
            r.name = par_rule["name"]
            r.description = par_rule["description"]
            r.save()

            mapping[par_rule["id"]] = r.uuid

            for scen in par_rule["scenarios"]:
                scen_mapping[scen["id"]].append(r)

            for simple_rule in par_rule["simpleRules"]:
                if simple_rule["id"] in mapping:
                    r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule["id"]]))

            r.save()

        for seq_rule in seq_rules:
            r = SequentialRule()
            r.package = pack
            r.uuid = UUID(seq_rule["id"].split("/")[-1]) if keep_ids else uuid4()
            r.name = seq_rule["name"]
            r.description = seq_rule["description"]
            r.save()

            mapping[seq_rule["id"]] = r.uuid

            for scen in seq_rule["scenarios"]:
                scen_mapping[scen["id"]].append(r)

            for i, simple_rule in enumerate(seq_rule["simpleRules"]):
                sro = SequentialRuleOrdering()
                # Resolve the already imported SimpleRule via the id mapping
                sro.simple_rule = SimpleRule.objects.get(uuid=mapping[simple_rule["id"]])
                sro.sequential_rule = r
                sro.order_index = i
                sro.save()

            r.save()

        print("Rules imported...")

        for reaction in data["reactions"]:
            r = Reaction()
            r.package = pack
            r.uuid = UUID(reaction["id"].split("/")[-1]) if keep_ids else uuid4()
            r.name = reaction["name"]
            r.description = reaction["description"]
            r.medlinereferences = reaction["medlinereferences"]
            r.multi_step = True if reaction["multistep"] == "true" else False
            r.save()

            mapping[reaction["id"]] = r.uuid

            for scen in reaction["scenarios"]:
                scen_mapping[scen["id"]].append(r)

            for educt in reaction["educts"]:
                r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt["id"]]))

            for product in reaction["products"]:
                r.products.add(CompoundStructure.objects.get(uuid=mapping[product["id"]]))

            if "rules" in reaction:
                for rule in reaction["rules"]:
                    try:
                        r.rules.add(Rule.objects.get(uuid=mapping[rule["id"]]))
                    except Exception as e:
                        print(f"Rule with id {rule['id']} not found!")
                        print(e)

            r.save()

        print("Reactions imported...")

        for pathway in data["pathways"]:
            pw = Pathway()
            pw.package = pack
            pw.uuid = UUID(pathway["id"].split("/")[-1]) if keep_ids else uuid4()
            pw.name = pathway["name"]
            pw.description = pathway["description"]
            pw.save()

            mapping[pathway["id"]] = pw.uuid
            for scen in pathway["scenarios"]:
                scen_mapping[scen["id"]].append(pw)

            out_nodes_mapping = defaultdict(set)

            for node in pathway["nodes"]:
                n = Node()
                n.uuid = UUID(node["id"].split("/")[-1]) if keep_ids else uuid4()
                n.name = node["name"]
                n.pathway = pw
                n.depth = node["depth"]
                n.default_node_label = CompoundStructure.objects.get(
                    uuid=mapping[node["defaultNodeLabel"]["id"]]
                )
                n.save()

                mapping[node["id"]] = n.uuid

                for scen in node["scenarios"]:
                    scen_mapping[scen["id"]].append(n)

                for node_label in node["nodeLabels"]:
                    n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label["id"]]))

                n.save()

                for out_edge in node["outEdges"]:
                    out_nodes_mapping[n.uuid].add(out_edge)

            for edge in pathway["edges"]:
                e = Edge()
                e.uuid = UUID(edge["id"].split("/")[-1]) if keep_ids else uuid4()
                e.name = edge["name"]
                e.pathway = pw
                e.description = edge["description"]
                e.edge_label = Reaction.objects.get(uuid=mapping[edge["edgeLabel"]["id"]])
                e.save()

                mapping[edge["id"]] = e.uuid

                for scen in edge["scenarios"]:
                    scen_mapping[scen["id"]].append(e)

                for start_node in edge["startNodes"]:
                    e.start_nodes.add(Node.objects.get(uuid=mapping[start_node]))

                for end_node in edge["endNodes"]:
                    e.end_nodes.add(Node.objects.get(uuid=mapping[end_node]))

                e.save()

            for k, v in out_nodes_mapping.items():
                n = Node.objects.get(uuid=k)
                for v1 in v:
                    n.out_edges.add(Edge.objects.get(uuid=mapping[v1]))
                n.save()

        print("Pathways imported...")

        # Linking Phase
        for child, parent in parent_mapping.items():
            child_obj = Scenario.objects.get(uuid=child)
            parent_obj = Scenario.objects.get(uuid=mapping[parent])
            child_obj.parent = parent_obj
            child_obj.save()

        for scen_id, objects in scen_mapping.items():
            scen = Scenario.objects.get(uuid=mapping[scen_id])
            for o in objects:
                o.scenarios.add(scen)
                o.save()

        print("Scenarios linked...")

        print("Import statistics:")
        print("Package {} stored".format(pack.url))
        print("Imported {} compounds".format(Compound.objects.filter(package=pack).count()))
        print("Imported {} rules".format(Rule.objects.filter(package=pack).count()))
        print("Imported {} reactions".format(Reaction.objects.filter(package=pack).count()))
        print("Imported {} pathways".format(Pathway.objects.filter(package=pack).count()))
        print("Imported {} Scenarios".format(Scenario.objects.filter(package=pack).count()))

        print("Fixing Node depths...")
        total_pws = Pathway.objects.filter(package=pack).count()
        for p, pw in enumerate(Pathway.objects.filter(package=pack)):
            print(pw.url)
            in_count = defaultdict(lambda: 0)
            out_count = defaultdict(lambda: 0)

            for e in pw.edges:
                # TODO check if this will remain
                for react in e.start_nodes.all():
                    out_count[str(react.uuid)] += 1

                for prod in e.end_nodes.all():
                    in_count[str(prod.uuid)] += 1

            root_nodes = []
            for n in pw.nodes:
                num_parents = in_count[str(n.uuid)]
                if num_parents == 0:
                    # must be a root node or unconnected node
                    if n.depth != 0:
                        n.depth = 0
                        n.save()

                    # Only root node may have children
                    if out_count[str(n.uuid)] > 0:
                        root_nodes.append(n)

            levels = [root_nodes]
            seen = set()
            # Do a bfs to determine depths starting with level 0 a.k.a. root nodes
            for i, level_nodes in enumerate(levels):
                new_level = []
                for n in level_nodes:
                    for e in n.out_edges.all():
                        for prod in e.end_nodes.all():
                            if str(prod.uuid) not in seen:
                                old_depth = prod.depth
                                if old_depth != i + 1:
                                    print(f"updating depth from {old_depth} to {i + 1}")
                                    prod.depth = i + 1
                                    prod.save()

                                new_level.append(prod)

                    seen.add(str(n.uuid))

                if new_level:
                    levels.append(new_level)

            print(f"{p + 1}/{total_pws} fixed.")

        return pack

    @staticmethod
    @transaction.atomic
    def import_package(
        data: Dict[str, Any],
        owner: User,
        preserve_uuids=False,
        add_import_timestamp=True,
        trust_reviewed=False,
    ) -> Package:
        importer = PackageImporter(data, preserve_uuids, add_import_timestamp, trust_reviewed)
        imported_package = importer.do_import()

        up = UserPackagePermission()
        up.user = owner
        up.package = imported_package
        up.permission = up.ALL[0]
        up.save()

        return imported_package

    @staticmethod
    def export_package(
        package: Package, include_models: bool = False, include_external_identifiers: bool = True
    ) -> Dict[str, Any]:
        return PackageExporter(package).do_export()
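
# Illustrative round trip (a sketch, not executed anywhere in this module; ``some_package``
# and ``some_user`` are hypothetical Package/User instances):
#
#   exported = PackageManager.export_package(some_package)
#   copy = PackageManager.import_package(exported, owner=some_user)
#
# import_package runs inside a transaction, so a failed import leaves no partial package behind.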


class SettingManager(object):
    setting_pattern = re.compile(
        r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$"
    )

    @staticmethod
    def get_setting_by_url(user, setting_url):
        match = re.findall(SettingManager.setting_pattern, setting_url)

        if match:
            setting_id = match[0].split("/")[-1]
            return SettingManager.get_setting_by_id(user, setting_id)
        else:
            raise ValueError(
                "Requested URL {} does not contain a valid setting identifier!".format(setting_url)
            )

    @staticmethod
    def get_setting_by_id(user, setting_id):
        s = Setting.objects.get(uuid=setting_id)

        if (
            s.global_default
            or s.public
            or user.is_superuser
            or UserSettingPermission.objects.filter(user=user, setting=s).exists()
        ):
            return s

        raise ValueError("Insufficient permissions to access Setting with ID {}".format(setting_id))

    @staticmethod
    def get_all_settings(user):
        sp = UserSettingPermission.objects.filter(user=user).values("setting")
        return (
            Setting.objects.filter(id__in=sp)
            | Setting.objects.filter(public=True)
            | Setting.objects.filter(global_default=True)
        ).distinct()

    @staticmethod
    @transaction.atomic
    def create_setting(
        user: User,
        name: str = None,
        description: str = None,
        max_nodes: int = None,
        max_depth: int = None,
        rule_packages: List[Package] = None,
        model: EPModel = None,
        model_threshold: float = None,
    ):
        s = Setting()
        s.name = name
        s.description = description
        s.max_nodes = max_nodes
        s.max_depth = max_depth
        s.model = model
        s.model_threshold = model_threshold

        s.save()

        if rule_packages is not None:
            for r in rule_packages:
                s.rule_packages.add(r)
            s.save()

        usp = UserSettingPermission()
        usp.user = user
        usp.setting = s
        usp.permission = Permission.ALL[0]
        usp.save()

        return s

    @staticmethod
    def get_default_setting(user: User):
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        pass


class SearchManager(object):
    @staticmethod
    def search(packages: Union[Package, List[Package]], searchterm: str, mode: str):
        match mode:
            case "text":
                return SearchManager._search_text(packages, searchterm)
            case "default":
                return SearchManager._search_default_smiles(packages, searchterm)
            case "exact":
                return SearchManager._search_exact_smiles(packages, searchterm)
            case "canonical":
                return SearchManager._search_canonical_smiles(packages, searchterm)
            case "inchikey":
                return SearchManager._search_inchikey(packages, searchterm)
            case _:
                raise ValueError(f"Unknown search mode {mode}!")

    @staticmethod
    def _search_inchikey(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(inchikey=searchterm)
        compound_qs = Compound.objects.filter(
            Q(package__in=packages) & Q(compoundstructure__inchikey=searchterm)
        ).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(
            Q(compound__package__in=packages) & search_cond
        )
        reactions_qs = Reaction.objects.filter(
            Q(package__in=packages)
            & (Q(educts__inchikey=searchterm) | Q(products__inchikey=searchterm))
        ).distinct()
        pathway_qs = Pathway.objects.filter(
            Q(package__in=packages)
            & (
                Q(edge__edge_label__educts__inchikey=searchterm)
                | Q(edge__edge_label__products__inchikey=searchterm)
            )
        ).distinct()

        return {
            "Compounds": [
                {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
            ],
            "Compound Structures": [
                {"name": c.name, "description": c.description, "url": c.url}
                for c in compound_structure_qs
            ],
            "Reactions": [
                {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
            ],
            "Pathways": [
                {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
            ],
        }

    @staticmethod
    def _search_exact_smiles(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(smiles=searchterm)
        compound_qs = Compound.objects.filter(
            Q(package__in=packages) & Q(compoundstructure__smiles=searchterm)
        ).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(
            Q(compound__package__in=packages) & search_cond
        )
        reactions_qs = Reaction.objects.filter(
            Q(package__in=packages)
            & (Q(educts__smiles=searchterm) | Q(products__smiles=searchterm))
        ).distinct()
        pathway_qs = Pathway.objects.filter(
            Q(package__in=packages)
            & (
                Q(edge__edge_label__educts__smiles=searchterm)
                | Q(edge__edge_label__products__smiles=searchterm)
            )
        ).distinct()

        return {
            "Compounds": [
                {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
            ],
            "Compound Structures": [
                {"name": c.name, "description": c.description, "url": c.url}
                for c in compound_structure_qs
            ],
            "Reactions": [
                {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
            ],
            "Pathways": [
                {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
            ],
        }

    @staticmethod
    def _search_default_smiles(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        inchi_front = FormatConverter.InChIKey(searchterm)[:14]

        search_cond = Q(inchikey__startswith=inchi_front)
        compound_qs = Compound.objects.filter(
            Q(package__in=packages) & Q(compoundstructure__inchikey__startswith=inchi_front)
        ).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(
            Q(compound__package__in=packages) & search_cond
        )
        reactions_qs = Reaction.objects.filter(
            Q(package__in=packages)
            & (
                Q(educts__inchikey__startswith=inchi_front)
                | Q(products__inchikey__startswith=inchi_front)
            )
        ).distinct()
        pathway_qs = Pathway.objects.filter(
            Q(package__in=packages)
            & (
                Q(edge__edge_label__educts__inchikey__startswith=inchi_front)
                | Q(edge__edge_label__products__inchikey__startswith=inchi_front)
            )
        ).distinct()

        return {
            "Compounds": [
                {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
            ],
            "Compound Structures": [
                {"name": c.name, "description": c.description, "url": c.url}
                for c in compound_structure_qs
            ],
            "Reactions": [
                {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
            ],
            "Pathways": [
                {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
            ],
        }

    @staticmethod
    def _search_canonical_smiles(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(canonical_smiles=searchterm)
        compound_qs = Compound.objects.filter(
            Q(package__in=packages) & Q(compoundstructure__canonical_smiles=searchterm)
        ).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(
            Q(compound__package__in=packages) & search_cond
        )
        reactions_qs = Reaction.objects.filter(
            Q(package__in=packages)
            & (Q(educts__canonical_smiles=searchterm) | Q(products__canonical_smiles=searchterm))
        ).distinct()
        pathway_qs = Pathway.objects.filter(
            Q(package__in=packages)
            & (
                Q(edge__edge_label__educts__canonical_smiles=searchterm)
                | Q(edge__edge_label__products__canonical_smiles=searchterm)
            )
        ).distinct()

        return {
            "Compounds": [
                {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
            ],
            "Compound Structures": [
                {"name": c.name, "description": c.description, "url": c.url}
                for c in compound_structure_qs
            ],
            "Reactions": [
                {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
            ],
            "Pathways": [
                {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
            ],
        }

    @staticmethod
    def _search_text(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(name__icontains=searchterm) | Q(description__icontains=searchterm)
        cond = Q(package__in=packages) & search_cond
        compound_qs = Compound.objects.filter(cond)
        compound_structure_qs = CompoundStructure.objects.filter(
            Q(compound__package__in=packages) & search_cond
        )
        rule_qs = Rule.objects.filter(cond)
        reactions_qs = Reaction.objects.filter(cond)
        pathway_qs = Pathway.objects.filter(cond)

        res = {
            "Compounds": [
                {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
            ],
            "Compound Structures": [
                {"name": c.name, "description": c.description, "url": c.url}
                for c in compound_structure_qs
            ],
            "Rules": [
                {"name": r.name, "description": r.description, "url": r.url} for r in rule_qs
            ],
            "Reactions": [
                {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
            ],
            "Pathways": [
                {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
            ],
        }
        return res


class SNode(object):
    def __init__(self, smiles: str, depth: int, app_domain_assessment: dict = None):
        self.smiles = smiles
        self.depth = depth
        self.app_domain_assessment = app_domain_assessment

    def __hash__(self):
        return hash(self.smiles)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.smiles == other.smiles
        return False

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"


class SEdge(object):
    def __init__(
        self,
        educts: Union[SNode, List[SNode]],
        products: Union[SNode, List[SNode]],
        rule: Optional["Rule"] = None,
        probability: Optional[float] = None,
    ):
        if not isinstance(educts, list):
            educts = [educts]

        self.educts = educts
        self.products = products
        self.rule = rule
        self.probability = probability

    def __hash__(self):
        full_hash = 0

        for n in sorted(self.educts, key=lambda x: x.smiles):
            full_hash += hash(n)

        for n in sorted(self.products, key=lambda x: x.smiles):
            full_hash += hash(n)

        if self.rule is not None:
            full_hash += hash(self.rule)

        return full_hash

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        if (
            self.rule is not None
            and other.rule is None
            or self.rule is None
            and other.rule is not None
            or self.rule != other.rule
        ):
            return False

        if not (len(self.educts) == len(other.educts)):
            return False

        for n1, n2 in zip(
            sorted(self.educts, key=lambda x: x.smiles),
            sorted(other.educts, key=lambda x: x.smiles),
        ):
            if n1.smiles != n2.smiles:
                return False

        if not (len(self.products) == len(other.products)):
            return False

        for n1, n2 in zip(
            sorted(self.products, key=lambda x: x.smiles),
            sorted(other.products, key=lambda x: x.smiles),
        ):
            if n1.smiles != n2.smiles:
                return False

        return True

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"


class SPathway(object):
    def __init__(
        self,
        root_nodes: Optional[Union[str, SNode, List[str | SNode]]] = None,
        persist: Optional["Pathway"] = None,
        prediction_setting: Optional[Setting] = None,
    ):
        self.root_nodes = []

        self.persist = persist
        self.snode_persist_lookup: Dict[SNode, Node] = dict()
        self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            for n in root_nodes:
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.smiles_to_node: Dict[str, SNode] = dict(**{n.smiles: n for n in self.root_nodes})
        self.edges: Set["SEdge"] = set()
        self.done = False

    @staticmethod
    def from_pathway(pw: "Pathway", persist: bool = True):
        """Initializes a SPathway with a state given by a Pathway"""
        spw = SPathway(
            root_nodes=pw.root_nodes, persist=pw if persist else None, prediction_setting=pw.setting
        )
        # root_nodes are already added in __init__, add remaining nodes
        for n in pw.nodes:
            snode = SNode(n.default_node_label.smiles, n.depth)
            if snode.smiles not in spw.smiles_to_node:
                spw.smiles_to_node[snode.smiles] = snode
                spw.snode_persist_lookup[snode] = n

        for e in pw.edges:
            sub = []
            prod = []

            for n in e.start_nodes.all():
                sub.append(spw.smiles_to_node[n.default_node_label.smiles])

            for n in e.end_nodes.all():
                prod.append(spw.smiles_to_node[n.default_node_label.smiles])

            rule = None
            if e.edge_label.rules.all():
                rule = e.edge_label.rules.all().first()

            prob = None
            if e.kv.get("probability"):
                prob = float(e.kv["probability"])

            sedge = SEdge(sub, prod, rule=rule, probability=prob)
            spw.edges.add(sedge)
            spw.sedge_persist_lookup[sedge] = e

        return spw

    def num_nodes(self):
        return len(self.smiles_to_node.keys())

    def depth(self):
        return max([v.depth for v in self.smiles_to_node.values()])

    def _get_nodes_for_depth(self, depth: int) -> List[SNode]:
        if depth == 0:
            return self.root_nodes

        res = []
        for n in self.smiles_to_node.values():
            if n.depth == depth:
                res.append(n)

        return sorted(res, key=lambda x: x.smiles)

    def _get_edges_for_depth(self, depth: int) -> List[SEdge]:
        res = []
        for e in self.edges:
            for n in e.educts:
                if n.depth == depth:
                    res.append(e)

        return sorted(res, key=lambda x: hash(x))

    def predict_step(self, from_depth: int = None, from_node: "Node" = None):
        substrates: List[SNode] = []

        if from_depth is not None:
            substrates = self._get_nodes_for_depth(from_depth)
        elif from_node is not None:
            for k, v in self.snode_persist_lookup.items():
                if from_node == v:
                    substrates = [k]
                    break
        else:
            raise ValueError("Neither from_depth nor from_node_url specified")

        new_tp = False
        if substrates:
            for sub in substrates:
                if sub.app_domain_assessment is None:
                    if self.prediction_setting.model:
                        if self.prediction_setting.model.app_domain:
                            app_domain_assessment = self.prediction_setting.model.app_domain.assess(
                                sub.smiles
                            )[0]

                            if self.persist is not None:
                                n = self.snode_persist_lookup[sub]

                                assert n.id is not None, (
                                    "Node has no id! Should have been saved already... aborting!"
                                )
                                node_data = n.simple_json()
                                node_data["image"] = f"{n.url}?image=svg"
                                app_domain_assessment["assessment"]["node"] = node_data

                                n.kv["app_domain_assessment"] = app_domain_assessment
                                n.save()

                            sub.app_domain_assessment = app_domain_assessment

                candidates = self.prediction_setting.expand(self, sub)
                # candidates is a List of PredictionResult. The length of the List is equal to the number of rules
                for cand_set in candidates:
                    if cand_set:
                        new_tp = True
                        # cand_set is a PredictionResult object that can consist of multiple candidate reactions
                        for cand in cand_set:
                            cand_nodes = []
                            # candidate reactions can have multiple fragments
                            for c in cand:
                                if c not in self.smiles_to_node:
                                    # For new nodes do an AppDomain Assessment if an AppDomain is attached
                                    app_domain_assessment = None
                                    if self.prediction_setting.model:
                                        if self.prediction_setting.model.app_domain:
                                            app_domain_assessment = (
                                                self.prediction_setting.model.app_domain.assess(c)[
                                                    0
                                                ]
                                            )

                                    self.smiles_to_node[c] = SNode(
                                        c, sub.depth + 1, app_domain_assessment
                                    )

                                node = self.smiles_to_node[c]
                                cand_nodes.append(node)

                            edge = SEdge(
                                sub,
                                cand_nodes,
                                rule=cand_set.rule,
                                probability=cand_set.probability,
                            )
                            self.edges.add(edge)

        # In case no substrates are found, we're done.
        # For "predict from node" we're always done
        if len(substrates) == 0 or from_node is not None:
            self.done = True

        # Check if we need to write back data to the database
        if new_tp and self.persist:
            self._sync_to_pathway()
            # call save to update the internal modified field
            self.persist.save()

    def _sync_to_pathway(self) -> None:
        logger.info("Updating Pathway with SPathway")

        for snode in self.smiles_to_node.values():
            if snode not in self.snode_persist_lookup:
                n = Node.create(self.persist, snode.smiles, snode.depth)

                if snode.app_domain_assessment is not None:
                    app_domain_assessment = snode.app_domain_assessment

                    assert n.id is not None, (
                        "Node has no id! Should have been saved already... aborting!"
                    )
                    node_data = n.simple_json()
                    node_data["image"] = f"{n.url}?image=svg"
                    app_domain_assessment["assessment"]["node"] = node_data

                    n.kv["app_domain_assessment"] = app_domain_assessment
                    n.save()

                self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge not in self.sedge_persist_lookup:
                educt_nodes = []
                for snode in sedge.educts:
                    educt_nodes.append(self.snode_persist_lookup[snode])

                product_nodes = []
                for snode in sedge.products:
                    product_nodes.append(self.snode_persist_lookup[snode])

                e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)

                if sedge.probability:
                    e.kv.update({"probability": sedge.probability})
                    e.save()

                self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        nodes = []
        edges = []

        idx_lookup = {}

        for i, smiles in enumerate(self.smiles_to_node):
            n = self.smiles_to_node[smiles]
            idx_lookup[smiles] = i

            nodes.append({"depth": n.depth, "smiles": n.smiles, "id": i})

        for edge in self.edges:
            from_idx = idx_lookup[edge.educts[0].smiles]
            to_indices = [idx_lookup[p.smiles] for p in edge.products]

            e = {
                "from": from_idx,
                "to": to_indices,
            }

            # if edge.rule:
            #     e['rule'] = {
            #         'name': edge.rule.name,
            #         'id': edge.rule.url,
            #     }
            edges.append(e)

        return {
            "nodes": nodes,
            "edges": edges,
        }
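
# Illustrative prediction loop (a sketch only; ``setting`` is a Setting with rule packages and,
# optionally, an EPModel attached, and "CCO" is just an example SMILES):
#
#   spw = SPathway(root_nodes="CCO", prediction_setting=setting)
#   depth = 0
#   while not spw.done and depth < (setting.max_depth or 5):
#       spw.predict_step(from_depth=depth)
#       depth += 1
#   payload = spw.to_json()  # {"nodes": [...], "edges": [...]} for the frontend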