[Feature] Engineer Pathway (#256)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#256
This commit is contained in:
2025-12-10 07:35:42 +13:00
parent 46b0f1c124
commit 648ec150a9
17 changed files with 990 additions and 127 deletions

View File

@ -9,7 +9,7 @@ from collections import defaultdict
from datetime import datetime
from enum import Enum
from types import NoneType
from typing import Any, Dict, List
from typing import Any, Dict, List, TYPE_CHECKING
from django.conf import settings as s
from django.db import transaction
@ -35,6 +35,7 @@ from epdb.models import (
RuleBasedRelativeReasoning,
Scenario,
SequentialRule,
Setting,
SimpleAmbitRule,
SimpleRDKitRule,
SimpleRule,
@ -44,6 +45,9 @@ from utilities.chem import FormatConverter
logger = logging.getLogger(__name__)
Package = s.GET_PACKAGE_MODEL()
if TYPE_CHECKING:
from epdb.logic import SPathway
class HTMLGenerator:
registry = {x.__name__: x for x in NAME_MAPPING.values()}
@ -1260,3 +1264,122 @@ class PathwayUtils:
res[edge.url] = rule_chain
return res
def engineer(self, setting: "Setting"):
from epdb.logic import SPathway
from utilities.chem import FormatConverter
from utilities.ml import graph_from_pathway, get_shortest_path
# get a fresh copy
pw = Pathway.objects.get(id=self.pathway.pk)
root_nodes = [n.default_node_label.smiles for n in pw.root_nodes]
if len(root_nodes) != 1:
logger.warning(f"Pathway {pw.name} has {len(root_nodes)} root nodes")
# spw, mapping, intermediates
return None, {}, []
# Predict the Pathway in memory
spw = SPathway(root_nodes[0], None, setting)
level = 0
while not spw.done:
spw.predict_step(from_depth=level)
level += 1
# Generate SNode -> Node mapping
node_mapping = {}
for node in pw.nodes:
for snode in spw.smiles_to_node.values():
data_smiles = node.default_node_label.smiles
pred_smiles = snode.smiles
# "~" denotes any bond remove and use implicit single bond for comparison
data_key = FormatConverter.InChIKey(data_smiles.replace("~", ""))
pred_key = FormatConverter.InChIKey(pred_smiles.replace("~", ""))
if data_key == pred_key:
node_mapping[snode] = node
reverse_mapping = {v: k for k, v in node_mapping.items()}
graph = graph_from_pathway(spw)
intermediate_mapping = []
# loop through each edge and each reactant <-> product pair
# and compute the shortest path on the predicted pathway
for e in pw.edges:
for start in e.start_nodes.all():
if start not in reverse_mapping:
continue
start_snode = reverse_mapping[start]
for end in e.end_nodes.all():
if end not in reverse_mapping:
continue
end_snode = reverse_mapping[end]
# If res is non-empty, we've found intermediates
intermediate_smiles = get_shortest_path(
graph,
FormatConverter.standardize(start_snode.smiles, remove_stereo=True),
FormatConverter.standardize(end_snode.smiles, remove_stereo=True),
)
if intermediate_smiles:
intermediates = []
prev = start_snode.smiles
for smi in intermediate_smiles + [end_snode.smiles]:
for e in spw.get_edge_for_educt_smiles(prev):
if smi in e.product_smiles():
intermediates.append(e)
prev = smi
intermediate_mapping.append(
(start, end, start_snode, end_snode, intermediates)
)
return spw, reverse_mapping, intermediate_mapping
@staticmethod
def spathway_to_pathway(
package: "Package", spw: "SPathway", name: str = None, description: str = None
):
snode_to_node_mapping = dict()
root_nodes = spw.root_nodes
pw = Pathway.create(
package=package,
smiles=root_nodes[0].smiles,
name=name,
description=description,
predicted=True,
)
pw.setting = spw.prediction_setting
pw.save()
snode_to_node_mapping[root_nodes[0]] = pw.root_nodes[0]
if len(root_nodes) > 1:
for rn in root_nodes[1:]:
n = Node.create(pw, rn.smiles, depth=0)
snode_to_node_mapping[rn] = n
for snode, node in snode_to_node_mapping.items():
spw.snode_persist_lookup[snode] = node
spw.persist = pw
spw._sync_to_pathway()
return pw