forked from enviPath/enviPy

Experimental App Domain (#43)

Backend App Domain done, Frontend missing

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#43
epdb/models.py (486 lines changed)
@@ -3,8 +3,8 @@ import json
 import logging
 import os
 from collections import defaultdict
-from datetime import datetime, timedelta, date
-from typing import Union, List, Optional
+from datetime import datetime, timedelta
+from typing import Union, List, Optional, Dict, Tuple
 from uuid import uuid4
 
 import joblib
@@ -14,7 +14,7 @@ from django.contrib.auth.hashers import make_password, check_password
 from django.contrib.auth.models import AbstractUser
 from django.contrib.postgres.fields import ArrayField
 from django.db import models, transaction
-from django.db.models import JSONField, Count, Q
+from django.db.models import JSONField, Count, Q, QuerySet
 from django.utils import timezone
 from django.utils.functional import cached_property
 from model_utils.models import TimeStampedModel
@@ -23,7 +23,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import SparseLabelECC
+from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
 
 logger = logging.getLogger(__name__)
 
@@ -172,6 +172,9 @@ class EnviPathModel(TimeStampedModel):
     class Meta:
         abstract = True
 
+    def __str__(self):
+        return f"{self.name} (pk={self.pk})"
+
 
 class AliasMixin(models.Model):
     aliases = ArrayField(
@@ -844,7 +847,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
 
         # We shouldn't lose or make up nodes...
         assert len(nodes) == len(self.nodes)
-        print(f"Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")
+        logger.debug(f"{self.name}: Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")
 
         links = [e.d3_json() for e in self.edges]
 
@@ -1136,19 +1139,44 @@ class MLRelativeReasoning(EPModel):
 
     eval_results = JSONField(null=True, blank=True, default=dict)
 
+    app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
+                                   default=None)
+
     def status(self):
         return self.PROGRESS_STATUS_CHOICES[self.model_status]
 
+    def ready_for_prediction(self) -> bool:
+        return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
+
     @staticmethod
     @transaction.atomic
-    def create(package, name, description, rule_packages, data_packages, eval_packages, threshold):
+    def create(package: 'Package', rule_packages: List['Package'],
+               data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
+               name: 'str' = None, description: str = None, build_app_domain: bool = False,
+               app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
+               app_domain_local_compatibility_threshold: float = None):
 
         mlrr = MLRelativeReasoning()
         mlrr.package = package
 
         if name is None or name.strip() == '':
             name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
 
         mlrr.name = name
-        mlrr.description = description
 
+        if description is not None and description.strip() != '':
+            mlrr.description = description
+
+        if threshold is None or (threshold <= 0 or 1 <= threshold):
+            raise ValueError("Threshold must be a float between 0 and 1.")
 
         mlrr.threshold = threshold
 
+        if len(rule_packages) == 0:
+            raise ValueError("At least one rule package must be provided.")
+
         mlrr.save()
 
         for p in rule_packages:
             mlrr.rule_packages.add(p)
 
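Note: a minimal usage sketch of the new `create` signature. The package variables below are hypothetical placeholders, not part of the diff; per the hunk above, `build_app_domain=True` makes `create` call `ApplicabilityDomain.create(...)` and attach the result to `mlrr.app_domain`.

    # Hypothetical example: create a model together with its applicability domain.
    mlrr = MLRelativeReasoning.create(
        package=my_package,            # assumed existing Package instances
        rule_packages=[rules_pkg],
        data_packages=[data_pkg],
        eval_packages=[eval_pkg],
        threshold=0.5,
        build_app_domain=True,
        app_domain_num_neighbours=5,
        app_domain_reliability_threshold=0.5,
        app_domain_local_compatibility_threshold=0.5,
    )
    mlrr.build_dataset()
    mlrr.build_model()    # also builds the attached app domain, see build_model below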
@@ -1163,11 +1191,17 @@ class MLRelativeReasoning(EPModel):
         for p in eval_packages:
             mlrr.eval_packages.add(p)
 
+        if build_app_domain:
+            ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
+                                            app_domain_local_compatibility_threshold)
+            mlrr.app_domain = ad
+
         mlrr.save()
 
         return mlrr
 
     @cached_property
-    def applicable_rules(self):
+    def applicable_rules(self) -> List['Rule']:
         """
         Returns an ordered set of rules where the following applies:
         1. All Composite will be added to result
@@ -1195,6 +1229,7 @@ class MLRelativeReasoning(EPModel):
                 rules.append(r)
 
+        rules = sorted(rules, key=lambda x: x.url)
 
         return rules
 
     def _get_excludes(self):
@@ -1209,197 +1244,79 @@ class MLRelativeReasoning(EPModel):
         pathway_qs = pathway_qs.distinct()
         return pathway_qs
 
+    def _get_reactions(self) -> QuerySet:
+        return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()
+
     def build_dataset(self):
         self.model_status = self.INITIALIZING
         self.save()
-        from datetime import datetime
 
         start = datetime.now()
 
         applicable_rules = self.applicable_rules
-        print("got rules")
-
-        # if s.DEBUG:
-        #     pathways = self._get_pathways().order_by('-name')[:20]
-        # else:
-        pathways = self._get_pathways()
-
-        print("got pathways")
-        excludes = self._get_excludes()
-
-        # Collect all compounds
-        compounds = set()
-        reactions = set()
-        for i, p in enumerate(pathways):
-            print(f"{i + 1}/{len(pathways)}...")
-            for n in p.nodes:
-                cs = n.default_node_label.compound.default_structure
-                # TODO too many lookups
-                if cs.smiles in excludes:
-                    continue
-
-                compounds.add(cs)
-
-            for e in p.edges:
-                reactions.add(e.edge_label)
-
-        print(len(compounds))
-        print(len(reactions))
-
-        triggered = set()
-        observed = set()
-
-        # TODO naming
-
-        pw = defaultdict(lambda: defaultdict(set))
-
-        for i, c in enumerate(compounds):
-            print(f"{i + 1}/{len(compounds)}...")
-            for r in applicable_rules:
-                # TODO check normalization
-                product_sets = r.apply(c.smiles)
-
-                if len(product_sets) == 0:
-                    continue
-
-                triggered.add(f"{r.uuid} + {c.uuid}")
-
-                for ps in product_sets:
-                    for p in ps:
-                        pw[c][r].add(p)
-
-        for r in reactions:
-            if r is None:
-                print(r)
-                continue
-            if len(r.educts.all()) != 1:
-                print(f"Skipping {r.url}")
-                continue
-
-            # Loop will run only once
-            for c in r.educts.all():
-                if c not in pw:
-                    continue
-
-                for rule in pw[c].keys():
-                    # standardize...
-
-                    if 0 != len(pw[c][rule]) and len(pw[c][rule]) == len(r.products.all()):
-                        print(f"potential match for {c.smiles} and {r.uuid} ({r.name})")
-
-                        standardized_products = []
-                        for cs in r.products.all():
-                            smi = cs.smiles
-
-                            try:
-                                smi = FormatConverter.standardize(smi)
-                            except Exception as e:
-                                # :shrug:
-                                pass
-
-                            standardized_products.append(smi)
-
-                        standardized_pred_products = []
-                        for smi in pw[c][rule]:
-
-                            try:
-                                smi = FormatConverter.standardize(smi)
-                            except Exception as e:
-                                # :shrug:
-                                pass
-
-                            standardized_pred_products.append(smi)
-
-                        if sorted(list(set(standardized_products))) == sorted(list(set(standardized_pred_products))):
-                            observed.add(f"{rule.uuid} + {c.uuid}")
-                            print(f"Adding observed, current count {len(observed)}")
-
-        header = None
-        X = []
-        y = []
-        for i, c in enumerate(compounds):
-            print(f'{i + 1}/{len(compounds)}...')
-            # Features
-            feat = FormatConverter.maccs(c.smiles)
-            trig = []
-            obs = []
-            for rule in applicable_rules:
-                key = f"{rule.uuid} + {c.uuid}"
-
-                # Check triggered
-                if key in triggered:
-                    trig.append(1)
-                else:
-                    trig.append(0)
-
-                # Check obs
-                if key in observed:
-                    obs.append(1)
-                else:
-                    obs.append(0)
-
-            if header is None:
-                header = [f'feature_{i}' for i, _ in enumerate(feat)] \
-                         + [f'trig_{r.uuid}' for r in applicable_rules] \
-                         + [f'corr_{r.uuid}' for r in applicable_rules]
-            X.append(feat + trig)
-            y.append(obs)
+        reactions = list(self._get_reactions())
+        ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
 
         end = datetime.now()
-        print(f"Duration {(end - start).total_seconds()}s")
+        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
 
-        data = {
-            'X': X,
-            'y': y,
-            'header': header
-        }
-        f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
-        json.dump(data, open(f, 'w'))
-        return X, y
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
+        ds.save(f)
+        return ds
 
-    def load_dataset(self):
-        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
-        return json.load(open(ds_path, 'r'))
+    def load_dataset(self) -> 'Dataset':
+        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
+        return Dataset.load(ds_path)
 
-    def build_model(self, X, y):
+    def build_model(self):
         self.model_status = self.BUILDING
         self.save()
 
-        mod = SparseLabelECC(
-            **s.DEFAULT_DT_MODEL_PARAMS
-        )
+        start = datetime.now()
+
+        ds = self.load_dataset()
+        X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
+
+        mod = EnsembleClassifierChain(
+            **s.DEFAULT_MODEL_PARAMS
+        )
         mod.fit(X, y)
-        f = os.path.join(s.MODEL_DIR, f"{self.uuid}.pkl")
+
+        end = datetime.now()
+        logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")
+
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
         joblib.dump(mod, f)
 
+        if self.app_domain is not None:
+            logger.debug("Building applicability domain...")
+            self.app_domain.build()
+            logger.debug("Done building applicability domain.")
+
         self.model_status = self.BUILT_NOT_EVALUATED
         self.save()
 
     def retrain(self):
         self.build_dataset()
         self.build_model()
 
     def rebuild(self):
-        data = self.load_dataset()
-        self.build_model(data['X'], data['y'])
+        self.build_model()
 
     def evaluate_model(self):
         """
         Performs Leave-One-Out cross-validation on a multi-label dataset.
 
         Parameters:
             X (list of lists): Feature matrix.
             y (list of lists): Multi-label targets.
            classifier (sklearn estimator, optional): Base classifier. Defaults to RandomForest.
 
         Returns:
             float: Average accuracy across all LOO splits.
         """
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
 
         self.model_status = self.EVALUATING
         self.save()
 
-        f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
-        data = json.load(open(f))
+        ds = self.load_dataset()
 
-        X = np.array(data['X'])
-        y = np.array(data['y'])
+        X = np.array(ds.X(na_replacement=np.nan))
+        y = np.array(ds.y(na_replacement=np.nan))
 
         n_splits = 20
 
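Note: the dataset round-trip changes here from a JSON file (`{uuid}.json` holding raw `X`, `y`, `header` lists) to a pickled `Dataset` object (`{uuid}_ds.pkl`). A rough sketch of the new persistence contract as used by `build_dataset`/`load_dataset`, assuming `Dataset.save`/`Dataset.load` in utilities.ml wrap joblib or pickle:

    # Sketch of the save/load pairing introduced by this hunk.
    ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
    ds.save(os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl"))    # build_dataset
    ds = Dataset.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl"))  # load_dataset
    X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)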
@@ -1409,22 +1326,32 @@ class MLRelativeReasoning(EPModel):
             X_train, X_test = X[train_index], X[test_index]
             y_train, y_test = y[train_index], y[test_index]
 
-            model = SparseLabelECC(
-                **s.DEFAULT_DT_MODEL_PARAMS
+            model = EnsembleClassifierChain(
+                **s.DEFAULT_MODEL_PARAMS
             )
             model.fit(X_train, y_train)
 
             y_pred = model.predict_proba(X_test)
             y_thresholded = (y_pred >= threshold).astype(int)
 
-            acc = jaccard_score(y_test, y_thresholded, average='samples', zero_division=0)
+            # Flatten them to get rid of np.nan
+            y_test = np.asarray(y_test).flatten()
+            y_pred = np.asarray(y_pred).flatten()
+            y_thresholded = np.asarray(y_thresholded).flatten()
+
+            mask = ~np.isnan(y_test)
+            y_test_filtered = y_test[mask]
+            y_pred_filtered = y_pred[mask]
+            y_thresholded_filtered = y_thresholded[mask]
+
+            acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
 
             prec, rec = dict(), dict()
 
             for t in np.arange(0, 1.05, 0.05):
-                temp_thresholded = (y_pred >= t).astype(int)
-                prec[f"{t:.2f}"] = precision_score(y_test, temp_thresholded, average='samples', zero_division=0)
-                rec[f"{t:.2f}"] = recall_score(y_test, temp_thresholded, average='samples', zero_division=0)
+                temp_thresholded = (y_pred_filtered >= t).astype(int)
+                prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
+                rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
 
             return acc, prec, rec
 
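Note: a minimal sketch of why the flatten-and-mask step above is needed. The sparse multi-label targets encode unknown labels as np.nan, which sklearn metrics reject; masking keeps only observed labels, and the scores are then computed over the flattened arrays instead of per-sample (`average='samples'`):

    import numpy as np
    from sklearn.metrics import jaccard_score

    y_test = np.array([[1.0, np.nan], [0.0, 1.0]]).flatten()
    y_pred = np.array([[1, 0], [0, 1]]).flatten()
    mask = ~np.isnan(y_test)                    # drop unknown labels
    print(jaccard_score(y_test[mask], y_pred[mask], zero_division=0))  # 1.0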
@@ -1462,38 +1389,30 @@ class MLRelativeReasoning(EPModel):
 
     @cached_property
     def model(self):
-        mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}.pkl'))
+        mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
         mod.base_clf.n_jobs = -1
         return mod
 
     def predict(self, smiles) -> List['PredictionResult']:
         start = datetime.now()
-        features = FormatConverter.maccs(smiles)
-
-        trig = []
-        prods = []
-        for rule in self.applicable_rules:
-            products = rule.apply(smiles)
-
-            if len(products):
-                trig.append(1)
-                prods.append(products)
-            else:
-                trig.append(0)
-                prods.append([])
-
-        end_ds_gen = datetime.now()
-        logger.info(f"Gen predict dataset took {(end_ds_gen - start).total_seconds()}s")
-        pred = self.model.predict_proba([features + trig])
+        ds = self.load_dataset()
+        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
+        pred = self.model.predict_proba(classify_ds.X())
 
-        res = []
-        for rule, p, smis in zip(self.applicable_rules, pred[0], prods):
-            res.append(PredictionResult(smis, p, rule))
+        res = MLRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
 
         end = datetime.now()
         logger.info(f"Full predict took {(end - start).total_seconds()}s")
         return res
 
+    @staticmethod
+    def combine_products_and_probs(rules: List['Rule'], probabilities, products):
+        res = []
+        for rule, p, smis in zip(rules, probabilities, products):
+            res.append(PredictionResult(smis, p, rule))
+        return res
+
     @property
     def pr_curve(self):
         if self.model_status != self.FINISHED:
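Note: `combine_products_and_probs` zips three parallel, equal-length sequences: the applicable rules, the model's predicted probabilities, and the products each rule generated for the query structure. A hypothetical call, reusing names from the hunk above:

    rules = mlrr.applicable_rules                            # N rules
    probs = mlrr.model.predict_proba(classify_ds.X())[0]     # N probabilities
    prods = classify_prods[0]                                # N product lists
    results = MLRelativeReasoning.combine_products_and_probs(rules, probs, prods)
    likely = [r for r in results if r.probability >= mlrr.threshold]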
@@ -1515,26 +1434,171 @@ class MLRelativeReasoning(EPModel):
 class ApplicabilityDomain(EnviPathModel):
     model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)
 
-    num_neighbours = models.FloatField(blank=False, null=False, default=5)
+    num_neighbours = models.IntegerField(blank=False, null=False, default=5)
     reliability_threshold = models.FloatField(blank=False, null=False, default=0.5)
     local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5)
 
-    def build_applicability_domain(self):
+    @staticmethod
+    @transaction.atomic
+    def create(mlrr: MLRelativeReasoning, num_neighbours: int = 5, reliability_threshold: float = 0.5,
+               local_compatibility_threshold: float = 0.5):
+        ad = ApplicabilityDomain()
+        ad.model = mlrr
+        # ad.uuid = mlrr.uuid
+        ad.name = f"AD for {mlrr.name}"
+        ad.num_neighbours = num_neighbours
+        ad.reliability_threshold = reliability_threshold
+        ad.local_compatibilty_threshold = local_compatibility_threshold
+        ad.save()
+        return ad
+
+    @cached_property
+    def pca(self) -> ApplicabilityDomainPCA:
+        pca = joblib.load(os.path.join(s.MODEL_DIR, f'{self.model.uuid}_pca.pkl'))
+        return pca
+
+    @cached_property
+    def training_set_probs(self):
+        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
+
+    def build(self):
         ds = self.model.load_dataset()
-        X = ds['X']
-        import numpy as np
-        from sklearn.decomposition import PCA
-        from sklearn.preprocessing import StandardScaler
 
-        scaler = StandardScaler()
-        X_scaled = scaler.fit_transform(X)
-        pca = PCA(n_components=5)  # choose number of components
-        X_pca = pca.fit_transform(X_scaled)
+        start = datetime.now()
 
-        max_vals = np.max(X_pca, axis=0)
-        min_vals = np.min(X_pca, axis=0)
+        # Get Trainingset probs and dump them as they're required when using the app domain
+        probs = self.model.model.predict_proba(ds.X())
+        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
+        joblib.dump(probs, f)
+
+        ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
+        ad.build(ds)
+
+        end = datetime.now()
+        logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds")
+
+        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")
+        joblib.dump(ad, f)
+
+    def assess(self, structure: Union[str, 'CompoundStructure']):
+        ds = self.model.load_dataset()
+
+        assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
+
+        # qualified_neighbours_per_rule is a nested dictionary structured as:
+        # {
+        #     assessment_structure_index: {
+        #         rule_index: [training_structure_indices_with_same_triggered_reaction]
+        #     }
+        # }
+        #
+        # For each structure in the assessment dataset and each rule (represented by a trigger feature),
+        # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
+        # This is used to find "qualified neighbours" — training examples that share the same triggered feature
+        # with a given assessment structure under a particular rule.
+        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(lambda: defaultdict(list))
+
+        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
+            feature = ds.columns[feature_index]
+            if feature.startswith('trig_'):
+                # TODO unroll loop
+                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
+                    if int(cx[feature_index]) == 1:
+                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
+                            if int(tx[feature_index]) == 1:
+                                qualified_neighbours_per_rule[i][rule_idx].append(j)
+
+        probs = self.training_set_probs
+        # preds = self.model.model.predict_proba(assessment_ds.X())
+        preds = self.model.combine_products_and_probs(self.model.applicable_rules,
+                                                      self.model.model.predict_proba(assessment_ds.X())[0],
+                                                      assessment_prods[0])
+
+        res = list()
+
+        # loop through our assessment dataset
+        for i, instance in enumerate(assessment_ds):
+
+            rule_reliabilities = dict()
+            local_compatibilities = dict()
+            neighbours_per_rule = dict()
+
+            # loop through rule indices together with the collected neighbour indices from the train dataset
+            for rule_idx, vals in qualified_neighbours_per_rule[i].items():
+
+                # collect the train dataset instances and store them along with the index (a.k.a. row number)
+                # of the train dataset
+                train_instances = []
+                for v in vals:
+                    train_instances.append((v, ds.at(v)))
+
+                # sf is a tuple with start/end index of the features
+                sf = ds.struct_features()
+
+                # compute tanimoto distance for all neighbours
+                # result is a list of tuples with train index and computed distance
+                dists = self._compute_distances(
+                    instance.X()[0][sf[0]:sf[1]],
+                    [ti[1].X()[0][sf[0]:sf[1]] for ti in train_instances]
+                )
+
+                dists_with_index = list()
+                for ti, dist in zip(train_instances, dists):
+                    dists_with_index.append((ti[0], dist[1]))
+
+                # sort them in a descending way and take at most `self.num_neighbours`
+                dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
+                dists_with_index = dists_with_index[:self.num_neighbours]
+
+                # compute average distance
+                rule_reliabilities[rule_idx] = sum([d[1] for d in dists_with_index]) / len(dists_with_index) if len(dists_with_index) > 0 else 0.0
+
+                # for local_compatibility we'll need the datasets for the indices having the highest similarity
+                neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
+                local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
+                neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
+
+            # Assemble result for instance
+            res.append({
+                'in_ad': self.pca.is_applicable(instance)[0],
+                'rule_reliabilities': rule_reliabilities,
+                'local_compatibilities': local_compatibilities,
+                'neighbours': neighbours_per_rule,
+                'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
+                'prob': preds
+            })
+
+        return res
+
+    @staticmethod
+    def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
+        from utilities.ml import tanimoto_distance
+        distances = [(i, tanimoto_distance(classify_instance, train)) for i, train in
+                     enumerate(train_instances)]
+        return distances
+
+    @staticmethod
+    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, 'Dataset']]):
+        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
+        accuracy = 0.0
+
+        for n in neighbours:
+            obs = n[1].y()[0][rule_idx]
+            pred = preds[n[0]][rule_idx]
+            if obs and pred:
+                tp += 1
+            elif not obs and pred:
+                fp += 1
+            elif obs and not pred:
+                fn += 1
+            else:
+                tn += 1
+
+        # fraction of correct predictions (accuracy) over the neighbour set
+        if tp + tn > 0.0:
+            accuracy = (tp + tn) / (tp + tn + fp + fn)
+
+        return accuracy
 
 
 class RuleBaseRelativeReasoning(EPModel):
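Note: `tanimoto_distance` lives in utilities.ml and is not shown in this diff; `_compute_distances` sorts its output in descending order and treats it as a similarity ("indices having the highest similarity"). A minimal sketch, assuming the helper computes Tanimoto (Jaccard) similarity over binary fingerprint vectors:

    def tanimoto_similarity(a, b):
        # |A ∩ B| / |A ∪ B| over 0/1 fingerprint vectors
        inter = sum(1 for x, y in zip(a, b) if x == 1 and y == 1)
        union = sum(1 for x, y in zip(a, b) if x == 1 or y == 1)
        return inter / union if union else 0.0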
@@ -1574,10 +1638,6 @@ class EnviFormer(EPModel):
         logger.info(f"Submitting {kek} to {hash(self.model)}")
         products = self.model.predict(kek)
         logger.info(f"Got results {products}")
-        # from pprint import pprint
-        #
-        # print(smiles)
-        # pprint(products)
 
         res = []
         for smi, prob in products.items():
@@ -1715,9 +1775,7 @@ class Setting(EnviPathModel):
 
         transformations = []
         if self.model is not None:
-            print(self.model)
             pred_results = self.model.predict(current_node.smiles)
-            print(pred_results)
             for pred_result in pred_results:
                 if pred_result.probability >= self.model_threshold:
                     transformations.append(pred_result)
 