From d2c2e643cb2cd0aef51227ff2295cbc5818ac15f Mon Sep 17 00:00:00 2001 From: jebus Date: Fri, 20 Feb 2026 10:14:28 +1300 Subject: [PATCH] [Fix] Compound Grouping, Identity prediction of enviFormer, Setting params (#337) Co-authored-by: Tim Lorsbach Reviewed-on: https://git.envipath.com/enviPath/enviPy/pulls/337 --- envipath/settings.py | 4 ++-- epdb/legacy_api.py | 11 ++++++++--- epdb/logic.py | 3 +++ epdb/models.py | 21 ++++++++++++++++----- epdb/tasks.py | 2 +- epdb/views.py | 14 ++++++++------ tests/test_formatconverter.py | 2 +- utilities/chem.py | 6 ++++-- 8 files changed, 43 insertions(+), 20 deletions(-) diff --git a/envipath/settings.py b/envipath/settings.py index 8d7ff657..a23cb85c 100644 --- a/envipath/settings.py +++ b/envipath/settings.py @@ -311,8 +311,8 @@ DEFAULT_MODEL_PARAMS = { "num_chains": 10, } -DEFAULT_MAX_NUMBER_OF_NODES = 30 -DEFAULT_MAX_DEPTH = 5 +DEFAULT_MAX_NUMBER_OF_NODES = 50 +DEFAULT_MAX_DEPTH = 8 DEFAULT_MODEL_THRESHOLD = 0.25 # Loading Plugins diff --git a/epdb/legacy_api.py b/epdb/legacy_api.py index b6aab0bf..ddf3b802 100644 --- a/epdb/legacy_api.py +++ b/epdb/legacy_api.py @@ -1332,8 +1332,13 @@ def get_package_scenario(request, package_uuid, scenario_uuid): } +@router.post("/package/{uuid:package_uuid}/scenario") +def create_package_scenario(request, package_uuid): + pass + + @router.delete("/package/{uuid:package_uuid}/scenario") -def delete_scenarios(request, package_uuid, scenario_uuid): +def delete_scenarios(request, package_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) @@ -1520,7 +1525,7 @@ def create_pathway( try: p = PackageManager.get_package_by_id(request.user, package_uuid) - stand_smiles = FormatConverter.standardize(pw.smilesinput.strip()) + stand_smiles = FormatConverter.standardize(pw.smilesinput.strip(), remove_stereo=True) new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description) @@ -1937,7 +1942,7 @@ def get_model(request, package_uuid, model_uuid, c: Query[Classify]): return 400, {"message": "Received empty SMILES"} try: - stand_smiles = FormatConverter.standardize(c.smiles) + stand_smiles = FormatConverter.standardize(c.smiles, remove_stereo=True) except ValueError: return 400, {"message": f'"{c.smiles}" is not a valid SMILES'} diff --git a/epdb/logic.py b/epdb/logic.py index 9718309b..42382e1c 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -1115,13 +1115,16 @@ class SettingManager(object): expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS, ): new_s = Setting() + # Clean for potential XSS new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() + new_s.max_nodes = max_nodes new_s.max_depth = max_depth new_s.model = model new_s.model_threshold = model_threshold + new_s.expansion_scheme = expansion_scheme new_s.save() diff --git a/epdb/models.py b/epdb/models.py index d15b03cf..79b55f14 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -769,7 +769,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin num_structs = self.structures.count() stand_smiles = set() for structure in self.structures.all(): - stand_smiles.add(FormatConverter.standardize(structure.smiles)) + stand_smiles.add(FormatConverter.standardize(structure.smiles, remove_stereo=True)) if len(stand_smiles) != 1: logger.debug( @@ -838,7 +838,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin if parsed is None: raise ValueError("Given SMILES is invalid") - standardized_smiles = FormatConverter.standardize(smiles) + standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True) # Check if we find a direct match for a given SMILES if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists(): @@ -911,7 +911,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin if parsed is None: raise ValueError("Given SMILES is invalid") - standardized_smiles = FormatConverter.standardize(smiles) + standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True) is_standardized = standardized_smiles == smiles @@ -2011,19 +2011,23 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): # Clean for potential XSS name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() + is_generic_name = False if name is None or name == "": name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}" + is_generic_name = True pw.name = name + if description is not None and description.strip() != "": pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() pw.predicted = predicted pw.save() + try: # create root node - Node.create(pw, smiles, 0) + Node.create(pw, smiles, 0, name=name if not is_generic_name else None) except ValueError as e: # Node creation failed, most likely due to an invalid smiles # delete this pathway... @@ -3445,10 +3449,17 @@ class EnviFormer(PackageBasedModel): for smile in smi.split(".") ] ) + + if smi in canon_smiles: + logger.debug(f"Found input SMILES={smi} in prediction results. Skipping...") + continue + except ValueError: # This occurs when the predicted string is an invalid SMILES logging.info(f"EnviFormer predicted an invalid SMILES: {smi}") continue + res.append(PredictionResult([ProductSet([smi])], prob, None)) + results.append(res) return results @@ -3587,7 +3598,7 @@ class EnviFormer(PackageBasedModel): ) root_node = ".".join( [ - FormatConverter.standardize(smile) + FormatConverter.standardize(smile, remove_stereo=True) for smile in root_node[0].default_node_label.smiles.split(".") ] ) diff --git a/epdb/tasks.py b/epdb/tasks.py index b3aaa5e1..be4806b2 100644 --- a/epdb/tasks.py +++ b/epdb/tasks.py @@ -420,7 +420,7 @@ def batch_predict( standardized_substrates_and_smiles = [] for substrate in substrate_and_names: try: - stand_smiles = FormatConverter.standardize(substrate[0]) + stand_smiles = FormatConverter.standardize(substrate[0], remove_stereo=True) standardized_substrates_and_smiles.append([stand_smiles, substrate[1]]) except ValueError: raise ValueError( diff --git a/epdb/views.py b/epdb/views.py index 3d569082..56884492 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -1591,7 +1591,7 @@ def package_rule(request, package_uuid, rule_uuid): context = get_base_context(request) if smiles := request.GET.get("smiles", False): - stand_smiles = FormatConverter.standardize(smiles) + stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True) res = current_rule.apply(stand_smiles) if len(res) > 1: logger.info( @@ -1919,7 +1919,7 @@ def package_pathways(request, package_uuid): "Pathway prediction failed due to missing or empty SMILES", ) try: - stand_smiles = FormatConverter.standardize(smiles) + stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True) except ValueError: return error( request, @@ -2793,7 +2793,6 @@ def settings(request): return render(request, "collections/settings_paginated.html", context) - return render(request, "collections/objects_list.html", context) elif request.method == "POST": if s.DEBUG: for k, v in request.POST.items(): @@ -2805,15 +2804,18 @@ def settings(request): new_default = request.POST.get("prediction-setting-new-default", "off") == "on" + # min 2, max s.DEFAULT_MAX_NUMBER_OF_NODES max_nodes = min( max( int(request.POST.get("prediction-setting-max-nodes", 1)), - s.DEFAULT_MAX_NUMBER_OF_NODES, + 2, ), s.DEFAULT_MAX_NUMBER_OF_NODES, ) + + # min 1, max s.DEFAULT_MAX_DEPTH max_depth = min( - max(int(request.POST.get("prediction-setting-max-depth", 1)), s.DEFAULT_MAX_DEPTH), + max(int(request.POST.get("prediction-setting-max-depth", 1)), 1), s.DEFAULT_MAX_DEPTH, ) @@ -2960,7 +2962,7 @@ def jobs(request): parts = pair.split(",") try: - smiles = FormatConverter.standardize(parts[0]) + smiles = FormatConverter.standardize(parts[0], remove_stereo=True) except ValueError: raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!") diff --git a/tests/test_formatconverter.py b/tests/test_formatconverter.py index 006d267a..f6d17148 100644 --- a/tests/test_formatconverter.py +++ b/tests/test_formatconverter.py @@ -6,5 +6,5 @@ from utilities.chem import FormatConverter class FormatConverterTestCase(TestCase): def test_standardization(self): smiles = "C[n+]1c([N-](C))cccc1" - standardized_smiles = FormatConverter.standardize(smiles) + standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True) self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C") diff --git a/utilities/chem.py b/utilities/chem.py index 661d06a1..6e4e48bb 100644 --- a/utilities/chem.py +++ b/utilities/chem.py @@ -474,7 +474,7 @@ class FormatConverter(object): for smi in l_smiles: try: smi = FormatConverter.standardize( - smi, canonicalize_tautomers=canonicalize_tautomers + smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers ) except Exception: # :shrug: @@ -488,7 +488,9 @@ class FormatConverter(object): if standardize: for smi in r_smiles: try: - smi = FormatConverter.standardize(smi) + smi = FormatConverter.standardize( + smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers + ) except Exception: # :shrug: # logger.debug(f'Standardizing SMILES failed for {smi}')