Files
enviPy-bayer/epdb/logic.py
jebus 49e02ed97d feature/additional_information (#30)
Fixes #12

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#30
2025-07-19 08:10:40 +12:00

953 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import logging
from typing import Union, List, Optional, Set, Dict
from django.contrib.auth import get_user_model
from django.db import transaction
from django.conf import settings as s
from epdb.models import User, Package, UserPackagePermission, GroupPackagePermission, Permission, Group, Setting, \
EPModel, UserSettingPermission, Rule, Pathway, Node, Edge
logger = logging.getLogger(__name__)
class UserManager(object):
    """Helpers for creating and resolving user accounts."""

    @staticmethod
    def create_user(username, email, password, *args, **kwargs):
        """Create a user together with a personal default package.

        The new account is active unless ``ADMIN_APPROVAL_REQUIRED`` is set in the
        Django settings; an explicit ``is_active`` keyword argument overrides this.
        For inactive accounts a verification mail is queued via
        ``send_registration_mail.delay``. Other ``args``/``kwargs`` are ignored.

        :returns: the created user, with ``default_package`` set.
        """
        # avoid circular import :S
        from .tasks import send_registration_mail

        is_active = kwargs.get('is_active', not s.ADMIN_APPROVAL_REQUIRED)
        u = get_user_model().objects.create_user(username, email, password, is_active=is_active)

        # Personal package, e.g. "Tims Package"; no extra 's' after an s/x/z/ß ending.
        package_name = f"{u.username}{'' if u.username[-1] in 'sxzß' else 's'} Package"
        package_description = "This package was generated during registration."
        p = PackageManager.create_package(u, package_name, package_description)
        u.default_package = p
        u.save()

        if not u.is_active:
            # send email for verification
            send_registration_mail.delay(u.pk)

        return u

    @staticmethod
    def get_user(user_url):
        """Not implemented yet; always returns None."""
        pass

    @staticmethod
    def get_user_lp(user_url: str):
        """Return the user whose UUID is the last path segment of ``user_url``.

        Raises the user model's ``DoesNotExist`` if no such user exists.
        """
        uuid = user_url.strip().split('/')[-1]
        return get_user_model().objects.get(uuid=uuid)

    @staticmethod
    def get_users():
        """Placeholder that currently always returns an empty list."""
        return []
class GroupManager(object):
    """Helpers for creating groups, resolving them and managing their members."""

    @staticmethod
    def create_group(current_user, name, description):
        """Create a group owned by ``current_user`` and add that user as a member."""
        g = Group()
        g.name = name
        g.description = description
        g.owner = current_user
        g.save()
        # The m2m relation can only be populated once the group has been saved.
        g.user_member.add(current_user)
        g.save()
        return g

    @staticmethod
    def get_group_lp(group_url: str):
        """Return the group whose UUID is the last path segment of ``group_url`` (no permission check)."""
        uuid = group_url.strip().split('/')[-1]
        return Group.objects.get(uuid=uuid)

    @staticmethod
    def get_group_by_url(user, group_url):
        """Resolve ``group_url`` via :meth:`get_group_by_id` (membership-checked)."""
        return GroupManager.get_group_by_id(user, group_url.strip().split('/')[-1])

    @staticmethod
    def get_group_by_id(user, group_id):
        """Return the group with UUID ``group_id`` if ``user`` is a direct user member, else None.

        A non-existent id propagates ``Group.DoesNotExist`` from ``objects.get``.
        """
        g = Group.objects.get(uuid=group_id)
        # Ask the database instead of loading every member row to test membership.
        if g.user_member.filter(pk=user.pk).exists():
            return g
        return None

    @staticmethod
    def get_groups(user):
        """Return all groups that list ``user`` as a direct user member."""
        return Group.objects.filter(user_member=user)

    @staticmethod
    @transaction.atomic
    def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str):
        """Add or remove ``member`` (a user or a nested group) on ``group``.

        Any ``add_or_remove`` value other than ``'add'`` removes the member.

        :raises ValueError: if ``caller`` is not the group owner.
        """
        if caller != group.owner:
            raise ValueError('Only the group Owner is allowed to add members!')

        # Nested groups and users live in separate m2m relations.
        relation = group.group_member if isinstance(member, Group) else group.user_member
        if add_or_remove == 'add':
            relation.add(member)
        else:
            relation.remove(member)

        group.save()
