App domain: add assess_batch and make assess delegate to it. Add threshold check for compatibility

Liam Brydon
2025-11-06 10:32:21 +13:00
parent 9f0e396437
commit 09ddd46d69
3 changed files with 17 additions and 23 deletions

@@ -1542,9 +1542,7 @@ class SPathway(object):
        if sub.app_domain_assessment is None:
            if self.prediction_setting.model:
                if self.prediction_setting.model.app_domain:
-                    app_domain_assessment = self.prediction_setting.model.app_domain.assess(
-                        sub.smiles
-                    )[0]
+                    app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)
        if self.persist is not None:
            n = self.snode_persist_lookup[sub]
@@ -1576,11 +1574,7 @@ class SPathway(object):
            app_domain_assessment = None
            if self.prediction_setting.model:
                if self.prediction_setting.model.app_domain:
-                    app_domain_assessment = (
-                        self.prediction_setting.model.app_domain.assess(c)[
-                            0
-                        ]
-                    )
+                    app_domain_assessment = self.prediction_setting.model.app_domain.assess(c)
            self.smiles_to_node[c] = SNode(
                c, sub.depth + 1, app_domain_assessment
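
As a quick reference for the call-site change in both hunks above, a comment-only sketch (here app_domain stands for self.prediction_setting.model.app_domain):

    # Before this commit: assess() returned a list, so SPathway indexed the first element
    #     app_domain_assessment = app_domain.assess(sub.smiles)[0]
    # After this commit: assess() returns the single assessment directly
    #     app_domain_assessment = app_domain.assess(sub.smiles)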

@@ -2694,7 +2694,7 @@ class MLRelativeReasoning(PackageBasedModel):
        X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
        model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
-       model.fit(X, y)
+       model.fit(X.to_numpy(), y.to_numpy())
        return model

    def _model_args(self):
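
The only change in this hunk converts the polars frames returned by ds.X() / ds.y() to NumPy arrays before fitting. A minimal, self-contained sketch of that conversion; the column names and values are illustrative stand-ins, not taken from the real dataset:

    import numpy as np
    import polars as pl

    # Stand-ins for ds.X(na_replacement=np.nan) / ds.y(na_replacement=np.nan)
    X = pl.DataFrame({"feat_0": [0.1, 0.7, 0.3], "feat_1": [1.0, 0.0, 1.0]})
    y = pl.DataFrame({"obs_0": [1, 0, 1]})

    # Estimators with a scikit-learn-style fit() generally expect plain ndarrays,
    # so the frames are converted before the call.
    X_arr, y_arr = X.to_numpy(), y.to_numpy()
    assert isinstance(X_arr, np.ndarray) and X_arr.shape == (3, 2)
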
@@ -2797,16 +2797,19 @@ class ApplicabilityDomain(EnviPathModel):
            joblib.dump(ad, f)

    def assess(self, structure: Union[str, "CompoundStructure"]):
+        return self.assess_batch([structure])[0]
+
+    def assess_batch(self, structures: List["CompoundStructure | str"]):
        ds = self.model.load_dataset()
-        if isinstance(structure, CompoundStructure):
-            smiles = structure.smiles
-        else:
-            smiles = structure
+        smiles = []
+        for struct in structures:
+            if isinstance(struct, CompoundStructure):
+                smiles.append(struct.smiles)
+            else:
+                smiles.append(struct)

-        assessment_ds, assessment_prods = ds.classification_dataset(
-            [structure], self.model.applicable_rules
-        )
+        assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)

        # qualified_neighbours_per_rule is a nested dictionary structured as:
        # {
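
A minimal sketch of the calling convention after this change, assuming `ad` is an already-loaded ApplicabilityDomain instance and plain SMILES strings are passed in: assess now delegates to assess_batch and returns a single assessment, while assess_batch returns one assessment per input structure.

    # `ad` is assumed to be a loaded ApplicabilityDomain; construction is out of scope here.
    single = ad.assess("CCO")                     # one assessment dict, no [0] indexing by callers
    batch = ad.assess_batch(["CCO", "c1ccccc1"])  # one assessment per input structure

    # Assuming results align with input order and the record shape shown further down in this diff
    assert single["assessment"]["smiles"] == "CCO"
    assert len(batch) == 2
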
@@ -2837,7 +2840,6 @@ class ApplicabilityDomain(EnviPathModel):
        )
        assessments = list()
        # loop through our assessment dataset
        for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
            rule_reliabilities = dict()
@@ -2871,7 +2873,7 @@ class ApplicabilityDomain(EnviPathModel):
                    "local_compatibility_threshold": self.local_compatibilty_threshold,
                },
                "assessment": {
-                    "smiles": smiles,
+                    "smiles": smiles[i],
                    "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
                },
            }
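
For orientation, a sketch of the per-structure record shape implied by this hunk; only the fields visible in the diff are shown, and the values are illustrative only:

    # Hypothetical, trimmed example of one entry in `assessments`
    example_assessment = {
        "parameters": {
            "local_compatibility_threshold": 0.5,  # illustrative value
            # ... further parameters not visible in this hunk
        },
        "assessment": {
            "smiles": "CCO",            # now taken from smiles[i], the i-th input structure
            "inside_app_domain": True,  # result of self.pca.is_applicable(assessment_ds[i])[0]
            # ... further fields not visible in this hunk
        },
    }
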
@@ -2930,13 +2932,11 @@ class ApplicabilityDomain(EnviPathModel):
        distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
        return distances

-    @staticmethod
-    def _compute_compatibility(rule_idx: int, neighbours: "RuleBasedDataset"):
+    def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
        accuracy = 0.0
        import polars as pl

-        # TODO: Use a threshold to convert prob to boolean, or pass boolean in
        obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
-                                     pred=pl.col(f"prob_{rule_idx}").cast(pl.Boolean))
+                                     pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
        # Compute tp, tn, fp, fn using polars expressions
        tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
        tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
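
A self-contained sketch of the new threshold-based compatibility computation, using a tiny hand-made neighbour frame in place of the real RuleBasedDataset and a stand-in value for self.model.threshold; the fp/fn/accuracy lines mirror the "Compute tp, tn, fp, fn" comment but are not shown in the hunk:

    import polars as pl

    threshold = 0.5  # stand-in for self.model.threshold
    rule_idx = 0

    # obs_<rule> holds observed outcomes, prob_<rule> holds predicted probabilities
    neighbours = pl.DataFrame({"obs_0": [1, 0, 1, 0], "prob_0": [0.9, 0.2, 0.4, 0.7]})

    obs_pred = neighbours.select(
        obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
        pred=pl.col(f"prob_{rule_idx}") >= threshold,  # threshold check replaces the old cast to Boolean
    )
    tp = obs_pred.filter(pl.col("obs") & pl.col("pred")).height
    tn = obs_pred.filter(~pl.col("obs") & ~pl.col("pred")).height
    fp = obs_pred.filter(~pl.col("obs") & pl.col("pred")).height
    fn = obs_pred.filter(pl.col("obs") & ~pl.col("pred")).height
    accuracy = (tp + tn) / max(tp + tn + fp + fn, 1)  # 0.5 for this toy frame: tp=1, tn=1, fp=1, fn=1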

@@ -888,7 +888,7 @@ def package_model(request, package_uuid, model_uuid):
            return JsonResponse(res, safe=False)
        else:
-            app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
+            app_domain_assessment = current_model.app_domain.assess(stand_smiles)
            return JsonResponse(app_domain_assessment, safe=False)

    context = get_base_context(request)
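
A small, hypothetical helper mirroring the view branch above, just to show that the endpoint now serializes the single assessment returned by assess() without indexing; current_model and stand_smiles are assumed to exist as they do in the view:

    from django.http import JsonResponse

    def assessment_response(current_model, stand_smiles):
        # assess() now returns one assessment dict for the standardised SMILES,
        # so it can be handed straight to JsonResponse (safe=False as in the view).
        app_domain_assessment = current_model.app_domain.assess(stand_smiles)
        return JsonResponse(app_domain_assessment, safe=False)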