# Forked from enviPath/enviPy.
#
# Bayer changes: initial Bayer app; show package classification; adjusted
# docker compose and Dockerfile to Bayer specifics; added secret flags to
# groups and secret data pools to packages; adjusted the view for package
# creation; prepared configs; added a package-create modal; work in progress
# on PES.
import logging
|
||
import re
|
||
from typing import Any, Dict, List, Optional, Set, Union, Tuple
|
||
from uuid import UUID
|
||
|
||
import nh3
|
||
from django.conf import settings as s
|
||
from django.contrib.auth import get_user_model
|
||
from django.db import transaction
|
||
from django.db.models import QuerySet
|
||
from pydantic import ValidationError
|
||
|
||
from epdb.models import (
|
||
AdditionalInformation,
|
||
Compound,
|
||
CompoundStructure,
|
||
Edge,
|
||
EnzymeLink,
|
||
EPModel,
|
||
ExpansionSchemeChoice,
|
||
Group,
|
||
GroupPackagePermission,
|
||
Node,
|
||
Pathway,
|
||
Permission,
|
||
PropertyPluginModel,
|
||
Reaction,
|
||
Rule,
|
||
Setting,
|
||
User,
|
||
UserPackagePermission,
|
||
UserSettingPermission,
|
||
)
|
||
from utilities.chem import FormatConverter
|
||
from utilities.misc import PackageExporter, PackageImporter
|
||
|
||
# Logger for this module, named after the module path.
logger = logging.getLogger(name=__name__)

# The concrete Package model is obtained from settings rather than imported
# from epdb.models, presumably so a deployment can substitute its own model.
Package = s.GET_PACKAGE_MODEL()
class EPDBURLParser:
    """Map an enviPath URL onto the database objects it refers to.

    The URL is matched once, at construction time, against every regex in
    ``MODEL_PATTERNS``. For each model whose pattern matches, the matched URL
    prefix (the URL of that object) is remembered in ``self._matches``.
    """

    UUID_PATTERN = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"

    MODEL_PATTERNS = {
        "epdb.User": re.compile(rf"^.*/user/{UUID_PATTERN}"),
        "epdb.Group": re.compile(rf"^.*/group/{UUID_PATTERN}"),
        "epdb.Package": re.compile(rf"^.*/package/{UUID_PATTERN}"),
        "epdb.Compound": re.compile(rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}"),
        "epdb.CompoundStructure": re.compile(
            rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}/structure/{UUID_PATTERN}"
        ),
        "epdb.Rule": re.compile(
            rf"^.*/package/{UUID_PATTERN}/(?:simple-ambit-rule|simple-rdkit-rule|parallel-rule|sequential-rule|rule)/{UUID_PATTERN}"
        ),
        "epdb.Reaction": re.compile(rf"^.*/package/{UUID_PATTERN}/reaction/{UUID_PATTERN}$"),
        "epdb.Pathway": re.compile(rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}"),
        "epdb.Node": re.compile(
            rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/node/{UUID_PATTERN}"
        ),
        "epdb.Edge": re.compile(
            rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/edge/{UUID_PATTERN}"
        ),
        "epdb.Scenario": re.compile(rf"^.*/package/{UUID_PATTERN}/scenario/{UUID_PATTERN}"),
        "epdb.EPModel": re.compile(rf"^.*/package/{UUID_PATTERN}/model/{UUID_PATTERN}"),
        "epdb.Setting": re.compile(rf"^.*/setting/{UUID_PATTERN}"),
    }

    # Model paths grouped by URL nesting depth, outermost level first.
    # get_objects() walks the levels top-down and get_object() bottom-up;
    # the order *within* a level is the same for both.
    _MODEL_LEVELS = (
        ("epdb.Package", "epdb.Setting", "epdb.Group", "epdb.User"),
        (
            "epdb.Compound",
            "epdb.Rule",
            "epdb.Reaction",
            "epdb.Scenario",
            "epdb.EPModel",
            "epdb.Pathway",
        ),
        ("epdb.CompoundStructure", "epdb.Node", "epdb.Edge"),
    )

    def __init__(self, url: str):
        self.url = url
        self._matches = {}
        self._analyze_url()

    def _analyze_url(self):
        """Store the matched URL prefix of every model pattern that hits ``self.url``."""
        for model_path, regex in self.MODEL_PATTERNS.items():
            hit = regex.search(self.url)
            if hit is not None:
                self._matches[model_path] = hit.group(0)

    def _get_model_class(self, model_path: str):
        """Resolve ``"app_label.ModelName"`` to the registered Django model class.

        Raises:
            ValueError: if Django cannot be imported, the path has fewer than
                two dot-separated parts, or the model is not registered.
        """
        try:
            from django.apps import apps

            *_, app_label, model_name = model_path.split(".")
            return apps.get_model(app_label, model_name)
        except (ImportError, LookupError, ValueError):
            raise ValueError(f"Model {model_path} does not exist!")

    def _get_object_by_url(self, model_path: str, url: str):
        """Fetch the ``model_path`` instance whose ``url`` field equals ``url``."""
        model_class = self._get_model_class(model_path)
        return model_class.objects.get(url=url)

    def _url_matches(self, model_path: str) -> bool:
        """Return True if ``self.url`` matches the pattern registered for ``model_path``."""
        return self.MODEL_PATTERNS[model_path].search(self.url) is not None

    def is_package_url(self) -> bool:
        """Return True if the URL ends exactly at a package, with nothing nested below it."""
        package_only = re.compile(rf"^.*/package/{self.UUID_PATTERN}$")
        return package_only.search(self.url) is not None

    def contains_package_url(self):
        """Return True if the URL points *inside* a package rather than at the package itself."""
        if not self._url_matches("epdb.Package"):
            return False
        return not self.is_package_url()

    def is_user_url(self) -> bool:
        return self._url_matches("epdb.User")

    def is_group_url(self) -> bool:
        return self._url_matches("epdb.Group")

    def is_setting_url(self) -> bool:
        return self._url_matches("epdb.Setting")

    def get_object(self) -> Optional[Any]:
        """Return the most specific object addressed by the URL.

        Deeper levels win; for example, a structure URL resolves to the
        structure, not to its compound or package.

        Raises:
            ValueError: if no known model pattern matched the URL.
        """
        for level in reversed(self._MODEL_LEVELS):
            for model_path in level:
                if model_path in self._matches:
                    return self._get_object_by_url(model_path, self._matches[model_path])
        raise ValueError(f"No object found for URL {self.url}")

    def get_objects(self) -> List[Any]:
        """
        Get all Django model objects along the URL path in hierarchical order.
        Returns objects from parent to child (e.g., Package -> Compound -> Structure).
        """
        return [
            self._get_object_by_url(model_path, self._matches[model_path])
            for level in self._MODEL_LEVELS
            for model_path in level
            if model_path in self._matches
        ]

    def __str__(self) -> str:
        return f"EPDBURLParser(url='{self.url}')"

    def __repr__(self) -> str:
        return f"EPDBURLParser(url='{self.url}', matches={list(self._matches)})"
