Enable App Domain Assessment on Model Page (#45)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#45
This commit is contained in:
2025-08-12 09:02:11 +12:00
parent ec52b8872d
commit 1267ca8ace
5 changed files with 264 additions and 18 deletions

View File

@ -4,7 +4,7 @@ import logging
import os
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Union, List, Optional, Dict, Tuple
from typing import Union, List, Optional, Dict, Tuple, Set
from uuid import uuid4
import joblib
@ -164,6 +164,18 @@ class EnviPathModel(TimeStampedModel):
def url(self):
pass
def simple_json(self, include_description=False):
    """Return a minimal, JSON-friendly summary of this object.

    Args:
        include_description: When truthy, the object's ``description``
            attribute is added to the result under ``'description'``.

    Returns:
        A dict with the keys ``'url'``, ``'uuid'`` (converted with
        ``str``) and ``'name'``, plus ``'description'`` on request.
    """
    res = {
        'url': self.url,
        # Stringified so the value can be dumped without a custom encoder.
        'uuid': str(self.uuid),
        'name': self.name,
    }
    if include_description:
        res['description'] = self.description
    return res
def get_v(self, k, default=None):
if self.kv:
return self.kv.get(k, default)
@ -618,7 +630,7 @@ class ParallelRule(Rule):
return '{}/parallel-rule/{}'.format(self.package.url, self.uuid)
@property
def srs(self):
def srs(self) -> QuerySet:
return self.simple_rules.all()
def apply(self, structure):
@ -628,6 +640,26 @@ class ParallelRule(Rule):
return list(set(res))
@property
def reactants_smarts(self) -> Set[str]:
    """Collect the distinct reactant SMARTS fragments of all simple rules.

    Each rule in ``self.srs`` exposes ``reactants_smarts`` as a single
    string; it is split on ``'.'`` and the pieces are merged into one set,
    so a fragment shared by several rules appears only once.

    Returns:
        The set of ``'.'``-separated fragments over all rules in ``self.srs``.
    """
    res = set()
    for sr in self.srs:
        # update() adds every fragment at once instead of an inner add() loop.
        res.update(sr.reactants_smarts.split('.'))
    return res
@property
def products_smarts(self) -> Set[str]:
    """Collect the distinct product SMARTS fragments of all simple rules.

    Each rule in ``self.srs`` exposes ``products_smarts`` as a single
    string; it is split on ``'.'`` and the pieces are merged into one set,
    so a fragment shared by several rules appears only once.

    Returns:
        The set of ``'.'``-separated fragments over all rules in ``self.srs``.
    """
    res = set()
    for sr in self.srs:
        # update() adds every fragment at once instead of an inner add() loop.
        res.update(sr.products_smarts.split('.'))
    return res
class SequentialRule(Rule):
simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules',
@ -1494,6 +1526,11 @@ class ApplicabilityDomain(EnviPathModel):
def assess(self, structure: Union[str, 'CompoundStructure']):
ds = self.model.load_dataset()
if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
# qualified_neighbours_per_rule is a nested dictionary structured as:
@ -1525,7 +1562,7 @@ class ApplicabilityDomain(EnviPathModel):
self.model.model.predict_proba(assessment_ds.X())[0],
assessment_prods[0])
res = list()
assessments = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds):
@ -1533,6 +1570,7 @@ class ApplicabilityDomain(EnviPathModel):
rule_reliabilities = dict()
local_compatibilities = dict()
neighbours_per_rule = dict()
neighbor_probs_per_rule = dict()
# loop through rule indices together with the collected neighbours indices from train dataset
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
@ -1568,19 +1606,63 @@ class ApplicabilityDomain(EnviPathModel):
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index]
# Assemble result for instance
res.append({
'in_ad': self.pca.is_applicable(instance)[0],
'rule_reliabilities': rule_reliabilities,
'local_compatibilities': local_compatibilities,
'neighbours': neighbours_per_rule,
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
'prob': preds
})
ad_res = {
'ad_params': {
'uuid': str(self.uuid),
'model': self.model.simple_json(),
'num_neighbours': self.num_neighbours,
'reliability_threshold': self.reliability_threshold,
'local_compatibilty_threshold': self.local_compatibilty_threshold,
},
'assessment': {
'smiles': smiles,
'inside_app_domain': self.pca.is_applicable(instance)[0],
}
}
transformations = list()
for rule_idx in rule_reliabilities.keys():
rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', ''))
return res
rule_data = rule.simple_json()
rule_data['image'] = f"{rule.url}?image=svg"
neighbors = []
for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]):
neighbor = n.simple_json()
neighbor['image'] = f"{n.url}?image=svg"
neighbor['smiles'] = n.smiles
neighbor['related_pathways'] = [
pw.simple_json() for pw in Pathway.objects.filter(
node__default_node_label=n,
package__in=self.model.data_packages.all()
).distinct()
]
neighbor['probability'] = n_prob
neighbors.append(neighbor)
transformation = {
'rule': rule_data,
'reliability': rule_reliabilities[rule_idx],
# TODO
'is_predicted': False,
'local_compatibility': local_compatibilities[rule_idx],
'probability': preds[rule_idx].probability,
'transformation_products': [x.product_set for x in preds[rule_idx].product_sets],
'times_triggered': ds.times_triggered(str(rule.uuid)),
'neighbors': neighbors,
}
transformations.append(transformation)
ad_res['assessment']['transformations'] = transformations
assessments.append(ad_res)
return assessments
@staticmethod
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
@ -1607,7 +1689,7 @@ class ApplicabilityDomain(EnviPathModel):
tn += 1
# Accuracy (simple matching coefficient): true negatives are counted, so this is not the Jaccard index
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn);
accuracy = (tp + tn) / (tp + tn + fp + fn)
return accuracy