Files
enviPy-bayer/epdb/logic.py
jebus 4fff78541b Implement Admin approval (#29)
This PR fixes #7

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#29
2025-07-19 06:42:50 +12:00

577 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import logging
from typing import Union, List, Optional, Set, Dict
from django.contrib.auth import get_user_model
from django.db import transaction
from django.conf import settings as s
from epdb.models import User, Package, UserPackagePermission, GroupPackagePermission, Permission, Group, Setting, \
EPModel, UserSettingPermission, Rule, Pathway, Node, Edge
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class UserManager(object):
    """Static helpers for creating and looking up user accounts."""

    @staticmethod
    def create_user(username, email, password, *args, **kwargs):
        """Create a user account together with its personal default package.

        The account's ``is_active`` flag defaults to
        ``not settings.ADMIN_APPROVAL_REQUIRED`` and can be overridden by
        passing ``is_active=...``. Any other positional or keyword extras are
        accepted but ignored.

        When the resulting account is inactive, the ``send_registration_mail``
        task is queued with the user's primary key.

        Returns:
            The newly created user, with ``default_package`` set.
        """
        # avoid circular import :S
        from .tasks import send_registration_mail

        is_active = kwargs.get('is_active', not s.ADMIN_APPROVAL_REQUIRED)
        u = get_user_model().objects.create_user(username, email, password, is_active=is_active)

        # Create package
        # NOTE(review): the hosting page flags ambiguous Unicode characters in
        # this file; a possessive apostrophe may have been lost in the two
        # suffix literals below -- confirm against the repository.
        package_name = f"{u.username}{'' if u.username[-1] in 'sxzß' else 's'} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()

        if not u.is_active:
            # send email for verification
            send_registration_mail.delay(u.pk)

        return u

    @staticmethod
    def get_user(user_url):
        """Placeholder: not implemented yet, always returns ``None``."""
        pass

    @staticmethod
    def get_user_lp(user_url: str):
        """Look up a user by the UUID in the last path segment of ``user_url``.

        Surrounding whitespace is stripped before the URL is split on ``/``.
        No permission check is performed; a failed lookup propagates whatever
        ``objects.get`` raises.
        """
        uuid = user_url.strip().split('/')[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users():
        """Placeholder: not implemented yet, always returns an empty list."""
        return []
class GroupManager(object):
    """Static helpers for creating groups, resolving them and editing members."""

    @staticmethod
    def create_group(current_user, name, description):
        """Create a group owned by ``current_user`` and add that user as member.

        The group is saved once before the membership is added (the instance
        must exist for the relation) and once more afterwards.
        """
        group = Group()
        group.name = name
        group.description = description
        group.owner = current_user
        group.save()

        group.user_member.add(current_user)
        group.save()
        return group

    @staticmethod
    def get_group_lp(group_url: str):
        """Fetch a group by the UUID in the last path segment of ``group_url``.

        No membership check is performed.
        """
        group_uuid = group_url.strip().split('/')[-1]
        return Group.objects.get(uuid=group_uuid)

    @staticmethod
    def get_group_by_url(user, group_url):
        """Resolve ``group_url`` to a group via :meth:`get_group_by_id`."""
        group_id = group_url.split('/')[-1]
        return GroupManager.get_group_by_id(user, group_id)

    @staticmethod
    def get_group_by_id(user, group_id):
        """Return the group with UUID ``group_id`` if ``user`` is a direct
        user member of it, otherwise ``None``."""
        group = Group.objects.get(uuid=group_id)
        return group if user in group.user_member.all() else None

    @staticmethod
    def get_groups(user):
        """Return the queryset of groups that list ``user`` as a user member."""
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        """Add ``member`` to ``group`` or remove it from the group.

        ``member`` may be a user or another group; the matching relation
        (``user_member`` / ``group_member``) is picked accordingly. The member
        is added when ``add_or_remove`` equals ``'add'``; any other value
        removes it.

        Raises:
            ValueError: if ``caller`` is not the owner of ``group``.
        """
        if caller != group.owner:
            raise ValueError('Only the group Owner is allowed to add members!')

        # Select the relation matching the member's type, then apply the action.
        relation = group.group_member if isinstance(member, Group) else group.user_member
        if add_or_remove == 'add':
            relation.add(member)
        else:
            relation.remove(member)

        group.save()
class PackageManager(object):
    """Static helpers for package lookup, access checks and permission edits."""

    # Matches any URL that ends in ``/package/<lowercase uuid>``.
    package_pattern = re.compile(r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def _write_levels():
        """Return the stored permission values that imply write access."""
        return [Permission.WRITE[0], Permission.ALL[0]]

    @staticmethod
    def get_reviewed_packages():
        """Return the queryset of all packages flagged as reviewed."""
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        """Tell whether ``user`` may read ``package``.

        Access is granted to superusers, for reviewed packages, and when any
        permission entry exists for the user or for one of the user's groups.
        """
        # TODO Owner!
        # Cheap attribute checks first, database lookups only if needed.
        if user.is_superuser or package.reviewed is True:
            return True
        if UserPackagePermission.objects.filter(package=package, user=user).exists():
            return True
        return GroupPackagePermission.objects.filter(
            package=package, group__in=GroupManager.get_groups(user)).exists()

    @staticmethod
    def writable(user, package):
        """Tell whether ``user`` may modify ``package``.

        Access is granted to superusers and to users holding a write or
        all-level permission, directly or through one of their groups. The
        levels are compared by their stored value (``Permission.X[0]``), the
        same way ``get_all_writeable_packages`` does, so the owner entry
        written by ``create_package`` counts as writable.
        """
        # TODO Owner!
        if user.is_superuser:
            return True
        levels = PackageManager._write_levels()
        if UserPackagePermission.objects.filter(
                package=package, user=user, permission__in=levels).exists():
            return True
        return GroupPackagePermission.objects.filter(
            package=package, group__in=GroupManager.get_groups(user),
            permission__in=levels).exists()

    @staticmethod
    def get_package_by_url(user, package_url):
        """Resolve ``package_url`` to a package readable by ``user``.

        Raises:
            ValueError: if the URL does not end in ``/package/<uuid>``, or for
                the reasons listed on :meth:`get_package_by_id`.
        """
        match = PackageManager.package_pattern.search(package_url)
        if match is None:
            raise ValueError("Requested URL {} does not contain a valid package identifier!".format(package_url))
        package_id = match.group(0).split('/')[-1]
        return PackageManager.get_package_by_id(user, package_id)

    @staticmethod
    def get_package_by_id(user, package_id):
        """Return the package with UUID ``package_id`` if ``user`` may read it.

        Raises:
            ValueError: if no such package exists or ``user`` lacks read
                access to it.
        """
        try:
            p = Package.objects.get(uuid=package_id)
        except Package.DoesNotExist:
            raise ValueError("Package with ID {} does not exist!".format(package_id))

        if not PackageManager.readable(user, p):
            raise ValueError(
                "Insufficient permissions to access Package with ID {}".format(package_id))
        return p

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        """Return the distinct packages ``user`` can read.

        Superusers start from all packages; other users from the packages
        they or their groups hold any permission on. With
        ``include_reviewed`` the reviewed packages are added, otherwise
        reviewed packages are filtered out of the result.
        """
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values('package').distinct())
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(
                    group__in=GroupManager.get_groups(user)).values('package').distinct())
            qs = user_package_qs | group_package_qs

        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        """Return the distinct, non-reviewed packages ``user`` can modify.

        Superusers start from all packages; other users from the packages
        they or their groups hold a write or all-level permission on.
        """
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            levels = PackageManager._write_levels()
            user_packs = UserPackagePermission.objects.filter(
                user=user, permission__in=levels).values('package').distinct()
            group_packs = GroupPackagePermission.objects.filter(
                group__in=GroupManager.get_groups(user),
                permission__in=levels).values('package').distinct()
            qs = Package.objects.filter(id__in=user_packs) | Package.objects.filter(id__in=group_packs)

        qs = qs.filter(reviewed=False)
        return qs.distinct()

    @staticmethod
    def get_packages():
        """Return the queryset of all packages, without any access filtering."""
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        """Create a package and grant ``current_user`` the all-level permission.

        Returns:
            The saved package.
        """
        p = Package()
        p.name = name
        p.description = description
        p.save()

        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()

        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]):
        """Set or revoke the permission of ``grantee`` on ``package``.

        ``new_perm=None`` deletes the grantee's permission entry if there is
        one; any other value is stored via ``update_or_create``.

        Raises:
            ValueError: if ``caller`` may not write to ``package``, or if more
                than one permission entry exists for the grantee on revoke.
        """
        if not PackageManager.writable(caller, package):
            raise ValueError(f"User {caller} is not allowed to modify permissions on {package}")

        data = {
            'package': package,
        }

        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data['user'] = grantee
        else:
            perm_cls = GroupPackagePermission
            data['group'] = grantee

        if new_perm is None:
            qs = perm_cls.objects.filter(**data)
            # Count once; the original issued the same COUNT query twice.
            existing = qs.count()
            if existing > 1:
                raise ValueError("Got more Permission objects than expected!")
            if existing != 0:
                logger.info("Deleting Perm %s", qs.first())
                qs.delete()
            else:
                logger.debug("No Permission object for %s with filter %s found!", perm_cls, data)
        else:
            perm_cls.objects.update_or_create(defaults={'permission': new_perm}, **data)
