forked from enviPath/enviPy
Enable App Domain Assessment on Model Page (#45)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#45
This commit is contained in:
@ -1066,7 +1066,7 @@ class SPathway(object):
|
||||
if from_depth is not None:
|
||||
substrates = self._get_nodes_for_depth(from_depth)
|
||||
elif from_node is not None:
|
||||
for k,v in self.snode_persist_lookup.items():
|
||||
for k, v in self.snode_persist_lookup.items():
|
||||
if from_node == v:
|
||||
substrates = [k]
|
||||
break
|
||||
|
||||
110
epdb/models.py
110
epdb/models.py
@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union, List, Optional, Dict, Tuple
|
||||
from typing import Union, List, Optional, Dict, Tuple, Set
|
||||
from uuid import uuid4
|
||||
|
||||
import joblib
|
||||
@ -164,6 +164,18 @@ class EnviPathModel(TimeStampedModel):
|
||||
def url(self):
|
||||
pass
|
||||
|
||||
def simple_json(self, include_description=False):
|
||||
res = {
|
||||
'url': self.url,
|
||||
'uuid': str(self.uuid),
|
||||
'name': self.name,
|
||||
}
|
||||
|
||||
if include_description:
|
||||
res['description'] = self.description
|
||||
|
||||
return res
|
||||
|
||||
def get_v(self, k, default=None):
|
||||
if self.kv:
|
||||
return self.kv.get(k, default)
|
||||
@ -618,7 +630,7 @@ class ParallelRule(Rule):
|
||||
return '{}/parallel-rule/{}'.format(self.package.url, self.uuid)
|
||||
|
||||
@property
|
||||
def srs(self):
|
||||
def srs(self) -> QuerySet:
|
||||
return self.simple_rules.all()
|
||||
|
||||
def apply(self, structure):
|
||||
@ -628,6 +640,26 @@ class ParallelRule(Rule):
|
||||
|
||||
return list(set(res))
|
||||
|
||||
@property
|
||||
def reactants_smarts(self) -> Set[str]:
|
||||
res = set()
|
||||
|
||||
for sr in self.srs:
|
||||
for part in sr.reactants_smarts.split('.'):
|
||||
res.add(part)
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def products_smarts(self) -> Set[str]:
|
||||
res = set()
|
||||
|
||||
for sr in self.srs:
|
||||
for part in sr.products_smarts.split('.'):
|
||||
res.add(part)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class SequentialRule(Rule):
|
||||
simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules',
|
||||
@ -1494,6 +1526,11 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
def assess(self, structure: Union[str, 'CompoundStructure']):
|
||||
ds = self.model.load_dataset()
|
||||
|
||||
if isinstance(structure, CompoundStructure):
|
||||
smiles = structure.smiles
|
||||
else:
|
||||
smiles = structure
|
||||
|
||||
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
|
||||
|
||||
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
||||
@ -1525,7 +1562,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
self.model.model.predict_proba(assessment_ds.X())[0],
|
||||
assessment_prods[0])
|
||||
|
||||
res = list()
|
||||
assessments = list()
|
||||
|
||||
# loop through our assessment dataset
|
||||
for i, instance in enumerate(assessment_ds):
|
||||
@ -1533,6 +1570,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
rule_reliabilities = dict()
|
||||
local_compatibilities = dict()
|
||||
neighbours_per_rule = dict()
|
||||
neighbor_probs_per_rule = dict()
|
||||
|
||||
# loop through rule indices together with the collected neighbours indices from train dataset
|
||||
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
|
||||
@ -1568,19 +1606,63 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
|
||||
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
|
||||
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
|
||||
neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index]
|
||||
|
||||
# Assemble result for instance
|
||||
res.append({
|
||||
'in_ad': self.pca.is_applicable(instance)[0],
|
||||
'rule_reliabilities': rule_reliabilities,
|
||||
'local_compatibilities': local_compatibilities,
|
||||
'neighbours': neighbours_per_rule,
|
||||
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
|
||||
'prob': preds
|
||||
})
|
||||
ad_res = {
|
||||
'ad_params': {
|
||||
'uuid': str(self.uuid),
|
||||
'model': self.model.simple_json(),
|
||||
'num_neighbours': self.num_neighbours,
|
||||
'reliability_threshold': self.reliability_threshold,
|
||||
'local_compatibilty_threshold': self.local_compatibilty_threshold,
|
||||
},
|
||||
'assessment': {
|
||||
'smiles': smiles,
|
||||
'inside_app_domain': self.pca.is_applicable(instance)[0],
|
||||
}
|
||||
}
|
||||
|
||||
transformations = list()
|
||||
for rule_idx in rule_reliabilities.keys():
|
||||
rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', ''))
|
||||
|
||||
return res
|
||||
rule_data = rule.simple_json()
|
||||
rule_data['image'] = f"{rule.url}?image=svg"
|
||||
|
||||
neighbors = []
|
||||
for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]):
|
||||
neighbor = n.simple_json()
|
||||
neighbor['image'] = f"{n.url}?image=svg"
|
||||
neighbor['smiles'] = n.smiles
|
||||
neighbor['related_pathways'] = [
|
||||
pw.simple_json() for pw in Pathway.objects.filter(
|
||||
node__default_node_label=n,
|
||||
package__in=self.model.data_packages.all()
|
||||
).distinct()
|
||||
]
|
||||
neighbor['probability'] = n_prob
|
||||
|
||||
neighbors.append(neighbor)
|
||||
|
||||
transformation = {
|
||||
'rule': rule_data,
|
||||
'reliability': rule_reliabilities[rule_idx],
|
||||
# TODO
|
||||
'is_predicted': False,
|
||||
'local_compatibility': local_compatibilities[rule_idx],
|
||||
'probability': preds[rule_idx].probability,
|
||||
'transformation_products': [x.product_set for x in preds[rule_idx].product_sets],
|
||||
'times_triggered': ds.times_triggered(str(rule.uuid)),
|
||||
'neighbors': neighbors,
|
||||
}
|
||||
|
||||
transformations.append(transformation)
|
||||
|
||||
ad_res['assessment']['transformations'] = transformations
|
||||
|
||||
assessments.append(ad_res)
|
||||
|
||||
return assessments
|
||||
|
||||
@staticmethod
|
||||
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
|
||||
@ -1607,7 +1689,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
tn += 1
|
||||
# Jaccard Index
|
||||
if tp + tn > 0.0:
|
||||
accuracy = (tp + tn) / (tp + tn + fp + fn);
|
||||
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
||||
|
||||
return accuracy
|
||||
|
||||
|
||||
@ -680,7 +680,7 @@ def package_model(request, package_uuid, model_uuid):
|
||||
elif request.GET.get('app-domain-assessment', False):
|
||||
smiles = request.GET['smiles']
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
|
||||
return JsonResponse(app_domain_assessment, safe=False)
|
||||
|
||||
context = get_base_context(request)
|
||||
@ -1048,6 +1048,24 @@ def package_rule(request, package_uuid, rule_uuid):
|
||||
|
||||
if request.method == 'GET':
|
||||
context = get_base_context(request)
|
||||
|
||||
if smiles := request.GET.get('smiles', False):
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
res = current_rule.apply(stand_smiles)
|
||||
if len(res) > 1:
|
||||
logger.info(f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one.")
|
||||
|
||||
smirks = f"{stand_smiles}>>{'.'.join(sorted(res[0]))}"
|
||||
# Usually the functional groups are a mapping of fg -> count
|
||||
# As we are doing it on the fly here fake a high count to ensure that its properly highlighted
|
||||
educt_functional_groups = {x: 1000 for x in current_rule.reactants_smarts}
|
||||
product_functional_groups = {x: 1000 for x in current_rule.products_smarts}
|
||||
return HttpResponse(
|
||||
IndigoUtils.smirks_to_svg(smirks, False, 0, 0,
|
||||
educt_functional_groups=educt_functional_groups,
|
||||
product_functional_groups=product_functional_groups),
|
||||
content_type='image/svg+xml')
|
||||
|
||||
context['title'] = f'enviPath - {current_package.name} - {current_rule.name}'
|
||||
|
||||
context['meta']['current_package'] = current_package
|
||||
|
||||
Reference in New Issue
Block a user