From 35c342a3e3e711d4a7f5826b5fb43d5bf4b8441a Mon Sep 17 00:00:00 2001 From: Tim Lorsbach Date: Tue, 11 Nov 2025 10:09:22 +0100 Subject: [PATCH] Fixed handling for SMIRKS/SMARTS, adjusted test values as they are now cleaned, refactored logic for object update --- epdb/logic.py | 14 +++++-- epdb/models.py | 88 +++++++++++++++++++++++++++++----------- epdb/views.py | 81 +++++++++++++++++++++++++++--------- tests/test_rule_model.py | 2 +- utilities/chem.py | 24 +++++++++++ 5 files changed, 163 insertions(+), 46 deletions(-) diff --git a/epdb/logic.py b/epdb/logic.py index d8415977..f9e1192a 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -526,9 +526,13 @@ class PackageManager(object): @transaction.atomic def create_package(current_user, name: str, description: str = None): p = Package() + # Clean for potential XSS p.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() - p.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() + + if description is not None and description.strip() != "": + p.description = nh3.clean(description.strip(), tags=s.ALLOWED_HTML_TAGS).strip() + p.save() up = UserPackagePermission() @@ -1552,7 +1556,9 @@ class SPathway(object): if sub.app_domain_assessment is None: if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles) + app_domain_assessment = self.prediction_setting.model.app_domain.assess( + sub.smiles + ) if self.persist is not None: n = self.snode_persist_lookup[sub] @@ -1584,7 +1590,9 @@ class SPathway(object): app_domain_assessment = None if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c)) + app_domain_assessment = ( + self.prediction_setting.model.app_domain.assess(c) + ) self.smiles_to_node[c] = SNode( c, sub.depth + 1, app_domain_assessment diff --git a/epdb/models.py b/epdb/models.py index d48b0476..4b6d7500 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -29,8 +29,14 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score from sklearn.model_selection import ShuffleSplit from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils -from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \ - EnviFormerDataset, Dataset +from utilities.ml import ( + RuleBasedDataset, + ApplicabilityDomainPCA, + EnsembleClassifierChain, + RelativeReasoning, + EnviFormerDataset, + Dataset, +) logger = logging.getLogger(__name__) @@ -1190,9 +1196,10 @@ class SimpleAmbitRule(SimpleRule): r = SimpleAmbitRule() r.package = package + if name is not None: - # Clean for potential XSS name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() + if name is None or name == "": name = f"Rule {Rule.objects.filter(package=package).count() + 1}" @@ -1200,13 +1207,19 @@ class SimpleAmbitRule(SimpleRule): if description is not None and description.strip() != "": r.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() - r.smirks = nh3.clean(smirks).strip() + r.smirks = smirks if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != "": - r.reactant_filter_smarts = nh3.clean(reactant_filter_smarts).strip() + if not FormatConverter.is_valid_smarts(reactant_filter_smarts.strip()): + raise ValueError(f'Reactant Filter SMARTS "{reactant_filter_smarts}" is invalid!') + else: + r.reactant_filter_smarts = reactant_filter_smarts.strip() if product_filter_smarts is not None and product_filter_smarts.strip() != "": - r.product_filter_smarts = nh3.clean(product_filter_smarts).strip() + if not FormatConverter.is_valid_smarts(product_filter_smarts.strip()): + raise ValueError(f'Product Filter SMARTS "{product_filter_smarts}" is invalid!') + else: + r.product_filter_smarts = product_filter_smarts.strip() r.save() return r @@ -2353,7 +2366,9 @@ class PackageBasedModel(EPModel): eval_reactions = list( Reaction.objects.filter(package__in=self.eval_packages.all()).distinct() ) - ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True) + ds = RuleBasedDataset.generate_dataset( + eval_reactions, self.applicable_rules, educts_only=True + ) if isinstance(self, RuleBasedRelativeReasoning): X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy() y = ds.y(na_replacement=np.nan).to_numpy() @@ -2818,7 +2833,9 @@ class ApplicabilityDomain(EnviPathModel): else: smiles.append(structures) - assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules) + assessment_ds, assessment_prods = ds.classification_dataset( + structures, self.model.applicable_rules + ) # qualified_neighbours_per_rule is a nested dictionary structured as: # { @@ -2834,12 +2851,16 @@ class ApplicabilityDomain(EnviPathModel): qualified_neighbours_per_rule: Dict = {} import polars as pl + # Select only the triggered columns for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)): # Find the rules the structure triggers. For each rule, filter the training dataset to rows that also # trigger that rule. - train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1)) - for trig_uuid, value in row.items() if value == 1} + train_trig = { + trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1)) + for trig_uuid, value in row.items() + if value == 1 + } qualified_neighbours_per_rule[i] = train_trig rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)} preds = self.model.combine_products_and_probs( @@ -2859,18 +2880,28 @@ class ApplicabilityDomain(EnviPathModel): # loop through rule indices together with the collected neighbours indices from train dataset for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items(): # compute tanimoto distance for all neighbours and add to dataset - dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0], - train_instances[:, train_instances.struct_features()].to_numpy()) + dists = self._compute_distances( + assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0], + train_instances[:, train_instances.struct_features()].to_numpy(), + ) train_instances = train_instances.with_columns(dist=pl.Series(dists)) # sort them in a descending way and take at most `self.num_neighbours` # TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending) - train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours] + train_instances = train_instances.sort("dist", descending=True)[ + : self.num_neighbours + ] # compute average distance - rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item() + rule_reliabilities[rule_uuid] = ( + train_instances.select(pl.mean("dist")).fill_nan(0.0).item() + ) # for local_compatibility we'll need the datasets for the indices having the highest similarity - local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances) - neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"])) + local_compatibilities[rule_uuid] = self._compute_compatibility( + rule_uuid, train_instances + ) + neighbours_per_rule[rule_uuid] = list( + CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]) + ) neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list() ad_res = { @@ -2944,8 +2975,11 @@ class ApplicabilityDomain(EnviPathModel): def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"): accuracy = 0.0 import polars as pl - obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean), - pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold) + + obs_pred = neighbours.select( + obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean), + pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold, + ) # Compute tp, tn, fp, fn using polars expressions tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height @@ -3115,7 +3149,7 @@ class EnviFormer(PackageBasedModel): pred_dict = {} for k, pred in enumerate(predictions): pred_smiles, pred_proba = zip(*pred.items()) - reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"] + reactant, _ = test_ds[k, "educts"], test_ds[k, "products"] pred_dict.setdefault(reactant, {"predict": [], "scores": []}) for smiles, proba in zip(pred_smiles, pred_proba): smiles = set(smiles.split(".")) @@ -3229,8 +3263,9 @@ class EnviFormer(PackageBasedModel): # If there are eval packages perform single generation evaluation on them instead of random splits if self.eval_packages.count() > 0: - ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter( - package__in=self.eval_packages.all()).distinct()) + ds = EnviFormerDataset.generate_dataset( + Reaction.objects.filter(package__in=self.eval_packages.all()).distinct() + ) test_result = self.model.predict_batch(ds.X()) single_gen_result = evaluate_sg(ds, test_result, self.threshold) self.eval_results = self.compute_averages([single_gen_result]) @@ -3248,7 +3283,9 @@ class EnviFormer(PackageBasedModel): train = ds[train_index] test = ds[test_index] start = datetime.now() - model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE) + model = fine_tune( + train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE + ) end = datetime.now() logger.debug( f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds" @@ -3325,7 +3362,12 @@ class EnviFormer(PackageBasedModel): for pathway in train_pathways: for reaction in pathway.edges: reaction = reaction.edge_label - if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]): + if any( + [ + educt in test_educts + for educt in reaction_to_educts[str(reaction.uuid)] + ] + ): overlap += 1 continue train_reactions.append(reaction) diff --git a/epdb/views.py b/epdb/views.py index 9c8f9761..36bb0d6e 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -397,7 +397,9 @@ def packages(request): return HttpResponseBadRequest() else: package_name = request.POST.get("package-name") - package_description = request.POST.get("package-description", s.DEFAULT_VALUES["description"]) + package_description = request.POST.get( + "package-description", s.DEFAULT_VALUES["description"] + ) created_package = PackageManager.create_package( current_user, package_name, package_description @@ -939,8 +941,13 @@ def package_model(request, package_uuid, model_uuid): return HttpResponseBadRequest() else: # TODO: Move cleaning to property updater - name = nh3.clean(request.POST.get("model-name", "").strip(), tags=s.ALLOWED_HTML_TAGS).strip() - description = nh3.clean(request.POST.get("model-description", "").strip(), tags=s.ALLOWED_HTML_TAGS).strip() + name = request.POST.get("model-name") + if name is not None: + name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() + + description = request.POST.get("model-description") + if description is not None: + description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() if any([name, description]): if name: @@ -1043,8 +1050,15 @@ def package(request, package_uuid): return HttpResponseBadRequest() # TODO: Move cleaning to property updater - new_package_name = nh3.clean(request.POST.get("package-name"), tags=s.ALLOWED_HTML_TAGS).strip() - new_package_description = nh3.clean(request.POST.get("package-description"), tags=s.ALLOWED_HTML_TAGS).strip() + new_package_name = request.POST.get("package-name") + if new_package_name is not None: + new_package_name = nh3.clean(new_package_name, tags=s.ALLOWED_HTML_TAGS).strip() + + new_package_description = request.POST.get("package-description") + if new_package_description is not None: + new_package_description = nh3.clean( + new_package_description, tags=s.ALLOWED_HTML_TAGS + ).strip() grantee_url = request.POST.get("grantee") read = request.POST.get("read") == "on" @@ -1205,10 +1219,17 @@ def package_compound(request, package_uuid, compound_uuid): return JsonResponse({"error": str(e)}, status=400) return JsonResponse({"success": current_compound.url}) + # TODO: Move cleaning to property updater - new_compound_name = nh3.clean(request.POST.get("compound-name", ""), tags=s.ALLOWED_HTML_TAGS).strip() - new_compound_description = nh3.clean(request.POST.get("compound-description", ""), - tags=s.ALLOWED_HTML_TAGS).strip() + new_compound_name = request.POST.get("compound-name") + if new_compound_name is not None: + new_compound_name = nh3.clean(new_compound_name, tags=s.ALLOWED_HTML_TAGS).strip() + + new_compound_description = request.POST.get("compound-description") + if new_compound_description is not None: + new_compound_description = nh3.clean( + new_compound_description, tags=s.ALLOWED_HTML_TAGS + ).strip() if new_compound_name: current_compound.name = new_compound_name @@ -1343,11 +1364,17 @@ def package_compound_structure(request, package_uuid, compound_uuid, structure_u return redirect(current_compound.url + "/structure") else: return HttpResponseBadRequest() + # TODO: Move cleaning to property updater - new_structure_name = nh3.clean(request.POST.get("compound-structure-name", ""), - tags=s.ALLOWED_HTML_TAGS).strip() - new_structure_description = nh3.clean(request.POST.get("compound-structure-description", ""), - tags=s.ALLOWED_HTML_TAGS).strip() + new_structure_name = request.POST.get("compound-structure-name") + if new_structure_name is not None: + new_structure_name = nh3.clean(new_structure_name, tags=s.ALLOWED_HTML_TAGS).strip() + + new_structure_description = request.POST.get("compound-structure-description") + if new_structure_description is not None: + new_structure_description = nh3.clean( + new_structure_description, tags=s.ALLOWED_HTML_TAGS + ).strip() if new_structure_name: current_structure.name = new_structure_name @@ -1555,8 +1582,13 @@ def package_rule(request, package_uuid, rule_uuid): return JsonResponse({"success": current_rule.url}) # TODO: Move cleaning to property updater - rule_name = nh3.clean(request.POST.get("rule-name", ""), tags=s.ALLOWED_HTML_TAGS).strip() - rule_description = nh3.clean(request.POST.get("rule-description", "").strip(), tags=s.ALLOWED_HTML_TAGS).strip() + rule_name = request.POST.get("rule-name") + if rule_name is not None: + rule_name = nh3.clean(rule_name, tags=s.ALLOWED_HTML_TAGS).strip() + + rule_description = request.POST.get("rule-description") + if rule_description is not None: + rule_description = nh3.clean(rule_description, tags=s.ALLOWED_HTML_TAGS).strip() if rule_name: current_rule.name = rule_name @@ -1708,9 +1740,15 @@ def package_reaction(request, package_uuid, reaction_uuid): return JsonResponse({"success": current_reaction.url}) # TODO: Move cleaning to property updater - new_reaction_name = nh3.clean(request.POST.get("reaction-name", ""), tags=s.ALLOWED_HTML_TAGS).strip() - new_reaction_description = nh3.clean(request.POST.get("reaction-description", ""), - tags=s.ALLOWED_HTML_TAGS).strip() + new_reaction_name = request.POST.get("reaction-name") + if new_reaction_name is not None: + new_reaction_name = nh3.clean(new_reaction_name, tags=s.ALLOWED_HTML_TAGS).strip() + + new_reaction_description = request.POST.get("reaction-description") + if new_reaction_description is not None: + new_reaction_description = nh3.clean( + new_reaction_description, tags=s.ALLOWED_HTML_TAGS + ).strip() if new_reaction_name: current_reaction.name = new_reaction_name @@ -1957,8 +1995,13 @@ def package_pathway(request, package_uuid, pathway_uuid): return JsonResponse({"success": current_pathway.url}) # TODO: Move cleaning to property updater - pathway_name = nh3.clean(request.POST.get("pathway-name"), tags=s.ALLOWED_HTML_TAGS).strip() - pathway_description = nh3.clean(request.POST.get("pathway-description"), tags=s.ALLOWED_HTML_TAGS).strip() + pathway_name = request.POST.get("pathway-name") + if pathway_name is not None: + pathway_name = nh3.clean(pathway_name, tags=s.ALLOWED_HTML_TAGS).strip() + + pathway_description = request.POST.get("pathway-description") + if pathway_description is not None: + pathway_description = nh3.clean(pathway_description, tags=s.ALLOWED_HTML_TAGS).strip() if any([pathway_name, pathway_description]): if pathway_name is not None and pathway_name.strip() != "": diff --git a/tests/test_rule_model.py b/tests/test_rule_model.py index 520e049f..b50d01f0 100644 --- a/tests/test_rule_model.py +++ b/tests/test_rule_model.py @@ -29,7 +29,7 @@ class RuleTest(TestCase): self.assertEqual(r.name, "bt0022-2833") self.assertEqual( r.description, - "Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative", + "Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative", ) def test_smirks_are_trimmed(self): diff --git a/utilities/chem.py b/utilities/chem.py index d7a68d75..40251911 100644 --- a/utilities/chem.py +++ b/utilities/chem.py @@ -255,6 +255,30 @@ class FormatConverter(object): except Exception: return False + @staticmethod + def is_valid_smarts(smarts: str) -> bool: + """ + Checks whether a given string is a valid SMARTS pattern. + + Parameters + ---------- + smarts : str + The SMARTS string to validate. + + Returns + ------- + bool + True if the SMARTS string is valid, False otherwise. + """ + if not isinstance(smarts, str) or not smarts.strip(): + return False + + try: + mol = Chem.MolFromSmarts(smarts) + return mol is not None + except Exception: + return False + @staticmethod def apply( smiles: str,