forked from enviPath/enviPy
572 lines
18 KiB
Python
572 lines
18 KiB
Python
import re
|
||
import logging
|
||
from typing import Union, List, Optional, Set, Dict
|
||
|
||
from django.contrib.auth import get_user_model
|
||
from django.db import transaction
|
||
|
||
from epdb.models import User, Package, UserPackagePermission, GroupPackagePermission, Permission, Group, Setting, \
|
||
EPModel, UserSettingPermission, Rule, Pathway, Node, Edge
|
||
|
||
# Module-level logger named after this module; used by the managers and SPathway below.
logger = logging.getLogger(__name__)
|
||
|
||
class UserManager(object):
    """Helpers for creating and looking up user accounts."""

    @staticmethod
    def create_user(username, email, password):
        """Create a user together with a personal default package.

        The user, its package and the ``default_package`` link are written in a
        single transaction, so a failure part-way through does not leave a user
        without a default package behind.

        :param username: login name of the new user; also used to derive the package name.
        :param email: e-mail address passed to the user model's ``create_user``.
        :param password: raw password passed to the user model's ``create_user``.
        :returns: the saved user instance.
        """
        # avoid circular import :S
        from .tasks import send_registration_mail

        with transaction.atomic():
            # TODO flip to False
            u = get_user_model().objects.create_user(username, email, password, is_active=True)

            # Create package. Names ending in s/x/z/ß get only a trailing apostrophe,
            # all others get a plain "s" appended.
            package_name = f"{u.username}{'’' if u.username[-1] in 'sxzß' else 's'} Package"
            package_description = "This package was generated during registration."
            p = PackageManager.create_package(u, package_name, package_description)

            u.default_package = p
            u.save()

        if not u.is_active:
            # send email for verification. Deferred until the surrounding transaction
            # (if any) commits -- `.delay` presumably enqueues an async task that must
            # be able to see the committed user row.
            transaction.on_commit(lambda: send_registration_mail.delay(u.pk))

        return u

    @staticmethod
    def get_user(user_url):
        """Not implemented yet; always returns None."""
        pass

    @staticmethod
    def get_user_lp(user_url: str):
        """Return the user whose UUID is the last path segment of ``user_url``.

        No permission check is performed.

        :raises DoesNotExist: (of the user model) if no user has that UUID.
        """
        uuid = user_url.strip().split('/')[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users():
        """Not implemented yet; always returns an empty list."""
        return []
|
||
|
||
class GroupManager(object):
    """Creation, lookup and membership management for :class:`Group` objects."""

    @staticmethod
    @transaction.atomic
    def create_group(current_user, name, description):
        """Create a group owned by ``current_user`` with the owner as its first member.

        The group row and the owner's membership are written in one transaction,
        so a group is never left behind without its owner as a member.

        :returns: the saved group.
        """
        g = Group()
        g.name = name
        g.description = description
        g.owner = current_user
        g.save()

        g.user_member.add(current_user)
        g.save()

        return g

    @staticmethod
    def get_group_lp(group_url: str):
        """Return the group whose UUID is the last path segment of ``group_url``.

        No permission check is performed.

        :raises Group.DoesNotExist: if no group has that UUID.
        """
        uuid = group_url.strip().split('/')[-1]
        return Group.objects.get(uuid=uuid)

    @staticmethod
    def get_group_by_url(user, group_url):
        """Resolve ``group_url`` to a group ``user`` is a member of (see ``get_group_by_id``)."""
        return GroupManager.get_group_by_id(user, group_url.strip().split('/')[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        """Return the group with UUID ``group_id`` if ``user`` is a direct user member, else None.

        :raises Group.DoesNotExist: if no group has that UUID.
        """
        g = Group.objects.get(uuid=group_id)

        # Let the database answer the membership question instead of loading every member row.
        if g.user_member.filter(pk=user.pk).exists():
            return g
        return None

    @staticmethod
    def get_groups(user):
        """Return all groups that list ``user`` as a direct user member."""
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        """Add or remove ``member`` (a user or a nested group) on ``group``.

        :param caller: the acting user; must be the group's owner.
        :param member: a ``Group`` is handled as a nested group member, anything else as a user member.
        :param add_or_remove: ``'add'`` adds the member; any other value removes it.
        :raises ValueError: if ``caller`` is not the owner of ``group``.
        """
        if caller != group.owner:
            raise ValueError('Only the group Owner is allowed to add members!')

        if isinstance(member, Group):
            if add_or_remove == 'add':
                group.group_member.add(member)
            else:
                group.group_member.remove(member)
        else:
            if add_or_remove == 'add':
                group.user_member.add(member)
            else:
                group.user_member.remove(member)

        group.save()
|
||
|
||
|
||
class PackageManager(object):
    """Permission-aware creation and lookup of :class:`Package` objects."""

    package_pattern = re.compile(r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def _write_permissions():
        """Return the stored ``permission`` values that grant write access.

        Permission rows store the first element of the ``Permission.*`` constants
        (``create_package`` stores ``ALL[0]``). ``ALL``, which package creators
        receive, is treated as implying write access, as in ``get_all_writeable_packages``.
        """
        return [Permission.WRITE[0], Permission.ALL[0]]

    @staticmethod
    def get_reviewed_packages():
        """Return all packages flagged as reviewed."""
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        """Return True if ``user`` may read ``package``.

        Read access is granted by any direct user permission, any permission held by
        one of the user's groups, the package being reviewed, or the user being a superuser.
        """
        # TODO Owner!
        # .exists() lets the database answer without fetching every matching permission row.
        if UserPackagePermission.objects.filter(package=package, user=user).exists() or \
                GroupPackagePermission.objects.filter(package=package,
                                                      group__in=GroupManager.get_groups(user)).exists() or \
                package.reviewed is True or \
                user.is_superuser:
            return True

        return False

    @staticmethod
    def writable(user, package):
        """Return True if ``user`` may modify ``package``.

        Write access is granted by a direct or group permission whose stored value is
        WRITE or ALL, or by the user being a superuser.
        """
        # TODO Owner!
        # Filter on the stored values (``X[0]``), the same form create_package writes;
        # comparing against the whole constant would never match a stored row.
        write_perms = PackageManager._write_permissions()
        if UserPackagePermission.objects.filter(package=package, user=user,
                                                permission__in=write_perms).exists() or \
                GroupPackagePermission.objects.filter(package=package,
                                                      group__in=GroupManager.get_groups(user),
                                                      permission__in=write_perms).exists() or \
                user.is_superuser:
            return True

        return False

    @staticmethod
    def get_package_by_url(user, package_url):
        """Resolve a ``.../package/<uuid>`` URL to a package readable by ``user``.

        :raises ValueError: if the URL does not end in a package UUID, the package does
            not exist, or ``user`` may not read it.
        """
        match = re.findall(PackageManager.package_pattern, package_url)

        if match:
            package_id = match[0].split('/')[-1]
            return PackageManager.get_package_by_id(user, package_id)
        else:
            raise ValueError("Requested URL {} does not contain a valid package identifier!".format(package_url))

    @staticmethod
    def get_package_by_id(user, package_id):
        """Return the package with UUID ``package_id`` if ``user`` may read it.

        :raises ValueError: if the package does not exist or ``user`` may not read it.
        """
        try:
            p = Package.objects.get(uuid=package_id)
            if PackageManager.readable(user, p):
                return p
            else:
                raise ValueError(
                    "Insufficient permissions to access Package with ID {}".format(package_id))
        except Package.DoesNotExist:
            raise ValueError("Package with ID {} does not exist!".format(package_id))

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        """Return the distinct packages ``user`` has a direct or group permission on.

        Superusers get every package. Reviewed packages are added when ``include_reviewed``
        is true and removed otherwise.
        """
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values('package').distinct())
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user)).values(
                    'package').distinct())
            qs = user_package_qs | group_package_qs

        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        """Return the distinct, non-reviewed packages ``user`` may write to.

        Superusers get every non-reviewed package; other users get those where they,
        or one of their groups, hold WRITE or ALL.
        """
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            write_perms = PackageManager._write_permissions()
            user_packs = UserPackagePermission.objects.filter(
                user=user, permission__in=write_perms).values('package').distinct()
            group_packs = GroupPackagePermission.objects.filter(
                group__in=GroupManager.get_groups(user), permission__in=write_perms).values('package').distinct()

            qs = Package.objects.filter(id__in=user_packs) | Package.objects.filter(id__in=group_packs)

        qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_packages():
        """Return every package, without any permission filtering."""
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        """Create a package and grant ``current_user`` the ALL permission on it.

        :returns: the saved package.
        """
        p = Package()
        p.name = name
        p.description = description
        p.save()

        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()

        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]):
        """Grant, change or revoke ``grantee``'s permission on ``package``.

        :param caller: the acting user; must be able to write ``package``.
        :param grantee: a ``User`` gets a direct permission; anything else is treated as a group.
        :param new_perm: the permission value to store, or None to revoke the existing one.
        :raises ValueError: if ``caller`` may not write ``package``, or if more than one
            permission row exists for the grantee when revoking.
        """
        if not PackageManager.writable(caller, package):
            raise ValueError(f"User {caller} is not allowed to modify permissions on {package}")

        data = {
            'package': package,
        }

        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data['user'] = grantee
        else:
            perm_cls = GroupPackagePermission
            data['group'] = grantee

        if new_perm is None:
            qs = perm_cls.objects.filter(**data)
            count = qs.count()
            if count > 1:
                raise ValueError("Got more Permission objects than expected!")
            if count != 0:
                logger.info(f"Deleting Perm {qs.first()}")
                qs.delete()
            else:
                logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
        else:
            _ = perm_cls.objects.update_or_create(defaults={'permission': new_perm}, **data)
