From 13ed86a780d8938b4e7d082fbb612470907c350f Mon Sep 17 00:00:00 2001 From: jebus Date: Thu, 30 Oct 2025 00:47:45 +1300 Subject: [PATCH] [Feature] Identify Missing Rules (#177) Fixes #97 Co-authored-by: Tim Lorsbach Reviewed-on: https://git.envipath.com/enviPath/enviPy/pulls/177 --- epdb/tasks.py | 99 +++++++++++++- epdb/views.py | 19 +++ templates/actions/objects/pathway.html | 4 + .../objects/identify_missing_rules_modal.html | 54 ++++++++ templates/objects/pathway.html | 1 + utilities/chem.py | 84 +++++++++++- utilities/misc.py | 129 +++++++++++++++--- 7 files changed, 361 insertions(+), 29 deletions(-) create mode 100644 templates/modals/objects/identify_missing_rules_modal.html diff --git a/epdb/tasks.py b/epdb/tasks.py index b6f4e6b0..b872d4a9 100644 --- a/epdb/tasks.py +++ b/epdb/tasks.py @@ -1,13 +1,15 @@ +import csv +import io import logging from datetime import datetime -from typing import Callable, Optional +from typing import Any, Callable, List, Optional from uuid import uuid4 from celery import shared_task from celery.utils.functional import LRUCache from epdb.logic import SPathway -from epdb.models import EPModel, JobLog, Node, Package, Pathway, Setting, User +from epdb.models import EPModel, JobLog, Node, Package, Pathway, Rule, Setting, User, Edge logger = logging.getLogger(__name__) ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times. 
@shared_task(bind=True, queue="background")
def identify_missing_rules(
    self,
    pw_pks: List[int],
    rule_package_pk: int,
):
    """Build a CSV report of pathway edges (reactions) that cannot be
    reproduced by the rules of the given rule package.

    For every pathway, ``PathwayUtils.find_missing_rules`` returns the edges
    no applicable rule (or two-rule chain) covers; each such edge becomes one
    CSV row.

    Args:
        pw_pks: Primary keys of the Pathways to analyze.
        rule_package_pk: Primary key of the Package whose applicable rules
            are tested against each pathway edge.

    Returns:
        str: The complete CSV document (header plus one row per missing edge).
    """
    from utilities.misc import PathwayUtils

    def _join(values) -> str:
        # Multi-valued fields go into a SINGLE cell joined with "; " so every
        # row stays aligned with the 11-column header. The previous version
        # appended one cell per educt and wrote raw Python list reprs into
        # cells, which misaligned and polluted the CSV.
        return "; ".join(str(v) for v in values)

    rules = Package.objects.get(pk=rule_package_pk).get_applicable_rules()

    header = [
        "Package Name",
        "Pathway Name",
        "Educt Name",
        "Educt SMILES",
        "Reaction Name",
        "Reaction SMIRKS",
        "Triggered Rules",
        "Reactant SMARTS",
        "Product SMARTS",
        "Product Names",
        "Product SMILES",
    ]
    rows: List[List[Any]] = [header]

    for pw in Pathway.objects.filter(pk__in=pw_pks):
        pu = PathwayUtils(pw)
        missing_rules = pu.find_missing_rules(rules)

        for edge_url, rule_chain in missing_rules.items():
            edge = Edge.objects.get(url=edge_url)
            educts = edge.start_nodes.all()
            products = edge.end_nodes.all()

            rule_names = []
            reactant_smarts = []
            product_smarts = []

            # rule_chain holds (rule_url, rule_name) tuples.
            for rule_url, _rule_name in rule_chain:
                r = Rule.objects.get(url=rule_url)
                rule_names.append(r.name)

                # reactants_/products_smarts may come back as sets; normalize
                # to lists so the cell content is stable within a run.
                rs = r.reactants_smarts
                reactant_smarts.append(list(rs) if isinstance(rs, set) else rs)
                ps = r.products_smarts
                product_smarts.append(list(ps) if isinstance(ps, set) else ps)

            rows.append(
                [
                    pw.package.name,
                    pw.name,
                    _join(n.default_node_label.name for n in educts),
                    _join(n.default_node_label.smiles for n in educts),
                    edge.edge_label.name,
                    edge.edge_label.smirks(),
                    _join(rule_names),
                    _join(reactant_smarts),
                    _join(product_smarts),
                    _join(p.default_node_label.name for p in products),
                    _join(p.default_node_label.smiles for p in products),
                ]
            )

    buffer = io.StringIO()
    csv.writer(buffer).writerows(rows)
    return buffer.getvalue()
= csv.writer(buffer) + writer.writerows(rows) + + buffer.seek(0) + + return buffer.getvalue() diff --git a/epdb/views.py b/epdb/views.py index 64f68a76..6778a221 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -1866,6 +1866,25 @@ def package_pathway(request, package_uuid, pathway_uuid): return response + + if ( + request.GET.get("identify-missing-rules", False) == "true" + and request.GET.get("rule-package") is not None + ): + from .tasks import dispatch_eager, identify_missing_rules + + rule_package = PackageManager.get_package_by_url( + current_user, request.GET.get("rule-package") + ) + res = dispatch_eager( + current_user, identify_missing_rules, [current_pathway.pk], rule_package.pk + ) + + filename = f"{current_pathway.name.replace(' ', '_')}_{current_pathway.uuid}.csv" + response = HttpResponse(res, content_type="text/csv") + response["Content-Disposition"] = f'attachment; filename="{filename}"' + + return response + # Pathway d3_json() relies on a lot of related objects (Nodes, Structures, Edges, Reaction, Rules, ...) # we will again fetch the current pathway identified by this url, but this time together with nearly all # related objects diff --git a/templates/actions/objects/pathway.html b/templates/actions/objects/pathway.html index 28f74443..785f6213 100644 --- a/templates/actions/objects/pathway.html +++ b/templates/actions/objects/pathway.html @@ -22,6 +22,10 @@ Download Pathway as Image {% if meta.can_edit %} +
  • + + Identify Missing Rules +
  • diff --git a/templates/modals/objects/identify_missing_rules_modal.html b/templates/modals/objects/identify_missing_rules_modal.html new file mode 100644 index 00000000..23f2a953 --- /dev/null +++ b/templates/modals/objects/identify_missing_rules_modal.html @@ -0,0 +1,54 @@ +{% load static %} + + + diff --git a/templates/objects/pathway.html b/templates/objects/pathway.html index 4e4cc27d..faa38686 100644 --- a/templates/objects/pathway.html +++ b/templates/objects/pathway.html @@ -83,6 +83,7 @@ {% include "modals/objects/add_pathway_edge_modal.html" %} {% include "modals/objects/download_pathway_csv_modal.html" %} {% include "modals/objects/download_pathway_image_modal.html" %} + {% include "modals/objects/identify_missing_rules_modal.html" %} {% include "modals/objects/generic_copy_object_modal.html" %} {% include "modals/objects/edit_pathway_modal.html" %} {% include "modals/objects/generic_set_aliases_modal.html" %} diff --git a/utilities/chem.py b/utilities/chem.py index 279de26f..250ccfb6 100644 --- a/utilities/chem.py +++ b/utilities/chem.py @@ -185,7 +185,7 @@ class FormatConverter(object): return smiles @staticmethod - def standardize(smiles, remove_stereo=False): + def standardize(smiles, remove_stereo=False, canonicalize_tautomers=False): # Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/ # follows the steps in # https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/MolStandardize%20pieces.ipynb @@ -203,19 +203,21 @@ class FormatConverter(object): uncharger = ( rdMolStandardize.Uncharger() ) # annoying, but necessary as no convenience method exists - uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) + res_mol = uncharger.uncharge(parent_clean_mol) # note that no attempt is made at reionization at this step # nor at ionization at some pH (rdkit has no pKa caculator) # the main aim to to represent all molecules from different sources # in a (single) standard way, for use in ML, 
catalogue, etc. - # te = rdMolStandardize.TautomerEnumerator() # idem - # taut_uncharged_parent_clean_mol = te.Canonicalize(uncharged_parent_clean_mol) if remove_stereo: - Chem.RemoveStereochemistry(uncharged_parent_clean_mol) + Chem.RemoveStereochemistry(res_mol) - return Chem.MolToSmiles(uncharged_parent_clean_mol, kekuleSmiles=True) + if canonicalize_tautomers: + te = rdMolStandardize.TautomerEnumerator() # idem + res_mol = te.Canonicalize(res_mol) + + return Chem.MolToSmiles(res_mol, kekuleSmiles=True) @staticmethod def neutralize_smiles(smiles): @@ -363,6 +365,76 @@ class FormatConverter(object): return parsed_smiles, errors + @staticmethod + def smiles_covered_by( + l_smiles: List[str], + r_smiles: List[str], + standardize: bool = True, + canonicalize_tautomers: bool = True, + ) -> bool: + """ + Check if all SMILES in the left list are covered by (contained in) the right list. + + This function performs a subset check to determine if every chemical structure + represented in l_smiles has a corresponding representation in r_smiles. + + Args: + l_smiles (List[str]): List of SMILES strings to check for coverage. + r_smiles (List[str]): List of SMILES strings that should contain all l_smiles. + standardize (bool, optional): Whether to standardize SMILES before comparison. + Defaults to True. When True, applies FormatConverter.standardize() to + normalize representations for accurate comparison. + canonicalize_tautomers (bool, optional): Whether to canonicalize tautomers + Defaults to False. When True, applies rdMolStandardize.TautomerEnumerator().Canonicalize(res_mol) + to the compounds before comparison. + Returns: + bool: True if all SMILES in l_smiles are found in r_smiles (i.e., l_smiles + is a subset of r_smiles), False otherwise. 
+ + Note: + - Comparison treats lists as sets, ignoring duplicates and order + - Failed standardization attempts are silently ignored (original SMILES used) + - This is a one-directional check: l_smiles ⊆ r_smiles + - For bidirectional equality, both directions must be checked separately + + Example: + >>> FormatConverter.smiles_covered_by(["CCO", "CC"], ["CCO", "CC", "CCC"]) + True + >>> FormatConverter.smiles_covered_by(["CCO", "CCCC"], ["CCO", "CC", "CCC"]) + False + """ + + standardized_l_smiles = [] + + if standardize: + for smi in l_smiles: + try: + smi = FormatConverter.standardize( + smi, canonicalize_tautomers=canonicalize_tautomers + ) + except Exception: + # :shrug: + # logger.debug(f'Standardizing SMILES failed for {smi}') + pass + standardized_l_smiles.append(smi) + else: + standardized_l_smiles = l_smiles + + standardized_r_smiles = [] + if standardize: + for smi in r_smiles: + try: + smi = FormatConverter.standardize(smi) + except Exception: + # :shrug: + # logger.debug(f'Standardizing SMILES failed for {smi}') + pass + standardized_r_smiles.append(smi) + else: + standardized_r_smiles = r_smiles + + return len(set(standardized_l_smiles).difference(set(standardized_r_smiles))) == 0 + class Standardizer(ABC): def __init__(self, name): diff --git a/utilities/misc.py b/utilities/misc.py index 3e4eeb59..0b7222f7 100644 --- a/utilities/misc.py +++ b/utilities/misc.py @@ -9,36 +9,37 @@ from collections import defaultdict from datetime import datetime from enum import Enum from types import NoneType -from typing import Dict, Any, List +from typing import Any, Dict, List from django.db import transaction -from envipy_additional_information import Interval, EnviPyModel -from envipy_additional_information import NAME_MAPPING +from envipy_additional_information import NAME_MAPPING, EnviPyModel, Interval from pydantic import BaseModel, HttpUrl from epdb.models import ( - Package, Compound, CompoundStructure, - SimpleRule, + Edge, + EnviFormer, + EPModel, + 
class PathwayUtils:
    """Helper utilities operating on a single Pathway."""

    def __init__(self, pathway: "Pathway"):
        self.pathway = pathway

    @staticmethod
    def _get_products(smiles: str, rules: List["Rule"]) -> Dict[str, Dict[str, List[str]]]:
        """Apply every rule to *smiles*.

        Returns:
            Mapping ``{smiles: {rule_url: [product_smiles, ...]}}``; rules
            that do not trigger simply have no entry.
        """
        educt_rule_products: Dict[str, Dict[str, List[str]]] = defaultdict(
            lambda: defaultdict(list)
        )

        for rule in rules:
            for product_set in rule.apply(smiles):
                for product in product_set:
                    educt_rule_products[smiles][rule.url].append(product)

        return educt_rule_products

    def find_missing_rules(self, rules: List["Rule"]):
        """Find pathway edges (reactions) not covered by any single rule.

        An edge counts as covered when some rule applied to one of its educts
        yields all reaction products. Edges not covered in one step are
        reported together with any two-rule chains that would cover them.

        Args:
            rules: Candidate rules to test against every edge.

        Returns:
            Dict[str, list]: edge URL -> list of ``(rule_url, rule_name)``
            tuples (pairs of chained rules; empty when no chain covers the
            edge either). Edges covered by a single rule are omitted.
        """
        logger.info("Processing %s", self.pathway.name)

        # Pre-compute products for each pathway node / rule combination.
        educt_rule_products: Dict[str, Dict[str, List[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for node in self.pathway.nodes:
            # Plain update(mapping): SMILES keys are arbitrary strings, so
            # update(**mapping) keyword expansion is needlessly fragile.
            educt_rule_products.update(
                self._get_products(node.default_node_label.smiles, rules)
            )

        res: Dict[str, list] = {}

        for edge in self.pathway.edges:
            reaction = edge.edge_label
            products = [cs.smiles for cs in reaction.products.all()]

            found = False
            rule_chain: list = []

            for educt in reaction.educts.all():
                educt_smiles = educt.smiles
                for rule_url, rule_products in educt_rule_products.get(
                    educt_smiles, {}
                ).items():
                    if not rule_products:
                        continue

                    if FormatConverter.smiles_covered_by(
                        products, rule_products, standardize=True, canonicalize_tautomers=True
                    ):
                        # A single rule reproduces the reaction; stop searching
                        # this edge entirely.
                        found = True
                        break

                    # No single-step coverage: test whether chaining a second
                    # rule application on any intermediate covers the reaction.
                    for intermediate in rule_products:
                        inter_products = self._get_products(intermediate, rules)
                        for second_url, second_products in inter_products.get(
                            intermediate, {}
                        ).items():
                            if second_products and FormatConverter.smiles_covered_by(
                                products,
                                second_products,
                                standardize=True,
                                canonicalize_tautomers=True,
                            ):
                                rule_chain.append(
                                    (rule_url, Rule.objects.get(url=rule_url).name)
                                )
                                rule_chain.append(
                                    (second_url, Rule.objects.get(url=second_url).name)
                                )
                if found:
                    break

            # Record the edge only when no single rule covers it. (The earlier
            # implementation wrote res[edge.url] as soon as a chain matched,
            # so an edge later covered in one step was still reported missing.)
            if not found:
                res[edge.url] = rule_chain

        return res