class PackageManager(object):
    """Permission-aware lookup, creation and JSON import of packages."""

    # Accepts URLs that end in "/package/<uuid>" (lower-case hex UUID).
    package_pattern = re.compile(r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def get_reviewed_packages():
        """Return all packages flagged as reviewed."""
        return Package.objects.filter(reviewed=True)

    @staticmethod
    def readable(user, package):
        """Return True if ``user`` may read ``package``.

        Any user- or group-level permission on the package grants read access,
        as do ``package.reviewed`` and superuser status.
        """
        # TODO Owner!
        if UserPackagePermission.objects.filter(package=package, user=user).exists() or \
                GroupPackagePermission.objects.filter(package=package,
                                                      group__in=GroupManager.get_groups(user)).exists() or \
                package.reviewed is True or \
                user.is_superuser:
            return True
        return False

    @staticmethod
    def writable(user, package):
        """Return True if ``user`` may modify ``package``.

        Requires a WRITE or ALL permission (granted to the user directly or to one
        of their groups) or superuser status. ``create_package`` grants ALL to the
        creator, so creators can write to their own packages.
        """
        # TODO Owner!
        # Permission.WRITE / Permission.ALL are indexed with [0] everywhere else in
        # this module (the stored value). Filtering on the whole WRITE object, and
        # ignoring ALL, meant only superusers ever passed this check.
        write_perms = [Permission.WRITE[0], Permission.ALL[0]]
        if UserPackagePermission.objects.filter(package=package, user=user,
                                                permission__in=write_perms).exists() or \
                GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user),
                                                      permission__in=write_perms).exists() or \
                user.is_superuser:
            return True
        return False

    @staticmethod
    def get_package_by_url(user, package_url):
        """Resolve ``package_url`` to a Package readable by ``user``.

        :raises ValueError: if the URL holds no package UUID, the package does not
            exist, or ``user`` may not read it.
        """
        match = re.findall(PackageManager.package_pattern, package_url)
        if match:
            package_id = match[0].split('/')[-1]
            return PackageManager.get_package_by_id(user, package_id)
        else:
            raise ValueError("Requested URL {} does not contain a valid package identifier!".format(package_url))

    @staticmethod
    def get_package_by_id(user, package_id):
        """Return the Package with UUID ``package_id`` if ``user`` may read it.

        :raises ValueError: if the package does not exist or is not readable.
        """
        try:
            p = Package.objects.get(uuid=package_id)
        except Package.DoesNotExist:
            raise ValueError("Package with ID {} does not exist!".format(package_id))
        if PackageManager.readable(user, p):
            return p
        raise ValueError(
            "Insufficient permissions to access Package with ID {}".format(package_id))

    @staticmethod
    def get_all_readable_packages(user, include_reviewed=False):
        """Return packages ``user`` can read through user or group permissions.

        Superusers get every package. With ``include_reviewed`` all reviewed
        packages are added; otherwise reviewed packages are excluded.
        """
        # UserPermission only exists if at least read is granted...
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            user_package_qs = Package.objects.filter(
                id__in=UserPackagePermission.objects.filter(user=user).values('package').distinct())
            group_package_qs = Package.objects.filter(
                id__in=GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user)).values(
                    'package').distinct())
            qs = user_package_qs | group_package_qs

        if include_reviewed:
            qs |= Package.objects.filter(reviewed=True)
        else:
            # remove package if user is owner and package is reviewed e.g. admin
            qs = qs.filter(reviewed=False)

        return qs.distinct()

    @staticmethod
    def get_all_writeable_packages(user):
        """Return unreviewed packages ``user`` holds WRITE or ALL on (directly or via a group).

        Superusers get every unreviewed package.
        """
        if user.is_superuser:
            qs = Package.objects.all()
        else:
            write_perms = [Permission.WRITE[0], Permission.ALL[0]]
            user_packs = UserPackagePermission.objects.filter(
                user=user, permission__in=write_perms).values('package').distinct()
            group_packs = GroupPackagePermission.objects.filter(
                group__in=GroupManager.get_groups(user), permission__in=write_perms).values('package').distinct()
            qs = Package.objects.filter(id__in=user_packs) | Package.objects.filter(id__in=group_packs)

        qs = qs.filter(reviewed=False)
        return qs.distinct()

    @staticmethod
    def get_packages():
        """Return every package (no permission filtering)."""
        return Package.objects.all()

    @staticmethod
    @transaction.atomic
    def create_package(current_user, name: str, description: str = None):
        """Create a package and grant ``current_user`` the ALL permission on it."""
        p = Package()
        p.name = name
        p.description = description
        p.save()

        up = UserPackagePermission()
        up.user = current_user
        up.package = p
        up.permission = UserPackagePermission.ALL[0]
        up.save()

        return p

    @staticmethod
    @transaction.atomic
    def update_permissions(caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]):
        """Set or change ``grantee``'s permission on ``package``; ``new_perm=None`` revokes it.

        :raises ValueError: if ``caller`` may not write to ``package`` or more than
            one matching permission row exists.
        """
        if not PackageManager.writable(caller, package):
            raise ValueError(f"User {caller} is not allowed to modify permissions on {package}")

        data = {
            'package': package,
        }

        if isinstance(grantee, User):
            perm_cls = UserPackagePermission
            data['user'] = grantee
        else:
            perm_cls = GroupPackagePermission
            data['group'] = grantee

        if new_perm is None:
            qs = perm_cls.objects.filter(**data)
            count = qs.count()
            if count > 1:
                raise ValueError("Got more Permission objects than expected!")
            if count != 0:
                logger.info(f"Deleting Perm {qs.first()}")
                qs.delete()
            else:
                logger.debug(f"No Permission object for {perm_cls} with filter {data} found!")
        else:
            perm_cls.objects.update_or_create(defaults={'permission': new_perm}, **data)

    @staticmethod
    def _import_uuid(obj: dict, keep_ids: bool):
        """Return the UUID at the end of ``obj['id']`` when ``keep_ids`` is set, else a fresh random UUID."""
        from uuid import UUID, uuid4
        if keep_ids:
            return UUID(obj['id'].split('/')[-1])
        return uuid4()

    @staticmethod
    def _fix_node_depths(pw):
        """Recompute node depths of pathway ``pw``.

        Nodes without incoming edges get depth 0. Every node reachable from a
        root (a parent-less node with outgoing edges) gets its BFS distance, i.e.
        the shortest number of edges from any root. Changed nodes are saved.
        """
        from collections import defaultdict

        in_count = defaultdict(int)
        out_count = defaultdict(int)
        for e in pw.edges:
            # TODO check if this will remain
            for react in e.start_nodes.all():
                out_count[str(react.uuid)] += 1
            for prod in e.end_nodes.all():
                in_count[str(prod.uuid)] += 1

        root_nodes = []
        for n in pw.nodes:
            if in_count[str(n.uuid)] == 0:
                # must be a root node or unconnected node
                if n.depth != 0:
                    n.depth = 0
                    n.save()
                # Only root node may have children
                if out_count[str(n.uuid)] > 0:
                    root_nodes.append(n)

        levels = [root_nodes]
        # Mark nodes when first discovered (not when expanded), so a node that is
        # reachable at several levels keeps the shallowest depth.
        seen = {str(n.uuid) for n in root_nodes}
        # Do a bfs to determine depths starting with level 0 a.k.a. root nodes
        for i, level_nodes in enumerate(levels):
            new_level = []
            for n in level_nodes:
                for e in n.out_edges.all():
                    for prod in e.end_nodes.all():
                        prod_id = str(prod.uuid)
                        if prod_id in seen:
                            continue
                        seen.add(prod_id)
                        old_depth = prod.depth
                        if old_depth != i + 1:
                            print(f'updating depth from {old_depth} to {i + 1}')
                            prod.depth = i + 1
                            prod.save()
                        new_level.append(prod)
            if new_level:
                levels.append(new_level)

    @staticmethod
    @transaction.atomic
    def import_package(data: dict, owner: User, keep_ids=False):
        """Import a package from its JSON export ``data`` and grant ``owner`` ALL on it.

        Scenarios, compounds with structures, rules (simple, parallel, sequential),
        reactions and pathways (nodes and edges) are recreated inside one
        transaction. With ``keep_ids`` the UUIDs from the exported ids are reused,
        otherwise fresh UUIDs are generated. Scenario parents and scenario
        attachments are linked in a second pass, and node depths are recomputed
        at the end.

        :returns: the newly created package.
        """
        from datetime import datetime
        from collections import defaultdict
        from .models import Package, Compound, CompoundStructure, SimpleRule, SimpleAmbitRule, SimpleRDKitRule, \
            ParallelRule, SequentialRule, SequentialRuleOrdering, Reaction, Pathway, Node, Edge, Scenario
        from envipy_additional_information import AdditionalInformationConverter

        new_uuid = PackageManager._import_uuid

        pack = Package()
        pack.uuid = new_uuid(data, keep_ids)
        pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M'))
        pack.reviewed = data['reviewStatus'] == 'reviewed'
        pack.description = data['description']
        pack.save()

        up = UserPackagePermission()
        up.user = owner
        up.package = pack
        up.permission = up.ALL[0]
        up.save()

        # Maps exported object id -> new UUID
        mapping = {}
        # Maps new scenario UUID -> exported id of its parent ("referringscenario")
        parent_mapping = {}
        # Maps exported scenario id -> imported objects referencing that scenario
        scen_mapping = defaultdict(list)

        # Store Scenarios
        for scenario in data['scenarios']:
            scen = Scenario()
            scen.package = pack
            scen.uuid = new_uuid(scenario, keep_ids)
            scen.name = scenario['name']
            scen.description = scenario['description']
            scen.scenario_type = scenario['type']
            scen.scenario_date = scenario['date']
            scen.additional_information = dict()
            scen.save()

            mapping[scenario['id']] = scen.uuid

            new_add_inf = defaultdict(list)
            for ex in scenario.get('additionalInformationCollection', {}).get('additionalInformation', []):
                name = ex['name']
                addinf_data = ex['data']

                # park the parent scen id for now and link it later
                if name == 'referringscenario':
                    parent_mapping[scen.uuid] = addinf_data
                    continue

                # Broken eP Data
                if name == 'initialmasssediment' and addinf_data == 'missing data':
                    continue

                # TODO Enzymes arent ready yet
                if name == 'enzyme':
                    continue

                try:
                    res = AdditionalInformationConverter.convert(name, addinf_data)
                except Exception:
                    # Skip entries that cannot be converted; falling through would
                    # append an unbound (or a previous entry's stale) ``res``.
                    logger.exception(f"Failed to convert {name} with {addinf_data}")
                    continue

                new_add_inf[name].append(res.model_dump_json())

            scen.additional_information = new_add_inf
            scen.save()

        print('Scenarios imported...')

        # Store compounds and its structures
        for compound in data['compounds']:
            comp = Compound()
            comp.package = pack
            comp.uuid = new_uuid(compound, keep_ids)
            comp.name = compound['name']
            comp.description = compound['description']
            comp.aliases = compound['aliases']
            comp.save()

            mapping[compound['id']] = comp.uuid

            for scen in compound['scenarios']:
                scen_mapping[scen['id']].append(comp)

            default_structure = None

            for structure in compound['structures']:
                struc = CompoundStructure()
                struc.compound = comp
                struc.uuid = new_uuid(structure, keep_ids)
                struc.name = structure['name']
                struc.description = structure['description']
                struc.smiles = structure['smiles']
                struc.save()

                for scen in structure['scenarios']:
                    scen_mapping[scen['id']].append(struc)

                mapping[structure['id']] = struc.uuid

                if structure['id'] == compound['defaultStructure']['id']:
                    default_structure = struc

            if default_structure is None:
                raise ValueError('No default structure set')

            comp.default_structure = default_structure
            comp.save()

        print('Compounds imported...')

        # Store simple rules first; parallel and sequential rules reference them.
        par_rules = []
        seq_rules = []

        for rule in data['rules']:
            if rule['identifier'] == 'parallel-rule':
                par_rules.append(rule)
                continue
            if rule['identifier'] == 'sequential-rule':
                seq_rules.append(rule)
                continue
            r = SimpleAmbitRule()
            r.uuid = new_uuid(rule, keep_ids)
            r.package = pack
            r.name = rule['name']
            r.description = rule['description']
            r.smirks = rule['smirks']
            r.reactant_filter_smarts = rule.get('reactantFilterSmarts', None)
            r.product_filter_smarts = rule.get('productFilterSmarts', None)
            r.save()

            mapping[rule['id']] = r.uuid
            for scen in rule['scenarios']:
                scen_mapping[scen['id']].append(r)

        print("Par: ", len(par_rules))
        print("Seq: ", len(seq_rules))

        for par_rule in par_rules:
            r = ParallelRule()
            r.package = pack
            r.uuid = new_uuid(par_rule, keep_ids)
            r.name = par_rule['name']
            r.description = par_rule['description']
            r.save()

            mapping[par_rule['id']] = r.uuid
            for scen in par_rule['scenarios']:
                scen_mapping[scen['id']].append(r)

            for simple_rule in par_rule['simpleRules']:
                if simple_rule['id'] in mapping:
                    r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule['id']]))

            r.save()

        for seq_rule in seq_rules:
            r = SequentialRule()
            r.package = pack
            r.uuid = new_uuid(seq_rule, keep_ids)
            r.name = seq_rule['name']
            r.description = seq_rule['description']
            r.save()

            mapping[seq_rule['id']] = r.uuid
            for scen in seq_rule['scenarios']:
                scen_mapping[scen['id']].append(r)

            for i, simple_rule in enumerate(seq_rule['simpleRules']):
                # The export references rules by id; resolve it to the imported
                # SimpleRule instead of assigning the raw dict to the relation.
                # Unknown ids are skipped, as for parallel rules.
                if simple_rule['id'] not in mapping:
                    logger.warning(f"Simple rule {simple_rule['id']} of {seq_rule['id']} not found, skipping")
                    continue
                sro = SequentialRuleOrdering()
                sro.simple_rule = SimpleRule.objects.get(uuid=mapping[simple_rule['id']])
                sro.sequential_rule = r
                sro.order_index = i
                sro.save()

            r.save()

        print('Rules imported...')

        for reaction in data['reactions']:
            r = Reaction()
            r.package = pack
            r.uuid = new_uuid(reaction, keep_ids)
            r.name = reaction['name']
            r.description = reaction['description']
            # No trailing comma here: it would store a 1-tuple instead of the value.
            r.medlinereferences = reaction['medlinereferences']
            r.multi_step = reaction['multistep'] == 'true'
            r.save()

            mapping[reaction['id']] = r.uuid

            for scen in reaction['scenarios']:
                scen_mapping[scen['id']].append(r)

            for educt in reaction['educts']:
                r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt['id']]))

            for product in reaction['products']:
                r.products.add(CompoundStructure.objects.get(uuid=mapping[product['id']]))

            if 'rules' in reaction:
                for rule in reaction['rules']:
                    # Best effort: a reaction may reference rules that were not imported.
                    try:
                        r.rules.add(Rule.objects.get(uuid=mapping[rule['id']]))
                    except Exception as e:
                        print(f"Rule with id {rule['id']} not found!")
                        print(e)

            r.save()

        print('Reactions imported...')

        for pathway in data['pathways']:
            pw = Pathway()
            pw.package = pack
            pw.uuid = new_uuid(pathway, keep_ids)
            pw.name = pathway['name']
            pw.description = pathway['description']
            pw.save()

            mapping[pathway['id']] = pw.uuid
            for scen in pathway['scenarios']:
                scen_mapping[scen['id']].append(pw)

            # New node UUID -> exported ids of its out-edges; linked once edges exist.
            out_nodes_mapping = defaultdict(set)

            for node in pathway['nodes']:
                n = Node()
                n.uuid = new_uuid(node, keep_ids)
                n.name = node['name']
                n.pathway = pw
                n.depth = node['depth']
                n.default_node_label = CompoundStructure.objects.get(uuid=mapping[node['defaultNodeLabel']['id']])
                n.save()

                mapping[node['id']] = n.uuid

                for scen in node['scenarios']:
                    scen_mapping[scen['id']].append(n)

                for node_label in node['nodeLabels']:
                    n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label['id']]))

                n.save()

                for out_edge in node['outEdges']:
                    out_nodes_mapping[n.uuid].add(out_edge)

            for edge in pathway['edges']:
                e = Edge()
                e.uuid = new_uuid(edge, keep_ids)
                e.name = edge['name']
                e.pathway = pw
                e.description = edge['description']
                e.edge_label = Reaction.objects.get(uuid=mapping[edge['edgeLabel']['id']])
                e.save()

                mapping[edge['id']] = e.uuid

                for scen in edge['scenarios']:
                    scen_mapping[scen['id']].append(e)

                for start_node in edge['startNodes']:
                    e.start_nodes.add(Node.objects.get(uuid=mapping[start_node]))

                for end_node in edge['endNodes']:
                    e.end_nodes.add(Node.objects.get(uuid=mapping[end_node]))

                e.save()

            for k, v in out_nodes_mapping.items():
                n = Node.objects.get(uuid=k)
                for v1 in v:
                    n.out_edges.add(Edge.objects.get(uuid=mapping[v1]))
                n.save()

        print('Pathways imported...')

        # Linking Phase
        for child, parent in parent_mapping.items():
            child_obj = Scenario.objects.get(uuid=child)
            parent_obj = Scenario.objects.get(uuid=mapping[parent])
            child_obj.parent = parent_obj
            child_obj.save()

        for scen_id, objects in scen_mapping.items():
            scen = Scenario.objects.get(uuid=mapping[scen_id])
            for o in objects:
                o.scenarios.add(scen)
                o.save()

        print("Scenarios linked...")

        print('Import statistics:')
        print('Package {} stored'.format(pack.url))
        print('Imported {} compounds'.format(Compound.objects.filter(package=pack).count()))
        print('Imported {} rules'.format(Rule.objects.filter(package=pack).count()))
        print('Imported {} reactions'.format(Reaction.objects.filter(package=pack).count()))
        print('Imported {} pathways'.format(Pathway.objects.filter(package=pack).count()))
        print('Imported {} Scenarios'.format(Scenario.objects.filter(package=pack).count()))

        print("Fixing Node depths...")
        total_pws = Pathway.objects.filter(package=pack).count()
        for p, pw in enumerate(Pathway.objects.filter(package=pack)):
            print(pw.url)
            PackageManager._fix_node_depths(pw)
            print(f'{p + 1}/{total_pws} fixed.')

        return pack