|
||
|
||
class SettingManager(object):
    """Lookup and creation of :class:`Setting` objects with per-user visibility."""

    setting_pattern = re.compile(r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def get_setting_by_url(user, setting_url):
        """Resolve a ``.../setting/<uuid>`` URL to a Setting visible to ``user``.

        :raises ValueError: if the URL does not end in a setting UUID, or if
            ``get_setting_by_id`` denies access.
        """
        hits = re.findall(SettingManager.setting_pattern, setting_url)
        if not hits:
            raise ValueError("Requested URL {} does not contain a valid setting identifier!".format(setting_url))

        setting_uuid = hits[0].split('/')[-1]
        return SettingManager.get_setting_by_id(user, setting_uuid)

    @staticmethod
    def get_setting_by_id(user, setting_id):
        """Return the Setting with UUID ``setting_id`` if ``user`` may see it.

        A setting is visible when it is the global default, public, owned by
        ``user``, or when ``user`` is a superuser.

        :raises Setting.DoesNotExist: if no setting has that UUID.
        :raises ValueError: if the setting is not visible to ``user``.
        """
        setting = Setting.objects.get(uuid=setting_id)

        allowed = setting.global_default or setting.public or setting.owner == user or user.is_superuser
        if not allowed:
            raise ValueError(
                "Insufficient permissions to access Setting with ID {}".format(setting_id))
        return setting

    @staticmethod
    def get_all_settings(user):
        """Return every Setting for which ``user`` holds a UserSettingPermission."""
        granted = UserSettingPermission.objects.filter(user=user).values('setting').distinct()
        return Setting.objects.filter(id__in=granted)

    @staticmethod
    @transaction.atomic
    def create_setting(user: User, name: str = None, description: str = None, max_nodes: int = None,
                       max_depth: int = None, rule_packages: List[Package] = None, model: EPModel = None,
                       model_threshold: float = None):
        """Create a Setting and grant ``user`` the ALL permission on it.

        :param rule_packages: packages to attach to the setting's ``rule_packages``; None attaches none.
        :returns: the saved setting.
        """
        setting = Setting()
        setting.name = name
        setting.description = description
        setting.max_nodes = max_nodes
        setting.max_depth = max_depth
        setting.model = model
        setting.model_threshold = model_threshold
        setting.save()

        if rule_packages is not None:
            # Many-to-many links require the primary key from the save() above.
            for package in rule_packages:
                setting.rule_packages.add(package)
            setting.save()

        owner_perm = UserSettingPermission()
        owner_perm.user = user
        owner_perm.setting = setting
        owner_perm.permission = Permission.ALL[0]
        owner_perm.save()

        return setting

    @staticmethod
    def get_default_setting(user: User):
        """Not implemented yet; always returns None."""
        return None

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        """Not implemented yet; does nothing and returns None."""
        return None
|
||
|
||
|
||
class SNode(object):
    """In-memory pathway node.

    Identity (equality and hashing) is based on the SMILES string alone; the
    ``depth`` attribute is carried along but ignored by comparisons.
    """

    def __init__(self, smiles: str, depth: int):
        self.smiles = smiles
        self.depth = depth

    def __hash__(self):
        # Only the SMILES participates, mirroring __eq__.
        return hash(self.smiles)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.smiles == other.smiles

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"
|
||
|
||
|
||
class SEdge(object):
    """In-memory reaction edge linking educt nodes to product nodes.

    Two edges are equal when they have the same rule and the same SMILES
    (compared as sorted lists) on the educt and on the product side; node
    depth is not compared.
    """

    def __init__(self, educts: Union['SNode', List['SNode']], products: Union['SNode', List['SNode']],
                 rule: Optional['Rule'] = None):
        """
        :param educts: a single node or a list of nodes consumed by the reaction.
        :param products: a single node or a list of nodes produced by the reaction.
        :param rule: the rule that generated this edge, if any.
        """
        # Normalise both sides to lists: __hash__/__eq__ iterate over them and take len().
        if not isinstance(educts, list):
            educts = [educts]
        if not isinstance(products, list):
            products = [products]

        self.educts = educts
        self.products = products
        self.rule = rule

    @staticmethod
    def _sorted_smiles(nodes):
        """Return the SMILES of ``nodes`` in sorted order (an order-insensitive comparison key)."""
        return sorted(n.smiles for n in nodes)

    def __hash__(self):
        # A sum is order-independent, matching __eq__, which ignores node order.
        full_hash = sum(hash(n) for n in self.educts) + sum(hash(n) for n in self.products)

        if self.rule is not None:
            full_hash += hash(self.rule)

        return full_hash

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        # Also covers the "one rule is None, the other is not" cases.
        if self.rule != other.rule:
            return False

        return (self._sorted_smiles(self.educts) == self._sorted_smiles(other.educts)
                and self._sorted_smiles(self.products) == self._sorted_smiles(other.products))

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"
|
||
|
||
|
||
class SPathway(object):
    """In-memory pathway graph that is expanded depth by depth.

    Optionally mirrors new nodes and edges into a persisted ``Pathway``.
    """

    def __init__(self, root_nodes: Optional[Union[str, 'SNode', List[Union[str, 'SNode']]]] = None,
                 persist: Optional['Pathway'] = None, prediction_setting: Optional['Setting'] = None
                 ):
        """
        :param root_nodes: SMILES string(s) and/or SNode(s) to start from. Strings become
            depth-0 nodes; entries of any other type are skipped. Ignored when ``persist`` is given.
        :param persist: an existing pathway whose root nodes seed this graph and which
            receives newly predicted nodes and edges.
        :param prediction_setting: object whose ``expand(pathway, node)`` yields the candidate sets.
        """
        self.root_nodes = []

        self.persist = persist
        # In-memory objects -> database rows already created for them.
        self.snode_persist_lookup: Dict[SNode, Node] = dict()
        self.sedge_persist_lookup: Dict[SEdge, Edge] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            for n in root_nodes:
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.queue = list()
        self.smiles_to_node: Dict[str, SNode] = {n.smiles: n for n in self.root_nodes}
        self.edges: Set['SEdge'] = set()
        self.done = False

    def num_nodes(self):
        """Return the number of distinct SMILES in the graph."""
        return len(self.smiles_to_node)

    def depth(self):
        """Return the largest node depth (raises ValueError if the graph has no nodes)."""
        return max([v.depth for v in self.smiles_to_node.values()])

    def _get_nodes_for_depth(self, depth: int):
        """Return the root nodes for depth 0, otherwise the nodes at ``depth`` sorted by SMILES."""
        if depth == 0:
            return self.root_nodes

        res = [n for n in self.smiles_to_node.values() if n.depth == depth]
        return sorted(res, key=lambda x: x.smiles)

    def _get_edges_for_depth(self, depth: int):
        """Return each edge with at least one educt at ``depth`` once, ordered by hash."""
        # any() keeps an edge with several educts at this depth from being listed twice.
        res = [e for e in self.edges if any(n.depth == depth for n in e.educts)]
        return sorted(res, key=lambda x: hash(x))

    def predict_step(self, from_depth: int = 0):
        """Expand every node at ``from_depth`` by one step using ``prediction_setting``.

        Each candidate in a non-empty candidate set becomes one edge from the substrate to
        its product nodes. Unknown product SMILES are added at ``substrate.depth + 1``, and
        known ones reuse the existing node. ``done`` is set when there are no nodes at
        ``from_depth``. If any non-empty candidate set was returned and the pathway is
        persisted, the graph is synced to the database.
        """
        substrates = self._get_nodes_for_depth(from_depth)

        new_tp = False
        if substrates:
            for sub in substrates:
                candidates = self.prediction_setting.expand(self, sub)
                for cand_set in candidates:
                    if cand_set:
                        new_tp = True
                        for cand in cand_set:
                            cand_nodes = []
                            for c in cand:
                                if c not in self.smiles_to_node:
                                    self.smiles_to_node[c] = SNode(c, sub.depth + 1)

                                node = self.smiles_to_node[c]
                                cand_nodes.append(node)

                            edge = SEdge(sub, cand_nodes, cand_set.rule)
                            self.edges.add(edge)
        else:
            self.done = True

        if new_tp and self.persist:
            self._sync_to_pathway()

    def _sync_to_pathway(self):
        """Create persisted Node/Edge rows for every in-memory node/edge not yet mirrored."""
        logger.info("Updating Pathway with SPathway")

        for snode in self.smiles_to_node.values():
            if snode not in self.snode_persist_lookup:
                n = Node.create(self.persist, snode.smiles, snode.depth)
                self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge not in self.sedge_persist_lookup:
                # All nodes were persisted by the loop above, so these lookups resolve.
                educt_nodes = [self.snode_persist_lookup[snode] for snode in sedge.educts]
                product_nodes = [self.snode_persist_lookup[snode] for snode in sedge.products]

                e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)
                self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        """Serialise the graph as ``{'nodes': [...], 'edges': [...]}``.

        Node ids are positions in ``smiles_to_node`` insertion order. An edge's ``'from'``
        refers to its first educt only; ``'to'`` lists all product ids.
        """
        nodes = []
        edges = []

        idx_lookup = {}

        for i, s in enumerate(self.smiles_to_node):
            n = self.smiles_to_node[s]
            idx_lookup[s] = i

            nodes.append({'depth': n.depth, 'smiles': n.smiles, 'id': i})

        for edge in self.edges:
            from_idx = idx_lookup[edge.educts[0].smiles]
            to_indices = [idx_lookup[p.smiles] for p in edge.products]

            e = {
                'from': from_idx,
                'to': to_indices,
            }

            # if edge.rule:
            #     e['rule'] = {
            #         'name': edge.rule.name,
            #         'id': edge.rule.url,
            #     }
            edges.append(e)

        return {
            'nodes': nodes,
            'edges': edges,
        }

    def graph_to_tree_string(self):
        """Render the graph from ``to_json()`` as an indented text tree, one tree per depth-0 node."""
        graph_json = self.to_json()
        nodes = {node['id']: node for node in graph_json['nodes']}
        edges = graph_json['edges']

        children_map = {}
        for edge in edges:
            src = edge['from']
            for tgt in edge['to']:
                children_map.setdefault(src, []).append(tgt)

        # Shared across one root's traversal, so a node reached via several paths is
        # also reported as "[loop detected]" on its second visit, not only true cycles.
        visited = set()

        def recurse(node_id, prefix='', child_prefix=''):
            # ``prefix`` starts this node's own line (guides + branch glyph);
            # ``child_prefix`` is the guide column its children's lines build on.
            if node_id in visited:
                return prefix + nodes[node_id]['smiles'] + " [loop detected]\n"
            visited.add(node_id)

            line = prefix + nodes[node_id]['smiles'] + f" [{node_id}]\n"
            kids = children_map.get(node_id, [])
            for i, kid in enumerate(kids):
                if i == len(kids) - 1:
                    branch = '└── '
                    extension = '    '
                else:
                    branch = '├── '
                    extension = '│   '
                line += recurse(kid, prefix=child_prefix + branch, child_prefix=child_prefix + extension)
            return line

        root_nodes = [n['id'] for n in graph_json['nodes'] if n['depth'] == 0]
        result = ''
        for root in root_nodes:
            visited.clear()
            result += recurse(root)
        return result
|