From 09ddd46d69dd3040bdafbda8bb7d586873bab1db Mon Sep 17 00:00:00 2001 From: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com> Date: Thu, 6 Nov 2025 10:32:21 +1300 Subject: [PATCH] app domain assess and assess_batch. Add threshold check for compatibility --- epdb/logic.py | 10 ++-------- epdb/models.py | 28 ++++++++++++++-------------- epdb/views.py | 2 +- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/epdb/logic.py b/epdb/logic.py index 19f03ae2..0aaebf32 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -1542,9 +1542,7 @@ class SPathway(object): if sub.app_domain_assessment is None: if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = self.prediction_setting.model.app_domain.assess( - sub.smiles - )[0] + app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles) if self.persist is not None: n = self.snode_persist_lookup[sub] @@ -1576,11 +1574,7 @@ class SPathway(object): app_domain_assessment = None if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = ( - self.prediction_setting.model.app_domain.assess(c)[ - 0 - ] - ) + app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c)) self.smiles_to_node[c] = SNode( c, sub.depth + 1, app_domain_assessment diff --git a/epdb/models.py b/epdb/models.py index 8643ae7e..61cd3f1a 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -2694,7 +2694,7 @@ class MLRelativeReasoning(PackageBasedModel): X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan) model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS) - model.fit(X, y) + model.fit(X.to_numpy(), y.to_numpy()) return model def _model_args(self): @@ -2797,16 +2797,19 @@ class ApplicabilityDomain(EnviPathModel): joblib.dump(ad, f) def assess(self, structure: Union[str, "CompoundStructure"]): + return self.assess_batch([structure])[0] + + def assess_batch(self, structures: 
List["CompoundStructure | str"]): ds = self.model.load_dataset() - if isinstance(structure, CompoundStructure): - smiles = structure.smiles - else: - smiles = structure + smiles = [] + for struct in structures: + if isinstance(struct, CompoundStructure): + smiles.append(struct.smiles) + else: + smiles.append(struct) - assessment_ds, assessment_prods = ds.classification_dataset( - [structure], self.model.applicable_rules - ) + assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules) # qualified_neighbours_per_rule is a nested dictionary structured as: # { @@ -2837,7 +2840,6 @@ class ApplicabilityDomain(EnviPathModel): ) assessments = list() - # loop through our assessment dataset for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]): rule_reliabilities = dict() @@ -2871,7 +2873,7 @@ class ApplicabilityDomain(EnviPathModel): "local_compatibility_threshold": self.local_compatibilty_threshold, }, "assessment": { - "smiles": smiles, + "smiles": smiles[i], "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0], }, } @@ -2930,13 +2932,11 @@ class ApplicabilityDomain(EnviPathModel): distances = [tanimoto_distance(classify_instance, train) for train in train_instances] return distances - @staticmethod - def _compute_compatibility(rule_idx: int, neighbours: "RuleBasedDataset"): + def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"): accuracy = 0.0 import polars as pl - # TODO: Use a threshold to convert prob to boolean, or pass boolean in obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean), - pred=pl.col(f"prob_{rule_idx}").cast(pl.Boolean)) + pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold) # Compute tp, tn, fp, fn using polars expressions tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height diff --git a/epdb/views.py b/epdb/views.py index 
4844d3be..bf545bf3 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -888,7 +888,7 @@ def package_model(request, package_uuid, model_uuid): return JsonResponse(res, safe=False) else: - app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0] + app_domain_assessment = current_model.app_domain.assess(stand_smiles) return JsonResponse(app_domain_assessment, safe=False) context = get_base_context(request)