class UserManager(object):
    """Create and look up users and decide who may modify a user account."""

    user_pattern = re.compile(
        r".*/user/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_user_url(url: str):
        """Return True if ``url`` contains ``/user/<uuid>``."""
        return bool(re.findall(UserManager.user_pattern, url))

    @staticmethod
    @transaction.atomic
    def create_user(
        username, email, password, set_setting=True, add_to_group=True, *args, **kwargs
    ):
        """Create a user together with their personal default package.

        Args:
            username: Login name; rejected if HTML sanitising would alter it.
            email: E-mail address; rejected if HTML sanitising would alter it.
            password: Passed unchanged to the user model's ``create_user``.
            set_setting: Make the global default ``Setting`` the user's default.
            add_to_group: Add the user to the public "enviPath Users" group and
                make it the user's default group.
            **kwargs: ``is_active`` overrides the activation state (default:
                active unless ``ADMIN_APPROVAL_REQUIRED``); ``uuid`` fixes the
                user's UUID.

        Returns:
            The created user.

        Raises:
            ValueError: if the username or email contains content that the
                sanitiser strips.
        """
        # Clean for potential XSS: refuse input that sanitising would change.
        clean_username = nh3.clean(username).strip()
        clean_email = nh3.clean(email).strip()
        if clean_username != username or clean_email != email:
            # This will be caught by the try in view.py/register.
            # The message names the two fields actually checked here.
            raise ValueError("Invalid username or email")

        extra_fields = {"is_active": not s.ADMIN_APPROVAL_REQUIRED}
        if "is_active" in kwargs:
            extra_fields["is_active"] = kwargs["is_active"]
        if "uuid" in kwargs:
            extra_fields["uuid"] = kwargs["uuid"]

        u = get_user_model().objects.create_user(username, email, password, **extra_fields)

        # Personal package, named with a genitive-style suffix: "’" after a
        # trailing s/x/z/ß, otherwise "s" (e.g. "Hans’ Package", "Annas Package").
        package_name = f"{u.username}{'’' if u.username[-1] in 'sxzß' else 's'} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()

        if set_setting:
            u.default_setting = Setting.objects.get(global_default=True)
            u.save()

        if add_to_group:
            g = Group.objects.get(public=True, name="enviPath Users")
            g.user_member.add(u)
            g.save()
            u.default_group = g
            u.save()

        return u

    @staticmethod
    def get_user(user_url):
        """Not implemented; always returns None."""
        pass

    @staticmethod
    def get_user_by_id(user, user_uuid: str):
        """Return the user with ``user_uuid`` if ``user`` is that user or a superuser.

        ``user_uuid`` may be given as a string or as a ``uuid.UUID``.

        Raises:
            ValueError: if ``user`` is neither the requested user nor a superuser.
        """
        # Compare string forms so a UUID instance and its string are treated alike.
        if str(user.uuid) != str(user_uuid) and not user.is_superuser:
            raise ValueError("Getting user failed!")
        return get_user_model().objects.get(uuid=user_uuid)

    @staticmethod
    def get_user_lp(user_url: str):
        """Return the user whose UUID is the last path segment of ``user_url``.

        No access check is performed. A trailing slash is tolerated.
        """
        uuid = user_url.strip().rstrip("/").split("/")[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users_lp():
        """Return all users (no access check)."""
        return get_user_model().objects.all()

    @staticmethod
    def get_users():
        """Not supported; always raises ValueError."""
        raise ValueError("")

    @staticmethod
    def writable(current_user, user):
        """Return True if ``current_user`` may modify the account of ``user``.

        Only the account owner and superusers are allowed.
        """
        # The superuser check must apply to the *acting* user. Checking
        # ``user.is_superuser`` would let anyone modify a superuser's account
        # while preventing superusers from modifying anyone else.
        return current_user == user or current_user.is_superuser
class GroupManager(object):
    """Create and look up groups and manage their membership."""

    group_pattern = re.compile(
        r".*/group/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
    )

    @staticmethod
    def is_group_url(url: str):
        """Return True if ``url`` contains ``/group/<uuid>``."""
        return bool(re.findall(GroupManager.group_pattern, url))

    @staticmethod
    @transaction.atomic
    def create_group(current_user, name, description, *args, **kwargs):
        """Create a group owned by ``current_user``, who also becomes its first member.

        Name and description are sanitised with ``nh3`` using
        ``settings.ALLOWED_HTML_TAGS``. A ``uuid`` keyword fixes the group's UUID.
        The group row and the owner's membership are written in one transaction,
        so a failure cannot leave a group without its owner as a member.
        """
        g = Group()

        if "uuid" in kwargs:
            g.uuid = kwargs["uuid"]

        # Clean for potential XSS
        g.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        g.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
        g.owner = current_user
        g.save()

        g.user_member.add(current_user)
        g.save()

        return g

    @staticmethod
    def get_group_lp(group_url: str):
        """Return the group whose UUID is the last path segment of ``group_url`` (no access check)."""
        uuid = group_url.strip().rstrip("/").split("/")[-1]
        return Group.objects.get(uuid=uuid)

    @staticmethod
    def get_groups_lp():
        """Return all groups (no access check)."""
        return Group.objects.all()

    @staticmethod
    def get_group_by_url(user, group_url):
        """Resolve ``group_url`` and return the group if ``user`` is a direct member, else None."""
        return GroupManager.get_group_by_id(user, group_url.strip().rstrip("/").split("/")[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        """Return the group with UUID ``group_id`` if ``user`` is a direct member, else None.

        Raises:
            Group.DoesNotExist: if no group has that UUID.
        """
        g = Group.objects.get(uuid=group_id)

        # Membership test in the database instead of loading every member.
        if g.user_member.filter(pk=user.pk).exists():
            return g
        return None

    @staticmethod
    def get_groups(user):
        """Return the groups that ``user`` is a direct ``user_member`` of."""
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        """Add or remove ``member`` (a user or a nested group) to or from ``group``.

        ``add_or_remove == "add"`` adds; any other value removes.

        Raises:
            ValueError: if ``caller`` is neither the group owner nor a superuser.
        """
        if caller != group.owner and not caller.is_superuser:
            raise ValueError("Only the group Owner is allowed to add members!")

        members = group.group_member if isinstance(member, Group) else group.user_member
        if add_or_remove == "add":
            members.add(member)
        else:
            members.remove(member)

        group.save()

    @staticmethod
    def writable(user, group):
        """Return True if ``user`` owns ``group`` or is a superuser."""
        return (user == group.owner) or user.is_superuser
class PackageManager(object):
|
||
package_pattern = re.compile(
|
||
r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
|
||
)
|
||
|
||
@staticmethod
|
||
def is_package_url(url: str):
|
||
return bool(re.findall(PackageManager.package_pattern, url))
|
||
|
||
@staticmethod
|
||
def get_reviewed_packages():
|
||
return Package.objects.filter(reviewed=True)
|
||
|
||
@staticmethod
|
||
def readable(user, package):
|
||
return (
|
||
PackageManager.has_package_permission(user, package, "read") | package.reviewed is True
|
||
)
|
||
|
||
@staticmethod
|
||
def writable(user, package):
|
||
return PackageManager.has_package_permission(user, package, "write")
|
||
|
||
@staticmethod
|
||
def administrable(user, package):
|
||
return PackageManager.has_package_permission(user, package, "all")
|
||
|
||
@staticmethod
|
||
def has_package_permission(user: "User", package: Union[str, UUID, "Package"], permission: str):
|
||
if isinstance(package, str) or isinstance(package, UUID):
|
||
package = Package.objects.get(uuid=package)
|
||
|
||
groups = GroupManager.get_groups(user)
|
||
|
||
# EDIT START
|
||
|
||
if package.classification_level == Package.Classification.SECRET:
|
||
if package.data_pool not in groups:
|
||
return False
|
||
|
||
# EDIT END
|
||
|
||
perms = {"all": ["all"], "write": ["all", "write"], "read": ["all", "write", "read"]}
|
||
|
||
valid_perms = perms.get(permission)
|
||
|
||
if (
|
||
UserPackagePermission.objects.filter(
|
||
package=package, user=user, permission__in=valid_perms
|
||
).exists()
|
||
or GroupPackagePermission.objects.filter(
|
||
package=package, group__in=groups, permission__in=valid_perms
|
||
).exists()
|
||
or user.is_superuser
|
||
):
|
||
return True
|
||
|
||
return False
|
||
|
||
@staticmethod
|
||
def get_package_lp(package_url):
|
||
match = re.findall(PackageManager.package_pattern, package_url)
|
||
if match:
|
||
package_id = match[0].split("/")[-1]
|
||
return Package.objects.get(uuid=package_id)
|
||
return None
|
||
|
||
@staticmethod
|
||
def get_package_by_url(user, package_url):
|
||
match = re.findall(PackageManager.package_pattern, package_url)
|
||
|
||
if match:
|
||
package_id = match[0].split("/")[-1]
|
||
return PackageManager.get_package_by_id(user, package_id)
|
||
else:
|
||
raise ValueError(
|
||
"Requested URL {} does not contain a valid package identifier!".format(package_url)
|
||
)
|
||
|
||
@staticmethod
|
||
def get_package_by_id(user, package_id):
|
||
try:
|
||
p = Package.objects.get(uuid=package_id)
|
||
if PackageManager.readable(user, p):
|
||
p = PackageManager.check_package_classification(user, p)
|
||
return p
|
||
else:
|
||
# FIXME: use custom exception to be translatable to 403 in API
|
||
raise ValueError(
|
||
"Insufficient permissions to access Package with ID {}".format(package_id)
|
||
)
|
||
except Package.DoesNotExist:
|
||
raise ValueError("Package with ID {} does not exist!".format(package_id))
|
||
|
||
# EDIT START
|
||
|
||
@staticmethod
|
||
def check_package_classification(user, pack: Package):
|
||
if pack.classification_level == Package.Classification.SECRET:
|
||
if pack.data_pool.user_member.filter(id=user.id).exists():
|
||
return pack
|
||
|
||
raise ValueError("Package is secret and not accessible to user!")
|
||
|
||
else:
|
||
return pack
|
||
|
||
|
||
@staticmethod
|
||
def check_package_classifications(user, package_qs: QuerySet[Package]):
|
||
non_secret = package_qs.exclude(classification_level=Package.Classification.SECRET)
|
||
secret = package_qs.filter(classification_level=Package.Classification.SECRET)
|
||
|
||
# TODO we should be able to do via the db
|
||
accessible_secret = []
|
||
|
||
for s_package in secret:
|
||
if s_package.data_pool.user_member.filter(id=user.id).exists():
|
||
accessible_secret.append(s_package.pk)
|
||
|
||
# Cannot combine a unique query with a non-unique query -> we have to call distinct
|
||
return Package.objects.filter(pk__in=accessible_secret).distinct() | non_secret.distinct()
|
||
|
||
# EDIT END
|
||
|
||
@staticmethod
|
||
def get_all_readable_packages(user, include_reviewed=False):
|
||
# UserPermission only exists if at least read is granted...
|
||
if user.is_superuser:
|
||
qs = Package.objects.all()
|
||
else:
|
||
user_package_qs = Package.objects.filter(
|
||
id__in=UserPackagePermission.objects.filter(user=user).values("package").distinct()
|
||
)
|
||
group_package_qs = Package.objects.filter(
|
||
id__in=GroupPackagePermission.objects.filter(
|
||
group__in=GroupManager.get_groups(user)
|
||
)
|
||
.values("package")
|
||
.distinct()
|
||
)
|
||
qs = user_package_qs | group_package_qs
|
||
|
||
if include_reviewed:
|
||
qs |= Package.objects.filter(reviewed=True)
|
||
else:
|
||
# remove package if user is owner and package is reviewed e.g. admin
|
||
qs = qs.filter(reviewed=False)
|
||
|
||
qs = qs.distinct()
|
||
|
||
# EDIT START
|
||
qs = PackageManager.check_package_classifications(user, qs)
|
||
# EDIT END
|
||
|
||
return qs
|
||
|
||
@staticmethod
|
||
def get_all_writeable_packages(user):
|
||
# UserPermission only exists if at least read is granted...
|
||
if user.is_superuser:
|
||
qs = Package.objects.all()
|
||
else:
|
||
write_user_packs = (
|
||
UserPackagePermission.objects.filter(user=user, permission=Permission.WRITE[0])
|
||
.values("package")
|
||
.distinct()
|
||
)
|
||
owner_user_packs = (
|
||
UserPackagePermission.objects.filter(user=user, permission=Permission.ALL[0])
|
||
.values("package")
|
||
.distinct()
|
||
)
|
||
|
||
user_packs = write_user_packs | owner_user_packs
|
||
user_package_qs = Package.objects.filter(id__in=user_packs)
|
||
|
||
write_group_packs = (
|
||
GroupPackagePermission.objects.filter(
|
||
group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]
|
||
)
|
||
.values("package")
|
||
.distinct()
|
||
)
|
||
owner_group_packs = (
|
||
GroupPackagePermission.objects.filter(
|
||
group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]
|
||
)
|
||
.values("package")
|
||
.distinct()
|
||
)
|
||
|
||
group_packs = write_group_packs | owner_group_packs
|
||
group_package_qs = Package.objects.filter(id__in=group_packs)
|
||
|
||
qs = user_package_qs | group_package_qs
|
||
|
||
qs = qs.filter(reviewed=False)
|
||
|
||
qs = qs.distinct()
|
||
|
||
# EDIT START
|
||
qs = PackageManager.check_package_classifications(user, qs)
|
||
# EDIT END
|
||
|
||
return qs
|
||
|
||
@staticmethod
|
||
@transaction.atomic
|
||
def create_package(current_user, name: str, description: str = None):
|
||
p = Package()
|
||
|
||
# Clean for potential XSS
|
||
p.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||
|
||
if description is not None and description.strip() != "":
|
||
p.description = nh3.clean(description.strip(), tags=s.ALLOWED_HTML_TAGS).strip()
|
||
|
||
p.save()
|
||
|
||
up = UserPackagePermission()
|
||
up.user = current_user
|
||
up.package = p
|
||
up.permission = UserPackagePermission.ALL[0]
|
||
up.save()
|
||
|
||
return p
|
||
|
||
@staticmethod
|
||
@transaction.atomic
|
||
def update_permissions(
|
||
caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]
|
||
):
|
||
caller_perm = None
|
||
if not caller.is_superuser:
|
||
caller_perm = UserPackagePermission.objects.get(user=caller, package=package).permission
|
||
|
||
if caller_perm != Permission.ALL[0] and not caller.is_superuser:
|
||
raise ValueError("Only owner are allowed to modify permissions")
|
||
|
||
data = {
|
||
"package": package,
|
||
}
|
||
|
||
if isinstance(grantee, User):
|
||
perm_cls = UserPackagePermission
|
||
data["user"] = grantee
|
||
else:
|
||
perm_cls = GroupPackagePermission
|
||
data["group"] = grantee
|
||
|
||
if new_perm is None:
|
||
qs = perm_cls.objects.filter(**data)
|
||
if qs.count() > 1:
|
||
raise ValueError("Got more Permission objects than expected!")
|
||
if qs.count() != 0:
|
||
logger.info(f"Deleting Perm {qs.first()}")
|
||
qs.delete()
|
||
else:
|
||
logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
|
||
else:
|
||
_ = perm_cls.objects.update_or_create(defaults={"permission": new_perm}, **data)
|
||
|
||
@staticmethod
|
||
def grant_read(caller: User, package: Package, grantee: Union[User, Group]):
|
||
PackageManager.update_permissions(caller, package, grantee, Permission.READ[0])
|
||
|
||
@staticmethod
|
||
def grant_write(caller: User, package: Package, grantee: Union[User, Group]):
|
||
PackageManager.update_permissions(caller, package, grantee, Permission.WRITE[0])
|
||
|
||
@staticmethod
|
||
@transaction.atomic
|
||
def import_legacy_package(
|
||
data: dict, owner: User, keep_ids=False, add_import_timestamp=True, trust_reviewed=False
|
||
):
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
from uuid import UUID, uuid4
|
||
|
||
from envipy_additional_information import AdditionalInformationConverter
|
||
|
||
from .models import (
|
||
Compound,
|
||
CompoundStructure,
|
||
Edge,
|
||
Node,
|
||
ParallelRule,
|
||
Pathway,
|
||
Reaction,
|
||
Scenario,
|
||
SequentialRule,
|
||
SequentialRuleOrdering,
|
||
SimpleAmbitRule,
|
||
SimpleRule,
|
||
)
|
||
|
||
pack = Package()
|
||
pack.uuid = UUID(data["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
|
||
if add_import_timestamp:
|
||
pack.name = "{} - {}".format(data["name"], datetime.now().strftime("%Y-%m-%d %H:%M"))
|
||
else:
|
||
pack.name = data["name"]
|
||
|
||
if trust_reviewed:
|
||
pack.reviewed = True if data["reviewStatus"] == "reviewed" else False
|
||
else:
|
||
pack.reviewed = False
|
||
|
||
# EDIT START
|
||
if data.get("classification"):
|
||
if data["classification"] == "INTERNAL":
|
||
pack.classification = Package.Classification.RESTRICTED
|
||
elif data["classification"] == "RESTRICTED":
|
||
pack.classification = Package.Classification.RESTRICTED
|
||
elif data["classification"] == "SECRET":
|
||
pack.classification = Package.Classification.SECRET
|
||
|
||
if not "datapool" in data:
|
||
raise ValueError("Missing datapool in package")
|
||
|
||
g = Group.objects.get(uuid=data["datapool"].split('/')[-1])
|
||
pack.data_pool = g
|
||
else:
|
||
raise ValueError(f"Invalid classification {data['classification']}")
|
||
|
||
# EDIT END
|
||
|
||
pack.description = data["description"]
|
||
pack.save()
|
||
|
||
up = UserPackagePermission()
|
||
up.user = owner
|
||
up.package = pack
|
||
up.permission = up.ALL[0]
|
||
up.save()
|
||
|
||
# Stores old_id to new_id
|
||
mapping = {}
|
||
# Mapping old scen_id to old_obj_id
|
||
scen_mapping = defaultdict(list)
|
||
# Enzymelink Mapping rule_id to enzymelink objects
|
||
enzyme_mapping = defaultdict(list)
|
||
|
||
# old_parent_id to child
|
||
postponed_scens = defaultdict(list)
|
||
|
||
# Store Scenarios
|
||
for scenario in data["scenarios"]:
|
||
skip_scen = False
|
||
# Check if parent exists and park this Scenario to convert it later into an
|
||
# AdditionalInformation object
|
||
for ex in scenario.get("additionalInformationCollection", {}).get(
|
||
"additionalInformation", []
|
||
):
|
||
if ex["name"] == "referringscenario":
|
||
postponed_scens[ex["data"]].append(scenario)
|
||
skip_scen = True
|
||
break
|
||
|
||
if skip_scen:
|
||
continue
|
||
|
||
scen = Scenario()
|
||
scen.package = pack
|
||
scen.uuid = UUID(scenario["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
scen.name = scenario["name"]
|
||
scen.description = scenario["description"]
|
||
scen.scenario_type = scenario["type"]
|
||
scen.scenario_date = scenario["date"]
|
||
scen.additional_information = dict()
|
||
scen.save()
|
||
|
||
mapping[scenario["id"]] = scen.uuid
|
||
|
||
for ex in scenario.get("additionalInformationCollection", {}).get(
|
||
"additionalInformation", []
|
||
):
|
||
name = ex["name"]
|
||
addinf_data = ex["data"]
|
||
|
||
# Broken eP Data
|
||
if name == "initialmasssediment" and addinf_data == "missing data":
|
||
continue
|
||
if name == "columnheight" and addinf_data == "(2)-(2.5);(6)-(8)":
|
||
continue
|
||
|
||
try:
|
||
ai = AdditionalInformationConverter.convert(name, addinf_data)
|
||
AdditionalInformation.create(pack, ai, scenario=scen)
|
||
except (ValidationError, ValueError):
|
||
logger.error(f"Failed to convert {name} with {addinf_data}")
|
||
|
||
print("Scenarios imported...")
|
||
|
||
# Store compounds and its structures
|
||
for compound in data["compounds"]:
|
||
comp = Compound()
|
||
comp.package = pack
|
||
comp.uuid = UUID(compound["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
comp.name = compound["name"]
|
||
comp.description = compound["description"]
|
||
comp.aliases = compound["aliases"]
|
||
comp.save()
|
||
|
||
mapping[compound["id"]] = comp.uuid
|
||
|
||
for scen in compound["scenarios"]:
|
||
scen_mapping[scen["id"]].append(comp)
|
||
|
||
default_structure = None
|
||
|
||
for structure in compound["structures"]:
|
||
if structure.get("pesLink"):
|
||
from bayer.models import PESStructure
|
||
struc = PESStructure()
|
||
struc.pes_link = structure["pesLink"]
|
||
else:
|
||
struc = CompoundStructure()
|
||
|
||
# struc.object_url = Command.get_id(structure, keep_ids)
|
||
struc.compound = comp
|
||
struc.uuid = UUID(structure["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
struc.name = structure["name"]
|
||
struc.description = structure["description"]
|
||
struc.aliases = structure.get("aliases", [])
|
||
struc.smiles = structure["smiles"]
|
||
|
||
if structure.get("molfile"):
|
||
struc.molfile = structure["molfile"]
|
||
|
||
struc.save()
|
||
|
||
for scen in structure["scenarios"]:
|
||
scen_mapping[scen["id"]].append(struc)
|
||
|
||
mapping[structure["id"]] = struc.uuid
|
||
|
||
if structure["id"] == compound["defaultStructure"]["id"]:
|
||
default_structure = struc
|
||
|
||
struc.save()
|
||
|
||
if default_structure is None:
|
||
raise ValueError("No default structure set")
|
||
|
||
comp.default_structure = default_structure
|
||
comp.save()
|
||
|
||
print("Compounds imported...")
|
||
|
||
# Store simple and parallel-rules
|
||
par_rules = []
|
||
seq_rules = []
|
||
|
||
for rule in data["rules"]:
|
||
if rule["identifier"] == "parallel-rule":
|
||
par_rules.append(rule)
|
||
continue
|
||
if rule["identifier"] == "sequential-rule":
|
||
seq_rules.append(rule)
|
||
continue
|
||
r = SimpleAmbitRule()
|
||
r.uuid = UUID(rule["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
r.package = pack
|
||
r.name = rule["name"]
|
||
r.description = rule["description"]
|
||
r.aliases = rule.get("aliases", [])
|
||
r.smirks = rule["smirks"]
|
||
r.reactant_filter_smarts = rule.get("reactantFilterSmarts", None)
|
||
r.product_filter_smarts = rule.get("productFilterSmarts", None)
|
||
r.save()
|
||
|
||
mapping[rule["id"]] = r.uuid
|
||
|
||
for scen in rule["scenarios"]:
|
||
scen_mapping[scen["id"]].append(r)
|
||
|
||
for enzyme_link in rule.get("enzymeLinks", []):
|
||
enzyme_mapping[r.uuid].append(enzyme_link)
|
||
|
||
print("Par: ", len(par_rules))
|
||
print("Seq: ", len(seq_rules))
|
||
|
||
for par_rule in par_rules:
|
||
r = ParallelRule()
|
||
r.package = pack
|
||
r.uuid = UUID(par_rule["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
r.name = par_rule["name"]
|
||
r.description = par_rule["description"]
|
||
r.aliases = par_rule.get("aliases", [])
|
||
r.save()
|
||
|
||
mapping[par_rule["id"]] = r.uuid
|
||
|
||
for scen in par_rule["scenarios"]:
|
||
scen_mapping[scen["id"]].append(r)
|
||
|
||
for enzyme_link in par_rule.get("enzymeLinks", []):
|
||
enzyme_mapping[r.uuid].append(enzyme_link)
|
||
|
||
for simple_rule in par_rule["simpleRules"]:
|
||
if simple_rule["id"] in mapping:
|
||
r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule["id"]]))
|
||
|
||
r.save()
|
||
|
||
for seq_rule in seq_rules:
|
||
r = SequentialRule()
|
||
r.package = pack
|
||
r.uuid = UUID(seq_rule["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
r.name = seq_rule["name"]
|
||
r.description = seq_rule["description"]
|
||
r.aliases = seq_rule.get("aliases", [])
|
||
r.save()
|
||
|
||
mapping[seq_rule["id"]] = r.uuid
|
||
|
||
for scen in seq_rule["scenarios"]:
|
||
scen_mapping[scen["id"]].append(r)
|
||
|
||
for enzyme_link in seq_rule.get("enzymeLinks", []):
|
||
enzyme_mapping[r.uuid].append(enzyme_link)
|
||
|
||
for i, simple_rule in enumerate(seq_rule["simpleRules"]):
|
||
sro = SequentialRuleOrdering()
|
||
sro.simple_rule = simple_rule
|
||
sro.sequential_rule = r
|
||
sro.order_index = i
|
||
sro.save()
|
||
# r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule['id']]))
|
||
|
||
r.save()
|
||
|
||
print("Rules imported...")
|
||
|
||
for reaction in data["reactions"]:
|
||
r = Reaction()
|
||
r.package = pack
|
||
r.uuid = UUID(reaction["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
r.name = reaction["name"]
|
||
r.description = reaction["description"]
|
||
r.aliases = reaction.get("aliases", [])
|
||
r.medlinereferences = (reaction["medlinereferences"],)
|
||
r.multi_step = True if reaction["multistep"] == "true" else False
|
||
r.save()
|
||
|
||
mapping[reaction["id"]] = r.uuid
|
||
|
||
for scen in reaction["scenarios"]:
|
||
scen_mapping[scen["id"]].append(r)
|
||
|
||
for educt in reaction["educts"]:
|
||
r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt["id"]]))
|
||
|
||
for product in reaction["products"]:
|
||
r.products.add(CompoundStructure.objects.get(uuid=mapping[product["id"]]))
|
||
|
||
if "rules" in reaction:
|
||
for rule in reaction["rules"]:
|
||
try:
|
||
r.rules.add(Rule.objects.get(uuid=mapping[rule["id"]]))
|
||
except Exception as e:
|
||
print(f"Rule with id {rule['id']} not found!")
|
||
print(e)
|
||
|
||
r.save()
|
||
|
||
print("Reactions imported...")
|
||
|
||
for pathway in data["pathways"]:
|
||
pw = Pathway()
|
||
pw.package = pack
|
||
pw.uuid = UUID(pathway["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
pw.name = pathway["name"]
|
||
pw.description = pathway["description"]
|
||
pw.aliases = pathway.get("aliases", [])
|
||
pw.save()
|
||
|
||
mapping[pathway["id"]] = pw.uuid
|
||
for scen in pathway["scenarios"]:
|
||
scen_mapping[scen["id"]].append(pw)
|
||
|
||
out_nodes_mapping = defaultdict(set)
|
||
|
||
for node in pathway["nodes"]:
|
||
n = Node()
|
||
n.uuid = UUID(node["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
n.name = node["name"]
|
||
n.description = node.get("description")
|
||
n.aliases = node.get("aliases", [])
|
||
n.pathway = pw
|
||
n.depth = node["depth"]
|
||
n.default_node_label = CompoundStructure.objects.get(
|
||
uuid=mapping[node["defaultNodeLabel"]["id"]]
|
||
)
|
||
n.save()
|
||
|
||
mapping[node["id"]] = n.uuid
|
||
|
||
for scen in node["scenarios"]:
|
||
scen_mapping[scen["id"]].append(n)
|
||
|
||
for node_label in node["nodeLabels"]:
|
||
n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label["id"]]))
|
||
|
||
n.save()
|
||
|
||
for out_edge in node["outEdges"]:
|
||
out_nodes_mapping[n.uuid].add(out_edge)
|
||
|
||
for edge in pathway["edges"]:
|
||
e = Edge()
|
||
e.uuid = UUID(edge["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
e.name = edge["name"]
|
||
e.pathway = pw
|
||
e.description = edge["description"]
|
||
e.aliases = edge.get("aliases", [])
|
||
e.edge_label = Reaction.objects.get(uuid=mapping[edge["edgeLabel"]["id"]])
|
||
e.save()
|
||
|
||
mapping[edge["id"]] = e.uuid
|
||
|
||
for scen in edge["scenarios"]:
|
||
scen_mapping[scen["id"]].append(e)
|
||
|
||
for start_node in edge["startNodes"]:
|
||
e.start_nodes.add(Node.objects.get(uuid=mapping[start_node]))
|
||
|
||
for end_node in edge["endNodes"]:
|
||
e.end_nodes.add(Node.objects.get(uuid=mapping[end_node]))
|
||
|
||
e.save()
|
||
|
||
for k, v in out_nodes_mapping.items():
|
||
n = Node.objects.get(uuid=k)
|
||
for v1 in v:
|
||
n.out_edges.add(Edge.objects.get(uuid=mapping[v1]))
|
||
n.save()
|
||
|
||
print("Pathways imported...")
|
||
|
||
for parent, children in postponed_scens.items():
|
||
for child in children:
|
||
for ex in child.get("additionalInformationCollection", {}).get(
|
||
"additionalInformation", []
|
||
):
|
||
child_id = child["id"]
|
||
name = ex["name"]
|
||
addinf_data = ex["data"]
|
||
|
||
if name == "referringscenario":
|
||
continue
|
||
# Broken eP Data
|
||
if name == "initialmasssediment" and addinf_data == "missing data":
|
||
continue
|
||
if name == "columnheight" and addinf_data == "(2)-(2.5);(6)-(8)":
|
||
continue
|
||
|
||
ai = AdditionalInformationConverter.convert(name, addinf_data)
|
||
|
||
if child_id not in scen_mapping:
|
||
logger.info(
|
||
f"{child_id} not found in scen_mapping. Seems like its not attached to any object"
|
||
)
|
||
print(
|
||
f"{child_id} not found in scen_mapping. Seems like its not attached to any object"
|
||
)
|
||
|
||
scen = Scenario.objects.get(uuid=mapping[parent])
|
||
mapping[child_id] = scen.uuid
|
||
for obj in scen_mapping[child_id]:
|
||
_ = AdditionalInformation.create(pack, ai, scen, content_object=obj)
|
||
|
||
for scen_id, objects in scen_mapping.items():
|
||
new_id = mapping.get(scen_id)
|
||
|
||
if new_id is None:
|
||
logger.warning(f"Could not find mapping for {scen_id}")
|
||
print(f"Could not find mapping for {scen_id}")
|
||
continue
|
||
|
||
scen = Scenario.objects.get(uuid=mapping[scen_id])
|
||
for o in objects:
|
||
o.scenarios.add(scen)
|
||
o.save()
|
||
|
||
print("Scenarios linked...")
|
||
|
||
# Import Enzyme Links
|
||
for rule_uuid, enzyme_links in enzyme_mapping.items():
|
||
r = Rule.objects.get(uuid=rule_uuid)
|
||
for enzyme in enzyme_links:
|
||
e = EnzymeLink()
|
||
e.uuid = UUID(enzyme["id"].split("/")[-1]) if keep_ids else uuid4()
|
||
e.rule = r
|
||
e.name = enzyme["name"]
|
||
e.ec_number = enzyme["ecNumber"]
|
||
e.classification_level = enzyme["classificationLevel"]
|
||
e.linking_method = enzyme["linkingMethod"]
|
||
e.save()
|
||
|
||
for reaction in enzyme["reactionLinkEvidence"]:
|
||
reaction = Reaction.objects.get(uuid=mapping[reaction["id"]])
|
||
e.reaction_evidence.add(reaction)
|
||
|
||
for edge in enzyme["edgeLinkEvidence"]:
|
||
edge = Edge.objects.get(uuid=mapping[edge["id"]])
|
||
e.reaction_evidence.add(edge)
|
||
|
||
for evidence in enzyme["linkEvidence"]:
|
||
matches = re.findall(r">(R[0-9]+)<", evidence["evidence"])
|
||
if not matches or len(matches) != 1:
|
||
logger.warning(f"Could not find reaction id in {evidence['evidence']}")
|
||
print(f"Could not find reaction id in {evidence['evidence']}")
|
||
continue
|
||
|
||
e.add_kegg_reaction_id(matches[0])
|
||
|
||
e.save()
|
||
|
||
print("Enzyme links imported...")
|
||
|
||
print("Import statistics:")
|
||
print("Package {} stored".format(pack.url))
|
||
print("Imported {} compounds".format(Compound.objects.filter(package=pack).count()))
|
||
print("Imported {} rules".format(Rule.objects.filter(package=pack).count()))
|
||
print("Imported {} reactions".format(Reaction.objects.filter(package=pack).count()))
|
||
print("Imported {} pathways".format(Pathway.objects.filter(package=pack).count()))
|
||
print("Imported {} Scenarios".format(Scenario.objects.filter(package=pack).count()))
|
||
|
||
print("Fixing Node depths...")
|
||
total_pws = Pathway.objects.filter(package=pack).count()
|
||
|
||
for p, pw in enumerate(Pathway.objects.filter(package=pack)):
|
||
pw.update_depths()
|
||
print(f"{p + 1}/{total_pws} fixed.", end="\r")
|
||
|
||
return pack
|
||
|
||
@staticmethod
@transaction.atomic
def import_package(
    data: Dict[str, Any],
    owner: User,
    preserve_uuids=False,
    add_import_timestamp=True,
    trust_reviewed=False,
) -> Package:
    """Import a serialized package and make ``owner`` its fully-privileged user.

    The import and the permission grant run inside a single transaction, so a
    failure in either leaves no partially imported package behind.

    Args:
        data: Serialized package as consumed by ``PackageImporter``.
        owner: User that receives the first entry of ``ALL`` permissions.
        preserve_uuids: Forwarded to ``PackageImporter``.
        add_import_timestamp: Forwarded to ``PackageImporter``.
        trust_reviewed: Forwarded to ``PackageImporter``.

    Returns:
        The package object produced by ``PackageImporter.do_import()``.
    """
    package = PackageImporter(
        data, preserve_uuids, add_import_timestamp, trust_reviewed
    ).do_import()

    # Grant the importing user the first (broadest) permission level on the new package.
    owner_permission = UserPackagePermission()
    owner_permission.user = owner
    owner_permission.package = package
    owner_permission.permission = owner_permission.ALL[0]
    owner_permission.save()

    return package
|
||
|
||
@staticmethod
def export_package(
    package: Package, include_models: bool = False, include_external_identifiers: bool = True
) -> Dict[str, Any]:
    """Serialize ``package`` with ``PackageExporter`` and return the exported dict.

    Args:
        package: The package to export.
        include_models: Currently unused (see note below).
        include_external_identifiers: Currently unused (see note below).

    Returns:
        Whatever ``PackageExporter(package).do_export()`` produces.

    NOTE(review): neither ``include_models`` nor ``include_external_identifiers`` is
    forwarded to ``PackageExporter`` here, so callers passing them get no effect —
    confirm whether the exporter should honour these flags.
    """
    return PackageExporter(package).do_export()
|
||
|
||
|
||
class SettingManager(object):
    """Permission-aware lookup, listing and creation of prediction ``Setting`` objects."""

    # Matches any URL that ends in ``/setting/<uuid>``.
    setting_pattern = re.compile(
        r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$"
    )

    @staticmethod
    def get_setting_by_url(user, setting_url):
        """Resolve a setting URL to a ``Setting`` that ``user`` may access.

        Raises:
            ValueError: If the URL does not end in a setting UUID, or (via
                ``get_setting_by_id``) the user lacks access.
        """
        match = re.findall(SettingManager.setting_pattern, setting_url)

        if match:
            setting_id = match[0].split("/")[-1]
            return SettingManager.get_setting_by_id(user, setting_id)
        else:
            raise ValueError(
                "Requested URL {} does not contain a valid setting identifier!".format(setting_url)
            )

    @staticmethod
    def get_setting_by_id(user, setting_id):
        """Return the setting with UUID ``setting_id`` if ``user`` may access it.

        Access is granted for global-default or public settings, superusers, and
        users holding a ``UserSettingPermission`` on the setting.

        Raises:
            Setting.DoesNotExist: If no setting has this UUID.
            ValueError: If the user is not allowed to access the setting.
        """
        # Named ``setting`` rather than ``s`` so the module-level django settings alias
        # is not shadowed inside this function.
        setting = Setting.objects.get(uuid=setting_id)

        if (
            setting.global_default
            or setting.public
            or user.is_superuser
            or UserSettingPermission.objects.filter(user=user, setting=setting).exists()
        ):
            return setting

        raise ValueError("Insufficient permissions to access Setting with ID {}".format(setting_id))

    @staticmethod
    def get_all_settings(user):
        """Return all distinct settings the user holds a permission on, plus public and global-default ones."""
        sp = UserSettingPermission.objects.filter(user=user).values("setting")
        return (
            Setting.objects.filter(id__in=sp)
            | Setting.objects.filter(public=True)
            | Setting.objects.filter(global_default=True)
        ).distinct()

    @staticmethod
    @transaction.atomic
    def create_setting(
        user: User,
        name: str = None,
        description: str = None,
        max_nodes: int = None,
        max_depth: int = None,
        rule_packages: List[Package] | None = None,
        model: EPModel = None,
        model_threshold: float = None,
        expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
        property_models: List["PropertyPluginModel"] | None = None,
    ):
        """Create a new ``Setting`` and give ``user`` full permission on it.

        ``name`` and ``description`` are sanitized with ``nh3`` (only
        ``settings.ALLOWED_HTML_TAGS`` survive) and stripped. When either is
        ``None`` the model's default value is kept instead of passing ``None``
        to ``nh3.clean``, which would raise ``TypeError``.

        Returns:
            The saved ``Setting``.
        """
        new_s = Setting()

        # Clean for potential XSS
        if name is not None:
            new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        if description is not None:
            new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        new_s.max_nodes = max_nodes
        new_s.max_depth = max_depth
        new_s.model = model
        new_s.model_threshold = model_threshold
        new_s.expansion_scheme = expansion_scheme

        # The instance needs a primary key before many-to-many relations can be added.
        new_s.save()

        if rule_packages:
            new_s.rule_packages.add(*rule_packages)
            new_s.save()

        if property_models:
            new_s.property_models.add(*property_models)
            new_s.save()

        usp = UserSettingPermission()
        usp.user = user
        usp.setting = new_s
        usp.permission = Permission.ALL[0]
        usp.save()

        return new_s

    @staticmethod
    def get_default_setting(user: User):
        """Not implemented yet; always returns ``None``."""
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        """Not implemented yet; does nothing."""
        pass
|
||
|
||
|
||
class SearchManager(object):
|
||
@staticmethod
|
||
def search(packages: Union[Package, List[Package]], searchterm: str, mode: str):
|
||
match mode:
|
||
case "text":
|
||
return SearchManager._search_text(packages, searchterm)
|
||
case "default":
|
||
return SearchManager._search_default_smiles(packages, searchterm)
|
||
case "exact":
|
||
return SearchManager._search_exact_smiles(packages, searchterm)
|
||
case "canonical":
|
||
return SearchManager._search_canonical_smiles(packages, searchterm)
|
||
case "inchikey":
|
||
return SearchManager._search_inchikey(packages, searchterm)
|
||
case _:
|
||
raise ValueError(f"Unknown search mode {mode}!")
|
||
|
||
@staticmethod
|
||
def _search_inchikey(packages: Union[Package, List[Package]], searchterm: str):
|
||
from django.db.models import Q
|
||
|
||
search_cond = Q(inchikey=searchterm)
|
||
compound_qs = Compound.objects.filter(
|
||
Q(package__in=packages) & Q(compoundstructure__inchikey=searchterm)
|
||
).distinct()
|
||
compound_structure_qs = CompoundStructure.objects.filter(
|
||
Q(compound__package__in=packages) & search_cond
|
||
)
|
||
reactions_qs = Reaction.objects.filter(
|
||
Q(package__in=packages)
|
||
& (Q(educts__inchikey=searchterm) | Q(products__inchikey=searchterm))
|
||
).distinct()
|
||
pathway_qs = Pathway.objects.filter(
|
||
Q(package__in=packages)
|
||
& (
|
||
Q(edge__edge_label__educts__inchikey=searchterm)
|
||
| Q(edge__edge_label__products__inchikey=searchterm)
|
||
)
|
||
).distinct()
|
||
|
||
return {
|
||
"Compounds": [
|
||
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
|
||
],
|
||
"Compound Structures": [
|
||
{"name": c.name, "description": c.description, "url": c.url}
|
||
for c in compound_structure_qs
|
||
],
|
||
"Reactions": [
|
||
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
|
||
],
|
||
"Pathways": [
|
||
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
|
||
],
|
||
}
|
||
|
||
@staticmethod
|
||
def _search_exact_smiles(packages: Union[Package, List[Package]], searchterm: str):
|
||
from django.db.models import Q
|
||
|
||
search_cond = Q(smiles=searchterm)
|
||
compound_qs = Compound.objects.filter(
|
||
Q(package__in=packages) & Q(compoundstructure__smiles=searchterm)
|
||
).distinct()
|
||
compound_structure_qs = CompoundStructure.objects.filter(
|
||
Q(compound__package__in=packages) & search_cond
|
||
)
|
||
reactions_qs = Reaction.objects.filter(
|
||
Q(package__in=packages)
|
||
& (Q(educts__smiles=searchterm) | Q(products__smiles=searchterm))
|
||
).distinct()
|
||
pathway_qs = Pathway.objects.filter(
|
||
Q(package__in=packages)
|
||
& (
|
||
Q(edge__edge_label__educts__smiles=searchterm)
|
||
| Q(edge__edge_label__products__smiles=searchterm)
|
||
)
|
||
).distinct()
|
||
|
||
return {
|
||
"Compounds": [
|
||
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
|
||
],
|
||
"Compound Structures": [
|
||
{"name": c.name, "description": c.description, "url": c.url}
|
||
for c in compound_structure_qs
|
||
],
|
||
"Reactions": [
|
||
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
|
||
],
|
||
"Pathways": [
|
||
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
|
||
],
|
||
}
|
||
|
||
@staticmethod
|
||
def _search_default_smiles(packages: Union[Package, List[Package]], searchterm: str):
|
||
from django.db.models import Q
|
||
|
||
inchi_front = FormatConverter.InChIKey(searchterm)[:14]
|
||
|
||
search_cond = Q(inchikey__startswith=inchi_front)
|
||
compound_qs = Compound.objects.filter(
|
||
Q(package__in=packages) & Q(compoundstructure__inchikey__startswith=inchi_front)
|
||
).distinct()
|
||
compound_structure_qs = CompoundStructure.objects.filter(
|
||
Q(compound__package__in=packages) & search_cond
|
||
)
|
||
reactions_qs = Reaction.objects.filter(
|
||
Q(package__in=packages)
|
||
& (
|
||
Q(educts__inchikey__startswith=inchi_front)
|
||
| Q(products__inchikey__startswith=inchi_front)
|
||
)
|
||
).distinct()
|
||
pathway_qs = Pathway.objects.filter(
|
||
Q(package__in=packages)
|
||
& (
|
||
Q(edge__edge_label__educts__inchikey__startswith=inchi_front)
|
||
| Q(edge__edge_label__products__inchikey__startswith=inchi_front)
|
||
)
|
||
).distinct()
|
||
|
||
return {
|
||
"Compounds": [
|
||
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
|
||
],
|
||
"Compound Structures": [
|
||
{"name": c.name, "description": c.description, "url": c.url}
|
||
for c in compound_structure_qs
|
||
],
|
||
"Reactions": [
|
||
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
|
||
],
|
||
"Pathways": [
|
||
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
|
||
],
|
||
}
|
||
|
||
@staticmethod
|
||
def _search_canonical_smiles(packages: Union[Package, List[Package]], searchterm: str):
|
||
from django.db.models import Q
|
||
|
||
search_cond = Q(canonical_smiles=searchterm)
|
||
compound_qs = Compound.objects.filter(
|
||
Q(package__in=packages) & Q(compoundstructure__canonical_smiles=searchterm)
|
||
).distinct()
|
||
compound_structure_qs = CompoundStructure.objects.filter(
|
||
Q(compound__package__in=packages) & search_cond
|
||
)
|
||
reactions_qs = Reaction.objects.filter(
|
||
Q(package__in=packages)
|
||
& (Q(educts__canonical_smiles=searchterm) | Q(products__canonical_smiles=searchterm))
|
||
).distinct()
|
||
pathway_qs = Pathway.objects.filter(
|
||
Q(package__in=packages)
|
||
& (
|
||
Q(edge__edge_label__educts__canonical_smiles=searchterm)
|
||
| Q(edge__edge_label__products__canonical_smiles=searchterm)
|
||
)
|
||
).distinct()
|
||
|
||
return {
|
||
"Compounds": [
|
||
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
|
||
],
|
||
"Compound Structures": [
|
||
{"name": c.name, "description": c.description, "url": c.url}
|
||
for c in compound_structure_qs
|
||
],
|
||
"Reactions": [
|
||
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
|
||
],
|
||
"Pathways": [
|
||
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
|
||
],
|
||
}
|
||
|
||
@staticmethod
|
||
def _search_text(packages: Union[Package, List[Package]], searchterm: str):
|
||
from django.db.models import Q
|
||
|
||
search_cond = Q(name__icontains=searchterm) | Q(description__icontains=searchterm)
|
||
cond = Q(package__in=packages) & search_cond
|
||
compound_qs = Compound.objects.filter(cond)
|
||
compound_structure_qs = CompoundStructure.objects.filter(
|
||
Q(compound__package__in=packages) & search_cond
|
||
)
|
||
rule_qs = Rule.objects.filter(cond)
|
||
reactions_qs = Reaction.objects.filter(cond)
|
||
pathway_qs = Pathway.objects.filter(cond)
|
||
|
||
res = {
|
||
"Compounds": [
|
||
{"name": c.name, "description": c.description, "url": c.url} for c in compound_qs
|
||
],
|
||
"Compound Structures": [
|
||
{"name": c.name, "description": c.description, "url": c.url}
|
||
for c in compound_structure_qs
|
||
],
|
||
"Rules": [
|
||
{"name": r.name, "description": r.description, "url": r.url} for r in rule_qs
|
||
],
|
||
"Reactions": [
|
||
{"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs
|
||
],
|
||
"Pathways": [
|
||
{"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs
|
||
],
|
||
}
|
||
return res
|
||
|
||
|
||
class SNode(object):
    """In-memory pathway node; identity (equality and hash) is defined by its SMILES only."""

    def __init__(self, smiles: str, depth: int, app_domain_assessment: Optional[dict] = None):
        self.smiles = smiles
        self.depth = depth
        # Applicability-domain result for this structure, or None if not (yet) assessed.
        self.app_domain_assessment = app_domain_assessment

    def __hash__(self):
        return hash(self.smiles)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.smiles == other.smiles
        return False

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"


class SEdge(object):
    """In-memory hyperedge from one or more educt nodes to one or more product nodes.

    Two edges are equal when their rules compare equal and their educt and product
    SMILES agree as multisets (order-insensitive, duplicates counted).
    """

    def __init__(
        self,
        educts: Union[SNode, List[SNode]],
        products: Union[SNode, List[SNode]],
        rule: Optional["Rule"] = None,
        probability: Optional[float] = None,
    ):
        # Both sides accept a single node or a sequence. Products used to be stored as
        # passed, so a single product SNode broke product_smiles/__hash__/__eq__.
        self.educts = SEdge._as_node_list(educts)
        self.products = SEdge._as_node_list(products)
        self.rule = rule
        self.probability = probability

    @staticmethod
    def _as_node_list(nodes) -> list:
        """Return ``nodes`` as a list: lists unchanged, tuples converted, anything else wrapped."""
        if isinstance(nodes, list):
            return nodes
        if isinstance(nodes, tuple):
            return list(nodes)
        return [nodes]

    def product_smiles(self):
        """SMILES of the product nodes, in stored order."""
        return [p.smiles for p in self.products]

    def _educt_key(self):
        return tuple(sorted(n.smiles for n in self.educts))

    def _product_key(self):
        return tuple(sorted(n.smiles for n in self.products))

    def __hash__(self):
        # Keep educts and products in separate tuple slots so A->B and B->A do not collide.
        # This is consistent with __eq__: equal edges have equal sorted SMILES and equal rules.
        return hash((self._educt_key(), self._product_key(), self.rule))

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        # `!=` also covers the "exactly one side has no rule" cases.
        if self.rule != other.rule:
            return False

        return (
            self._educt_key() == other._educt_key()
            and self._product_key() == other._product_key()
        )

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"
|
||
|
||
|
||
class SPathway(object):
    """In-memory pathway graph used during prediction, optionally mirrored to a ``Pathway``.

    Nodes are keyed by SMILES in ``smiles_to_node`` and edges are kept in ``edges``.
    When ``persist`` is set, ``snode_persist_lookup`` / ``sedge_persist_lookup`` map
    in-memory objects to their database rows, and ``_sync_to_pathway`` writes new ones back.
    """

    def __init__(
        self,
        root_nodes: Optional[Union[str, SNode, List[str | SNode]]] = None,
        persist: Optional["Pathway"] = None,
        prediction_setting: Optional[Setting] = None,
    ):
        self.root_nodes = []

        self.persist = persist
        self.snode_persist_lookup: Dict[SNode, Node] = dict()
        self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            # Root nodes of a persisted pathway come from the database; `root_nodes` is ignored.
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            # Only SMILES strings and SNodes are accepted; anything else (e.g. None) is skipped.
            for n in root_nodes:
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.smiles_to_node: Dict[str, SNode] = {n.smiles: n for n in self.root_nodes}
        self.edges: Set["SEdge"] = set()
        self.done = False
        self.empty_due_to_threshold = False

    @staticmethod
    def from_pathway(pw: "Pathway", persist: bool = True):
        """Initializes a SPathway with a state given by a Pathway.

        Args:
            pw: The pathway whose nodes and edges are loaded.
            persist: If True, the SPathway writes new nodes/edges back to ``pw``.
        """
        spw = SPathway(persist=pw if persist else None, prediction_setting=pw.setting)

        if not persist:
            # __init__ only registers root nodes for a persisted pathway. Passing the Node
            # models as `root_nodes` would be silently dropped (they are neither str nor
            # SNode), leaving the SPathway without roots, so register them here.
            for n in pw.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                spw.root_nodes.append(snode)
                spw.smiles_to_node[snode.smiles] = snode
                spw.snode_persist_lookup[snode] = n

        # root_nodes are already added above, add remaining nodes
        for n in pw.nodes:
            snode = SNode(n.default_node_label.smiles, n.depth)
            if snode.smiles not in spw.smiles_to_node:
                spw.smiles_to_node[snode.smiles] = snode
                spw.snode_persist_lookup[snode] = n

        for e in pw.edges:
            sub = [spw.smiles_to_node[n.default_node_label.smiles] for n in e.start_nodes.all()]
            prod = [spw.smiles_to_node[n.default_node_label.smiles] for n in e.end_nodes.all()]

            # first() yields None when the edge label has no rules.
            rule = e.edge_label.rules.all().first()

            prob = None
            if e.kv.get("probability"):
                prob = float(e.kv["probability"])

            sedge = SEdge(sub, prod, rule=rule, probability=prob)
            spw.edges.add(sedge)
            spw.sedge_persist_lookup[sedge] = e

        return spw

    def num_nodes(self):
        """Number of distinct SMILES in the graph."""
        return len(self.smiles_to_node)

    def depth(self):
        """Maximum node depth (raises ValueError on an empty graph)."""
        return max(v.depth for v in self.smiles_to_node.values())

    def _get_nodes_for_depth(self, depth: int) -> List[SNode]:
        """Root nodes for depth 0, otherwise all nodes at ``depth`` sorted by SMILES."""
        if depth == 0:
            return self.root_nodes

        return sorted(
            (n for n in self.smiles_to_node.values() if n.depth == depth),
            key=lambda x: x.smiles,
        )

    def _get_edges_for_depth(self, depth: int) -> List[SEdge]:
        """Edges with at least one educt at ``depth``, each listed once, ordered by hash."""
        res = [e for e in self.edges if any(n.depth == depth for n in e.educts)]
        return sorted(res, key=lambda x: hash(x))

    def _get_app_domain(self):
        """Applicability domain of the setting's model, or None if there is no model/domain."""
        model = self.prediction_setting.model
        if model and model.app_domain:
            return model.app_domain
        return None

    def _expand(self, substrates: List[SNode]) -> Tuple[List[SNode], List[SEdge]]:
        """
        Expands the given substrates by generating new nodes and edges based on prediction settings.

        Each substrate is first assessed against the model's applicability domain (if any and
        not done yet; persisted to its Node when ``persist`` is set), then expanded via
        ``prediction_setting.expand``. Candidate products become new SNodes (with their own
        domain assessment) and an SEdge is created per candidate reaction.

        Parameters:
            substrates (List[SNode]): A list of substrate nodes to be expanded.

        Returns:
            Tuple[List[SNode], List[SEdge]]:
                - A list of new nodes generated during the expansion.
                - A list of new edges representing connections between nodes based on candidate reactions.

        Raises:
            ValueError: If a node does not have an ID when it should have been saved already.
        """
        new_nodes: List[SNode] = []
        new_edges: List[SEdge] = []

        app_domain = self._get_app_domain()

        for sub in substrates:
            # For App Domain we have to ensure that each Node is evaluated
            if sub.app_domain_assessment is None and app_domain is not None:
                app_domain_assessment = app_domain.assess(sub.smiles)

                if self.persist is not None:
                    n = self.snode_persist_lookup[sub]

                    if n.id is None:
                        raise ValueError(f"Node {n} has no ID... aborting!")

                    node_data = n.simple_json()
                    node_data["image"] = f"{n.url}?image=svg"
                    app_domain_assessment["assessment"]["node"] = node_data

                    n.kv["app_domain_assessment"] = app_domain_assessment
                    n.save()

                sub.app_domain_assessment = app_domain_assessment

            expansion_result = self.prediction_setting.expand(self, sub)

            # We don't have any substrate, but technically we have at least one rule that triggered.
            # If our substrate is a root node a.k.a. depth == 0 store that info in SPathway
            if (
                len(expansion_result["transformations"]) == 0
                and expansion_result["rule_triggered"]
                and sub.depth == 0
            ):
                self.empty_due_to_threshold = True

                # Emit directly
                if self.persist is not None:
                    self.persist.kv["empty_due_to_threshold"] = True
                    self.persist.save()

            # candidates is a List of PredictionResult. The length of the List is equal to the number of rules
            for cand_set in expansion_result["transformations"]:
                if not cand_set:
                    continue

                # cand_set is a PredictionResult object that can consist of multiple candidate reactions
                for cand in cand_set:
                    cand_nodes = []
                    # candidate reactions can have multiple fragments
                    for c in cand:
                        if c not in self.smiles_to_node:
                            # For new nodes do an AppDomain Assessment if an AppDomain is attached
                            app_domain_assessment = (
                                app_domain.assess(c) if app_domain is not None else None
                            )
                            snode = SNode(c, sub.depth + 1, app_domain_assessment)
                            self.smiles_to_node[c] = snode
                            new_nodes.append(snode)

                        cand_nodes.append(self.smiles_to_node[c])

                    edge = SEdge(
                        sub,
                        cand_nodes,
                        rule=cand_set.rule,
                        probability=cand_set.probability,
                    )
                    self.edges.add(edge)
                    new_edges.append(edge)

        return new_nodes, new_edges

    def predict(self):
        """
        Expand the pathway from its root nodes until no unprocessed node remains.

        Node probabilities start at 1.0 for roots and are propagated as
        ``parent_prob * edge_prob`` (keeping the highest value per node); an edge
        without a probability counts as 1.0. ``prediction_setting.expansion_scheme``
        decides the traversal order: "DFS" (prepend new nodes), "BFS" (append), or
        "GREEDY" (append, then re-sort the queue by descending node probability).
        New nodes/edges are synced to ``persist`` after each expansion when set.
        ``done`` is set to True when the queue is exhausted.

        Raises
        ------
        ValueError
            If an invalid or unknown expansion schema is provided in `prediction_setting`.
        """
        # populate initial queue
        queue = list(self.root_nodes)
        processed = set()

        # initial nodes have prob 1.0
        node_probs: Dict[SNode, float] = {n: 1.0 for n in queue}

        while queue:
            current = queue.pop(0)

            if current in processed:
                continue

            processed.add(current)

            new_nodes, new_edges = self._expand([current])

            if (new_nodes or new_edges) and self.persist:
                # Write back data to the database
                self._sync_to_pathway()
                # call save to update the internal modified field
                self.persist.save()

            if not new_nodes:
                continue

            # All new edges have `current` as educt, so propagate from its probability.
            current_prob = node_probs[current]
            for edge in new_edges:
                # Edges without a probability (e.g. from pure rule-based expansion) would make
                # `float * None` raise; treat them as not attenuating the path.
                edge_prob = edge.probability if edge.probability is not None else 1.0
                path_prob = current_prob * edge_prob
                for prod in edge.products:
                    # Either is a new product or a product and we found a path with a higher prob
                    if prod not in node_probs or path_prob > node_probs[prod]:
                        node_probs[prod] = path_prob

            # Update Queue to proceed
            scheme = self.prediction_setting.expansion_scheme
            if scheme == "DFS":
                for n in new_nodes:
                    if n not in processed:
                        # We want to follow this path -> prepend queue
                        queue.insert(0, n)
            elif scheme == "BFS":
                for n in new_nodes:
                    if n not in processed:
                        # Add at the end, everything queued before will be processed before new_nodes
                        queue.append(n)
            elif scheme == "GREEDY":
                # Simply add them, then re-order the queue by descending probability (stable sort)
                for n in new_nodes:
                    if n not in processed:
                        queue.append(n)
                queue = sorted(queue, key=lambda n: node_probs[n], reverse=True)
            else:
                raise ValueError(
                    f"Unknown expansion schema: {self.prediction_setting.expansion_scheme}"
                )

        # Queue exhausted, we're done
        self.done = True

    def predict_step(self, from_depth: int = None, from_node: "Node" = None):
        """Expand a single layer (``from_depth``) or a single persisted node (``from_node``).

        Raises:
            ValueError: If ``from_node`` is not part of this SPathway, or neither argument is given.
        """
        substrates: List[SNode] = []

        if from_depth is not None:
            substrates = self._get_nodes_for_depth(from_depth)
        elif from_node is not None:
            for k, v in self.snode_persist_lookup.items():
                if from_node == v:
                    substrates = [k]
                    break
            else:
                raise ValueError(f"Node {from_node} not found in SPathway!")
        else:
            raise ValueError("Neither from_depth nor from_node_url specified")

        changed = False
        if substrates:
            new_nodes, new_edges = self._expand(substrates)
            # Like predict(), also persist expansions that only add edges between existing nodes.
            changed = bool(new_nodes or new_edges)

        # In case no substrates are found, we're done.
        # For "predict from node" we're always done
        if len(substrates) == 0 or from_node is not None:
            self.done = True

        # Check if we need to write back data to the database
        if changed and self.persist:
            self._sync_to_pathway()
            # call save to update the internal modified field
            self.persist.save()

    def get_edge_for_educt_smiles(self, smiles: str) -> List[SEdge]:
        """Edges having an educt with the given SMILES, each listed once."""
        return [e for e in self.edges if any(n.smiles == smiles for n in e.educts)]

    def _sync_to_pathway(self) -> None:
        """Create database Nodes/Edges for every SNode/SEdge not yet persisted.

        Raises:
            ValueError: If a freshly created Node has no id.
        """
        logger.info("Updating Pathway with SPathway")

        for snode in self.smiles_to_node.values():
            if snode in self.snode_persist_lookup:
                continue

            n = Node.create(self.persist, snode.smiles, snode.depth)

            if snode.app_domain_assessment is not None:
                app_domain_assessment = snode.app_domain_assessment

                # Explicit raise instead of `assert`, which is stripped under `python -O`.
                if n.id is None:
                    raise ValueError("Node has no id! Should have been saved already... aborting!")

                node_data = n.simple_json()
                node_data["image"] = f"{n.url}?image=svg"
                app_domain_assessment["assessment"]["node"] = node_data

                n.kv["app_domain_assessment"] = app_domain_assessment
                n.save()

            self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge in self.sedge_persist_lookup:
                continue

            educt_nodes = [self.snode_persist_lookup[snode] for snode in sedge.educts]
            product_nodes = [self.snode_persist_lookup[snode] for snode in sedge.products]

            e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)

            if sedge.probability:
                e.kv.update({"probability": sedge.probability})
                e.save()

            self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        """Serialize to ``{"nodes": [...], "edges": [...]}``; edges reference node indices.

        Only the first educt of each edge is emitted as ``from``.
        """
        nodes = []
        edges = []

        idx_lookup = {}

        for i, smiles in enumerate(self.smiles_to_node):
            n = self.smiles_to_node[smiles]
            idx_lookup[smiles] = i
            nodes.append({"depth": n.depth, "smiles": n.smiles, "id": i})

        for edge in self.edges:
            edges.append(
                {
                    "from": idx_lookup[edge.educts[0].smiles],
                    "to": [idx_lookup[p.smiles] for p in edge.products],
                }
            )

        return {
            "nodes": nodes,
            "edges": edges,
        }
|