From 1267ca8ace2141d9a77468ecbf4ad9be968b804e Mon Sep 17 00:00:00 2001 From: jebus Date: Tue, 12 Aug 2025 09:02:11 +1200 Subject: [PATCH] Enable App Domain Assessment on Model Page (#45) Co-authored-by: Tim Lorsbach Reviewed-on: https://git.envipath.com/enviPath/enviPy/pulls/45 --- epdb/logic.py | 2 +- epdb/models.py | 110 +++++++++++++++++++++++---- epdb/views.py | 20 ++++- templates/objects/model.html | 140 ++++++++++++++++++++++++++++++++++- utilities/ml.py | 10 +++ 5 files changed, 264 insertions(+), 18 deletions(-) diff --git a/epdb/logic.py b/epdb/logic.py index 8c6ce056..d7d8d8cd 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -1066,7 +1066,7 @@ class SPathway(object): if from_depth is not None: substrates = self._get_nodes_for_depth(from_depth) elif from_node is not None: - for k,v in self.snode_persist_lookup.items(): + for k, v in self.snode_persist_lookup.items(): if from_node == v: substrates = [k] break diff --git a/epdb/models.py b/epdb/models.py index 25f7bb07..61ef704c 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -4,7 +4,7 @@ import logging import os from collections import defaultdict from datetime import datetime, timedelta -from typing import Union, List, Optional, Dict, Tuple +from typing import Union, List, Optional, Dict, Tuple, Set from uuid import uuid4 import joblib @@ -164,6 +164,18 @@ class EnviPathModel(TimeStampedModel): def url(self): pass + def simple_json(self, include_description=False): + res = { + 'url': self.url, + 'uuid': str(self.uuid), + 'name': self.name, + } + + if include_description: + res['description'] = self.description + + return res + def get_v(self, k, default=None): if self.kv: return self.kv.get(k, default) @@ -618,7 +630,7 @@ class ParallelRule(Rule): return '{}/parallel-rule/{}'.format(self.package.url, self.uuid) @property - def srs(self): + def srs(self) -> QuerySet: return self.simple_rules.all() def apply(self, structure): @@ -628,6 +640,26 @@ class ParallelRule(Rule): return list(set(res)) + @property + def reactants_smarts(self) -> Set[str]: + res = set() + + for sr in self.srs: + for part in sr.reactants_smarts.split('.'): + res.add(part) + + return res + + @property + def products_smarts(self) -> Set[str]: + res = set() + + for sr in self.srs: + for part in sr.products_smarts.split('.'): + res.add(part) + + return res + class SequentialRule(Rule): simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules', @@ -1494,6 +1526,11 @@ class ApplicabilityDomain(EnviPathModel): def assess(self, structure: Union[str, 'CompoundStructure']): ds = self.model.load_dataset() + if isinstance(structure, CompoundStructure): + smiles = structure.smiles + else: + smiles = structure + assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules) # qualified_neighbours_per_rule is a nested dictionary structured as: @@ -1525,7 +1562,7 @@ class ApplicabilityDomain(EnviPathModel): self.model.model.predict_proba(assessment_ds.X())[0], assessment_prods[0]) - res = list() + assessments = list() # loop through our assessment dataset for i, instance in enumerate(assessment_ds): @@ -1533,6 +1570,7 @@ class ApplicabilityDomain(EnviPathModel): rule_reliabilities = dict() local_compatibilities = dict() neighbours_per_rule = dict() + neighbor_probs_per_rule = dict() # loop through rule indices together with the collected neighbours indices from train dataset for rule_idx, vals in qualified_neighbours_per_rule[i].items(): @@ -1568,19 +1606,63 @@ class ApplicabilityDomain(EnviPathModel): neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index] local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets) neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets] + neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index] - # Assemble result for instance - res.append({ - 'in_ad': self.pca.is_applicable(instance)[0], - 'rule_reliabilities': rule_reliabilities, - 'local_compatibilities': local_compatibilities, - 'neighbours': neighbours_per_rule, - 'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]], - 'prob': preds - }) + ad_res = { + 'ad_params': { + 'uuid': str(self.uuid), + 'model': self.model.simple_json(), + 'num_neighbours': self.num_neighbours, + 'reliability_threshold': self.reliability_threshold, + 'local_compatibilty_threshold': self.local_compatibilty_threshold, + }, + 'assessment': { + 'smiles': smiles, + 'inside_app_domain': self.pca.is_applicable(instance)[0], + } + } + transformations = list() + for rule_idx in rule_reliabilities.keys(): + rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', '')) - return res + rule_data = rule.simple_json() + rule_data['image'] = f"{rule.url}?image=svg" + + neighbors = [] + for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]): + neighbor = n.simple_json() + neighbor['image'] = f"{n.url}?image=svg" + neighbor['smiles'] = n.smiles + neighbor['related_pathways'] = [ + pw.simple_json() for pw in Pathway.objects.filter( + node__default_node_label=n, + package__in=self.model.data_packages.all() + ).distinct() + ] + neighbor['probability'] = n_prob + + neighbors.append(neighbor) + + transformation = { + 'rule': rule_data, + 'reliability': rule_reliabilities[rule_idx], + # TODO + 'is_predicted': False, + 'local_compatibility': local_compatibilities[rule_idx], + 'probability': preds[rule_idx].probability, + 'transformation_products': [x.product_set for x in preds[rule_idx].product_sets], + 'times_triggered': ds.times_triggered(str(rule.uuid)), + 'neighbors': neighbors, + } + + transformations.append(transformation) + + ad_res['assessment']['transformations'] = transformations + + assessments.append(ad_res) + + return assessments @staticmethod def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]): @@ -1607,7 +1689,7 @@ class ApplicabilityDomain(EnviPathModel): tn += 1 # Jaccard Index if tp + tn > 0.0: - accuracy = (tp + tn) / (tp + tn + fp + fn); + accuracy = (tp + tn) / (tp + tn + fp + fn) return accuracy diff --git a/epdb/views.py b/epdb/views.py index 74d835e8..37a3ae72 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -680,7 +680,7 @@ def package_model(request, package_uuid, model_uuid): elif request.GET.get('app-domain-assessment', False): smiles = request.GET['smiles'] stand_smiles = FormatConverter.standardize(smiles) - app_domain_assessment = current_model.app_domain.assess(stand_smiles) + app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0] return JsonResponse(app_domain_assessment, safe=False) context = get_base_context(request) @@ -1048,6 +1048,24 @@ def package_rule(request, package_uuid, rule_uuid): if request.method == 'GET': context = get_base_context(request) + + if smiles := request.GET.get('smiles', False): + stand_smiles = FormatConverter.standardize(smiles) + res = current_rule.apply(stand_smiles) + if len(res) > 1: + logger.info(f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one.") + + smirks = f"{stand_smiles}>>{'.'.join(sorted(res[0]))}" + # Usually the functional groups are a mapping of fg -> count + # As we are doing it on the fly here fake a high count to ensure that its properly highlighted + educt_functional_groups = {x: 1000 for x in current_rule.reactants_smarts} + product_functional_groups = {x: 1000 for x in current_rule.products_smarts} + return HttpResponse( + IndigoUtils.smirks_to_svg(smirks, False, 0, 0, + educt_functional_groups=educt_functional_groups, + product_functional_groups=product_functional_groups), + content_type='image/svg+xml') + context['title'] = f'enviPath - {current_package.name} - {current_rule.name}' context['meta']['current_package'] = current_package diff --git a/templates/objects/model.html b/templates/objects/model.html index 4c2686b5..c58359b8 100644 --- a/templates/objects/model.html +++ b/templates/objects/model.html @@ -269,7 +269,142 @@