[Feature] Engineer Pathway (#256)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#256
This commit is contained in:
2025-12-10 07:35:42 +13:00
parent 46b0f1c124
commit 648ec150a9
17 changed files with 990 additions and 127 deletions

View File

@ -2,12 +2,13 @@ import logging
import re
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Dict, TYPE_CHECKING
from typing import List, Optional, Dict, TYPE_CHECKING, Union
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
from rdkit.Chem import rdchem
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MolStandardize import rdMolStandardize
@ -94,8 +95,15 @@ class FormatConverter(object):
return Chem.MolToSmiles(mol, canonical=canonical)
@staticmethod
def InChIKey(smiles):
return Chem.MolToInchiKey(FormatConverter.from_smiles(smiles))
def InChIKey(mol_or_smiles: Union[rdchem.Mol | str]):
if isinstance(mol_or_smiles, str):
mol_or_smiles = mol_or_smiles.replace("~", "")
mol_or_smiles = FormatConverter.from_smiles(mol_or_smiles)
if mol_or_smiles is None:
return None
return Chem.MolToInchiKey(mol_or_smiles)
@staticmethod
def InChI(smiles):
@ -352,7 +360,8 @@ class FormatConverter(object):
product = GetMolFrags(product, asMols=True)
for p in product:
p = FormatConverter.standardize(
Chem.MolToSmiles(p), remove_stereo=remove_stereo
Chem.MolToSmiles(p).replace("~", ""),
remove_stereo=remove_stereo,
)
if product_filter_smarts and FormatConverter.smarts_matches(

View File

@ -9,7 +9,7 @@ from collections import defaultdict
from datetime import datetime
from enum import Enum
from types import NoneType
from typing import Any, Dict, List
from typing import Any, Dict, List, TYPE_CHECKING
from django.conf import settings as s
from django.db import transaction
@ -35,6 +35,7 @@ from epdb.models import (
RuleBasedRelativeReasoning,
Scenario,
SequentialRule,
Setting,
SimpleAmbitRule,
SimpleRDKitRule,
SimpleRule,
@ -44,6 +45,9 @@ from utilities.chem import FormatConverter
logger = logging.getLogger(__name__)
Package = s.GET_PACKAGE_MODEL()
if TYPE_CHECKING:
from epdb.logic import SPathway
class HTMLGenerator:
registry = {x.__name__: x for x in NAME_MAPPING.values()}
@ -1260,3 +1264,122 @@ class PathwayUtils:
res[edge.url] = rule_chain
return res
def engineer(self, setting: "Setting"):
from epdb.logic import SPathway
from utilities.chem import FormatConverter
from utilities.ml import graph_from_pathway, get_shortest_path
# get a fresh copy
pw = Pathway.objects.get(id=self.pathway.pk)
root_nodes = [n.default_node_label.smiles for n in pw.root_nodes]
if len(root_nodes) != 1:
logger.warning(f"Pathway {pw.name} has {len(root_nodes)} root nodes")
# spw, mapping, intermediates
return None, {}, []
# Predict the Pathway in memory
spw = SPathway(root_nodes[0], None, setting)
level = 0
while not spw.done:
spw.predict_step(from_depth=level)
level += 1
# Generate SNode -> Node mapping
node_mapping = {}
for node in pw.nodes:
for snode in spw.smiles_to_node.values():
data_smiles = node.default_node_label.smiles
pred_smiles = snode.smiles
# "~" denotes any bond remove and use implicit single bond for comparison
data_key = FormatConverter.InChIKey(data_smiles.replace("~", ""))
pred_key = FormatConverter.InChIKey(pred_smiles.replace("~", ""))
if data_key == pred_key:
node_mapping[snode] = node
reverse_mapping = {v: k for k, v in node_mapping.items()}
graph = graph_from_pathway(spw)
intermediate_mapping = []
# loop through each edge and each reactant <-> product pair
# and compute the shortest path on the predicted pathway
for e in pw.edges:
for start in e.start_nodes.all():
if start not in reverse_mapping:
continue
start_snode = reverse_mapping[start]
for end in e.end_nodes.all():
if end not in reverse_mapping:
continue
end_snode = reverse_mapping[end]
# If res is non-empty, we've found intermediates
intermediate_smiles = get_shortest_path(
graph,
FormatConverter.standardize(start_snode.smiles, remove_stereo=True),
FormatConverter.standardize(end_snode.smiles, remove_stereo=True),
)
if intermediate_smiles:
intermediates = []
prev = start_snode.smiles
for smi in intermediate_smiles + [end_snode.smiles]:
for e in spw.get_edge_for_educt_smiles(prev):
if smi in e.product_smiles():
intermediates.append(e)
prev = smi
intermediate_mapping.append(
(start, end, start_snode, end_snode, intermediates)
)
return spw, reverse_mapping, intermediate_mapping
@staticmethod
def spathway_to_pathway(
package: "Package", spw: "SPathway", name: str = None, description: str = None
):
snode_to_node_mapping = dict()
root_nodes = spw.root_nodes
pw = Pathway.create(
package=package,
smiles=root_nodes[0].smiles,
name=name,
description=description,
predicted=True,
)
pw.setting = spw.prediction_setting
pw.save()
snode_to_node_mapping[root_nodes[0]] = pw.root_nodes[0]
if len(root_nodes) > 1:
for rn in root_nodes[1:]:
n = Node.create(pw, rn.smiles, depth=0)
snode_to_node_mapping[rn] = n
for snode, node in snode_to_node_mapping.items():
spw.snode_persist_lookup[snode] = node
spw.persist = pw
spw._sync_to_pathway()
return pw