Enable App Domain Assessment on Model Page (#45)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#45
This commit is contained in:
2025-08-12 09:02:11 +12:00
parent ec52b8872d
commit 1267ca8ace
5 changed files with 264 additions and 18 deletions

View File

@ -1066,7 +1066,7 @@ class SPathway(object):
if from_depth is not None:
substrates = self._get_nodes_for_depth(from_depth)
elif from_node is not None:
for k,v in self.snode_persist_lookup.items():
for k, v in self.snode_persist_lookup.items():
if from_node == v:
substrates = [k]
break

View File

@ -4,7 +4,7 @@ import logging
import os
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Union, List, Optional, Dict, Tuple
from typing import Union, List, Optional, Dict, Tuple, Set
from uuid import uuid4
import joblib
@ -164,6 +164,18 @@ class EnviPathModel(TimeStampedModel):
def url(self):
pass
def simple_json(self, include_description=False):
res = {
'url': self.url,
'uuid': str(self.uuid),
'name': self.name,
}
if include_description:
res['description'] = self.description
return res
def get_v(self, k, default=None):
if self.kv:
return self.kv.get(k, default)
@ -618,7 +630,7 @@ class ParallelRule(Rule):
return '{}/parallel-rule/{}'.format(self.package.url, self.uuid)
@property
def srs(self):
def srs(self) -> QuerySet:
return self.simple_rules.all()
def apply(self, structure):
@ -628,6 +640,26 @@ class ParallelRule(Rule):
return list(set(res))
@property
def reactants_smarts(self) -> Set[str]:
res = set()
for sr in self.srs:
for part in sr.reactants_smarts.split('.'):
res.add(part)
return res
@property
def products_smarts(self) -> Set[str]:
res = set()
for sr in self.srs:
for part in sr.products_smarts.split('.'):
res.add(part)
return res
class SequentialRule(Rule):
simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules',
@ -1494,6 +1526,11 @@ class ApplicabilityDomain(EnviPathModel):
def assess(self, structure: Union[str, 'CompoundStructure']):
ds = self.model.load_dataset()
if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
# qualified_neighbours_per_rule is a nested dictionary structured as:
@ -1525,7 +1562,7 @@ class ApplicabilityDomain(EnviPathModel):
self.model.model.predict_proba(assessment_ds.X())[0],
assessment_prods[0])
res = list()
assessments = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds):
@ -1533,6 +1570,7 @@ class ApplicabilityDomain(EnviPathModel):
rule_reliabilities = dict()
local_compatibilities = dict()
neighbours_per_rule = dict()
neighbor_probs_per_rule = dict()
# loop through rule indices together with the collected neighbours indices from train dataset
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
@ -1568,19 +1606,63 @@ class ApplicabilityDomain(EnviPathModel):
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index]
# Assemble result for instance
res.append({
'in_ad': self.pca.is_applicable(instance)[0],
'rule_reliabilities': rule_reliabilities,
'local_compatibilities': local_compatibilities,
'neighbours': neighbours_per_rule,
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
'prob': preds
})
ad_res = {
'ad_params': {
'uuid': str(self.uuid),
'model': self.model.simple_json(),
'num_neighbours': self.num_neighbours,
'reliability_threshold': self.reliability_threshold,
'local_compatibilty_threshold': self.local_compatibilty_threshold,
},
'assessment': {
'smiles': smiles,
'inside_app_domain': self.pca.is_applicable(instance)[0],
}
}
transformations = list()
for rule_idx in rule_reliabilities.keys():
rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', ''))
return res
rule_data = rule.simple_json()
rule_data['image'] = f"{rule.url}?image=svg"
neighbors = []
for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]):
neighbor = n.simple_json()
neighbor['image'] = f"{n.url}?image=svg"
neighbor['smiles'] = n.smiles
neighbor['related_pathways'] = [
pw.simple_json() for pw in Pathway.objects.filter(
node__default_node_label=n,
package__in=self.model.data_packages.all()
).distinct()
]
neighbor['probability'] = n_prob
neighbors.append(neighbor)
transformation = {
'rule': rule_data,
'reliability': rule_reliabilities[rule_idx],
# TODO
'is_predicted': False,
'local_compatibility': local_compatibilities[rule_idx],
'probability': preds[rule_idx].probability,
'transformation_products': [x.product_set for x in preds[rule_idx].product_sets],
'times_triggered': ds.times_triggered(str(rule.uuid)),
'neighbors': neighbors,
}
transformations.append(transformation)
ad_res['assessment']['transformations'] = transformations
assessments.append(ad_res)
return assessments
@staticmethod
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
@ -1607,7 +1689,7 @@ class ApplicabilityDomain(EnviPathModel):
tn += 1
# Jaccard Index
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn);
accuracy = (tp + tn) / (tp + tn + fp + fn)
return accuracy

