forked from enviPath/enviPy
app domain assess and assess_batch. Add threshold check for compatibility
This commit is contained in:
@ -1542,9 +1542,7 @@ class SPathway(object):
|
|||||||
if sub.app_domain_assessment is None:
|
if sub.app_domain_assessment is None:
|
||||||
if self.prediction_setting.model:
|
if self.prediction_setting.model:
|
||||||
if self.prediction_setting.model.app_domain:
|
if self.prediction_setting.model.app_domain:
|
||||||
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
|
app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)
|
||||||
sub.smiles
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
if self.persist is not None:
|
if self.persist is not None:
|
||||||
n = self.snode_persist_lookup[sub]
|
n = self.snode_persist_lookup[sub]
|
||||||
@ -1576,11 +1574,7 @@ class SPathway(object):
|
|||||||
app_domain_assessment = None
|
app_domain_assessment = None
|
||||||
if self.prediction_setting.model:
|
if self.prediction_setting.model:
|
||||||
if self.prediction_setting.model.app_domain:
|
if self.prediction_setting.model.app_domain:
|
||||||
app_domain_assessment = (
|
app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c))
|
||||||
self.prediction_setting.model.app_domain.assess(c)[
|
|
||||||
0
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.smiles_to_node[c] = SNode(
|
self.smiles_to_node[c] = SNode(
|
||||||
c, sub.depth + 1, app_domain_assessment
|
c, sub.depth + 1, app_domain_assessment
|
||||||
|
|||||||
@ -2694,7 +2694,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
|||||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||||
|
|
||||||
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
||||||
model.fit(X, y)
|
model.fit(X.to_numpy(), y.to_numpy())
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _model_args(self):
|
def _model_args(self):
|
||||||
@ -2797,16 +2797,19 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
joblib.dump(ad, f)
|
joblib.dump(ad, f)
|
||||||
|
|
||||||
def assess(self, structure: Union[str, "CompoundStructure"]):
|
def assess(self, structure: Union[str, "CompoundStructure"]):
|
||||||
|
return self.assess_batch([structure])[0]
|
||||||
|
|
||||||
|
def assess_batch(self, structures: List["CompoundStructure | str"]):
|
||||||
ds = self.model.load_dataset()
|
ds = self.model.load_dataset()
|
||||||
|
|
||||||
if isinstance(structure, CompoundStructure):
|
smiles = []
|
||||||
smiles = structure.smiles
|
for struct in structures:
|
||||||
else:
|
if isinstance(struct, CompoundStructure):
|
||||||
smiles = structure
|
smiles.append(structures.smiles)
|
||||||
|
else:
|
||||||
|
smiles.append(structures)
|
||||||
|
|
||||||
assessment_ds, assessment_prods = ds.classification_dataset(
|
assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
|
||||||
[structure], self.model.applicable_rules
|
|
||||||
)
|
|
||||||
|
|
||||||
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
||||||
# {
|
# {
|
||||||
@ -2837,7 +2840,6 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assessments = list()
|
assessments = list()
|
||||||
|
|
||||||
# loop through our assessment dataset
|
# loop through our assessment dataset
|
||||||
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
|
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
|
||||||
rule_reliabilities = dict()
|
rule_reliabilities = dict()
|
||||||
@ -2871,7 +2873,7 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
"local_compatibility_threshold": self.local_compatibilty_threshold,
|
"local_compatibility_threshold": self.local_compatibilty_threshold,
|
||||||
},
|
},
|
||||||
"assessment": {
|
"assessment": {
|
||||||
"smiles": smiles,
|
"smiles": smiles[i],
|
||||||
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
|
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -2930,13 +2932,11 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
|
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@staticmethod
|
def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
|
||||||
def _compute_compatibility(rule_idx: int, neighbours: "RuleBasedDataset"):
|
|
||||||
accuracy = 0.0
|
accuracy = 0.0
|
||||||
import polars as pl
|
import polars as pl
|
||||||
# TODO: Use a threshold to convert prob to boolean, or pass boolean in
|
|
||||||
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
|
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
|
||||||
pred=pl.col(f"prob_{rule_idx}").cast(pl.Boolean))
|
pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
|
||||||
# Compute tp, tn, fp, fn using polars expressions
|
# Compute tp, tn, fp, fn using polars expressions
|
||||||
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
|
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
|
||||||
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
|
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
|
||||||
|
|||||||
@ -888,7 +888,7 @@ def package_model(request, package_uuid, model_uuid):
|
|||||||
return JsonResponse(res, safe=False)
|
return JsonResponse(res, safe=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
|
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
|
||||||
return JsonResponse(app_domain_assessment, safe=False)
|
return JsonResponse(app_domain_assessment, safe=False)
|
||||||
|
|
||||||
context = get_base_context(request)
|
context = get_base_context(request)
|
||||||
|
|||||||
Reference in New Issue
Block a user