forked from enviPath/enviPy
Enable App Domain Assessment on Model Page (#45)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#45
This commit is contained in:
@ -1066,7 +1066,7 @@ class SPathway(object):
|
||||
if from_depth is not None:
|
||||
substrates = self._get_nodes_for_depth(from_depth)
|
||||
elif from_node is not None:
|
||||
for k,v in self.snode_persist_lookup.items():
|
||||
for k, v in self.snode_persist_lookup.items():
|
||||
if from_node == v:
|
||||
substrates = [k]
|
||||
break
|
||||
|
||||
110
epdb/models.py
110
epdb/models.py
@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union, List, Optional, Dict, Tuple
|
||||
from typing import Union, List, Optional, Dict, Tuple, Set
|
||||
from uuid import uuid4
|
||||
|
||||
import joblib
|
||||
@ -164,6 +164,18 @@ class EnviPathModel(TimeStampedModel):
|
||||
def url(self):
|
||||
pass
|
||||
|
||||
def simple_json(self, include_description=False):
|
||||
res = {
|
||||
'url': self.url,
|
||||
'uuid': str(self.uuid),
|
||||
'name': self.name,
|
||||
}
|
||||
|
||||
if include_description:
|
||||
res['description'] = self.description
|
||||
|
||||
return res
|
||||
|
||||
def get_v(self, k, default=None):
|
||||
if self.kv:
|
||||
return self.kv.get(k, default)
|
||||
@ -618,7 +630,7 @@ class ParallelRule(Rule):
|
||||
return '{}/parallel-rule/{}'.format(self.package.url, self.uuid)
|
||||
|
||||
@property
|
||||
def srs(self):
|
||||
def srs(self) -> QuerySet:
|
||||
return self.simple_rules.all()
|
||||
|
||||
def apply(self, structure):
|
||||
@ -628,6 +640,26 @@ class ParallelRule(Rule):
|
||||
|
||||
return list(set(res))
|
||||
|
||||
@property
|
||||
def reactants_smarts(self) -> Set[str]:
|
||||
res = set()
|
||||
|
||||
for sr in self.srs:
|
||||
for part in sr.reactants_smarts.split('.'):
|
||||
res.add(part)
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def products_smarts(self) -> Set[str]:
|
||||
res = set()
|
||||
|
||||
for sr in self.srs:
|
||||
for part in sr.products_smarts.split('.'):
|
||||
res.add(part)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class SequentialRule(Rule):
|
||||
simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules',
|
||||
@ -1494,6 +1526,11 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
def assess(self, structure: Union[str, 'CompoundStructure']):
|
||||
ds = self.model.load_dataset()
|
||||
|
||||
if isinstance(structure, CompoundStructure):
|
||||
smiles = structure.smiles
|
||||
else:
|
||||
smiles = structure
|
||||
|
||||
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
|
||||
|
||||
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
||||
@ -1525,7 +1562,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
self.model.model.predict_proba(assessment_ds.X())[0],
|
||||
assessment_prods[0])
|
||||
|
||||
res = list()
|
||||
assessments = list()
|
||||
|
||||
# loop through our assessment dataset
|
||||
for i, instance in enumerate(assessment_ds):
|
||||
@ -1533,6 +1570,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
rule_reliabilities = dict()
|
||||
local_compatibilities = dict()
|
||||
neighbours_per_rule = dict()
|
||||
neighbor_probs_per_rule = dict()
|
||||
|
||||
# loop through rule indices together with the collected neighbours indices from train dataset
|
||||
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
|
||||
@ -1568,19 +1606,63 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
|
||||
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
|
||||
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
|
||||
neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index]
|
||||
|
||||
# Assemble result for instance
|
||||
res.append({
|
||||
'in_ad': self.pca.is_applicable(instance)[0],
|
||||
'rule_reliabilities': rule_reliabilities,
|
||||
'local_compatibilities': local_compatibilities,
|
||||
'neighbours': neighbours_per_rule,
|
||||
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
|
||||
'prob': preds
|
||||
})
|
||||
ad_res = {
|
||||
'ad_params': {
|
||||
'uuid': str(self.uuid),
|
||||
'model': self.model.simple_json(),
|
||||
'num_neighbours': self.num_neighbours,
|
||||
'reliability_threshold': self.reliability_threshold,
|
||||
'local_compatibilty_threshold': self.local_compatibilty_threshold,
|
||||
},
|
||||
'assessment': {
|
||||
'smiles': smiles,
|
||||
'inside_app_domain': self.pca.is_applicable(instance)[0],
|
||||
}
|
||||
}
|
||||
|
||||
transformations = list()
|
||||
for rule_idx in rule_reliabilities.keys():
|
||||
rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', ''))
|
||||
|
||||
return res
|
||||
rule_data = rule.simple_json()
|
||||
rule_data['image'] = f"{rule.url}?image=svg"
|
||||
|
||||
neighbors = []
|
||||
for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]):
|
||||
neighbor = n.simple_json()
|
||||
neighbor['image'] = f"{n.url}?image=svg"
|
||||
neighbor['smiles'] = n.smiles
|
||||
neighbor['related_pathways'] = [
|
||||
pw.simple_json() for pw in Pathway.objects.filter(
|
||||
node__default_node_label=n,
|
||||
package__in=self.model.data_packages.all()
|
||||
).distinct()
|
||||
]
|
||||
neighbor['probability'] = n_prob
|
||||
|
||||
neighbors.append(neighbor)
|
||||
|
||||
transformation = {
|
||||
'rule': rule_data,
|
||||
'reliability': rule_reliabilities[rule_idx],
|
||||
# TODO
|
||||
'is_predicted': False,
|
||||
'local_compatibility': local_compatibilities[rule_idx],
|
||||
'probability': preds[rule_idx].probability,
|
||||
'transformation_products': [x.product_set for x in preds[rule_idx].product_sets],
|
||||
'times_triggered': ds.times_triggered(str(rule.uuid)),
|
||||
'neighbors': neighbors,
|
||||
}
|
||||
|
||||
transformations.append(transformation)
|
||||
|
||||
ad_res['assessment']['transformations'] = transformations
|
||||
|
||||
assessments.append(ad_res)
|
||||
|
||||
return assessments
|
||||
|
||||
@staticmethod
|
||||
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
|
||||
@ -1607,7 +1689,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
tn += 1
|
||||
# Jaccard Index
|
||||
if tp + tn > 0.0:
|
||||
accuracy = (tp + tn) / (tp + tn + fp + fn);
|
||||
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
||||
|
||||
return accuracy
|
||||
|
||||
|
||||
@ -680,7 +680,7 @@ def package_model(request, package_uuid, model_uuid):
|
||||
elif request.GET.get('app-domain-assessment', False):
|
||||
smiles = request.GET['smiles']
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
|
||||
return JsonResponse(app_domain_assessment, safe=False)
|
||||
|
||||
context = get_base_context(request)
|
||||
@ -1048,6 +1048,24 @@ def package_rule(request, package_uuid, rule_uuid):
|
||||
|
||||
if request.method == 'GET':
|
||||
context = get_base_context(request)
|
||||
|
||||
if smiles := request.GET.get('smiles', False):
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
res = current_rule.apply(stand_smiles)
|
||||
if len(res) > 1:
|
||||
logger.info(f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one.")
|
||||
|
||||
smirks = f"{stand_smiles}>>{'.'.join(sorted(res[0]))}"
|
||||
# Usually the functional groups are a mapping of fg -> count
|
||||
# As we are doing it on the fly here fake a high count to ensure that its properly highlighted
|
||||
educt_functional_groups = {x: 1000 for x in current_rule.reactants_smarts}
|
||||
product_functional_groups = {x: 1000 for x in current_rule.products_smarts}
|
||||
return HttpResponse(
|
||||
IndigoUtils.smirks_to_svg(smirks, False, 0, 0,
|
||||
educt_functional_groups=educt_functional_groups,
|
||||
product_functional_groups=product_functional_groups),
|
||||
content_type='image/svg+xml')
|
||||
|
||||
context['title'] = f'enviPath - {current_package.name} - {current_rule.name}'
|
||||
|
||||
context['meta']['current_package'] = current_package
|
||||
|
||||
@ -269,7 +269,142 @@
|
||||
|
||||
<script>
|
||||
|
||||
function handleResponse(data) {
|
||||
function handleAssessmentResponse(data) {
|
||||
var inside_app_domain = "<a class='list-group-item'>This compound is " + (data["assessment"]["inside_app_domain"] ? "inside" : "outside") + " the Applicability Domain derived from the chemical (PCA) space constructed using the training data." + "</a>";
|
||||
var functionalGroupsImgSrc = "<img width='400' src='{% url 'depict' %}?smiles=" + encodeURIComponent(data['assessment']['smiles']) + "'>";
|
||||
var reactivityCentersImgSrc = "<img width='400' src='{% url 'depict' %}?smiles=" + encodeURIComponent(data['assessment']['smiles']) + "'>";
|
||||
|
||||
tpl = `<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
<a id="app-domain-assessment-functional-groups-link" data-toggle="collapse" data-parent="#app-domain-assessment" href="#app-domain-assessment-functional-groups">Functional Groups Covered by Model</a>
|
||||
</h4>
|
||||
</div>
|
||||
<div id="app-domain-assessment-functional-groups" class="panel-collapse collapse">
|
||||
<div class="panel-body list-group-item">
|
||||
${inside_app_domain}
|
||||
<p></p>
|
||||
<div id="image-div" align="center">
|
||||
${functionalGroupsImgSrc}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
<a id="app-domain-assessment-reactivity-centers-link" data-toggle="collapse" data-parent="#app-domain-assessment" href="#app-domain-assessment-reactivity-centers">Reactivity Centers</a>
|
||||
</h4>
|
||||
</div>
|
||||
<div id="app-domain-assessment-reactivity-centers" class="panel-collapse collapse">
|
||||
<div class="panel-body list-group-item">
|
||||
<div id="image-div" align="center">
|
||||
${reactivityCentersImgSrc}
|
||||
</div>
|
||||
</div>
|
||||
</div>`
|
||||
|
||||
var transformations = '';
|
||||
|
||||
for (t in data['assessment']['transformations']) {
|
||||
transObj = data['assessment']['transformations'][t];
|
||||
var neighbors = '';
|
||||
for (n in transObj['neighbors']) {
|
||||
neighObj = transObj['neighbors'][n];
|
||||
var neighImg = "<img width='100%' src='" + transObj['rule']['url'] + "?smiles=" + encodeURIComponent(neighObj['smiles']) + "'>";
|
||||
var objLink = `<a class='list-group-item' href="${neighObj['url']}">${neighObj['name']}</a>`
|
||||
var neighPredProb = "<a class='list-group-item'>Predicted probability: " + neighObj['probability'].toFixed(2) + "</a>";
|
||||
|
||||
var pwLinks = '';
|
||||
for (pw in neighObj['related_pathways']) {
|
||||
var pwObj = neighObj['related_pathways'][pw];
|
||||
pwLinks += "<a class='list-group-item' href=" + pwObj['url'] + ">" + pwObj['name'] + "</a>";
|
||||
}
|
||||
|
||||
var expPathways = `
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
<a id="transformation-${t}-neighbor-${n}-exp-pathway-link" data-toggle="collapse" data-parent="#transformation-${t}-neighbor-${n}" href="#transformation-${t}-neighbor-${n}-exp-pathway">Experimental Pathways</a>
|
||||
</h4>
|
||||
</div>
|
||||
<div id="transformation-${t}-neighbor-${n}-exp-pathway" class="panel-collapse collapse">
|
||||
<div class="panel-body list-group-item">
|
||||
${pwLinks}
|
||||
</div>
|
||||
</div>
|
||||
`
|
||||
|
||||
if (pwLinks === '') {
|
||||
expPathways = ''
|
||||
}
|
||||
|
||||
neighbors += `
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
<a id="transformation-${t}-neighbor-${n}-link" data-toggle="collapse" data-parent="#transformation-${t}" href="#transformation-${t}-neighbor-${n}">Analog Transformation on ${neighObj['name']}</a>
|
||||
</h4>
|
||||
</div>
|
||||
<div id="transformation-${t}-neighbor-${n}" class="panel-collapse collapse">
|
||||
<div class="panel-body list-group-item">
|
||||
${objLink}
|
||||
${neighPredProb}
|
||||
${expPathways}
|
||||
<p></p>
|
||||
<div id="image-div" align="center">
|
||||
${neighImg}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`
|
||||
}
|
||||
|
||||
var panelName = null;
|
||||
var objLink = null;
|
||||
if (transObj['is_predicted']) {
|
||||
panelName = `Predicted Transformation by ${transObj['rule']['name']}`;
|
||||
objLink = `<a class='list-group-item' href="${transObj['edge']['url']}">${transObj['edge']['name']}</a>`
|
||||
} else {
|
||||
panelName = `Potential Transformation by applying ${transObj['rule']['name']}`;
|
||||
objLink = `<a class='list-group-item' href="${transObj['rule']['url']}">${transObj['rule']['name']}</a>`
|
||||
}
|
||||
|
||||
var predProb = "<a class='list-group-item'>Predicted probability: " + transObj['probability'].toFixed(2) + "</a>";
|
||||
var timesTriggered = "<a class='list-group-item'>This rule has triggered " + transObj['times_triggered'] + " times in the training set</a>";
|
||||
var reliability = "<a class='list-group-item'>Reliability: " + transObj['reliability'].toFixed(2) + " (" + (transObj['reliability'] > data['ad_params']['reliability_threshold'] ? ">" : "<") + " Reliability Threshold of " + data['ad_params']['reliability_threshold'] + ") </a>";
|
||||
var localCompatibility = "<a class='list-group-item'>Local Compatibility: " + transObj['local_compatibility'].toFixed(2) + " (" + (transObj['local_compatibility'] > data['ad_params']['local_compatibilty_threshold'] ? ">" : "<") + " Local Compatibility Threshold of " + data['ad_params']['local_compatibilty_threshold'] + ")</a>";
|
||||
|
||||
var transImg = "<img width='100%' src='" + transObj['rule']['url'] + "?smiles=" + encodeURIComponent(data['assessment']['smiles']) + "'>";
|
||||
|
||||
var transformation = `
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
<a id="transformation-${t}-link" data-toggle="collapse" data-parent="#transformation-${t}" href="#transformation-${t}">${panelName}</a>
|
||||
</h4>
|
||||
</div>
|
||||
<div id="transformation-${t}" class="panel-collapse collapse">
|
||||
<div class="panel-body list-group-item">
|
||||
${objLink}
|
||||
${predProb}
|
||||
${timesTriggered}
|
||||
${reliability}
|
||||
${localCompatibility}
|
||||
<p></p>
|
||||
<div id="image-div" align="center">
|
||||
${transImg}
|
||||
</div>
|
||||
<p></p>
|
||||
${neighbors}
|
||||
</div>
|
||||
</div>
|
||||
`
|
||||
transformations += transformation;
|
||||
}
|
||||
|
||||
res = tpl + transformations;
|
||||
|
||||
$("#appDomainAssessmentResultTable").append(res);
|
||||
|
||||
}
|
||||
|
||||
function handlePredictionResponse(data) {
|
||||
res = "<table class='table table-striped'>"
|
||||
res += "<thead>"
|
||||
res += "<th scope='col'>#</th>"
|
||||
@ -327,7 +462,7 @@
|
||||
success: function (data, textStatus) {
|
||||
try {
|
||||
$("#predictLoading").empty();
|
||||
handleResponse(data);
|
||||
handlePredictionResponse(data);
|
||||
} catch (error) {
|
||||
console.log("Error");
|
||||
$("#predictLoading").empty();
|
||||
@ -363,6 +498,7 @@
|
||||
success: function (data, textStatus) {
|
||||
try {
|
||||
$("#appDomainLoading").empty();
|
||||
handleAssessmentResponse(data);
|
||||
console.log(data);
|
||||
} catch (error) {
|
||||
console.log("Error");
|
||||
|
||||
@ -109,6 +109,16 @@ class Dataset:
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
|
||||
self.data.append(row)
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
|
||||
idx = self.columns.index(f'trig_{rule_uuid}')
|
||||
|
||||
times_triggered = 0
|
||||
for row in self.data:
|
||||
if row[idx] == 1:
|
||||
times_triggered += 1
|
||||
|
||||
return times_triggered
|
||||
|
||||
def struct_features(self) -> Tuple[int, int]:
|
||||
return self._struct_features
|
||||
|
||||
|
||||
Reference in New Issue
Block a user