class SettingManager(object):
    """Static helpers for looking up and creating prediction settings."""

    # Matches any URL that ends in ``/setting/<lowercase uuid>``.
    setting_pattern = re.compile(r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def get_setting_by_url(user, setting_url):
        """Resolve ``setting_url`` to a setting that ``user`` may access.

        Raises:
            ValueError: if the URL does not end in ``/setting/<uuid>``, or if
                :meth:`get_setting_by_id` rejects the access.
        """
        hits = re.findall(SettingManager.setting_pattern, setting_url)
        if not hits:
            raise ValueError("Requested URL {} does not contain a valid setting identifier!".format(setting_url))
        setting_id = hits[0].split('/')[-1]
        return SettingManager.get_setting_by_id(user, setting_id)

    @staticmethod
    def get_setting_by_id(user, setting_id):
        """Return the setting with UUID ``setting_id`` if ``user`` may see it.

        A setting is accessible when it is the global default, is public, is
        owned by ``user``, or when ``user`` is a superuser.

        Raises:
            ValueError: if none of those conditions holds.
        """
        setting = Setting.objects.get(uuid=setting_id)
        accessible = setting.global_default or setting.public or setting.owner == user or user.is_superuser
        if not accessible:
            raise ValueError(
                "Insufficient permissions to access Setting with ID {}".format(setting_id))
        return setting

    @staticmethod
    def get_all_settings(user):
        """Return the settings ``user`` holds a ``UserSettingPermission`` on."""
        permitted_ids = UserSettingPermission.objects.filter(user=user).values('setting').distinct()
        return Setting.objects.filter(id__in=permitted_ids)

    @staticmethod
    @transaction.atomic
    def create_setting(user: User, name: str = None, description: str = None, max_nodes: int = None,
                       max_depth: int = None, rule_packages: List[Package] = None, model: EPModel = None,
                       model_threshold: float = None):
        """Create a setting and grant ``user`` the all-level permission on it.

        The setting is saved before any rule packages are attached, and saved
        again after they have been added.

        Returns:
            The saved setting.
        """
        setting = Setting()
        setting.name = name
        setting.description = description
        setting.max_nodes = max_nodes
        setting.max_depth = max_depth
        setting.model = model
        setting.model_threshold = model_threshold
        setting.save()

        if rule_packages is not None:
            for rule_package in rule_packages:
                setting.rule_packages.add(rule_package)
            setting.save()

        permission = UserSettingPermission()
        permission.user = user
        permission.setting = setting
        permission.permission = Permission.ALL[0]
        permission.save()

        return setting

    @staticmethod
    def get_default_setting(user: User):
        """Placeholder: not implemented yet, always returns ``None``."""
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        """Placeholder: not implemented yet, does nothing."""
        pass
class SNode(object):
    """Lightweight in-memory pathway node: a SMILES string plus its depth.

    Identity (equality and hashing) is based on ``smiles`` alone; ``depth``
    does not take part in comparisons.
    """

    def __init__(self, smiles: str, depth: int):
        self.smiles, self.depth = smiles, depth

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"

    def __eq__(self, other):
        # Only instances of this class (or subclasses) can compare equal.
        return isinstance(other, self.__class__) and self.smiles == other.smiles

    def __hash__(self):
        return hash(self.smiles)
class SEdge(object):
    """Lightweight in-memory pathway edge: educts -> products via a rule.

    Equality and hashing ignore the order of educts and of products; two
    edges are equal when their rules are equal and the educt and product
    SMILES match as multisets.
    """

    def __init__(self, educts: Union['SNode', List['SNode']], products: Union['SNode', List['SNode']],
                 rule: Optional['Rule'] = None):
        """Store educts, products and rule.

        A single node passed as ``educts`` or ``products`` is wrapped in a
        list, so both attributes are always lists afterwards.
        """
        if not isinstance(educts, list):
            educts = [educts]
        # Normalise products as well: __hash__ and __eq__ iterate over them
        # and would fail on a bare node.
        if not isinstance(products, list):
            products = [products]

        self.educts = educts
        self.products = products
        self.rule = rule

    def __hash__(self):
        # A plain sum is order-independent, so no sorting is required.
        full_hash = sum(hash(n) for n in self.educts)
        full_hash += sum(hash(n) for n in self.products)
        if self.rule is not None:
            full_hash += hash(self.rule)
        return full_hash

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        # Rules must be both absent, or both present and equal.
        if (self.rule is None) != (other.rule is None) or self.rule != other.rule:
            return False

        # Comparing sorted SMILES lists covers both length and content.
        if sorted(n.smiles for n in self.educts) != sorted(n.smiles for n in other.educts):
            return False
        return sorted(n.smiles for n in self.products) == sorted(n.smiles for n in other.products)

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"
class SPathway(object):
    """In-memory pathway made of ``SNode``/``SEdge`` objects.

    The pathway can be expanded level by level with :meth:`predict_step`.
    When constructed with a ``persist`` pathway, newly predicted nodes and
    edges are mirrored into it through ``Node.create`` / ``Edge.create``.
    """

    def __init__(self, root_nodes: Optional[Union[str, 'SNode', List[Union[str, 'SNode']]]] = None,
                 persist: Optional['Pathway'] = None, prediction_setting: Optional['Setting'] = None
                 ):
        """Initialise the pathway.

        Args:
            root_nodes: a SMILES string, an ``SNode``, or a list of either.
                Strings become depth-0 nodes; entries of any other type are
                ignored. Only used when ``persist`` is falsy.
            persist: pathway whose ``root_nodes`` seed this object and which
                receives newly predicted nodes and edges.
            prediction_setting: object whose ``expand(pathway, node)`` is
                called by :meth:`predict_step`.
        """
        self.root_nodes = []

        self.persist = persist
        self.snode_persist_lookup: Dict[SNode, Node] = dict()
        self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            for n in root_nodes:
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.queue = list()
        self.smiles_to_node: Dict[str, SNode] = {n.smiles: n for n in self.root_nodes}
        self.edges: Set['SEdge'] = set()
        self.done = False

    def num_nodes(self):
        """Return the number of distinct nodes (by SMILES) in the pathway."""
        return len(self.smiles_to_node)

    def depth(self):
        """Return the largest node depth, or 0 for a pathway without nodes."""
        return max((v.depth for v in self.smiles_to_node.values()), default=0)

    def _get_nodes_for_depth(self, depth: int):
        """Return the nodes at ``depth``.

        Depth 0 returns the root-node list itself (unsorted); other depths
        return a new list sorted by SMILES.
        """
        if depth == 0:
            return self.root_nodes

        res = [n for n in self.smiles_to_node.values() if n.depth == depth]
        return sorted(res, key=lambda x: x.smiles)

    def _get_edges_for_depth(self, depth: int):
        """Return the edges having at least one educt at ``depth``, ordered by
        hash. Each edge is listed once, even if several educts match."""
        res = [e for e in self.edges if any(n.depth == depth for n in e.educts)]
        return sorted(res, key=lambda x: hash(x))

    def predict_step(self, from_depth: int = 0):
        """Expand every node at ``from_depth`` by one prediction step.

        For each substrate, ``prediction_setting.expand(self, substrate)`` is
        called. Every truthy candidate set it yields is iterated; each
        candidate is itself an iterable of SMILES strings, which become the
        product nodes (created at ``substrate.depth + 1`` if unseen) of a new
        edge carrying ``cand_set.rule``.

        If there is no node at ``from_depth``, ``done`` is set instead. When
        anything was predicted and a ``persist`` pathway is set, the new
        nodes and edges are synced to it.
        """
        substrates = self._get_nodes_for_depth(from_depth)
        if not substrates:
            self.done = True
            return

        new_tp = False
        for sub in substrates:
            for cand_set in self.prediction_setting.expand(self, sub):
                if not cand_set:
                    continue
                new_tp = True
                for cand in cand_set:
                    cand_nodes = []
                    for c in cand:
                        if c not in self.smiles_to_node:
                            self.smiles_to_node[c] = SNode(c, sub.depth + 1)
                        cand_nodes.append(self.smiles_to_node[c])
                    self.edges.add(SEdge(sub, cand_nodes, cand_set.rule))

        if new_tp and self.persist:
            self._sync_to_pathway()

    def _sync_to_pathway(self):
        """Create persistent counterparts for all nodes and edges that are
        not yet recorded in the persist lookups."""
        logger.info("Updating Pathway with SPathway")

        # Nodes first: edge creation below looks the persisted nodes up.
        for snode in self.smiles_to_node.values():
            if snode not in self.snode_persist_lookup:
                n = Node.create(self.persist, snode.smiles, snode.depth)
                self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge not in self.sedge_persist_lookup:
                educt_nodes = [self.snode_persist_lookup[snode] for snode in sedge.educts]
                product_nodes = [self.snode_persist_lookup[snode] for snode in sedge.products]
                e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)
                self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        """Return the pathway as ``{'nodes': [...], 'edges': [...]}``.

        Nodes get consecutive integer ids in ``smiles_to_node`` order. Each
        edge is emitted as ``{'from': id, 'to': [ids]}``, where ``from`` is
        the id of the edge's first educt only.
        """
        nodes = []
        edges = []

        idx_lookup = {}

        for i, (smiles, n) in enumerate(self.smiles_to_node.items()):
            idx_lookup[smiles] = i
            nodes.append({'depth': n.depth, 'smiles': n.smiles, 'id': i})

        for edge in self.edges:
            e = {
                'from': idx_lookup[edge.educts[0].smiles],
                'to': [idx_lookup[p.smiles] for p in edge.products],
            }

            # if edge.rule:
            #     e['rule'] = {
            #         'name': edge.rule.name,
            #         'id': edge.rule.url,
            #     }
            edges.append(e)

        return {
            'nodes': nodes,
            'edges': edges,
        }

    def graph_to_tree_string(self):
        """Render the pathway as an indented text tree, one node per line.

        Every depth-0 node starts a tree. Children are drawn with
        ``├── `` / ``└── `` connectors, and deeper levels are indented
        beneath their parent. A node that was already printed within the
        current root's tree is printed again with ``[loop detected]`` and
        not expanded further.
        """
        graph_json = self.to_json()
        nodes = {node['id']: node for node in graph_json['nodes']}

        children_map = {}
        for edge in graph_json['edges']:
            children_map.setdefault(edge['from'], []).extend(edge['to'])

        # NOTE: shared across all branches of one root, so a node reachable
        # via two different paths is also reported as a loop.
        visited = set()

        def recurse(node_id, line_prefix='', child_prefix=''):
            # line_prefix is printed in front of this node; child_prefix is
            # the indentation the node's children build their prefixes on.
            smiles = nodes[node_id]['smiles']
            if node_id in visited:
                return f"{line_prefix}{smiles} [loop detected]\n"
            visited.add(node_id)

            parts = [f"{line_prefix}{smiles} [{node_id}]\n"]
            kids = children_map.get(node_id, [])
            for i, kid in enumerate(kids):
                if i == len(kids) - 1:
                    branch, continuation = '└── ', '    '
                else:
                    branch, continuation = '├── ', '│   '
                parts.append(recurse(kid, child_prefix + branch, child_prefix + continuation))
            return ''.join(parts)

        root_ids = [n['id'] for n in graph_json['nodes'] if n['depth'] == 0]
        rendered = []
        for root in root_ids:
            visited.clear()
            rendered.append(recurse(root))
        return ''.join(rendered)