forked from enviPath/enviPy
[Feature] Engineer Pathway (#256)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#256
This commit is contained in:
@ -9,7 +9,7 @@ from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from types import NoneType
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
from django.conf import settings as s
|
||||
from django.db import transaction
|
||||
@ -35,6 +35,7 @@ from epdb.models import (
|
||||
RuleBasedRelativeReasoning,
|
||||
Scenario,
|
||||
SequentialRule,
|
||||
Setting,
|
||||
SimpleAmbitRule,
|
||||
SimpleRDKitRule,
|
||||
SimpleRule,
|
||||
@ -44,6 +45,9 @@ from utilities.chem import FormatConverter
|
||||
logger = logging.getLogger(__name__)
|
||||
Package = s.GET_PACKAGE_MODEL()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from epdb.logic import SPathway
|
||||
|
||||
|
||||
class HTMLGenerator:
|
||||
registry = {x.__name__: x for x in NAME_MAPPING.values()}
|
||||
@ -1260,3 +1264,122 @@ class PathwayUtils:
|
||||
res[edge.url] = rule_chain
|
||||
|
||||
return res
|
||||
|
||||
def engineer(self, setting: "Setting"):
|
||||
from epdb.logic import SPathway
|
||||
from utilities.chem import FormatConverter
|
||||
from utilities.ml import graph_from_pathway, get_shortest_path
|
||||
|
||||
# get a fresh copy
|
||||
pw = Pathway.objects.get(id=self.pathway.pk)
|
||||
|
||||
root_nodes = [n.default_node_label.smiles for n in pw.root_nodes]
|
||||
|
||||
if len(root_nodes) != 1:
|
||||
logger.warning(f"Pathway {pw.name} has {len(root_nodes)} root nodes")
|
||||
# spw, mapping, intermediates
|
||||
return None, {}, []
|
||||
|
||||
# Predict the Pathway in memory
|
||||
spw = SPathway(root_nodes[0], None, setting)
|
||||
|
||||
level = 0
|
||||
while not spw.done:
|
||||
spw.predict_step(from_depth=level)
|
||||
level += 1
|
||||
|
||||
# Generate SNode -> Node mapping
|
||||
node_mapping = {}
|
||||
|
||||
for node in pw.nodes:
|
||||
for snode in spw.smiles_to_node.values():
|
||||
data_smiles = node.default_node_label.smiles
|
||||
pred_smiles = snode.smiles
|
||||
|
||||
# "~" denotes any bond remove and use implicit single bond for comparison
|
||||
data_key = FormatConverter.InChIKey(data_smiles.replace("~", ""))
|
||||
pred_key = FormatConverter.InChIKey(pred_smiles.replace("~", ""))
|
||||
|
||||
if data_key == pred_key:
|
||||
node_mapping[snode] = node
|
||||
|
||||
reverse_mapping = {v: k for k, v in node_mapping.items()}
|
||||
|
||||
graph = graph_from_pathway(spw)
|
||||
|
||||
intermediate_mapping = []
|
||||
|
||||
# loop through each edge and each reactant <-> product pair
|
||||
# and compute the shortest path on the predicted pathway
|
||||
for e in pw.edges:
|
||||
for start in e.start_nodes.all():
|
||||
if start not in reverse_mapping:
|
||||
continue
|
||||
|
||||
start_snode = reverse_mapping[start]
|
||||
|
||||
for end in e.end_nodes.all():
|
||||
if end not in reverse_mapping:
|
||||
continue
|
||||
|
||||
end_snode = reverse_mapping[end]
|
||||
|
||||
# If res is non-empty, we've found intermediates
|
||||
intermediate_smiles = get_shortest_path(
|
||||
graph,
|
||||
FormatConverter.standardize(start_snode.smiles, remove_stereo=True),
|
||||
FormatConverter.standardize(end_snode.smiles, remove_stereo=True),
|
||||
)
|
||||
|
||||
if intermediate_smiles:
|
||||
intermediates = []
|
||||
|
||||
prev = start_snode.smiles
|
||||
|
||||
for smi in intermediate_smiles + [end_snode.smiles]:
|
||||
for e in spw.get_edge_for_educt_smiles(prev):
|
||||
if smi in e.product_smiles():
|
||||
intermediates.append(e)
|
||||
|
||||
prev = smi
|
||||
|
||||
intermediate_mapping.append(
|
||||
(start, end, start_snode, end_snode, intermediates)
|
||||
)
|
||||
|
||||
return spw, reverse_mapping, intermediate_mapping
|
||||
|
||||
@staticmethod
|
||||
def spathway_to_pathway(
|
||||
package: "Package", spw: "SPathway", name: str = None, description: str = None
|
||||
):
|
||||
snode_to_node_mapping = dict()
|
||||
|
||||
root_nodes = spw.root_nodes
|
||||
|
||||
pw = Pathway.create(
|
||||
package=package,
|
||||
smiles=root_nodes[0].smiles,
|
||||
name=name,
|
||||
description=description,
|
||||
predicted=True,
|
||||
)
|
||||
|
||||
pw.setting = spw.prediction_setting
|
||||
pw.save()
|
||||
|
||||
snode_to_node_mapping[root_nodes[0]] = pw.root_nodes[0]
|
||||
|
||||
if len(root_nodes) > 1:
|
||||
for rn in root_nodes[1:]:
|
||||
n = Node.create(pw, rn.smiles, depth=0)
|
||||
snode_to_node_mapping[rn] = n
|
||||
|
||||
for snode, node in snode_to_node_mapping.items():
|
||||
spw.snode_persist_lookup[snode] = node
|
||||
|
||||
spw.persist = pw
|
||||
|
||||
spw._sync_to_pathway()
|
||||
|
||||
return pw
|
||||
|
||||
Reference in New Issue
Block a user