import logging
import re
from typing import Dict, List, Optional, Set, Union

from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.db import transaction

from epdb.models import (
    Edge,
    EPModel,
    Group,
    GroupPackagePermission,
    Node,
    Package,
    Pathway,
    Permission,
    Rule,
    Setting,
    User,
    UserPackagePermission,
    UserSettingPermission,
)

logger = logging.getLogger(__name__)


class UserManager(object):
    """Helpers for creating and looking up user accounts."""

    @staticmethod
    def create_user(username, email, password, *args, **kwargs):
        """Create a user account together with its default package.

        The account is active unless ``settings.ADMIN_APPROVAL_REQUIRED`` is
        set; an explicit ``is_active`` keyword argument overrides that default.
        For inactive accounts the registration mail task is queued.

        Returns:
            The newly created user.
        """
        # Imported lazily to avoid a circular import with the tasks module.
        from .tasks import send_registration_mail

        is_active = kwargs.get('is_active', not s.ADMIN_APPROVAL_REQUIRED)
        u = get_user_model().objects.create_user(username, email, password, is_active=is_active)

        # Every user gets a personal package that becomes their default.
        possessive = '’' if u.username[-1] in 'sxzß' else 's'
        package_name = f"{u.username}{possessive} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()

        if not u.is_active:
            # Send the verification mail asynchronously.
            send_registration_mail.delay(u.pk)

        return u

    @staticmethod
    def get_user(user_url):
        """Not implemented yet; always returns ``None``."""
        pass

    @staticmethod
    def get_user_lp(user_url: str):
        """Return the user whose UUID is the last path segment of ``user_url``.

        No permission check is performed. The model's ``DoesNotExist`` is
        raised by the ORM if no such user exists.
        """
        uuid = user_url.strip().split('/')[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users():
        """Not implemented yet; always returns an empty list."""
        return []


class GroupManager(object):
    """Helpers for creating groups and managing their membership."""

    @staticmethod
    def create_group(current_user, name, description):
        """Create a group owned by ``current_user`` and add them as a member."""
        g = Group()
        g.name = name
        g.description = description
        g.owner = current_user
        # The group must be saved before many-to-many members can be attached.
        g.save()

        g.user_member.add(current_user)
        g.save()

        return g

    @staticmethod
    def get_group_lp(group_url: str):
        """Return the group whose UUID is the last path segment of ``group_url``.

        No membership check is performed.
        """
        uuid = group_url.strip().split('/')[-1]
        return Group.objects.get(uuid=uuid)

    @staticmethod
    def get_group_by_url(user, group_url):
        """Resolve ``group_url`` to a group, subject to the membership check
        of :meth:`get_group_by_id`."""
        return GroupManager.get_group_by_id(user, group_url.split('/')[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        """Return the group with UUID ``group_id`` if ``user`` is a direct
        user member of it, otherwise ``None``."""
        g = Group.objects.get(uuid=group_id)
        if user in g.user_member.all():
            return g
        return None

    @staticmethod
    def get_groups(user):
        """Return a queryset of all groups that list ``user`` as a user member."""
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        """Add ``member`` to or remove it from ``group``.

        Args:
            caller: User requesting the change; must be the group owner.
            group: Group to modify.
            member: A user or a nested group.
            add_or_remove: ``'add'`` adds the member; any other value removes it.

        Raises:
            ValueError: If ``caller`` is not the owner of ``group``.
        """
        if caller != group.owner:
            raise ValueError('Only the group Owner is allowed to add members!')

        # Nested groups and users are kept in separate relations.
        relation = group.group_member if isinstance(member, Group) else group.user_member
        if add_or_remove == 'add':
            relation.add(member)
        else:
            relation.remove(member)

        group.save()


class PackageManager(object):
    """Permission-aware access to packages."""

    package_pattern = re.compile(r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def _write_levels():
        """Return the stored permission values that allow modification.

        NOTE(review): assumes ``Permission.WRITE`` / ``Permission.ALL`` are
        tuples whose first element is the stored value, as they are indexed
        with ``[0]`` elsewhere in this module -- confirm against the model.
        """
        return [Permission.WRITE[0], Permission.ALL[0]]

    @staticmethod
    def get_reviewed_packages():
        """Return a queryset of all packages flagged as reviewed."""
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        """Return ``True`` if ``user`` may read ``package``.

        Reading is allowed for reviewed packages, for superusers, and for
        users that hold any user-level or group-level permission on the
        package.
        """
        # TODO Owner!
        # Cheap attribute checks come first so the queries are skipped when
        # they cannot change the outcome.
        if package.reviewed is True or user.is_superuser:
            return True
        if UserPackagePermission.objects.filter(package=package, user=user).exists():
            return True
        return GroupPackagePermission.objects.filter(
            package=package, group__in=GroupManager.get_groups(user)).exists()

    @staticmethod
    def writable(user, package):
        """Return ``True`` if ``user`` may modify ``package``.

        Writing is allowed for superusers and for users that hold the write
        or the all level on the package, either directly or through one of
        their groups. This mirrors :meth:`get_all_writeable_packages`.
        """
        # TODO Owner!
        if user.is_superuser:
            return True
        levels = PackageManager._write_levels()
        if UserPackagePermission.objects.filter(
                package=package, user=user, permission__in=levels).exists():
            return True
        return GroupPackagePermission.objects.filter(
            package=package, group__in=GroupManager.get_groups(user), permission__in=levels).exists()

    @staticmethod
    def get_package_by_url(user, package_url):
        """Resolve ``package_url`` to a package readable by ``user``.

        Raises:
            ValueError: If the URL does not end in ``/package/<uuid>``, or for
                the reasons listed on :meth:`get_package_by_id`.
        """
        match = PackageManager.package_pattern.search(package_url)
        if match is None:
            raise ValueError("Requested URL {} does not contain a valid package identifier!".format(package_url))

        package_id = match.group(0).split('/')[-1]
        return PackageManager.get_package_by_id(user, package_id)

    @staticmethod
    def get_package_by_id(user, package_id):
        """Return the package with UUID ``package_id`` if ``user`` may read it.

        Raises:
            ValueError: If the package does not exist or ``user`` lacks read
                access.
        """
        try:
            p = Package.objects.get(uuid=package_id)
        except Package.DoesNotExist as exc:
            raise ValueError("Package with ID {} does not exist!".format(package_id)) from exc

        if not PackageManager.readable(user, p):
            raise ValueError(
                "Insufficient permissions to access Package with ID {}".format(package_id))
        return p

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        """Return a queryset of the packages ``user`` may read.

        Superusers start from all packages; other users from the packages
        they hold a user-level or group-level permission on. With
        ``include_reviewed`` the reviewed packages are added, otherwise
        reviewed packages are filtered out of the result.
        """
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values('package').distinct())
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user)).values('package').distinct())
            qs = user_package_qs | group_package_qs

        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        """Return a queryset of the non-reviewed packages ``user`` may modify.

        Superusers start from all packages; other users from the packages on
        which they hold the write or the all level, directly or through a
        group. Reviewed packages are always excluded.
        """
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            levels = PackageManager._write_levels()
            user_packs = UserPackagePermission.objects.filter(
                user=user, permission__in=levels).values('package').distinct()
            group_packs = GroupPackagePermission.objects.filter(
                group__in=GroupManager.get_groups(user), permission__in=levels).values('package').distinct()
            qs = Package.objects.filter(id__in=user_packs) | Package.objects.filter(id__in=group_packs)

        qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_packages():
        """Return a queryset of all packages, without any permission check."""
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        """Create a package and grant ``current_user`` the all level on it.

        Returns:
            The newly created package.
        """
        p = Package()
        p.name = name
        p.description = description
        p.save()

        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()

        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]):
        """Set or revoke the permission ``grantee`` holds on ``package``.

        Args:
            caller: User requesting the change; must pass :meth:`writable`.
            package: Package whose permissions are modified.
            grantee: User or group receiving the permission.
            new_perm: Permission value to store, or ``None`` to delete the
                grantee's existing permission.

        Raises:
            ValueError: If ``caller`` may not modify ``package`` or more than
                one permission row exists for the grantee.
        """
        if not PackageManager.writable(caller, package):
            raise ValueError(f"User {caller} is not allowed to modify permissions on {package}")

        data = {
            'package': package,
        }

        # Users and groups are stored in separate permission tables.
        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data['user'] = grantee
        else:
            perm_cls = GroupPackagePermission
            data['group'] = grantee

        if new_perm is None:
            qs = perm_cls.objects.filter(**data)
            existing = qs.count()
            if existing > 1:
                raise ValueError("Got more Permission objects than expected!")
            if existing != 0:
                logger.info(f"Deleting Perm {qs.first()}")
                qs.delete()
            else:
                logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
        else:
            perm_cls.objects.update_or_create(defaults={'permission': new_perm}, **data)