class SettingManager(object):
    """Lookup and creation of prediction settings with visibility checks."""

    # Accepts URLs that end in "/setting/<uuid>" (lower-case hex UUID).
    setting_pattern = re.compile(r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$")

    @staticmethod
    def get_setting_by_url(user, setting_url):
        """Resolve ``setting_url`` to a Setting visible to ``user``.

        :raises ValueError: if the URL holds no setting UUID or access is denied.
        """
        match = re.findall(SettingManager.setting_pattern, setting_url)

        if match:
            setting_id = match[0].split('/')[-1]
            return SettingManager.get_setting_by_id(user, setting_id)
        else:
            raise ValueError("Requested URL {} does not contain a valid setting identifier!".format(setting_url))

    @staticmethod
    def get_setting_by_id(user, setting_id):
        """Return the Setting with UUID ``setting_id`` if ``user`` may see it.

        Global-default and public settings are visible to everyone; otherwise the
        user must own the setting or be a superuser. A missing id propagates
        ``Setting.DoesNotExist`` from ``objects.get``.

        :raises ValueError: if access is denied.
        """
        # Named ``setting`` rather than ``s`` so the module-level ``s``
        # (django.conf.settings) alias is not shadowed.
        setting = Setting.objects.get(uuid=setting_id)

        if setting.global_default or setting.public or setting.owner == user or user.is_superuser:
            return setting

        raise ValueError(
            "Insufficient permissions to access Setting with ID {}".format(setting_id))

    @staticmethod
    def get_all_settings(user):
        """Return the settings on which ``user`` holds a UserSettingPermission."""
        sp = UserSettingPermission.objects.filter(user=user).values('setting').distinct()
        return Setting.objects.filter(id__in=sp)

    @staticmethod
    @transaction.atomic
    def create_setting(user: User, name: str = None, description: str = None, max_nodes: int = None,
                       max_depth: int = None, rule_packages: List[Package] = None, model: EPModel = None,
                       model_threshold: float = None):
        """Create a Setting and grant ``user`` the ALL permission on it.

        ``rule_packages`` (if given) are attached through the m2m relation after
        the first save, because the setting needs a primary key first.
        """
        setting = Setting()
        setting.name = name
        setting.description = description
        setting.max_nodes = max_nodes
        setting.max_depth = max_depth
        setting.model = model
        setting.model_threshold = model_threshold
        setting.save()

        if rule_packages is not None:
            setting.rule_packages.add(*rule_packages)
            setting.save()

        usp = UserSettingPermission()
        usp.user = user
        usp.setting = setting
        usp.permission = Permission.ALL[0]
        usp.save()

        return setting

    @staticmethod
    def get_default_setting(user: User):
        """Not implemented yet; always returns None."""
        pass

    @staticmethod
    @transaction.atomic
    def set_default_setting(user: User, setting: Setting):
        """Not implemented yet; does nothing."""
        pass
class SNode(object):
    """In-memory pathway node identified solely by its SMILES string.

    ``depth`` travels with the node but is ignored by equality and hashing, so
    two nodes carrying the same SMILES are interchangeable regardless of depth.
    """

    def __init__(self, smiles: str, depth: int):
        self.smiles, self.depth = smiles, depth

    def __hash__(self):
        # Must agree with __eq__: both look only at the SMILES.
        return hash(self.smiles)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.smiles == other.smiles

    def __repr__(self):
        return f"SNode('{self.smiles}', {self.depth})"
class SEdge(object):
    """In-memory pathway edge from educt SNodes to product SNodes, optionally tagged with a rule.

    Equality and hashing depend on the educt and product SMILES (order
    insensitive) and on ``rule``; node depths are ignored.
    """

    def __init__(self, educts: Union['SNode', List['SNode']], products: Union['SNode', List['SNode']],
                 rule: Optional['Rule'] = None):
        # Either side may be given as a single node; both are stored as lists so
        # that hashing and comparison can iterate them.
        if not isinstance(educts, list):
            educts = [educts]
        if not isinstance(products, list):
            products = [products]
        self.educts = educts
        self.products = products
        self.rule = rule

    @staticmethod
    def _smiles_key(nodes):
        """Return the sorted SMILES of ``nodes``, the order-insensitive identity of one edge side."""
        return tuple(sorted(n.smiles for n in nodes))

    def __hash__(self):
        # A tuple keeps educts and products apart, so A->B and B->A no longer
        # collide (a plain sum did). Consistent with __eq__, which compares the
        # same sorted SMILES and the rule.
        return hash((self._smiles_key(self.educts), self._smiles_key(self.products), self.rule))

    def __eq__(self, other):
        if not isinstance(other, SEdge):
            return False

        # Explicit None check so a rule object can never compare equal to "no rule".
        if (self.rule is None) != (other.rule is None) or self.rule != other.rule:
            return False

        return (self._smiles_key(self.educts) == self._smiles_key(other.educts)
                and self._smiles_key(self.products) == self._smiles_key(other.products))

    def __repr__(self):
        return f"SEdge({self.educts}, {self.products}, {self.rule})"
class SPathway(object):
    """In-memory pathway used during prediction, optionally mirrored to a persisted Pathway.

    Nodes live in ``smiles_to_node`` (keyed by SMILES) and edges in the set
    ``edges``. When ``persist`` is given, root nodes are loaded from it and new
    nodes/edges are written back by ``_sync_to_pathway``.
    """

    def __init__(self, root_nodes: Optional[Union[str, 'SNode', List[Union[str, 'SNode']]]] = None,
                 persist: Optional['Pathway'] = None, prediction_setting: Optional['Setting'] = None):
        self.root_nodes = []
        self.persist = persist
        self.snode_persist_lookup: Dict['SNode', 'Node'] = dict()
        self.sedge_persist_lookup: Dict['SEdge', 'Edge'] = dict()
        self.prediction_setting = prediction_setting

        if persist:
            # Resume from a stored pathway; only its root nodes are loaded here.
            for n in persist.root_nodes:
                snode = SNode(n.default_node_label.smiles, n.depth)
                self.root_nodes.append(snode)
                self.snode_persist_lookup[snode] = n
        else:
            if not isinstance(root_nodes, list):
                root_nodes = [root_nodes]

            for n in root_nodes:
                # SMILES strings become depth-0 nodes; other types (e.g. None) are ignored.
                if isinstance(n, str):
                    self.root_nodes.append(SNode(n, 0))
                elif isinstance(n, SNode):
                    self.root_nodes.append(n)

        self.queue = list()
        self.smiles_to_node: Dict[str, 'SNode'] = {n.smiles: n for n in self.root_nodes}
        self.edges: Set['SEdge'] = set()
        self.done = False

    def num_nodes(self):
        """Return the number of distinct nodes (SMILES) in the pathway."""
        return len(self.smiles_to_node)

    def depth(self):
        """Return the largest node depth; raises ValueError if there are no nodes."""
        return max(v.depth for v in self.smiles_to_node.values())

    def _get_nodes_for_depth(self, depth: int):
        """Return the nodes at ``depth``: the root list for 0, otherwise sorted by SMILES."""
        if depth == 0:
            return self.root_nodes
        return sorted((n for n in self.smiles_to_node.values() if n.depth == depth),
                      key=lambda x: x.smiles)

    def _get_edges_for_depth(self, depth: int):
        """Return the edges with at least one educt at ``depth``, ordered by edge hash.

        Each edge is listed once, even if several of its educts are at ``depth``.
        """
        res = [e for e in self.edges if any(n.depth == depth for n in e.educts)]
        return sorted(res, key=hash)

    def predict_step(self, from_depth: int = 0):
        """Expand every node at ``from_depth`` one step via ``prediction_setting.expand``.

        Unseen product SMILES become nodes at ``sub.depth + 1``; known SMILES reuse
        the existing node. If no node exists at ``from_depth``, ``done`` is set.
        If new products were found and a ``persist`` pathway is attached, the
        changes are synced to it.
        """
        substrates = self._get_nodes_for_depth(from_depth)

        new_tp = False
        if substrates:
            for sub in substrates:
                # NOTE(review): assumes ``expand`` yields candidate sets that carry a
                # ``rule`` and iterate over tuples of product SMILES -- confirm in Setting.
                candidates = self.prediction_setting.expand(self, sub)
                for cand_set in candidates:
                    if cand_set:
                        new_tp = True
                        for cand in cand_set:
                            cand_nodes = []
                            for c in cand:
                                if c not in self.smiles_to_node:
                                    self.smiles_to_node[c] = SNode(c, sub.depth + 1)
                                cand_nodes.append(self.smiles_to_node[c])
                            self.edges.add(SEdge(sub, cand_nodes, cand_set.rule))
        else:
            self.done = True

        if new_tp and self.persist:
            self._sync_to_pathway()

    def _sync_to_pathway(self):
        """Create persisted Node/Edge objects for SNodes/SEdges that are not stored yet."""
        logger.info("Updating Pathway with SPathway")

        # Nodes first, so that every edge endpoint can be resolved below.
        for snode in self.smiles_to_node.values():
            if snode not in self.snode_persist_lookup:
                n = Node.create(self.persist, snode.smiles, snode.depth)
                self.snode_persist_lookup[snode] = n

        for sedge in self.edges:
            if sedge not in self.sedge_persist_lookup:
                educt_nodes = [self.snode_persist_lookup[snode] for snode in sedge.educts]
                product_nodes = [self.snode_persist_lookup[snode] for snode in sedge.products]
                e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule)
                self.sedge_persist_lookup[sedge] = e

        logger.info("Update done!")

    def to_json(self):
        """Serialise to ``{'nodes': [...], 'edges': [...]}`` using positional node ids.

        Node ids follow ``smiles_to_node`` insertion order. Each edge is exported
        from its first educt to the ids of its products; rules are not exported.
        """
        nodes = []
        idx_lookup = {}
        for i, (smiles, n) in enumerate(self.smiles_to_node.items()):
            idx_lookup[smiles] = i
            nodes.append({'depth': n.depth, 'smiles': n.smiles, 'id': i})

        edges = []
        for edge in self.edges:
            # TODO: rule information (name/url) is not exported yet.
            edges.append({
                'from': idx_lookup[edge.educts[0].smiles],
                'to': [idx_lookup[p.smiles] for p in edge.products],
            })

        return {
            'nodes': nodes,
            'edges': edges,
        }

    def graph_to_tree_string(self):
        """Render the pathway as a text tree, one tree per depth-0 node.

        Each line is ``<smiles> [<id>]``. A node reached again from the same root
        is printed with ``[loop detected]`` and not expanded a second time.
        """
        graph_json = self.to_json()
        nodes = {node['id']: node for node in graph_json['nodes']}

        children_map = {}
        for edge in graph_json['edges']:
            for tgt in edge['to']:
                children_map.setdefault(edge['from'], []).append(tgt)

        visited = set()

        def recurse(node_id, line_prefix='', child_prefix=''):
            # line_prefix is printed before this node; child_prefix is the
            # indentation its children prepend to their own connectors.
            if node_id in visited:
                return line_prefix + nodes[node_id]['smiles'] + " [loop detected]\n"
            visited.add(node_id)
            line = line_prefix + nodes[node_id]['smiles'] + f" [{node_id}]\n"
            kids = children_map.get(node_id, [])
            for i, kid in enumerate(kids):
                is_last = i == len(kids) - 1
                branch = '└── ' if is_last else '├── '
                extension = '    ' if is_last else '│   '
                line += recurse(kid, child_prefix + branch, child_prefix + extension)
            return line

        result = ''
        for root in (n['id'] for n in graph_json['nodes'] if n['depth'] == 0):
            visited.clear()
            result += recurse(root)
        return result