import re
import logging
import json
from typing import Union, List, Optional, Set, Dict, Any

from django.contrib.auth import get_user_model
from django.db import transaction
from django.conf import settings as s

from epdb.models import User, Package, UserPackagePermission, GroupPackagePermission, Permission, Group, Setting, \
    EPModel, UserSettingPermission, Rule, Pathway, Node, Edge, Compound, Reaction, CompoundStructure
from utilities.chem import FormatConverter

logger = logging.getLogger(__name__)


class EPDBURLParser:
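    """
    Parses enviPath-style object URLs.

    The URL is matched against the known ``/package/<uuid>/...`` patterns and the
    corresponding Django model instances are resolved via their ``url`` field, e.g.::

        parser = EPDBURLParser(url)
        obj = parser.get_object()      # most specific object on the path
        chain = parser.get_objects()   # hierarchical order, e.g. Package -> Compound -> Structure
    """
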
    UUID_PATTERN = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'

    MODEL_PATTERNS = {
        'epdb.User': re.compile(rf'^.*/user/{UUID_PATTERN}'),
        'epdb.Group': re.compile(rf'^.*/group/{UUID_PATTERN}'),
        'epdb.Package': re.compile(rf'^.*/package/{UUID_PATTERN}'),
        'epdb.Compound': re.compile(rf'^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}'),
        'epdb.CompoundStructure': re.compile(rf'^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}/structure/{UUID_PATTERN}'),
        'epdb.Rule': re.compile(rf'^.*/package/{UUID_PATTERN}/(?:simple-ambit-rule|simple-rdkit-rule|parallel-rule|sequential-rule|rule)/{UUID_PATTERN}'),
        'epdb.Reaction': re.compile(rf'^.*/package/{UUID_PATTERN}/reaction/{UUID_PATTERN}$'),
        'epdb.Pathway': re.compile(rf'^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}'),
        'epdb.Node': re.compile(rf'^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/node/{UUID_PATTERN}'),
        'epdb.Edge': re.compile(rf'^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/edge/{UUID_PATTERN}'),
        'epdb.Scenario': re.compile(rf'^.*/package/{UUID_PATTERN}/scenario/{UUID_PATTERN}'),
        'epdb.EPModel': re.compile(rf'^.*/package/{UUID_PATTERN}/model/{UUID_PATTERN}'),
        'epdb.Setting': re.compile(rf'^.*/setting/{UUID_PATTERN}'),
    }

    def __init__(self, url: str):
        self.url = url
        self._matches = {}
        self._analyze_url()

    def _analyze_url(self):
        for model_path, pattern in self.MODEL_PATTERNS.items():
            match = pattern.findall(self.url)
            if match:
                self._matches[model_path] = match[0]

    def _get_model_class(self, model_path: str):
        try:
            from django.apps import apps
            app_label, model_name = model_path.split('.')[-2:]
            return apps.get_model(app_label, model_name)
        except (ImportError, LookupError, ValueError):
            raise ValueError(f"Model {model_path} does not exist!")

    def _get_object_by_url(self, model_path: str, url: str):
        model_class = self._get_model_class(model_path)
        return model_class.objects.get(url=url)

    def is_package_url(self) -> bool:
        return bool(re.compile(rf'^.*/package/{self.UUID_PATTERN}$').findall(self.url))

    def contains_package_url(self):
        return bool(self.MODEL_PATTERNS['epdb.Package'].findall(self.url)) and not self.is_package_url()

    def is_user_url(self) -> bool:
        return bool(self.MODEL_PATTERNS['epdb.User'].findall(self.url))

    def is_group_url(self) -> bool:
        return bool(self.MODEL_PATTERNS['epdb.Group'].findall(self.url))

    def is_setting_url(self) -> bool:
        return bool(self.MODEL_PATTERNS['epdb.Setting'].findall(self.url))

    def get_object(self) -> Optional[Any]:
        # Define priority order from most specific to least specific
        priority_order = [
            # 3rd level
            'epdb.CompoundStructure',
            'epdb.Node',
            'epdb.Edge',
            # 2nd level
            'epdb.Compound',
            'epdb.Rule',
            'epdb.Reaction',
            'epdb.Scenario',
            'epdb.EPModel',
            'epdb.Pathway',
            # 1st level
            'epdb.Package',
            'epdb.Setting',
            'epdb.Group',
            'epdb.User',
        ]

        for model_path in priority_order:
            if model_path in self._matches:
                url = self._matches[model_path]
                return self._get_object_by_url(model_path, url)

        raise ValueError(f"No object found for URL {self.url}")

    def get_objects(self) -> List[Any]:
        """
        Get all Django model objects along the URL path in hierarchical order.
        Returns objects from parent to child (e.g., Package -> Compound -> Structure).
        """
        objects = []

        hierarchy_order = [
            # 1st level
            'epdb.Package',
            'epdb.Setting',
            'epdb.Group',
            'epdb.User',
            # 2nd level
            'epdb.Compound',
            'epdb.Rule',
            'epdb.Reaction',
            'epdb.Scenario',
            'epdb.EPModel',
            'epdb.Pathway',
            # 3rd level
            'epdb.CompoundStructure',
            'epdb.Node',
            'epdb.Edge',
        ]

        for model_path in hierarchy_order:
            if model_path in self._matches:
                url = self._matches[model_path]
                objects.append(self._get_object_by_url(model_path, url))

        return objects

    def __str__(self) -> str:
        return f"EPDBURLParser(url='{self.url}')"

    def __repr__(self) -> str:
        return f"EPDBURLParser(url='{self.url}', matches={list(self._matches.keys())})"


class UserManager(object):
    user_pattern = re.compile(r".*/user/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}")

    @staticmethod
    def is_user_url(url: str):
        return bool(re.findall(UserManager.user_pattern, url))

    @staticmethod
    @transaction.atomic
    def create_user(username, email, password, set_setting=True, add_to_group=True, *args, **kwargs):
        # avoid circular import :S
        from .tasks import send_registration_mail

        extra_fields = {
            'is_active': not s.ADMIN_APPROVAL_REQUIRED
        }

        if 'is_active' in kwargs:
            extra_fields['is_active'] = kwargs['is_active']

        if 'uuid' in kwargs:
            extra_fields['uuid'] = kwargs['uuid']

        u = get_user_model().objects.create_user(username, email, password, **extra_fields)

        # Create package
        package_name = f"{u.username}{'’' if u.username[-1] in 'sxzß' else '’s'} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()

        if not u.is_active:
            # send email for verification
            send_registration_mail.delay(u.pk)

        if set_setting:
            u.default_setting = Setting.objects.get(global_default=True)
            u.save()

        if add_to_group:
            g = Group.objects.get(public=True, name='enviPath Users')
            g.user_member.add(u)
            g.save()
            u.default_group = g
            u.save()

        return u

    @staticmethod
    def get_user(user_url):
        pass

    @staticmethod
    def get_user_by_id(user, user_uuid: str):
        if str(user.uuid) != user_uuid and not user.is_superuser:
            raise ValueError("Getting user failed!")
        return get_user_model().objects.get(uuid=user_uuid)

    @staticmethod
    def get_user_lp(user_url: str):
        uuid = user_url.strip().split('/')[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users_lp():
        return get_user_model().objects.all()

    @staticmethod
    def get_users():
        raise ValueError("")

    @staticmethod
    def writable(current_user, user):
        return (current_user == user) or user.is_superuser


class GroupManager(object):
    group_pattern = re.compile(r".*/group/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}")

    @staticmethod
    def is_group_url(url: str):
        return bool(re.findall(GroupManager.group_pattern, url))

    @staticmethod
    def create_group(current_user, name, description):
        g = Group()
        g.name = name
        g.description = description
        g.owner = current_user
        g.save()

        g.user_member.add(current_user)
        g.save()

        return g

    @staticmethod
    def get_group_lp(group_url: str):
        uuid = group_url.strip().split('/')[-1]
        return Group.objects.get(uuid=uuid)

    @staticmethod
    def get_groups_lp():
        return Group.objects.all()

    @staticmethod
    def get_group_by_url(user, group_url):
        return GroupManager.get_group_by_id(user, group_url.split('/')[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        g = Group.objects.get(uuid=group_id)

        if user in g.user_member.all():
            return g
        return None

    @staticmethod
    def get_groups(user):
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):

        if caller != group.owner:
            raise ValueError('Only the group owner is allowed to modify members!')

        if isinstance(member, Group):
            if add_or_remove == 'add':
                group.group_member.add(member)
            else:
                group.group_member.remove(member)
        else:
            if add_or_remove == 'add':
                group.user_member.add(member)
            else:
                group.user_member.remove(member)

        group.save()

    @staticmethod
    def writable(user, group):
        return (user == group.owner) or user.is_superuser


class PackageManager(object):
    package_pattern = re.compile(r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}")

    @staticmethod
    def is_package_url(url: str):
        return bool(re.findall(PackageManager.package_pattern, url))

    @staticmethod
    def get_reviewed_packages():
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        if UserPackagePermission.objects.filter(package=package, user=user).exists() or \
                GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user)).exists() or \
                package.reviewed is True or \
                user.is_superuser:
            return True

        return False

    @staticmethod
    def writable(user, package):
        if UserPackagePermission.objects.filter(package=package, user=user, permission=Permission.WRITE[0]).exists() or \
                GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]).exists() or \
                UserPackagePermission.objects.filter(package=package, user=user, permission=Permission.ALL[0]).exists() or \
                GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]).exists() or \
                user.is_superuser:
            return True
        return False

    @staticmethod
    def administrable(user, package):
        if UserPackagePermission.objects.filter(package=package, user=user, permission=Permission.ALL[0]).exists() or \
                GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]).exists() or \
                user.is_superuser:
            return True
        return False

    # @staticmethod
    # def get_package_permission(user: 'User', package: Union[str, 'Package']):
    #     if PackageManager.administrable(user, package):
    #         return Permission.ALL[0]
    #     elif PackageManager.writable(user, package):
    #         return Permission.WRITE[0]
    #     elif PackageManager.readable(user, package):
    #         return Permission.READ[0]
    #     else:
    #         return None

    @staticmethod
    def has_package_permission(user: 'User', package: Union[str, 'Package'], permission: str):

        if isinstance(package, str):
            package = Package.objects.get(uuid=package)

        groups = GroupManager.get_groups(user)

        perms = {
            'all': ['all'],
            'write': ['all', 'write'],
            'read': ['all', 'write', 'read']
        }

        valid_perms = perms.get(permission)

        if UserPackagePermission.objects.filter(package=package, user=user, permission__in=valid_perms).exists() or \
                GroupPackagePermission.objects.filter(package=package, group__in=groups,
                                                      permission__in=valid_perms).exists() or \
                user.is_superuser:
            return True

        return False

    @staticmethod
    def get_package_lp(package_url):
        match = re.findall(PackageManager.package_pattern, package_url)
        if match:
            package_id = match[0].split('/')[-1]
            return Package.objects.get(uuid=package_id)
        return None

    @staticmethod
    def get_package_by_url(user, package_url):
        match = re.findall(PackageManager.package_pattern, package_url)

        if match:
            package_id = match[0].split('/')[-1]
            return PackageManager.get_package_by_id(user, package_id)
        else:
            raise ValueError("Requested URL {} does not contain a valid package identifier!".format(package_url))

    @staticmethod
    def get_package_by_id(user, package_id):
        try:
            p = Package.objects.get(uuid=package_id)
            if PackageManager.readable(user, p):
                return p
            else:
                raise ValueError(
                    "Insufficient permissions to access Package with ID {}".format(package_id))
        except Package.DoesNotExist:
            raise ValueError("Package with ID {} does not exist!".format(package_id))

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values('package').distinct())
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user)).values(
                    'package').distinct())
            qs = user_package_qs | group_package_qs

        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            write_user_packs = UserPackagePermission.objects.filter(user=user, permission=Permission.WRITE[0]).values('package').distinct()
            owner_user_packs = UserPackagePermission.objects.filter(user=user, permission=Permission.ALL[0]).values('package').distinct()

            user_packs = write_user_packs | owner_user_packs
            user_package_qs = Package.objects.filter(id__in=user_packs)

            write_group_packs = GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]).values('package').distinct()
            owner_group_packs = GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]).values('package').distinct()

            group_packs = write_group_packs | owner_group_packs
            group_package_qs = Package.objects.filter(id__in=group_packs)

            qs = user_package_qs | group_package_qs

        qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_packages():
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        p = Package()
        p.name = name
        p.description = description
        p.save()

        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()

        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]):
        caller_perm = None
        if not caller.is_superuser:
            caller_perm = UserPackagePermission.objects.get(user=caller, package=package).permission

        if caller_perm != Permission.ALL[0] and not caller.is_superuser:
            raise ValueError("Only owners are allowed to modify permissions")

        data = {
            'package': package,
        }

        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data['user'] = grantee
        else:
            perm_cls = GroupPackagePermission
            data['group'] = grantee

        if new_perm is None:
            qs = perm_cls.objects.filter(**data)
            if qs.count() > 1:
                raise ValueError("Got more Permission objects than expected!")
            if qs.count() != 0:
                logger.info(f"Deleting Perm {qs.first()}")
                qs.delete()
            else:
                logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
        else:
            _ = perm_cls.objects.update_or_create(defaults={'permission': new_perm}, **data)

    @staticmethod
    @transaction.atomic
    def import_package(data: dict, owner: User, keep_ids=False, add_import_timestamp=True, trust_reviewed=False):
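        """
        Imports a package from its JSON export into a new Package owned by ``owner``.

        The import runs in phases: scenarios, compounds and their structures, rules
        (simple, parallel, sequential), reactions, and pathways with their nodes and
        edges. Afterwards scenarios are linked to the imported objects and node depths
        are recomputed via a BFS from the root nodes. Old object ids are mapped to
        newly generated uuids unless ``keep_ids`` is set.
        """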
        from uuid import UUID, uuid4
        from datetime import datetime
        from collections import defaultdict
        from .models import Package, Compound, CompoundStructure, SimpleRule, SimpleAmbitRule, SimpleRDKitRule, \
            ParallelRule, SequentialRule, SequentialRuleOrdering, Reaction, Pathway, Node, Edge, Scenario
        from envipy_additional_information import AdditionalInformationConverter

        pack = Package()
        pack.uuid = UUID(data['id'].split('/')[-1]) if keep_ids else uuid4()

        if add_import_timestamp:
            pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M'))
        else:
            pack.name = data['name']

        if trust_reviewed:
            pack.reviewed = data['reviewStatus'] == 'reviewed'
        else:
            pack.reviewed = False

        pack.description = data['description']
        pack.save()

        up = UserPackagePermission()
        up.user = owner
        up.package = pack
        up.permission = up.ALL[0]
        up.save()

        # Stores old_id to new_id
        mapping = {}
        # Stores new_scen_id to old_parent_scen_id
        parent_mapping = {}
        # Mapping old scen_id to old_obj_id
        scen_mapping = defaultdict(list)

        # Store Scenarios
        for scenario in data['scenarios']:
            scen = Scenario()
            scen.package = pack
            scen.uuid = UUID(scenario['id'].split('/')[-1]) if keep_ids else uuid4()
            scen.name = scenario['name']
            scen.description = scenario['description']
            scen.scenario_type = scenario['type']
            scen.scenario_date = scenario['date']
            scen.additional_information = dict()
            scen.save()

            mapping[scenario['id']] = scen.uuid

            new_add_inf = defaultdict(list)
            # TODO Store AI...
            for ex in scenario.get('additionalInformationCollection', {}).get('additionalInformation', []):
                name = ex['name']
                addinf_data = ex['data']

                # park the parent scen id for now and link it later
                if name == 'referringscenario':
                    parent_mapping[scen.uuid] = addinf_data
                    continue

                # Broken eP Data
                if name == 'initialmasssediment' and addinf_data == 'missing data':
                    continue

                # TODO Enzymes aren't ready yet
                if name == 'enzyme':
                    continue

                try:
                    res = AdditionalInformationConverter.convert(name, addinf_data)
                    res_cls_name = res.__class__.__name__
                    ai_data = json.loads(res.model_dump_json())
                    ai_data['uuid'] = f"{uuid4()}"
                    new_add_inf[res_cls_name].append(ai_data)
                except Exception:
                    logger.exception(f"Failed to convert {name} with {addinf_data}")

            scen.additional_information = new_add_inf
            scen.save()

        print('Scenarios imported...')

        # Store compounds and their structures
        for compound in data['compounds']:
            comp = Compound()
            comp.package = pack
            comp.uuid = UUID(compound['id'].split('/')[-1]) if keep_ids else uuid4()
            comp.name = compound['name']
            comp.description = compound['description']
            comp.aliases = compound['aliases']
            comp.save()

            mapping[compound['id']] = comp.uuid

            for scen in compound['scenarios']:
                scen_mapping[scen['id']].append(comp)

            default_structure = None

            for structure in compound['structures']:
                struc = CompoundStructure()
                # struc.object_url = Command.get_id(structure, keep_ids)
                struc.compound = comp
                struc.uuid = UUID(structure['id'].split('/')[-1]) if keep_ids else uuid4()
                struc.name = structure['name']
                struc.description = structure['description']
                struc.smiles = structure['smiles']
                struc.save()

                for scen in structure['scenarios']:
                    scen_mapping[scen['id']].append(struc)

                mapping[structure['id']] = struc.uuid

                if structure['id'] == compound['defaultStructure']['id']:
                    default_structure = struc

            if default_structure is None:
                raise ValueError('No default structure set')

            comp.default_structure = default_structure
            comp.save()

        print('Compounds imported...')

        # Store simple and parallel rules
        par_rules = []
        seq_rules = []

        for rule in data['rules']:
            if rule['identifier'] == 'parallel-rule':
                par_rules.append(rule)
                continue
            if rule['identifier'] == 'sequential-rule':
                seq_rules.append(rule)
                continue
            r = SimpleAmbitRule()
            r.uuid = UUID(rule['id'].split('/')[-1]) if keep_ids else uuid4()
            r.package = pack
            r.name = rule['name']
            r.description = rule['description']
            r.smirks = rule['smirks']
            r.reactant_filter_smarts = rule.get('reactantFilterSmarts', None)
            r.product_filter_smarts = rule.get('productFilterSmarts', None)
            r.save()

            mapping[rule['id']] = r.uuid

            for scen in rule['scenarios']:
                scen_mapping[scen['id']].append(r)

        print("Par: ", len(par_rules))
        print("Seq: ", len(seq_rules))

        for par_rule in par_rules:
            r = ParallelRule()
            r.package = pack
            r.uuid = UUID(par_rule['id'].split('/')[-1]) if keep_ids else uuid4()
            r.name = par_rule['name']
            r.description = par_rule['description']
            r.save()

            mapping[par_rule['id']] = r.uuid

            for scen in par_rule['scenarios']:
                scen_mapping[scen['id']].append(r)

            for simple_rule in par_rule['simpleRules']:
                if simple_rule['id'] in mapping:
                    r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule['id']]))

            r.save()

        for seq_rule in seq_rules:
            r = SequentialRule()
            r.package = pack
            r.uuid = UUID(seq_rule['id'].split('/')[-1]) if keep_ids else uuid4()
            r.name = seq_rule['name']
            r.description = seq_rule['description']
            r.save()

            mapping[seq_rule['id']] = r.uuid

            for scen in seq_rule['scenarios']:
                scen_mapping[scen['id']].append(r)

            for i, simple_rule in enumerate(seq_rule['simpleRules']):
                sro = SequentialRuleOrdering()
                sro.simple_rule = SimpleRule.objects.get(uuid=mapping[simple_rule['id']])
                sro.sequential_rule = r
                sro.order_index = i
                sro.save()
                # r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule['id']]))

            r.save()

        print('Rules imported...')

        for reaction in data['reactions']:
            r = Reaction()
            r.package = pack
            r.uuid = UUID(reaction['id'].split('/')[-1]) if keep_ids else uuid4()
            r.name = reaction['name']
            r.description = reaction['description']
            r.medlinereferences = reaction['medlinereferences']
            r.multi_step = reaction['multistep'] == 'true'
            r.save()

            mapping[reaction['id']] = r.uuid

            for scen in reaction['scenarios']:
                scen_mapping[scen['id']].append(r)

            for educt in reaction['educts']:
                r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt['id']]))

            for product in reaction['products']:
                r.products.add(CompoundStructure.objects.get(uuid=mapping[product['id']]))

            if 'rules' in reaction:
                for rule in reaction['rules']:
                    try:
                        r.rules.add(Rule.objects.get(uuid=mapping[rule['id']]))
                    except Exception as e:
                        print(f"Rule with id {rule['id']} not found!")
                        print(e)

            r.save()

        print('Reactions imported...')

        for pathway in data['pathways']:
            pw = Pathway()
            pw.package = pack
            pw.uuid = UUID(pathway['id'].split('/')[-1]) if keep_ids else uuid4()
            pw.name = pathway['name']
            pw.description = pathway['description']
            pw.save()

            mapping[pathway['id']] = pw.uuid
            for scen in pathway['scenarios']:
                scen_mapping[scen['id']].append(pw)

            out_nodes_mapping = defaultdict(set)

            for node in pathway['nodes']:
                n = Node()
                n.uuid = UUID(node['id'].split('/')[-1]) if keep_ids else uuid4()
                n.name = node['name']
                n.pathway = pw
                n.depth = node['depth']
                n.default_node_label = CompoundStructure.objects.get(uuid=mapping[node['defaultNodeLabel']['id']])
                n.save()

                mapping[node['id']] = n.uuid

                for scen in node['scenarios']:
                    scen_mapping[scen['id']].append(n)

                for node_label in node['nodeLabels']:
                    n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label['id']]))

                n.save()

                for out_edge in node['outEdges']:
                    out_nodes_mapping[n.uuid].add(out_edge)

            for edge in pathway['edges']:
                e = Edge()
                e.uuid = UUID(edge['id'].split('/')[-1]) if keep_ids else uuid4()
                e.name = edge['name']
                e.pathway = pw
                e.description = edge['description']
                e.edge_label = Reaction.objects.get(uuid=mapping[edge['edgeLabel']['id']])
                e.save()

                mapping[edge['id']] = e.uuid

                for scen in edge['scenarios']:
                    scen_mapping[scen['id']].append(e)

                for start_node in edge['startNodes']:
                    e.start_nodes.add(Node.objects.get(uuid=mapping[start_node]))

                for end_node in edge['endNodes']:
                    e.end_nodes.add(Node.objects.get(uuid=mapping[end_node]))

                e.save()

            for k, v in out_nodes_mapping.items():
                n = Node.objects.get(uuid=k)
                for v1 in v:
                    n.out_edges.add(Edge.objects.get(uuid=mapping[v1]))
                n.save()

        print('Pathways imported...')

        # Linking Phase
        for child, parent in parent_mapping.items():
            child_obj = Scenario.objects.get(uuid=child)
            parent_obj = Scenario.objects.get(uuid=mapping[parent])
            child_obj.parent = parent_obj
            child_obj.save()

        for scen_id, objects in scen_mapping.items():
            scen = Scenario.objects.get(uuid=mapping[scen_id])
            for o in objects:
                o.scenarios.add(scen)
                o.save()

        print("Scenarios linked...")

        print('Import statistics:')
        print('Package {} stored'.format(pack.url))
        print('Imported {} compounds'.format(Compound.objects.filter(package=pack).count()))
        print('Imported {} rules'.format(Rule.objects.filter(package=pack).count()))
        print('Imported {} reactions'.format(Reaction.objects.filter(package=pack).count()))
        print('Imported {} pathways'.format(Pathway.objects.filter(package=pack).count()))
        print('Imported {} scenarios'.format(Scenario.objects.filter(package=pack).count()))

print("Fixing Node depths...")
|
||
total_pws = Pathway.objects.filter(package=pack).count()
|
||
for p, pw in enumerate(Pathway.objects.filter(package=pack)):
|
||
print(pw.url)
|
||
in_count = defaultdict(lambda: 0)
|
||
out_count = defaultdict(lambda: 0)
|
||
|
||
for e in pw.edges:
|
||
# TODO check if this will remain
|
||
for react in e.start_nodes.all():
|
||
out_count[str(react.uuid)] += 1
|
||
|
||
for prod in e.end_nodes.all():
|
||
in_count[str(prod.uuid)] += 1
|
||
|
||
root_nodes = []
|
||
for n in pw.nodes:
|
||
num_parents = in_count[str(n.uuid)]
|
||
if num_parents == 0:
|
||
# must be a root node or unconnected node
|
||
if n.depth != 0:
|
||
n.depth = 0
|
||
n.save()
|
||
|
||
# Only root node may have children
|
||
if out_count[str(n.uuid)] > 0:
|
||
root_nodes.append(n)
|
||
|
||
levels = [root_nodes]
|
||
seen = set()
|
||
# Do a bfs to determine depths starting with level 0 a.k.a. root nodes
|
||
for i, level_nodes in enumerate(levels):
|
||
new_level = []
|
||
for n in level_nodes:
|
||
for e in n.out_edges.all():
|
||
for prod in e.end_nodes.all():
|
||
if str(prod.uuid) not in seen:
|
||
old_depth = prod.depth
|
||
if old_depth != i + 1:
|
||
print(f'updating depth from {old_depth} to {i + 1}')
|
||
prod.depth = i + 1
|
||
prod.save()
|
||
|
||
new_level.append(prod)
|
||
|
||
seen.add(str(n.uuid))
|
||
|
||
if new_level:
|
||
levels.append(new_level)
|
||
|
||
print(f'{p + 1}/{total_pws} fixed.')
|
||
|
||
return pack
|
||
|
||
class SettingManager(object):
    setting_pattern = re.compile(r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def get_setting_by_url(user, setting_url):
        match = re.findall(SettingManager.setting_pattern, setting_url)

        if match:
            setting_id = match[0].split('/')[-1]
            return SettingManager.get_setting_by_id(user, setting_id)
        else:
            raise ValueError("Requested URL {} does not contain a valid setting identifier!".format(setting_url))

    @staticmethod
    def get_setting_by_id(user, setting_id):
        # named 'setting' to avoid shadowing the module-level settings alias 's'
        setting = Setting.objects.get(uuid=setting_id)

        if setting.global_default or setting.public or user.is_superuser or \
                UserSettingPermission.objects.filter(user=user, setting=setting).exists():
            return setting

        raise ValueError(
            "Insufficient permissions to access Setting with ID {}".format(setting_id))

    @staticmethod
    def get_all_settings(user):
        sp = UserSettingPermission.objects.filter(user=user).values('setting')
        return (Setting.objects.filter(id__in=sp) | Setting.objects.filter(public=True) | Setting.objects.filter(
            global_default=True)).distinct()

    @staticmethod
    @transaction.atomic
    def create_setting(user: User, name: str = None, description: str = None, max_nodes: int = None,
                       max_depth: int = None, rule_packages: List[Package] = None, model: EPModel = None,
                       model_threshold: float = None):
        setting = Setting()
        setting.name = name
        setting.description = description
        setting.max_nodes = max_nodes
        setting.max_depth = max_depth
        setting.model = model
        setting.model_threshold = model_threshold

        setting.save()

        if rule_packages is not None:
            for r in rule_packages:
                setting.rule_packages.add(r)
            setting.save()

        usp = UserSettingPermission()
        usp.user = user
        usp.setting = setting
        usp.permission = Permission.ALL[0]
        usp.save()

        return setting

    @staticmethod
    def get_default_setting(user: User):
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        pass


class SearchManager(object):
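    """
    Package-scoped search over compounds, compound structures, reactions and pathways
    (plus rules in text mode).

    Supported modes: 'text' (name/description substring), 'exact' (SMILES as stored),
    'canonical' (canonical SMILES), 'default' (first InChIKey block, i.e. connectivity
    match) and 'inchikey' (full InChIKey).
    """
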
    @staticmethod
    def search(packages: Union[Package, List[Package]], searchterm: str, mode: str):
        match mode:
            case 'text':
                return SearchManager._search_text(packages, searchterm)
            case 'default':
                return SearchManager._search_default_smiles(packages, searchterm)
            case 'exact':
                return SearchManager._search_exact_smiles(packages, searchterm)
            case 'canonical':
                return SearchManager._search_canonical_smiles(packages, searchterm)
            case 'inchikey':
                return SearchManager._search_inchikey(packages, searchterm)
            case _:
                raise ValueError(f"Unknown search mode {mode}!")

    @staticmethod
    def _search_inchikey(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(inchikey=searchterm)
        compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__inchikey=searchterm)).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond)
        reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__inchikey=searchterm) | Q(products__inchikey=searchterm))).distinct()
        pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__inchikey=searchterm) | Q(edge__edge_label__products__inchikey=searchterm))).distinct()

        return {
            'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs],
            'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs],
            'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs],
            'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs],
        }

    @staticmethod
    def _search_exact_smiles(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(smiles=searchterm)
        compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__smiles=searchterm)).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond)
        reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__smiles=searchterm) | Q(products__smiles=searchterm))).distinct()
        pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__smiles=searchterm) | Q(edge__edge_label__products__smiles=searchterm))).distinct()

        return {
            'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs],
            'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs],
            'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs],
            'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs],
        }

    @staticmethod
    def _search_default_smiles(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        inchi_front = FormatConverter.InChIKey(searchterm)[:14]

        search_cond = Q(inchikey__startswith=inchi_front)
        compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__inchikey__startswith=inchi_front)).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond)
        reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__inchikey__startswith=inchi_front) | Q(products__inchikey__startswith=inchi_front))).distinct()
        pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__inchikey__startswith=inchi_front) | Q(edge__edge_label__products__inchikey__startswith=inchi_front))).distinct()

        return {
            'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs],
            'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs],
            'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs],
            'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs],
        }

    @staticmethod
    def _search_canonical_smiles(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = Q(canonical_smiles=searchterm)
        compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__canonical_smiles=searchterm)).distinct()
        compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond)
        reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__canonical_smiles=searchterm) | Q(products__canonical_smiles=searchterm))).distinct()
        pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__canonical_smiles=searchterm) | Q(edge__edge_label__products__canonical_smiles=searchterm))).distinct()

        return {
            'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs],
            'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs],
            'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs],
            'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs],
        }

    @staticmethod
    def _search_text(packages: Union[Package, List[Package]], searchterm: str):
        from django.db.models import Q

        search_cond = (Q(name__icontains=searchterm) | Q(description__icontains=searchterm))
        cond = Q(package__in=packages) & search_cond
        compound_qs = Compound.objects.filter(cond)
        compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond)
        rule_qs = Rule.objects.filter(cond)
        reactions_qs = Reaction.objects.filter(cond)
        pathway_qs = Pathway.objects.filter(cond)

        res = {
            'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs],
            'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs],
            'Rules': [{'name': r.name, 'description': r.description, 'url': r.url} for r in rule_qs],
            'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs],
            'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs],
        }
        return res


class SNode(object):
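    """In-memory pathway node used during prediction; identity is defined by its SMILES."""
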
    def __init__(self, smiles: str, depth: int, app_domain_assessment: dict = None):
        self.smiles = smiles
        self.depth = depth
        self.app_domain_assessment = app_domain_assessment

    def __hash__(self):
        return hash(self.smiles)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.smiles == other.smiles
        return False

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"


class SEdge(object):
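    """In-memory pathway edge connecting educt SNodes to product SNodes, optionally labelled with a Rule and a probability."""
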
    def __init__(self, educts: Union[SNode, List[SNode]], products: Union[SNode, List[SNode]],
                 rule: Optional['Rule'] = None, probability: Optional[float] = None):

        if not isinstance(educts, list):
            educts = [educts]

        if not isinstance(products, list):
            products = [products]

        self.educts = educts
        self.products = products
        self.rule = rule
        self.probability = probability

    def __hash__(self):
        full_hash = 0

        for n in sorted(self.educts, key=lambda x: x.smiles):
            full_hash += hash(n)

        for n in sorted(self.products, key=lambda x: x.smiles):
            full_hash += hash(n)

        if self.rule is not None:
            full_hash += hash(self.rule)

        return full_hash

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        if self.rule is not None and other.rule is None or \
                self.rule is None and other.rule is not None or \
                self.rule != other.rule:
            return False

        if not (len(self.educts) == len(other.educts)):
            return False

        for n1, n2 in zip(sorted(self.educts, key=lambda x: x.smiles), sorted(other.educts, key=lambda x: x.smiles)):
            if n1.smiles != n2.smiles:
                return False

        if not (len(self.products) == len(other.products)):
            return False

        for n1, n2 in zip(sorted(self.products, key=lambda x: x.smiles),
                          sorted(other.products, key=lambda x: x.smiles)):
            if n1.smiles != n2.smiles:
                return False

        return True

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"


class SPathway(object):
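    """
    Lightweight in-memory pathway used for pathway prediction.

    Nodes and edges are tracked as SNode/SEdge objects; if ``persist`` is set, newly
    predicted nodes and edges are written back to the underlying Pathway via
    _sync_to_pathway().
    """
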
    def __init__(self, root_nodes: Optional[Union[str, SNode, List[str | SNode]]] = None,
                 persist: Optional['Pathway'] = None, prediction_setting: Optional[Setting] = None
                 ):
        self.root_nodes = []

        self.persist = persist
        self.snode_persist_lookup: Dict[SNode, Node] = dict()
        self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            for n in root_nodes:
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.smiles_to_node: Dict[str, SNode] = {n.smiles: n for n in self.root_nodes}
        self.edges: Set['SEdge'] = set()
        self.done = False

    @staticmethod
    def from_pathway(pw: 'Pathway', persist: bool = True):
        """ Initializes a SPathway with a state given by a Pathway """
        spw = SPathway(root_nodes=pw.root_nodes, persist=pw if persist else None, prediction_setting=pw.setting)
        # root_nodes are already added in __init__, add remaining nodes
        for n in pw.nodes:
            snode = SNode(n.default_node_label.smiles, n.depth)
            if snode.smiles not in spw.smiles_to_node:
                spw.smiles_to_node[snode.smiles] = snode
                spw.snode_persist_lookup[snode] = n

        for e in pw.edges:
            sub = []
            prod = []

            for n in e.start_nodes.all():
                sub.append(spw.smiles_to_node[n.default_node_label.smiles])

            for n in e.end_nodes.all():
                prod.append(spw.smiles_to_node[n.default_node_label.smiles])

            rule = None
            if e.edge_label.rules.all():
                rule = e.edge_label.rules.all().first()

            prob = None
            if e.kv.get('probability'):
                prob = float(e.kv['probability'])

            sedge = SEdge(sub, prod, rule=rule, probability=prob)
            spw.edges.add(sedge)
            spw.sedge_persist_lookup[sedge] = e

        return spw

    def num_nodes(self):
        return len(self.smiles_to_node.keys())

    def depth(self):
        return max([v.depth for v in self.smiles_to_node.values()])

    def _get_nodes_for_depth(self, depth: int) -> List[SNode]:
        if depth == 0:
            return self.root_nodes

        res = []
        for n in self.smiles_to_node.values():
            if n.depth == depth:
                res.append(n)

        return sorted(res, key=lambda x: x.smiles)

    def _get_edges_for_depth(self, depth: int) -> List[SEdge]:
        res = []
        for e in self.edges:
            for n in e.educts:
                if n.depth == depth:
                    res.append(e)

        return sorted(res, key=lambda x: hash(x))

    def predict_step(self, from_depth: int = None, from_node: 'Node' = None):
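        """
        Expands the pathway by one prediction step.

        Substrates are taken either from the given depth or from a single persisted
        node; each substrate is expanded with the prediction setting, new nodes and
        edges are added in memory and, if this SPathway wraps a persisted Pathway,
        written back to the database.
        """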
        substrates: List[SNode] = []

        if from_depth is not None:
            substrates = self._get_nodes_for_depth(from_depth)
        elif from_node is not None:
            for k, v in self.snode_persist_lookup.items():
                if from_node == v:
                    substrates = [k]
                    break
        else:
            raise ValueError("Neither from_depth nor from_node specified")

        new_tp = False
        if substrates:
            for sub in substrates:

                if sub.app_domain_assessment is None:
                    if self.prediction_setting.model:
                        if self.prediction_setting.model.app_domain:
                            app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)[0]

                            if self.persist is not None:
                                n = self.snode_persist_lookup[sub]

                                assert n.id is not None, "Node has no id! Should have been saved already... aborting!"
                                node_data = n.simple_json()
                                node_data['image'] = f"{n.url}?image=svg"
                                app_domain_assessment['assessment']['node'] = node_data

                                n.kv['app_domain_assessment'] = app_domain_assessment
                                n.save()

                            sub.app_domain_assessment = app_domain_assessment

                candidates = self.prediction_setting.expand(self, sub)
                # candidates is a List of PredictionResult. The length of the List is equal to the number of rules
                for cand_set in candidates:
                    if cand_set:
                        new_tp = True
                        # cand_set is a PredictionResult object that can consist of multiple candidate reactions
                        for cand in cand_set:
                            cand_nodes = []
                            # candidate reactions can have multiple fragments
                            for c in cand:
                                if c not in self.smiles_to_node:
                                    # For new nodes do an AppDomain Assessment if an AppDomain is attached
                                    app_domain_assessment = None
                                    if self.prediction_setting.model:
                                        if self.prediction_setting.model.app_domain:
                                            app_domain_assessment = self.prediction_setting.model.app_domain.assess(c)[0]

                                    self.smiles_to_node[c] = SNode(c, sub.depth + 1, app_domain_assessment)

                                node = self.smiles_to_node[c]
                                cand_nodes.append(node)

                            edge = SEdge(sub, cand_nodes, rule=cand_set.rule, probability=cand_set.probability)
                            self.edges.add(edge)

        # In case no substrates are found, we're done.
        # For "predict from node" we're always done
        if len(substrates) == 0 or from_node is not None:
            self.done = True

        # Check if we need to write back data to the database
        if new_tp and self.persist:
            self._sync_to_pathway()
            # call save to update the internal modified field
            self.persist.save()

    def _sync_to_pathway(self) -> None:
        logger.info("Updating Pathway with SPathway")

        for snode in self.smiles_to_node.values():
            if snode not in self.snode_persist_lookup:
                n = Node.create(self.persist, snode.smiles, snode.depth)

                if snode.app_domain_assessment is not None:
                    app_domain_assessment = snode.app_domain_assessment

                    assert n.id is not None, "Node has no id! Should have been saved already... aborting!"
                    node_data = n.simple_json()
                    node_data['image'] = f"{n.url}?image=svg"
                    app_domain_assessment['assessment']['node'] = node_data

                    n.kv['app_domain_assessment'] = app_domain_assessment
                    n.save()

                self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge not in self.sedge_persist_lookup:
                educt_nodes = []
                for snode in sedge.educts:
                    educt_nodes.append(self.snode_persist_lookup[snode])

                product_nodes = []
                for snode in sedge.products:
                    product_nodes.append(self.snode_persist_lookup[snode])

                e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)

                if sedge.probability:
                    e.kv.update({'probability': sedge.probability})
                    e.save()

                self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        nodes = []
        edges = []

        idx_lookup = {}

        for i, s in enumerate(self.smiles_to_node):
            n = self.smiles_to_node[s]
            idx_lookup[s] = i

            nodes.append({'depth': n.depth, 'smiles': n.smiles, 'id': i})

        for edge in self.edges:
            from_idx = idx_lookup[edge.educts[0].smiles]
            to_indices = [idx_lookup[p.smiles] for p in edge.products]

            e = {
                'from': from_idx,
                'to': to_indices,
            }

            # if edge.rule:
            #     e['rule'] = {
            #         'name': edge.rule.name,
            #         'id': edge.rule.url,
            #     }
            edges.append(e)

        return {
            'nodes': nodes,
            'edges': edges,
        }

    def graph_to_tree_string(self):
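        """Renders the predicted graph as an ASCII tree, one subtree per root node."""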
        graph_json = self.to_json()
        nodes = {node['id']: node for node in graph_json['nodes']}
        edges = graph_json['edges']

        children_map = {}
        for edge in edges:
            src = edge['from']
            for tgt in edge['to']:
                children_map.setdefault(src, []).append(tgt)

        visited = set()

        def recurse(node_id, prefix='', child_prefix=''):
            if node_id in visited:
                return prefix + nodes[node_id]['smiles'] + " [loop detected]\n"
            visited.add(node_id)

            line = prefix + nodes[node_id]['smiles'] + f" [{node_id}]\n"
            kids = children_map.get(node_id, [])
            for i, kid in enumerate(kids):
                if i == len(kids) - 1:
                    branch = '└── '
                    next_child_prefix = child_prefix + '    '
                else:
                    branch = '├── '
                    next_child_prefix = child_prefix + '│   '
                line += recurse(kid, prefix=child_prefix + branch, child_prefix=next_child_prefix)
            return line

        root_nodes = [n['id'] for n in graph_json['nodes'] if n['depth'] == 0]
        result = ''
        for root in root_nodes:
            visited.clear()
            result += recurse(root)
        return result