app domain assess and assess_batch. Add threshold check for compatability

This commit is contained in:
Liam Brydon
2025-11-06 10:32:21 +13:00
parent 9f0e396437
commit 09ddd46d69
3 changed files with 17 additions and 23 deletions

View File

@ -2694,7 +2694,7 @@ class MLRelativeReasoning(PackageBasedModel):
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
model.fit(X, y)
model.fit(X.to_numpy(), y.to_numpy())
return model
def _model_args(self):
@ -2797,16 +2797,19 @@ class ApplicabilityDomain(EnviPathModel):
joblib.dump(ad, f)
def assess(self, structure: Union[str, "CompoundStructure"]):
return self.assess_batch([structure])[0]
def assess_batch(self, structures: List["CompoundStructure | str"]):
ds = self.model.load_dataset()
if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
smiles = []
for struct in structures:
if isinstance(struct, CompoundStructure):
smiles.append(structures.smiles)
else:
smiles.append(structures)
assessment_ds, assessment_prods = ds.classification_dataset(
[structure], self.model.applicable_rules
)
assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
# qualified_neighbours_per_rule is a nested dictionary structured as:
# {
@ -2837,7 +2840,6 @@ class ApplicabilityDomain(EnviPathModel):
)
assessments = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
rule_reliabilities = dict()
@ -2871,7 +2873,7 @@ class ApplicabilityDomain(EnviPathModel):
"local_compatibility_threshold": self.local_compatibilty_threshold,
},
"assessment": {
"smiles": smiles,
"smiles": smiles[i],
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
},
}
@ -2930,13 +2932,11 @@ class ApplicabilityDomain(EnviPathModel):
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
return distances
@staticmethod
def _compute_compatibility(rule_idx: int, neighbours: "RuleBasedDataset"):
def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
accuracy = 0.0
import polars as pl
# TODO: Use a threshold to convert prob to boolean, or pass boolean in
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
pred=pl.col(f"prob_{rule_idx}").cast(pl.Boolean))
pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
# Compute tp, tn, fp, fn using polars expressions
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height