class SettingManager(object):
    """Permission-aware access to prediction settings."""

    setting_pattern = re.compile(r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def get_setting_by_url(user, setting_url):
        """Resolve ``setting_url`` to a setting accessible by ``user``.

        Raises:
            ValueError: If the URL does not end in ``/setting/<uuid>`` or the
                user lacks access (see :meth:`get_setting_by_id`).
        """
        match = SettingManager.setting_pattern.search(setting_url)
        if match is None:
            raise ValueError("Requested URL {} does not contain a valid setting identifier!".format(setting_url))

        setting_id = match.group(0).split('/')[-1]
        return SettingManager.get_setting_by_id(user, setting_id)

    @staticmethod
    def get_setting_by_id(user, setting_id):
        """Return the setting with UUID ``setting_id`` if ``user`` may use it.

        Access is granted for global-default or public settings, to the
        owner, and to superusers. A missing setting surfaces as the ORM's
        ``Setting.DoesNotExist``.

        Raises:
            ValueError: If ``user`` lacks access.
        """
        # Named ``setting`` so the module-level ``s`` (Django settings) is
        # not shadowed.
        setting = Setting.objects.get(uuid=setting_id)

        if setting.global_default or setting.public or setting.owner == user or user.is_superuser:
            return setting

        raise ValueError(
            "Insufficient permissions to access Setting with ID {}".format(setting_id))

    @staticmethod
    def get_all_settings(user):
        """Return a queryset of the settings ``user`` holds a permission on."""
        sp = UserSettingPermission.objects.filter(user=user).values('setting').distinct()
        return Setting.objects.filter(id__in=sp)

    @staticmethod
    @transaction.atomic
    def create_setting(user: User, name: str = None, description: str = None, max_nodes: int = None,
                       max_depth: int = None, rule_packages: List[Package] = None, model: EPModel = None,
                       model_threshold: float = None):
        """Create a setting and grant ``user`` the all level on it.

        Returns:
            The newly created setting.
        """
        setting = Setting()
        setting.name = name
        setting.description = description
        setting.max_nodes = max_nodes
        setting.max_depth = max_depth
        setting.model = model
        setting.model_threshold = model_threshold
        # Must be saved before many-to-many rule packages can be attached.
        setting.save()

        if rule_packages is not None:
            for r in rule_packages:
                setting.rule_packages.add(r)
            setting.save()

        usp = UserSettingPermission()
        usp.user = user
        usp.setting = setting
        usp.permission = Permission.ALL[0]
        usp.save()

        return setting

    @staticmethod
    def get_default_setting(user: User):
        """Not implemented yet; always returns ``None``."""
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        """Not implemented yet; does nothing."""
        pass


class SNode(object):
    """In-memory pathway node; identity is determined by the SMILES only."""

    def __init__(self, smiles: str, depth: int):
        self.smiles = smiles
        self.depth = depth

    def __hash__(self):
        return hash(self.smiles)

    def __eq__(self, other):
        # Depth is deliberately not compared, matching __hash__.
        if isinstance(other, self.__class__):
            return self.smiles == other.smiles
        return False

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"


class SEdge(object):
    """In-memory pathway edge from educt nodes to product nodes.

    Two edges are equal when their rules are equal and they connect the same
    multisets of educt and product SMILES, regardless of order.
    """

    def __init__(self, educts: Union[SNode, List[SNode]], products: Union[SNode, List[SNode]],
                 rule: Optional['Rule'] = None):
        if not isinstance(educts, list):
            educts = [educts]
        # A single product node is accepted as well; other iterables are
        # stored untouched.
        if isinstance(products, SNode):
            products = [products]

        self.educts = educts
        self.products = products
        self.rule = rule

    def __hash__(self):
        # A sum is order-independent, so no sorting is required.
        full_hash = sum(hash(n) for n in self.educts)
        full_hash += sum(hash(n) for n in self.products)

        if self.rule is not None:
            full_hash += hash(self.rule)

        return full_hash

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        if (self.rule is None) != (other.rule is None) or self.rule != other.rule:
            return False

        if sorted(n.smiles for n in self.educts) != sorted(n.smiles for n in other.educts):
            return False

        return sorted(n.smiles for n in self.products) == sorted(n.smiles for n in other.products)

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"


class SPathway(object):
    """In-memory pathway graph that can mirror itself into a stored pathway.

    Args:
        root_nodes: One SMILES string or ``SNode``, or a list of them. Only
            used when ``persist`` is not given.
        persist: Stored pathway to mirror; its root nodes seed this graph and
            newly predicted nodes and edges are written back to it.
        prediction_setting: Object whose ``expand(pathway, node)`` method is
            used by :meth:`predict_step`.
    """

    def __init__(self, root_nodes: Optional[Union[str, SNode, List[Union[str, SNode]]]] = None,
                 persist: Optional['Pathway'] = None,
                 prediction_setting: Optional[Setting] = None):
        self.root_nodes = []
        self.persist = persist
        # Map in-memory objects to their stored counterparts.
        self.snode_persist_lookup: Dict[SNode, Node] = dict()
        self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            # Entries that are neither str nor SNode (e.g. None) are ignored.
            for n in root_nodes:
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.queue = list()
        self.smiles_to_node: Dict[str, SNode] = {n.smiles: n for n in self.root_nodes}
        self.edges: Set['SEdge'] = set()
        self.done = False

    def num_nodes(self):
        """Return the number of distinct nodes in the graph."""
        return len(self.smiles_to_node)

    def depth(self):
        """Return the largest node depth, or 0 for a graph without nodes."""
        return max((v.depth for v in self.smiles_to_node.values()), default=0)

    def _get_nodes_for_depth(self, depth: int):
        """Return the nodes at ``depth``.

        For depth 0 the root node list itself is returned; for any other
        depth a new list sorted by SMILES.
        """
        if depth == 0:
            return self.root_nodes

        res = [n for n in self.smiles_to_node.values() if n.depth == depth]
        return sorted(res, key=lambda x: x.smiles)

    def _get_edges_for_depth(self, depth: int):
        """Return the edges with at least one educt at ``depth``, ordered by hash."""
        # any() keeps each edge once even if several educts share the depth.
        res = [e for e in self.edges if any(n.depth == depth for n in e.educts)]
        return sorted(res, key=lambda x: hash(x))

    def predict_step(self, from_depth: int = 0):
        """Expand all nodes at ``from_depth`` by one prediction step.

        New product nodes are registered one level below their substrate and
        connected to it by an edge. If there is no node at ``from_depth`` the
        pathway is marked as done. When anything was predicted and a stored
        pathway is attached, the new state is written back to it.
        """
        substrates = self._get_nodes_for_depth(from_depth)

        new_tp = False
        if substrates:
            for sub in substrates:
                candidates = self.prediction_setting.expand(self, sub)
                for cand_set in candidates:
                    if not cand_set:
                        continue
                    new_tp = True
                    # Each candidate is an iterable of product SMILES.
                    for cand in cand_set:
                        cand_nodes = []
                        for c in cand:
                            if c not in self.smiles_to_node:
                                self.smiles_to_node[c] = SNode(c, sub.depth + 1)
                            cand_nodes.append(self.smiles_to_node[c])

                        self.edges.add(SEdge(sub, cand_nodes, cand_set.rule))
        else:
            self.done = True

        if new_tp and self.persist:
            self._sync_to_pathway()

    def _sync_to_pathway(self):
        """Create stored nodes and edges for everything not yet persisted."""
        logger.info("Updating Pathway with SPathway")

        # Nodes first: edges reference the stored nodes through the lookup.
        for snode in self.smiles_to_node.values():
            if snode not in self.snode_persist_lookup:
                n = Node.create(self.persist, snode.smiles, snode.depth)
                self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge not in self.sedge_persist_lookup:
                educt_nodes = [self.snode_persist_lookup[snode] for snode in sedge.educts]
                product_nodes = [self.snode_persist_lookup[snode] for snode in sedge.products]

                e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)
                self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        """Return the graph as a ``{'nodes': [...], 'edges': [...]}`` dict.

        Node ids are positions in the insertion order of the nodes. Each edge
        is reported from its first educt to all of its products.
        """
        nodes = []
        edges = []

        idx_lookup = {}

        for i, smiles in enumerate(self.smiles_to_node):
            n = self.smiles_to_node[smiles]
            idx_lookup[smiles] = i

            nodes.append({'depth': n.depth, 'smiles': n.smiles, 'id': i})

        for edge in self.edges:
            from_idx = idx_lookup[edge.educts[0].smiles]
            to_indices = [idx_lookup[p.smiles] for p in edge.products]

            e = {
                'from': from_idx,
                'to': to_indices,
            }

            # if edge.rule:
            #     e['rule'] = {
            #         'name': edge.rule.name,
            #         'id': edge.rule.url,
            #     }
            edges.append(e)

        return {
            'nodes': nodes,
            'edges': edges,
        }

    def graph_to_tree_string(self):
        """Render the graph as an indented text tree, one tree per root node.

        Each line shows a node's SMILES followed by its id. A node that was
        already printed within the current tree is not expanded again and is
        marked with ``[loop detected]``.
        """
        graph_json = self.to_json()
        nodes = {node['id']: node for node in graph_json['nodes']}

        children_map = {}
        for edge in graph_json['edges']:
            children_map.setdefault(edge['from'], []).extend(edge['to'])

        visited = set()
        lines = []

        def render(node_id, line_prefix='', child_prefix=''):
            # line_prefix precedes this node's own line; child_prefix is the
            # indentation that all of its descendants' lines start with.
            label = nodes[node_id]['smiles']
            if node_id in visited:
                lines.append(line_prefix + label + " [loop detected]\n")
                return
            visited.add(node_id)

            lines.append(line_prefix + label + f" [{node_id}]\n")
            kids = children_map.get(node_id, [])
            for i, kid in enumerate(kids):
                if i == len(kids) - 1:
                    render(kid, child_prefix + '└── ', child_prefix + '    ')
                else:
                    render(kid, child_prefix + '├── ', child_prefix + '│   ')

        for root in (n['id'] for n in graph_json['nodes'] if n['depth'] == 0):
            visited.clear()
            render(root)

        return ''.join(lines)