View File

@ -680,7 +680,7 @@ def package_model(request, package_uuid, model_uuid):
elif request.GET.get('app-domain-assessment', False):
smiles = request.GET['smiles']
stand_smiles = FormatConverter.standardize(smiles)
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
return JsonResponse(app_domain_assessment, safe=False)
context = get_base_context(request)
@ -1048,6 +1048,24 @@ def package_rule(request, package_uuid, rule_uuid):
if request.method == 'GET':
context = get_base_context(request)
if smiles := request.GET.get('smiles', False):
stand_smiles = FormatConverter.standardize(smiles)
res = current_rule.apply(stand_smiles)
if len(res) > 1:
logger.info(f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one.")
smirks = f"{stand_smiles}>>{'.'.join(sorted(res[0]))}"
# Usually the functional groups are a mapping of fg -> count
# As we are doing it on the fly here fake a high count to ensure that its properly highlighted
educt_functional_groups = {x: 1000 for x in current_rule.reactants_smarts}
product_functional_groups = {x: 1000 for x in current_rule.products_smarts}
return HttpResponse(
IndigoUtils.smirks_to_svg(smirks, False, 0, 0,
educt_functional_groups=educt_functional_groups,
product_functional_groups=product_functional_groups),
content_type='image/svg+xml')
context['title'] = f'enviPath - {current_package.name} - {current_rule.name}'
context['meta']['current_package'] = current_package

View File

@ -269,7 +269,142 @@
<script>
function handleResponse(data) {
function handleAssessmentResponse(data) {
var inside_app_domain = "<a class='list-group-item'>This compound is " + (data["assessment"]["inside_app_domain"] ? "inside" : "outside") + " the Applicability Domain derived from the chemical (PCA) space constructed using the training data." + "</a>";
var functionalGroupsImgSrc = "<img width='400' src='{% url 'depict' %}?smiles=" + encodeURIComponent(data['assessment']['smiles']) + "'>";
var reactivityCentersImgSrc = "<img width='400' src='{% url 'depict' %}?smiles=" + encodeURIComponent(data['assessment']['smiles']) + "'>";
tpl = `<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">
<a id="app-domain-assessment-functional-groups-link" data-toggle="collapse" data-parent="#app-domain-assessment" href="#app-domain-assessment-functional-groups">Functional Groups Covered by Model</a>
</h4>
</div>
<div id="app-domain-assessment-functional-groups" class="panel-collapse collapse">
<div class="panel-body list-group-item">
${inside_app_domain}
<p></p>
<div id="image-div" align="center">
${functionalGroupsImgSrc}
</div>
</div>
</div>
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">
<a id="app-domain-assessment-reactivity-centers-link" data-toggle="collapse" data-parent="#app-domain-assessment" href="#app-domain-assessment-reactivity-centers">Reactivity Centers</a>
</h4>
</div>
<div id="app-domain-assessment-reactivity-centers" class="panel-collapse collapse">
<div class="panel-body list-group-item">
<div id="image-div" align="center">
${reactivityCentersImgSrc}
</div>
</div>
</div>`
var transformations = '';
for (t in data['assessment']['transformations']) {
transObj = data['assessment']['transformations'][t];
var neighbors = '';
for (n in transObj['neighbors']) {
neighObj = transObj['neighbors'][n];
var neighImg = "<img width='100%' src='" + transObj['rule']['url'] + "?smiles=" + encodeURIComponent(neighObj['smiles']) + "'>";
var objLink = `<a class='list-group-item' href="${neighObj['url']}">${neighObj['name']}</a>`
var neighPredProb = "<a class='list-group-item'>Predicted probability: " + neighObj['probability'].toFixed(2) + "</a>";
var pwLinks = '';
for (pw in neighObj['related_pathways']) {
var pwObj = neighObj['related_pathways'][pw];
pwLinks += "<a class='list-group-item' href=" + pwObj['url'] + ">" + pwObj['name'] + "</a>";
}
var expPathways = `
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">
<a id="transformation-${t}-neighbor-${n}-exp-pathway-link" data-toggle="collapse" data-parent="#transformation-${t}-neighbor-${n}" href="#transformation-${t}-neighbor-${n}-exp-pathway">Experimental Pathways</a>
</h4>
</div>
<div id="transformation-${t}-neighbor-${n}-exp-pathway" class="panel-collapse collapse">
<div class="panel-body list-group-item">
${pwLinks}
</div>
</div>
`
if (pwLinks === '') {
expPathways = ''
}
neighbors += `
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">
<a id="transformation-${t}-neighbor-${n}-link" data-toggle="collapse" data-parent="#transformation-${t}" href="#transformation-${t}-neighbor-${n}">Analog Transformation on ${neighObj['name']}</a>
</h4>
</div>
<div id="transformation-${t}-neighbor-${n}" class="panel-collapse collapse">
<div class="panel-body list-group-item">
${objLink}
${neighPredProb}
${expPathways}
<p></p>
<div id="image-div" align="center">
${neighImg}
</div>
</div>
</div>
`
}
var panelName = null;
var objLink = null;
if (transObj['is_predicted']) {
panelName = `Predicted Transformation by ${transObj['rule']['name']}`;
objLink = `<a class='list-group-item' href="${transObj['edge']['url']}">${transObj['edge']['name']}</a>`
} else {
panelName = `Potential Transformation by applying ${transObj['rule']['name']}`;
objLink = `<a class='list-group-item' href="${transObj['rule']['url']}">${transObj['rule']['name']}</a>`
}
var predProb = "<a class='list-group-item'>Predicted probability: " + transObj['probability'].toFixed(2) + "</a>";
var timesTriggered = "<a class='list-group-item'>This rule has triggered " + transObj['times_triggered'] + " times in the training set</a>";
var reliability = "<a class='list-group-item'>Reliability: " + transObj['reliability'].toFixed(2) + " (" + (transObj['reliability'] > data['ad_params']['reliability_threshold'] ? "&gt" : "&lt") + " Reliability Threshold of " + data['ad_params']['reliability_threshold'] + ") </a>";
var localCompatibility = "<a class='list-group-item'>Local Compatibility: " + transObj['local_compatibility'].toFixed(2) + " (" + (transObj['local_compatibility'] > data['ad_params']['local_compatibilty_threshold'] ? "&gt" : "&lt") + " Local Compatibility Threshold of " + data['ad_params']['local_compatibilty_threshold'] + ")</a>";
var transImg = "<img width='100%' src='" + transObj['rule']['url'] + "?smiles=" + encodeURIComponent(data['assessment']['smiles']) + "'>";
var transformation = `
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">
<a id="transformation-${t}-link" data-toggle="collapse" data-parent="#transformation-${t}" href="#transformation-${t}">${panelName}</a>
</h4>
</div>
<div id="transformation-${t}" class="panel-collapse collapse">
<div class="panel-body list-group-item">
${objLink}
${predProb}
${timesTriggered}
${reliability}
${localCompatibility}
<p></p>
<div id="image-div" align="center">
${transImg}
</div>
<p></p>
${neighbors}
</div>
</div>
`
transformations += transformation;
}
res = tpl + transformations;
$("#appDomainAssessmentResultTable").append(res);
}
function handlePredictionResponse(data) {
res = "<table class='table table-striped'>"
res += "<thead>"
res += "<th scope='col'>#</th>"
@ -327,7 +462,7 @@
success: function (data, textStatus) {
try {
$("#predictLoading").empty();
handleResponse(data);
handlePredictionResponse(data);
} catch (error) {
console.log("Error");
$("#predictLoading").empty();
@ -363,6 +498,7 @@
success: function (data, textStatus) {
try {
$("#appDomainLoading").empty();
handleAssessmentResponse(data);
console.log(data);
} catch (error) {
console.log("Error");

View File

@ -109,6 +109,16 @@ class Dataset:
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
self.data.append(row)
def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f'trig_{rule_uuid}')
times_triggered = 0
for row in self.data:
if row[idx] == 1:
times_triggered += 1
return times_triggered
def struct_features(self) -> Tuple[int, int]:
return self._struct_features