Enable App Domain Assessment on Model Page (#45)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#45
This commit is contained in:
2025-08-12 09:02:11 +12:00
parent ec52b8872d
commit 1267ca8ace
5 changed files with 264 additions and 18 deletions

View File

@ -1066,7 +1066,7 @@ class SPathway(object):
if from_depth is not None:
substrates = self._get_nodes_for_depth(from_depth)
elif from_node is not None:
for k,v in self.snode_persist_lookup.items():
for k, v in self.snode_persist_lookup.items():
if from_node == v:
substrates = [k]
break

View File

@ -4,7 +4,7 @@ import logging
import os
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Union, List, Optional, Dict, Tuple
from typing import Union, List, Optional, Dict, Tuple, Set
from uuid import uuid4
import joblib
@ -164,6 +164,18 @@ class EnviPathModel(TimeStampedModel):
def url(self):
pass
def simple_json(self, include_description=False):
res = {
'url': self.url,
'uuid': str(self.uuid),
'name': self.name,
}
if include_description:
res['description'] = self.description
return res
def get_v(self, k, default=None):
if self.kv:
return self.kv.get(k, default)
@ -618,7 +630,7 @@ class ParallelRule(Rule):
return '{}/parallel-rule/{}'.format(self.package.url, self.uuid)
@property
def srs(self):
def srs(self) -> QuerySet:
return self.simple_rules.all()
def apply(self, structure):
@ -628,6 +640,26 @@ class ParallelRule(Rule):
return list(set(res))
@property
def reactants_smarts(self) -> Set[str]:
res = set()
for sr in self.srs:
for part in sr.reactants_smarts.split('.'):
res.add(part)
return res
@property
def products_smarts(self) -> Set[str]:
res = set()
for sr in self.srs:
for part in sr.products_smarts.split('.'):
res.add(part)
return res
class SequentialRule(Rule):
simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules',
@ -1494,6 +1526,11 @@ class ApplicabilityDomain(EnviPathModel):
def assess(self, structure: Union[str, 'CompoundStructure']):
ds = self.model.load_dataset()
if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
# qualified_neighbours_per_rule is a nested dictionary structured as:
@ -1525,7 +1562,7 @@ class ApplicabilityDomain(EnviPathModel):
self.model.model.predict_proba(assessment_ds.X())[0],
assessment_prods[0])
res = list()
assessments = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds):
@ -1533,6 +1570,7 @@ class ApplicabilityDomain(EnviPathModel):
rule_reliabilities = dict()
local_compatibilities = dict()
neighbours_per_rule = dict()
neighbor_probs_per_rule = dict()
# loop through rule indices together with the collected neighbours indices from train dataset
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
@ -1568,19 +1606,63 @@ class ApplicabilityDomain(EnviPathModel):
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index]
# Assemble result for instance
res.append({
'in_ad': self.pca.is_applicable(instance)[0],
'rule_reliabilities': rule_reliabilities,
'local_compatibilities': local_compatibilities,
'neighbours': neighbours_per_rule,
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
'prob': preds
})
ad_res = {
'ad_params': {
'uuid': str(self.uuid),
'model': self.model.simple_json(),
'num_neighbours': self.num_neighbours,
'reliability_threshold': self.reliability_threshold,
'local_compatibilty_threshold': self.local_compatibilty_threshold,
},
'assessment': {
'smiles': smiles,
'inside_app_domain': self.pca.is_applicable(instance)[0],
}
}
transformations = list()
for rule_idx in rule_reliabilities.keys():
rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', ''))
return res
rule_data = rule.simple_json()
rule_data['image'] = f"{rule.url}?image=svg"
neighbors = []
for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]):
neighbor = n.simple_json()
neighbor['image'] = f"{n.url}?image=svg"
neighbor['smiles'] = n.smiles
neighbor['related_pathways'] = [
pw.simple_json() for pw in Pathway.objects.filter(
node__default_node_label=n,
package__in=self.model.data_packages.all()
).distinct()
]
neighbor['probability'] = n_prob
neighbors.append(neighbor)
transformation = {
'rule': rule_data,
'reliability': rule_reliabilities[rule_idx],
# TODO
'is_predicted': False,
'local_compatibility': local_compatibilities[rule_idx],
'probability': preds[rule_idx].probability,
'transformation_products': [x.product_set for x in preds[rule_idx].product_sets],
'times_triggered': ds.times_triggered(str(rule.uuid)),
'neighbors': neighbors,
}
transformations.append(transformation)
ad_res['assessment']['transformations'] = transformations
assessments.append(ad_res)
return assessments
@staticmethod
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
@ -1607,7 +1689,7 @@ class ApplicabilityDomain(EnviPathModel):
tn += 1
# Jaccard Index
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn);
accuracy = (tp + tn) / (tp + tn + fp + fn)
return accuracy

View File

@ -680,7 +680,7 @@ def package_model(request, package_uuid, model_uuid):
elif request.GET.get('app-domain-assessment', False):
smiles = request.GET['smiles']
stand_smiles = FormatConverter.standardize(smiles)
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
return JsonResponse(app_domain_assessment, safe=False)
context = get_base_context(request)
@ -1048,6 +1048,24 @@ def package_rule(request, package_uuid, rule_uuid):
if request.method == 'GET':
context = get_base_context(request)
if smiles := request.GET.get('smiles', False):
stand_smiles = FormatConverter.standardize(smiles)
res = current_rule.apply(stand_smiles)
if len(res) > 1:
logger.info(f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one.")
smirks = f"{stand_smiles}>>{'.'.join(sorted(res[0]))}"
# Usually the functional groups are a mapping of fg -> count
# As we are doing it on the fly here fake a high count to ensure that its properly highlighted
educt_functional_groups = {x: 1000 for x in current_rule.reactants_smarts}
product_functional_groups = {x: 1000 for x in current_rule.products_smarts}
return HttpResponse(
IndigoUtils.smirks_to_svg(smirks, False, 0, 0,
educt_functional_groups=educt_functional_groups,
product_functional_groups=product_functional_groups),
content_type='image/svg+xml')
context['title'] = f'enviPath - {current_package.name} - {current_rule.name}'
context['meta']['current_package'] = current_package