Experimental App Domain (#43)

Backend App Domain done, Frontend missing

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#43
This commit is contained in:
2025-08-08 20:52:21 +12:00
parent 280ddc7205
commit 579cd519d0
14 changed files with 1094 additions and 574 deletions

View File

@ -261,7 +261,7 @@ CELERY_ACCEPT_CONTENT = ['json']
CELERY_TASK_SERIALIZER = 'json' CELERY_TASK_SERIALIZER = 'json'
MODEL_BUILDING_ENABLED = os.environ.get('MODEL_BUILDING_ENABLED', 'False') == 'True' MODEL_BUILDING_ENABLED = os.environ.get('MODEL_BUILDING_ENABLED', 'False') == 'True'
APPLICABILITY_DOMAIN_ENABLED = os.environ.get('APPLICABILITY_DOMAIN_ENABLED', 'False') == 'True'
DEFAULT_RF_MODEL_PARAMS = { DEFAULT_RF_MODEL_PARAMS = {
'base_clf': RandomForestClassifier( 'base_clf': RandomForestClassifier(
n_estimators=100, n_estimators=100,
@ -275,14 +275,14 @@ DEFAULT_RF_MODEL_PARAMS = {
'num_chains': 10, 'num_chains': 10,
} }
DEFAULT_DT_MODEL_PARAMS = { DEFAULT_MODEL_PARAMS = {
'base_clf': DecisionTreeClassifier( 'base_clf': DecisionTreeClassifier(
criterion='entropy', criterion='entropy',
max_depth=3, max_depth=3,
min_samples_split=5, min_samples_split=5,
min_samples_leaf=5, # min_samples_leaf=5,
max_features='sqrt', max_features='sqrt',
class_weight='balanced', # class_weight='balanced',
random_state=42 random_state=42
), ),
'num_chains': 10, 'num_chains': 10,
@ -322,4 +322,5 @@ FLAGS = {
'PLUGINS': PLUGINS_ENABLED, 'PLUGINS': PLUGINS_ENABLED,
'SENTRY': SENTRY_ENABLED, 'SENTRY': SENTRY_ENABLED,
'ENVIFORMER': ENVIFORMER_PRESENT, 'ENVIFORMER': ENVIFORMER_PRESENT,
'APPLICABILITY_DOMAIN': APPLICABILITY_DOMAIN_ENABLED,
} }

View File

@ -1,40 +1,105 @@
from django.contrib import admin from django.contrib import admin
from .models import User, Group, UserPackagePermission, GroupPackagePermission, Setting, SimpleAmbitRule, Scenario from .models import (
User,
UserPackagePermission,
Group,
GroupPackagePermission,
Package,
MLRelativeReasoning,
Compound,
CompoundStructure,
SimpleAmbitRule,
ParallelRule,
Reaction,
Pathway,
Node,
Edge,
Scenario,
Setting
)
class UserAdmin(admin.ModelAdmin): class UserAdmin(admin.ModelAdmin):
pass pass
class GroupAdmin(admin.ModelAdmin):
pass
class UserPackagePermissionAdmin(admin.ModelAdmin): class UserPackagePermissionAdmin(admin.ModelAdmin):
pass pass
class GroupAdmin(admin.ModelAdmin):
pass
class GroupPackagePermissionAdmin(admin.ModelAdmin): class GroupPackagePermissionAdmin(admin.ModelAdmin):
pass pass
class SettingAdmin(admin.ModelAdmin): class EPAdmin(admin.ModelAdmin):
search_fields = ['name', 'description']
class PackageAdmin(EPAdmin):
pass
class MLRelativeReasoningAdmin(EPAdmin):
pass pass
class SimpleAmbitRuleAdmin(admin.ModelAdmin): class CompoundAdmin(EPAdmin):
pass pass
class ScenarioAdmin(admin.ModelAdmin): class CompoundStructureAdmin(EPAdmin):
pass
class SimpleAmbitRuleAdmin(EPAdmin):
pass
class ParallelRuleAdmin(EPAdmin):
pass
class ReactionAdmin(EPAdmin):
pass
class PathwayAdmin(EPAdmin):
pass
class NodeAdmin(EPAdmin):
pass
class EdgeAdmin(EPAdmin):
pass
class ScenarioAdmin(EPAdmin):
pass
class SettingAdmin(EPAdmin):
pass pass
admin.site.register(User, UserAdmin) admin.site.register(User, UserAdmin)
admin.site.register(Group, GroupAdmin)
admin.site.register(UserPackagePermission, UserPackagePermissionAdmin) admin.site.register(UserPackagePermission, UserPackagePermissionAdmin)
admin.site.register(Group, GroupAdmin)
admin.site.register(GroupPackagePermission, GroupPackagePermissionAdmin) admin.site.register(GroupPackagePermission, GroupPackagePermissionAdmin)
admin.site.register(Setting, SettingAdmin) admin.site.register(Package, PackageAdmin)
admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin)
admin.site.register(Compound, CompoundAdmin)
admin.site.register(CompoundStructure, CompoundStructureAdmin)
admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin) admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin)
admin.site.register(ParallelRule, ParallelRuleAdmin)
admin.site.register(Reaction, ReactionAdmin)
admin.site.register(Pathway, PathwayAdmin)
admin.site.register(Node, NodeAdmin)
admin.site.register(Edge, EdgeAdmin)
admin.site.register(Setting, SettingAdmin)
admin.site.register(Scenario, ScenarioAdmin) admin.site.register(Scenario, ScenarioAdmin)

View File

@ -339,7 +339,7 @@ class PackageManager(object):
@staticmethod @staticmethod
@transaction.atomic @transaction.atomic
def import_package(data: dict, owner: User, keep_ids=False): def import_package(data: dict, owner: User, keep_ids=False, add_import_timestamp=True):
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
@ -349,7 +349,12 @@ class PackageManager(object):
pack = Package() pack = Package()
pack.uuid = UUID(data['id'].split('/')[-1]) if keep_ids else uuid4() pack.uuid = UUID(data['id'].split('/')[-1]) if keep_ids else uuid4()
pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M'))
if add_import_timestamp:
pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M'))
else:
pack.name = data['name']
pack.reviewed = True if data['reviewStatus'] == 'reviewed' else False pack.reviewed = True if data['reviewStatus'] == 'reviewed' else False
pack.description = data['description'] pack.description = data['description']
pack.save() pack.save()

View File

@ -58,7 +58,7 @@ class Command(BaseCommand):
return anon, admin, g, jebus return anon, admin, g, jebus
def import_package(self, data, owner): def import_package(self, data, owner):
return PackageManager.import_package(data, owner, keep_ids=True) return PackageManager.import_package(data, owner, keep_ids=True, add_import_timestamp=False)
def create_default_setting(self, owner, packages): def create_default_setting(self, owner, packages):
s = SettingManager.create_setting( s = SettingManager.create_setting(

View File

@ -3,8 +3,8 @@ import json
import logging import logging
import os import os
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta, date from datetime import datetime, timedelta
from typing import Union, List, Optional from typing import Union, List, Optional, Dict, Tuple
from uuid import uuid4 from uuid import uuid4
import joblib import joblib
@ -14,7 +14,7 @@ from django.contrib.auth.hashers import make_password, check_password
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.db import models, transaction from django.db import models, transaction
from django.db.models import JSONField, Count, Q from django.db.models import JSONField, Count, Q, QuerySet
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
from model_utils.models import TimeStampedModel from model_utils.models import TimeStampedModel
@ -23,7 +23,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import SparseLabelECC from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -172,6 +172,9 @@ class EnviPathModel(TimeStampedModel):
class Meta: class Meta:
abstract = True abstract = True
def __str__(self):
return f"{self.name} (pk={self.pk})"
class AliasMixin(models.Model): class AliasMixin(models.Model):
aliases = ArrayField( aliases = ArrayField(
@ -844,7 +847,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# We shouldn't lose or make up nodes... # We shouldn't lose or make up nodes...
assert len(nodes) == len(self.nodes) assert len(nodes) == len(self.nodes)
print(f"Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}") logger.debug(f"{self.name}: Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")
links = [e.d3_json() for e in self.edges] links = [e.d3_json() for e in self.edges]
@ -1136,19 +1139,44 @@ class MLRelativeReasoning(EPModel):
eval_results = JSONField(null=True, blank=True, default=dict) eval_results = JSONField(null=True, blank=True, default=dict)
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
default=None)
def status(self): def status(self):
return self.PROGRESS_STATUS_CHOICES[self.model_status] return self.PROGRESS_STATUS_CHOICES[self.model_status]
def ready_for_prediction(self) -> bool:
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
@staticmethod @staticmethod
@transaction.atomic @transaction.atomic
def create(package, name, description, rule_packages, data_packages, eval_packages, threshold): def create(package: 'Package', rule_packages: List['Package'],
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
name: 'str' = None, description: str = None, build_app_domain: bool = False,
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
app_domain_local_compatibility_threshold: float = None):
mlrr = MLRelativeReasoning() mlrr = MLRelativeReasoning()
mlrr.package = package mlrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
mlrr.name = name mlrr.name = name
mlrr.description = description
if description is not None and description.strip() != '':
mlrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
mlrr.threshold = threshold mlrr.threshold = threshold
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
mlrr.save() mlrr.save()
for p in rule_packages: for p in rule_packages:
mlrr.rule_packages.add(p) mlrr.rule_packages.add(p)
@ -1163,11 +1191,17 @@ class MLRelativeReasoning(EPModel):
for p in eval_packages: for p in eval_packages:
mlrr.eval_packages.add(p) mlrr.eval_packages.add(p)
if build_app_domain:
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
app_domain_local_compatibility_threshold)
mlrr.app_domain = ad
mlrr.save() mlrr.save()
return mlrr return mlrr
@cached_property @cached_property
def applicable_rules(self): def applicable_rules(self) -> List['Rule']:
""" """
Returns a ordered set of rules where the following applies: Returns a ordered set of rules where the following applies:
1. All Composite will be added to result 1. All Composite will be added to result
@ -1195,6 +1229,7 @@ class MLRelativeReasoning(EPModel):
rules.append(r) rules.append(r)
rules = sorted(rules, key=lambda x: x.url) rules = sorted(rules, key=lambda x: x.url)
return rules return rules
def _get_excludes(self): def _get_excludes(self):
@ -1209,197 +1244,79 @@ class MLRelativeReasoning(EPModel):
pathway_qs = pathway_qs.distinct() pathway_qs = pathway_qs.distinct()
return pathway_qs return pathway_qs
def _get_reactions(self) -> QuerySet:
return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()
def build_dataset(self): def build_dataset(self):
self.model_status = self.INITIALIZING self.model_status = self.INITIALIZING
self.save() self.save()
from datetime import datetime
start = datetime.now() start = datetime.now()
applicable_rules = self.applicable_rules applicable_rules = self.applicable_rules
print("got rules") reactions = list(self._get_reactions())
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
# if s.DEBUG:
# pathways = self._get_pathways().order_by('-name')[:20]
# else:
pathways = self._get_pathways()
print("got pathways")
excludes = self._get_excludes()
# Collect all compounds
compounds = set()
reactions = set()
for i, p in enumerate(pathways):
print(f"{i + 1}/{len(pathways)}...")
for n in p.nodes:
cs = n.default_node_label.compound.default_structure
# TODO too many lookups
if cs.smiles in excludes:
continue
compounds.add(cs)
for e in p.edges:
reactions.add(e.edge_label)
print(len(compounds))
print(len(reactions))
triggered = set()
observed = set()
# TODO naming
pw = defaultdict(lambda: defaultdict(set))
for i, c in enumerate(compounds):
print(f"{i + 1}/{len(compounds)}...")
for r in applicable_rules:
# TODO check normalization
product_sets = r.apply(c.smiles)
if len(product_sets) == 0:
continue
triggered.add(f"{r.uuid} + {c.uuid}")
for ps in product_sets:
for p in ps:
pw[c][r].add(p)
for r in reactions:
if r is None:
print(r)
continue
if len(r.educts.all()) != 1:
print(f"Skipping {r.url}")
continue
# Loop will run only once
for c in r.educts.all():
if c not in pw:
continue
for rule in pw[c].keys():
# standardize...
if 0 != len(pw[c][rule]) and len(pw[c][rule]) == len(r.products.all()):
print(f"potential match for {c.smiles} and {r.uuid} ({r.name})")
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi)
except Exception as e:
# :shrug:
pass
standardized_products.append(smi)
standardized_pred_products = []
for smi in pw[c][rule]:
try:
smi = FormatConverter.standardize(smi)
except Exception as e:
# :shrug:
pass
standardized_pred_products.append(smi)
if sorted(list(set(standardized_products))) == sorted(list(set(standardized_pred_products))):
observed.add(f"{rule.uuid} + {c.uuid}")
print(f"Adding observed, current count {len(observed)}")
header = None
X = []
y = []
for i, c in enumerate(compounds):
print(f'{i + 1}/{len(compounds)}...')
# Features
feat = FormatConverter.maccs(c.smiles)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {c.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
else:
obs.append(0)
if header is None:
header = [f'feature_{i}' for i, _ in enumerate(feat)] \
+ [f'trig_{r.uuid}' for r in applicable_rules] \
+ [f'corr_{r.uuid}' for r in applicable_rules]
X.append(feat + trig)
y.append(obs)
end = datetime.now() end = datetime.now()
print(f"Duration {(end - start).total_seconds()}s") logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
data = { f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
'X': X, ds.save(f)
'y': y, return ds
'header': header
}
f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
json.dump(data, open(f, 'w'))
return X, y
def load_dataset(self): def load_dataset(self) -> 'Dataset':
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}.json") ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return json.load(open(ds_path, 'r')) return Dataset.load(ds_path)
def build_model(self, X, y): def build_model(self):
self.model_status = self.BUILDING self.model_status = self.BUILDING
self.save() self.save()
mod = SparseLabelECC( start = datetime.now()
**s.DEFAULT_DT_MODEL_PARAMS
)
ds = self.load_dataset()
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
mod = EnsembleClassifierChain(
**s.DEFAULT_MODEL_PARAMS
)
mod.fit(X, y) mod.fit(X, y)
f = os.path.join(s.MODEL_DIR, f"{self.uuid}.pkl")
end = datetime.now()
logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f) joblib.dump(mod, f)
if self.app_domain is not None:
logger.debug("Building applicability domain...")
self.app_domain.build()
logger.debug("Done building applicability domain.")
self.model_status = self.BUILT_NOT_EVALUATED self.model_status = self.BUILT_NOT_EVALUATED
self.save() self.save()
def retrain(self):
self.build_dataset()
self.build_model()
def rebuild(self): def rebuild(self):
data = self.load_dataset() self.build_model()
self.build_model(data['X'], data['y'])
def evaluate_model(self): def evaluate_model(self):
"""
Performs Leave-One-Out cross-validation on a multi-label dataset.
Parameters:
X (list of lists): Feature matrix.
y (list of lists): Multi-label targets.
classifier (sklearn estimator, optional): Base classifier. Defaults to RandomForest.
Returns:
float: Average accuracy across all LOO splits.
"""
if self.model_status != self.BUILT_NOT_EVALUATED: if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!") raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
self.model_status = self.EVALUATING self.model_status = self.EVALUATING
self.save() self.save()
f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json") ds = self.load_dataset()
data = json.load(open(f))
X = np.array(data['X']) X = np.array(ds.X(na_replacement=np.nan))
y = np.array(data['y']) y = np.array(ds.y(na_replacement=np.nan))
n_splits = 20 n_splits = 20
@ -1409,22 +1326,32 @@ class MLRelativeReasoning(EPModel):
X_train, X_test = X[train_index], X[test_index] X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index] y_train, y_test = y[train_index], y[test_index]
model = SparseLabelECC( model = EnsembleClassifierChain(
**s.DEFAULT_DT_MODEL_PARAMS **s.DEFAULT_MODEL_PARAMS
) )
model.fit(X_train, y_train) model.fit(X_train, y_train)
y_pred = model.predict_proba(X_test) y_pred = model.predict_proba(X_test)
y_thresholded = (y_pred >= threshold).astype(int) y_thresholded = (y_pred >= threshold).astype(int)
acc = jaccard_score(y_test, y_thresholded, average='samples', zero_division=0) # Flatten them to get rid of np.nan
y_test = np.asarray(y_test).flatten()
y_pred = np.asarray(y_pred).flatten()
y_thresholded = np.asarray(y_thresholded).flatten()
mask = ~np.isnan(y_test)
y_test_filtered = y_test[mask]
y_pred_filtered = y_pred[mask]
y_thresholded_filtered = y_thresholded[mask]
acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
prec, rec = dict(), dict() prec, rec = dict(), dict()
for t in np.arange(0, 1.05, 0.05): for t in np.arange(0, 1.05, 0.05):
temp_thresholded = (y_pred >= t).astype(int) temp_thresholded = (y_pred_filtered >= t).astype(int)
prec[f"{t:.2f}"] = precision_score(y_test, temp_thresholded, average='samples', zero_division=0) prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
rec[f"{t:.2f}"] = recall_score(y_test, temp_thresholded, average='samples', zero_division=0) rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
return acc, prec, rec return acc, prec, rec
@ -1462,38 +1389,30 @@ class MLRelativeReasoning(EPModel):
@cached_property @cached_property
def model(self): def model(self):
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}.pkl')) mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
mod.base_clf.n_jobs = -1 mod.base_clf.n_jobs = -1
return mod return mod
def predict(self, smiles) -> List['PredictionResult']: def predict(self, smiles) -> List['PredictionResult']:
start = datetime.now() start = datetime.now()
features = FormatConverter.maccs(smiles) ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
pred = self.model.predict_proba(classify_ds.X())
trig = [] res = MLRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
prods = []
for rule in self.applicable_rules:
products = rule.apply(smiles)
if len(products):
trig.append(1)
prods.append(products)
else:
trig.append(0)
prods.append([])
end_ds_gen = datetime.now()
logger.info(f"Gen predict dataset took {(end_ds_gen - start).total_seconds()}s")
pred = self.model.predict_proba([features + trig])
res = []
for rule, p, smis in zip(self.applicable_rules, pred[0], prods):
res.append(PredictionResult(smis, p, rule))
end = datetime.now() end = datetime.now()
logger.info(f"Full predict took {(end - start).total_seconds()}s") logger.info(f"Full predict took {(end - start).total_seconds()}s")
return res return res
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
for rule, p, smis in zip(rules, probabilities, products):
res.append(PredictionResult(smis, p, rule))
return res
@property @property
def pr_curve(self): def pr_curve(self):
if self.model_status != self.FINISHED: if self.model_status != self.FINISHED:
@ -1515,26 +1434,171 @@ class MLRelativeReasoning(EPModel):
class ApplicabilityDomain(EnviPathModel): class ApplicabilityDomain(EnviPathModel):
model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE) model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)
num_neighbours = models.FloatField(blank=False, null=False, default=5) num_neighbours = models.IntegerField(blank=False, null=False, default=5)
reliability_threshold = models.FloatField(blank=False, null=False, default=0.5) reliability_threshold = models.FloatField(blank=False, null=False, default=0.5)
local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5) local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5)
def build_applicability_domain(self): @staticmethod
@transaction.atomic
def create(mlrr: MLRelativeReasoning, num_neighbours: int = 5, reliability_threshold: float = 0.5,
local_compatibility_threshold: float = 0.5):
ad = ApplicabilityDomain()
ad.model = mlrr
# ad.uuid = mlrr.uuid
ad.name = f"AD for {mlrr.name}"
ad.num_neighbours = num_neighbours
ad.reliability_threshold = reliability_threshold
ad.local_compatibilty_threshold = local_compatibility_threshold
ad.save()
return ad
@cached_property
def pca(self) -> ApplicabilityDomainPCA:
pca = joblib.load(os.path.join(s.MODEL_DIR, f'{self.model.uuid}_pca.pkl'))
return pca
@cached_property
def training_set_probs(self):
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
def build(self):
ds = self.model.load_dataset() ds = self.model.load_dataset()
X = ds['X']
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler() start = datetime.now()
X_scaled = scaler.fit_transform(X)
pca = PCA(n_components=5) # choose number of components
X_pca = pca.fit_transform(X_scaled)
max_vals = np.max(X_pca, axis=0) # Get Trainingset probs and dump them as they're required when using the app domain
min_vals = np.min(X_pca, axis=0) probs = self.model.model.predict_proba(ds.X())
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
joblib.dump(probs, f)
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
ad.build(ds)
end = datetime.now()
logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")
joblib.dump(ad, f)
def assess(self, structure: Union[str, 'CompoundStructure']):
ds = self.model.load_dataset()
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
# qualified_neighbours_per_rule is a nested dictionary structured as:
# {
# assessment_structure_index: {
# rule_index: [training_structure_indices_with_same_triggered_reaction]
# }
# }
#
# For each structure in the assessment dataset and each rule (represented by a trigger feature),
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
# with a given assessment structure under a particular rule.
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(lambda: defaultdict(list))
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
feature = ds.columns[feature_index]
if feature.startswith('trig_'):
# TODO unroll loop
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
if int(cx[feature_index]) == 1:
for j, tx in enumerate(ds.X(exclude_id_col=False)):
if int(tx[feature_index]) == 1:
qualified_neighbours_per_rule[i][rule_idx].append(j)
probs = self.training_set_probs
# preds = self.model.model.predict_proba(assessment_ds.X())
preds = self.model.combine_products_and_probs(self.model.applicable_rules,
self.model.model.predict_proba(assessment_ds.X())[0],
assessment_prods[0])
res = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds):
rule_reliabilities = dict()
local_compatibilities = dict()
neighbours_per_rule = dict()
# loop through rule indices together with the collected neighbours indices from train dataset
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
# collect the train dataset instances and store it along with the index (a.k.a. row number) of the
# train dataset
train_instances = []
for v in vals:
train_instances.append((v, ds.at(v)))
# sf is a tuple with start/end index of the features
sf = ds.struct_features()
# compute tanimoto distance for all neighbours
# result ist a list of tuples with train index and computed distance
dists = self._compute_distances(
instance.X()[0][sf[0]:sf[1]],
[ti[1].X()[0][sf[0]:sf[1]] for ti in train_instances]
)
dists_with_index = list()
for ti, dist in zip(train_instances, dists):
dists_with_index.append((ti[0], dist[1]))
# sort them in a descending way and take at most `self.num_neighbours`
dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
dists_with_index = dists_with_index[:self.num_neighbours]
# compute average distance
rule_reliabilities[rule_idx] = sum([d[1] for d in dists_with_index]) / len(dists_with_index) if len(dists_with_index) > 0 else 0.0
# for local_compatibility we'll need the datasets for the indices having the highest similarity
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
# Assemble result for instance
res.append({
'in_ad': self.pca.is_applicable(instance)[0],
'rule_reliabilities': rule_reliabilities,
'local_compatibilities': local_compatibilities,
'neighbours': neighbours_per_rule,
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
'prob': preds
})
return res
@staticmethod
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
from utilities.ml import tanimoto_distance
distances = [(i, tanimoto_distance(classify_instance, train)) for i, train in
enumerate(train_instances)]
return distances
@staticmethod
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, 'Dataset']]):
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
accuracy = 0.0
for n in neighbours:
obs = n[1].y()[0][rule_idx]
pred = preds[n[0]][rule_idx]
if obs and pred:
tp += 1
elif not obs and pred:
fp += 1
elif obs and not pred:
fn += 1
else:
tn += 1
# Jaccard Index
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn);
return accuracy
class RuleBaseRelativeReasoning(EPModel): class RuleBaseRelativeReasoning(EPModel):
@ -1574,10 +1638,6 @@ class EnviFormer(EPModel):
logger.info(f"Submitting {kek} to {hash(self.model)}") logger.info(f"Submitting {kek} to {hash(self.model)}")
products = self.model.predict(kek) products = self.model.predict(kek)
logger.info(f"Got results {products}") logger.info(f"Got results {products}")
# from pprint import pprint
#
# print(smiles)
# pprint(products)
res = [] res = []
for smi, prob in products.items(): for smi, prob in products.items():
@ -1715,9 +1775,7 @@ class Setting(EnviPathModel):
transformations = [] transformations = []
if self.model is not None: if self.model is not None:
print(self.model)
pred_results = self.model.predict(current_node.smiles) pred_results = self.model.predict(current_node.smiles)
print(pred_results)
for pred_result in pred_results: for pred_result in pred_results:
if pred_result.probability >= self.model_threshold: if pred_result.probability >= self.model_threshold:
transformations.append(pred_result) transformations.append(pred_result)

View File

@ -31,8 +31,8 @@ def send_registration_mail(user_pk: int):
@shared_task(queue='model') @shared_task(queue='model')
def build_model(model_pk: int): def build_model(model_pk: int):
mod = EPModel.objects.get(id=model_pk) mod = EPModel.objects.get(id=model_pk)
X, y = mod.build_dataset() mod.build_dataset()
mod.build_model(X, y) mod.build_model()
@shared_task(queue='model') @shared_task(queue='model')

View File

@ -103,7 +103,10 @@ def login(request):
else: else:
context['message'] = "Account has been created! You'll receive a mail to activate your account shortly." context['message'] = "Account has been created! You'll receive a mail to activate your account shortly."
return render(request, 'login.html', context) return render(request, 'login.html', context)
else:
return HttpResponseBadRequest()
else:
return HttpResponseNotAllowed(['GET', 'POST'])
def logout(request): def logout(request):
if request.method == 'POST': if request.method == 'POST':
@ -136,7 +139,7 @@ def editable(request, user):
f"{s.SERVER_URL}/group", f"{s.SERVER_URL}/search"]: f"{s.SERVER_URL}/group", f"{s.SERVER_URL}/search"]:
return True return True
else: else:
print(f"Unknown url: {url}") logger.debug(f"Unknown url: {url}")
return False return False
@ -584,6 +587,9 @@ def package_models(request, package_uuid):
return render(request, 'collections/objects_list.html', context) return render(request, 'collections/objects_list.html', context)
elif request.method == 'POST': elif request.method == 'POST':
log_post_params(request)
name = request.POST.get('model-name') name = request.POST.get('model-name')
description = request.POST.get('model-description') description = request.POST.get('model-description')
@ -606,14 +612,25 @@ def package_models(request, package_uuid):
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages] data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages] eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
# App Domain related parameters
build_ad = request.POST.get('build-app-domain', False) == 'on'
num_neighbors = request.POST.get('num-neighbors', 5)
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
mod = MLRelativeReasoning.create( mod = MLRelativeReasoning.create(
current_package, package=current_package,
name, name=name,
description, description=description,
rule_package_objs, rule_packages=rule_package_objs,
data_package_objs, data_packages=data_package_objs,
eval_packages_objs, eval_packages=eval_packages_objs,
threshold threshold=threshold,
# fingerprinter=fingerprinter,
build_app_domain=build_ad,
app_domain_num_neighbours=num_neighbors,
app_domain_reliability_threshold=reliability_threshold,
app_domain_local_compatibility_threshold=local_compatibility_threshold,
) )
from .tasks import build_model from .tasks import build_model
@ -649,7 +666,7 @@ def package_model(request, package_uuid, model_uuid):
if len(pr) > 0: if len(pr) > 0:
products = [] products = []
for prod_set in pr.product_sets: for prod_set in pr.product_sets:
print(f"Checking {prod_set}") logger.debug(f"Checking {prod_set}")
products.append(tuple([x for x in prod_set])) products.append(tuple([x for x in prod_set]))
res.append({ res.append({
@ -660,6 +677,12 @@ def package_model(request, package_uuid, model_uuid):
return JsonResponse(res, safe=False) return JsonResponse(res, safe=False)
elif request.GET.get('app-domain-assessment', False):
smiles = request.GET['smiles']
stand_smiles = FormatConverter.standardize(smiles)
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
return JsonResponse(app_domain_assessment, safe=False)
context = get_base_context(request) context = get_base_context(request)
context['title'] = f'enviPath - {current_package.name} - {current_model.name}' context['title'] = f'enviPath - {current_package.name} - {current_model.name}'
@ -1717,8 +1740,6 @@ def user(request, user_uuid):
} }
} }
print(setting)
return HttpResponseBadRequest() return HttpResponseBadRequest()
else: else:
@ -1781,9 +1802,7 @@ def group(request, group_uuid):
elif request.method == 'POST': elif request.method == 'POST':
if s.DEBUG: log_post_params(request)
for k, v in request.POST.items():
print(k, v)
if hidden := request.POST.get('hidden', None): if hidden := request.POST.get('hidden', None):
if hidden == 'delete-group': if hidden == 'delete-group':

View File

@ -16,14 +16,14 @@
<div class="jumbotron">Create a new Model to <div class="jumbotron">Create a new Model to
limit the number of degradation products in the limit the number of degradation products in the
prediction. You just need to set a name and the packages prediction. You just need to set a name and the packages
you want the object to be based on. If you want to use the you want the object to be based on. There are multiple types of models available.
default options suggested by us, simply click Submit, For additional information have a look at our
otherwise click Advanced Options. <a target="_blank" href="https://wiki.envipath.org/index.php/relative-reasoning" role="button">wiki &gt;&gt;</a>
</div> </div>
<label for="name">Name</label> <label for="model-name">Name</label>
<input id="name" name="model-name" class="form-control" placeholder="Name"/> <input id="model-name" name="model-name" class="form-control" placeholder="Name"/>
<label for="description">Description</label> <label for="model-description">Description</label>
<input id="description" name="model-description" class="form-control" <input id="model-description" name="model-description" class="form-control"
placeholder="Description"/> placeholder="Description"/>
<label for="model-type">Model Type</label> <label for="model-type">Model Type</label>
<select id="model-type" name="model-type" class="form-control" data-width='100%'> <select id="model-type" name="model-type" class="form-control" data-width='100%'>
@ -35,7 +35,7 @@
<!-- ML Based Form--> <!-- ML Based Form-->
<div id="ml-relative-reasoning-specific-form"> <div id="ml-relative-reasoning-specific-form">
<!-- Rule Packages --> <!-- Rule Packages -->
<label>Rule Packages</label><br> <label for="ml-relative-reasoning-rule-packages">Rule Packages</label>
<select id="ml-relative-reasoning-rule-packages" name="ml-relative-reasoning-rule-packages" <select id="ml-relative-reasoning-rule-packages" name="ml-relative-reasoning-rule-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'> data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option> <option disabled>Reviewed Packages</option>
@ -53,7 +53,7 @@
{% endfor %} {% endfor %}
</select> </select>
<!-- Data Packages --> <!-- Data Packages -->
<label>Data Packages</label><br> <label for="ml-relative-reasoning-data-packages" >Data Packages</label>
<select id="ml-relative-reasoning-data-packages" name="ml-relative-reasoning-data-packages" <select id="ml-relative-reasoning-data-packages" name="ml-relative-reasoning-data-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'> data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option> <option disabled>Reviewed Packages</option>
@ -77,22 +77,24 @@
class="form-control"> class="form-control">
<option value="MACCS" selected>MACCS Fingerprinter</option> <option value="MACCS" selected>MACCS Fingerprinter</option>
</select> </select>
{% if meta.enabled_features.PLUGINS %} {% if meta.enabled_features.PLUGINS and additional_descriptors %}
<!-- Property Plugins go here --> <!-- Property Plugins go here -->
<label for="ml-relative-reasoning-additional-fingerprinter">Fingerprinter</label> <label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter / Descriptors</label>
<select id="ml-relative-reasoning-additional-fingerprinter" <select id="ml-relative-reasoning-additional-fingerprinter" name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
name="ml-relative-reasoning-additional-fingerprinter" <option disabled selected>Select Additional Fingerprinter / Descriptor</option>
class="form-control"> {% for k, v in additional_descriptors.items %}
<option value="{{ v }}">{{ k }}</option>
{% endfor %}
</select> </select>
{% endif %} {% endif %}
<label for="ml-relative-reasoning-threshold">Threshold</label> <label for="ml-relative-reasoning-threshold">Threshold</label>
<input type="number" min="0" max="1" step="0.05" value="0.5" <input type="number" min="0" max="1" step="0.05" value="0.5"
id="ml-relative-reasoning-threshold" id="ml-relative-reasoning-threshold"
name="ml-relative-reasoning-threshold" class="form-control"> name="ml-relative-reasoning-threshold" class="form-control">
<!-- Evaluation --> <!-- Evaluation -->
<label for="ml-relative-reasoning-evaluation-packages">Evaluation Packages</label>
<label>Evaluation Packages</label><br>
<select id="ml-relative-reasoning-evaluation-packages" name="ml-relative-reasoning-evaluation-packages" <select id="ml-relative-reasoning-evaluation-packages" name="ml-relative-reasoning-evaluation-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'> data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option> <option disabled>Reviewed Packages</option>
@ -110,6 +112,26 @@
{% endfor %} {% endfor %}
</select> </select>
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
<!-- Build AD? -->
<div class="checkbox">
<label>
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an Applicability Domain?
</label>
</div>
<!-- Num Neighbors -->
<label for="num-neighbors">Number of Neighbors</label>
<input id="num-neighbors" name="num-neighbors" type="number" class="form-control" value="5"
step="1" min="0" max="10">
<!-- Local Compatibility -->
<label for="local-compatibility-threshold">Local Compatibility Threshold</label>
<input id="local-compatibility-threshold" name="local-compatibility-threshold" type="number"
class="form-control" value="0.5" step="0.01" min="0" max="1">
<!-- Reliability -->
<label for="reliability-threshold">Reliability Threshold</label>
<input id="reliability-threshold" name="reliability-threshold" type="number"
class="form-control" value="0.5" step="0.01" min="0" max="1">
{% endif %}
</div> </div>
<!-- Rule Based Based Form--> <!-- Rule Based Based Form-->
<div id="rule-based-relative-reasoning-specific-form"> <div id="rule-based-relative-reasoning-specific-form">
@ -118,47 +140,9 @@
<!-- EnviFormer--> <!-- EnviFormer-->
<div id="enviformer-specific-form"> <div id="enviformer-specific-form">
<label for="enviformer-threshold">Threshold</label> <label for="enviformer-threshold">Threshold</label>
<input type="number" min="0" , max="1" step="0.05" value="0.5" id="enviformer-threshold" <input type="number" min="0" max="1" step="0.05" value="0.5" id="enviformer-threshold"
name="enviformer-threshold" class="form-control"> name="enviformer-threshold" class="form-control">
</div> </div>
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
<div class="modal-body hide" data-step="3" data-title="Advanced Options II">
<div class="jumbotron">Selection of parameter values for the Applicability Domain process.
Number of Neighbours refers to a requirement on the minimum number of compounds from the
training
dataset that has at least one triggered transformation rule that is common with the compound
being
analyzed.
Reliability Threshold is a requirement on the average tanimoto distance to the set number of
"nearest neighbours" (Number of neighbours with the smallest tanimoto distances).
Local Compatibility Threshold is a requirement on the average F1 score determined from the
number of
nearest neighbours, using their respective precision and recall values computed from the
agreement
between their observed and triggered rules.
You can learn more about it in our wiki!
</div>
<!-- Use AD? -->
<div class="checkbox">
<label>
<input type="checkbox" id="buildAD" name="buildAD">Also build an Applicability Domain?
</label>
</div>
<!-- Num Neighbours -->
<label for="adK">Number of Neighbours</label>
<input id="adK" name="adK" type="number" class="form-control" value="5" step="1" min="0"
max="10">
<!-- F1 Threshold -->
<label for="localCompatibilityThreshold">Local Compatibility Threshold</label>
<input id="localCompatibilityThreshold" name="localCompatibilityThreshold" type="number"
class="form-control" value="0.5" step="0.01" min="0" max="1">
<!-- Percentile Threshold -->
<label for="reliabilityThreshold">Reliability Threshold</label>
<input id="reliabilityThreshold" name="reliabilityThreshold" type="number" class="form-control"
value="0.5" step="0.01" min="0" max="1">
</div>
{% endif %}
</form> </form>
</div> </div>
<div class="modal-footer"> <div class="modal-footer">
@ -179,6 +163,9 @@ $(function() {
$("#ml-relative-reasoning-rule-packages").selectpicker(); $("#ml-relative-reasoning-rule-packages").selectpicker();
$("#ml-relative-reasoning-data-packages").selectpicker(); $("#ml-relative-reasoning-data-packages").selectpicker();
$("#ml-relative-reasoning-evaluation-packages").selectpicker(); $("#ml-relative-reasoning-evaluation-packages").selectpicker();
if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) {
$("#ml-relative-reasoning-additional-fingerprinter").selectpicker();
}
// On change hide all and show only selected // On change hide all and show only selected
$("#model-type").change(function() { $("#model-type").change(function() {

View File

@ -90,27 +90,53 @@
</div> </div>
</div> </div>
{% endif %} {% endif %}
<!-- Predict Panel --> {% if model.ready_for_prediction %}
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver"> <!-- Predict Panel -->
<h4 class="panel-title"> <div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<a id="predict-smiles-link" data-toggle="collapse" data-parent="#model-detail" <h4 class="panel-title">
href="#predict-smiles">Predict</a> <a id="predict-smiles-link" data-toggle="collapse" data-parent="#model-detail"
</h4> href="#predict-smiles">Predict</a>
</div> </h4>
<div id="predict-smiles" class="panel-collapse collapse in"> </div>
<div class="panel-body list-group-item"> <div id="predict-smiles" class="panel-collapse collapse in">
<div class="input-group"> <div class="panel-body list-group-item">
<input id="smiles-to-predict" type="text" class="form-control" <div class="input-group">
placeholder="CCN(CC)C(=O)C1=CC(=CC=C1)C"> <input id="smiles-to-predict" type="text" class="form-control"
<span class="input-group-btn"> placeholder="CCN(CC)C(=O)C1=CC(=CC=C1)C">
<span class="input-group-btn">
<button class="btn btn-default" type="submit" id="predict-button">Predict!</button> <button class="btn btn-default" type="submit" id="predict-button">Predict!</button>
</span> </span>
</div>
<div id="predictLoading"></div>
<div id="predictResultTable"></div>
</div> </div>
<div id="loading"></div>
<div id="predictResultTable"></div>
</div> </div>
</div> <!-- End Predict Panel -->
<!-- End Predict Panel --> {% endif %}
{% if model.app_domain %}
<!-- App Domain -->
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">
<a id="app-domain-assessment-link" data-toggle="collapse" data-parent="#model-detail"
href="#app-domain-assessment">Applicability Domain Assessment</a>
</h4>
</div>
<div id="app-domain-assessment" class="panel-collapse collapse in">
<div class="panel-body list-group-item">
<div class="input-group">
<input id="smiles-to-assess" type="text" class="form-control" placeholder="CCN(CC)C(=O)C1=CC(=CC=C1)C">
<span class="input-group-btn">
<button class="btn btn-default" type="submit" id="assess-button">Assess!</button>
</span>
</div>
<div id="appDomainLoading"></div>
<div id="appDomainAssessmentResultTable"></div>
</div>
</div>
<!-- End App Domain -->
{% endif %}
{% if model.model_status == 'FINISHED' %} {% if model.model_status == 'FINISHED' %}
<!-- Single Gen Curve Panel --> <!-- Single Gen Curve Panel -->
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver"> <div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
@ -277,9 +303,9 @@
$("#predictResultTable").append(res); $("#predictResultTable").append(res);
} }
function clear() { function clear(divid) {
$("#predictResultTable").removeClass("alert alert-danger"); $("#" + divid).removeClass("alert alert-danger");
$("#predictResultTable").empty(); $("#" + divid).empty();
} }
if ($('#predict-button').length > 0) { if ($('#predict-button').length > 0) {
@ -291,32 +317,69 @@
"classify": "ILikeCats!" "classify": "ILikeCats!"
} }
clear(); clear("predictResultTable");
makeLoadingGif("#loading", "{% static '/images/wait.gif' %}"); makeLoadingGif("#predictLoading", "{% static '/images/wait.gif' %}");
$.ajax({ $.ajax({
type: 'get', type: 'get',
data: data, data: data,
url: '', url: '',
success: function (data, textStatus) { success: function (data, textStatus) {
try { try {
$("#loading").empty(); $("#predictLoading").empty();
handleResponse(data); handleResponse(data);
} catch (error) { } catch (error) {
console.log("Error"); console.log("Error");
$("#loading").empty(); $("#predictLoading").empty();
$("#predictResultTable").addClass("alert alert-danger"); $("#predictResultTable").addClass("alert alert-danger");
$("#predictResultTable").append("Error while processing request :/"); $("#predictResultTable").append("Error while processing request :/");
} }
}, },
error: function (jqXHR, textStatus, errorThrown) { error: function (jqXHR, textStatus, errorThrown) {
$("#loading").empty(); $("#predictLoading").empty();
$("#predictResultTable").addClass("alert alert-danger"); $("#predictResultTable").addClass("alert alert-danger");
$("#predictResultTable").append("Error while processing request :/"); $("#predictResultTable").append("Error while processing request :/");
} }
}); });
}); });
} }
if ($('#assess-button').length > 0) {
$("#assess-button").on("click", function (e) {
e.preventDefault();
data = {
"smiles": $("#smiles-to-assess").val(),
"app-domain-assessment": "ILikeCats!"
}
clear("appDomainAssessmentResultTable");
makeLoadingGif("#appDomainLoading", "{% static '/images/wait.gif' %}");
$.ajax({
type: 'get',
data: data,
url: '',
success: function (data, textStatus) {
try {
$("#appDomainLoading").empty();
console.log(data);
} catch (error) {
console.log("Error");
$("#appDomainLoading").empty();
$("#appDomainAssessmentResultTable").addClass("alert alert-danger");
$("#appDomainAssessmentResultTable").append("Error while processing request :/");
}
},
error: function (jqXHR, textStatus, errorThrown) {
$("#appDomainLoading").empty();
$("#appDomainAssessmentResultTable").addClass("alert alert-danger");
$("#appDomainAssessmentResultTable").append("Error while processing request :/");
}
});
});
}
</script> </script>
{% endblock content %} {% endblock content %}

52
tests/test_dataset.py Normal file
View File

@ -0,0 +1,52 @@
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule
from utilities.ml import Dataset
class DatasetTest(TestCase):
fixtures = ["test_fixture.cleaned.json"]
def setUp(self):
self.cs1 = Compound.create(
self.package,
name='2,6-Dibromohydroquinone',
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b',
smiles='C1=C(C(=C(C=C1O)Br)O)Br',
).default_structure
self.cs2 = Compound.create(
self.package,
smiles='O=C(O)CC(=O)/C=C(/Br)C(=O)O',
).default_structure
self.rule1 = Rule.create(
rule_type='SimpleAmbitRule',
package=self.package,
smirks='[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\[#6:3]=[#6:2](\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]',
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6'
)
self.reaction1 = Reaction.create(
package=self.package,
educts=[self.cs1],
products=[self.cs2],
rules=[self.rule1],
multi_step=False
)
@classmethod
def setUpClass(cls):
super(DatasetGeneratorTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
def test_smoke(self):
reactions = [r for r in Reaction.objects.filter(package=self.package)]
applicable_rules = [self.rule1]
ds = Dataset.generate_dataset(reactions, applicable_rules)
self.assertEqual(len(ds.y()), 1)
self.assertEqual(sum(ds.y()[0]), 1)

View File

@ -1,111 +0,0 @@
from django.test import TestCase
from epdb.models import ParallelRule
from utilities.ml import Compound, Reaction, DatasetGenerator
class CompoundTest(TestCase):
def setUp(self):
self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C", uuid='c1')
self.c2 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C", uuid='c2')
def test_compound_eq_ignores_uuid(self):
self.assertEqual(self.c1, self.c2)
class ReactionTest(TestCase):
def setUp(self):
self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C")
self.c2 = Compound(smiles="CCN(CCO)C(=O)C1=CC(C)=CC=C1")
# self.r1 = Rule(uuid="bt0334")
# c1 --r1--> c2
self.c3_1 = Compound(smiles="CCNC(=O)C1=CC(C)=CC=C1")
self.c3_2 = Compound(smiles="CC=O")
# self.r2 = Rule(uuid="bt0243")
# c1 --r2--> c3_1, c3_2
def test_reaction_equality_ignores_uuid(self):
r1 = Reaction([self.c1], [self.c2], self.r1, uuid="abc")
r2 = Reaction([self.c1], [self.c2], self.r1, uuid="xyz")
self.assertEqual(r1, r2)
def test_reaction_inequality_on_data_change(self):
r1 = Reaction([self.c1], [self.c2], self.r1)
r2 = Reaction([self.c1], [self.c3_1], self.r1)
self.assertNotEqual(r1, r2)
def test_reaction_is_hashable(self):
r = Reaction([self.c1], [self.c2], self.r1)
reactions = {r}
self.assertIn(Reaction([self.c1], [self.c2], self.r1), reactions)
def test_rule_is_optional(self):
r = Reaction([self.c1], [self.c2])
self.assertIsNone(r.rule)
def test_uuid_is_optional(self):
r = Reaction([self.c1], [self.c2], self.r1)
self.assertIsNone(r.uuid)
def test_repr_includes_uuid(self):
r = Reaction([self.c1], [self.c2], self.r1, uuid="abc")
self.assertIn("abc", repr(r))
def test_reaction_equality_with_multiple_compounds_different_ordering(self):
r1 = Reaction([self.c1], [self.c3_1, self.c3_2], self.r2)
r2 = Reaction([self.c1], [self.c3_2, self.c3_1], self.r2)
self.assertEqual(r1, r2, "Reaction equality should not rely on list order")
class RuleTest(TestCase):
def setUp(self):
pass
# self.r1 = Rule(uuid="bt0334")
# self.r2 = Rule(uuid="bt0243")
class DatasetGeneratorTest(TestCase):
fixtures = ['bootstrap.json']
def setUp(self):
self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C")
self.c2 = Compound(smiles="CCN(CCO)C(=O)C1=CC(C)=CC=C1")
self.c3_1 = Compound(smiles="CCNC(=O)C1=CC(C)=CC=C1")
self.c3_2 = Compound(smiles="CC=O")
# self.r1 = Rule(uuid="bt0334") # trig
# self.r2 = Rule(uuid="bt0243") # trig
# self.r3 = Rule(uuid="bt0003") # no trig
self.reaction1 = Reaction([self.c1], [self.c2], self.r3)
self.reaction2 = Reaction([self.c1], [self.c3_1, self.c3_2], self.r2)
def test_test(self):
compounds = [
self.c1,
self.c2,
self.c3_1,
self.c3_2,
]
reactions = [
self.reaction1,
self.reaction2,
]
applicable_rules = [
# Rule('bt0334', ParallelRule.objects.get(name='bt0334')),
# Rule('bt0243', ParallelRule.objects.get(name='bt0243')),
# Rule('bt0003', ParallelRule.objects.get(name='bt0003')),
]
ds = DatasetGenerator.generate_dataset(compounds, reactions, applicable_rules)
self.assertIsNotNone(ds)

55
tests/test_model.py Normal file
View File

@ -0,0 +1,55 @@
import json
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import Compound, User, CompoundStructure, Reaction, Rule, MLRelativeReasoning
class ModelTest(TestCase):
fixtures = ["test_fixture.cleaned.json"]
def setUp(self):
pass
@classmethod
def setUpClass(cls):
super(ModelTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
bbd_data = json.load(open('fixtures/packages/2025-07-18/EAWAG-BBD.json'))
cls.BBD = PackageManager.import_package(bbd_data, cls.user)
@classmethod
def tearDownClass(cls):
pass
def tearDown(self):
pass
def test_smoke(self):
threshold = float(0.5)
# get Package objects from urls
rule_package_objs = [self.BBD]
data_package_objs = [self.BBD]
eval_packages_objs = []
mod = MLRelativeReasoning.create(
self.package,
'ECC - BBD - 0.5',
'Created MLRelativeReasoning in Testcase',
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold
)
ds = mod.load_dataset()
mod.build_model()
print("Model built!")
mod.evaluate_model()
print("Model Evaluated")
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
print(results)

View File

@ -131,6 +131,24 @@ class FormatConverter(object):
# TODO call to AMBIT Service # TODO call to AMBIT Service
return smiles return smiles
@staticmethod
def ep_standardize(smiles):
change = True
while change:
change = False
for standardizer in MATCH_STANDARDIZER:
tmp_smiles = standardizer.standardize(smiles)
if tmp_smiles != smiles:
print(f"change {smiles} to {tmp_smiles}")
change = True
smiles = tmp_smiles
if change is False:
print(f"nothing changed")
return smiles
@staticmethod @staticmethod
def standardize(smiles): def standardize(smiles):
# Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/ # Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
@ -180,54 +198,6 @@ class FormatConverter(object):
atom.UpdatePropertyCache() atom.UpdatePropertyCache()
return mol return mol
# @staticmethod
# def apply(smiles, smirks, preprocess_smiles=True, bracketize=False, standardize=True):
# logger.debug(f'Applying {smirks} on {smiles}')
#
# if bracketize:
# smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")"
#
# res = set()
# try:
# rxn = rdChemReactions.ReactionFromSmarts(smirks)
# mol = Chem.MolFromSmiles(smiles)
#
# # Inplace
# if preprocess_smiles:
# Chem.SanitizeMol(mol)
# mol = Chem.AddHs(mol)
#
# # apply!
# reacts = rxn.RunReactants((mol,))
# if len(reacts):
# # Sanitize mols
# for product_set in reacts:
# prod_set = list()
# for product in product_set:
# # Fixes
# # [2025-01-30 23:00:50] ERROR chem - Sanitizing and converting failed:
# # non-ring atom 3 marked aromatic
# # But does not improve overall performance
# #
# # for a in product.GetAtoms():
# # if (not a.IsInRing()) and a.GetIsAromatic():
# # a.SetIsAromatic(False)
# # for b in product.GetBonds():
# # if (not b.IsInRing()) and b.GetIsAromatic():
# # b.SetIsAromatic(False)
#
# try:
# Chem.SanitizeMol(product)
# prod_set.append(FormatConverter.standardize(Chem.MolToSmiles(product)))
# except ValueError as e:
# logger.error(f'Sanitizing and converting failed:\n{e}')
# continue
# res.add(tuple(list(set(prod_set))))
# except Exception as e:
# logger.error(f'Applying {smirks} on {smiles} failed:\n{e}')
#
# return list(res)
@staticmethod @staticmethod
def is_valid_smirks(smirks: str) -> bool: def is_valid_smirks(smirks: str) -> bool:
try: try:

View File

@ -1,46 +1,29 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import logging
from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Dict, Set, Tuple
import numpy as np import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
# @dataclasses.dataclass logger = logging.getLogger(__name__)
# class Feature:
# name: str
# value: float
#
#
#
# class Row:
# def __init__(self, compound_uuid: str, compound_smiles: str, descriptors: List[int]):
# self.data = {}
#
#
#
# class DataSet(object):
#
# def __init__(self):
# self.rows: List[Row] = []
#
# def add_row(self, row: Row):
# pass
from dataclasses import dataclass, field from dataclasses import dataclass, field
from utilities.chem import FormatConverter from utilities.chem import FormatConverter, PredictionResult
@dataclass @dataclass
class Compound: class SCompound:
smiles: str smiles: str
uuid: str = field(default=None, compare=False, hash=False) uuid: str = field(default=None, compare=False, hash=False)
@ -53,10 +36,10 @@ class Compound:
@dataclass @dataclass
class Reaction: class SReaction:
educts: List[Compound] educts: List[SCompound]
products: List[Compound] products: List[SCompound]
rule_uuid: str = field(default=None, compare=False, hash=False) rule_uuid: SRule = field(default=None, compare=False, hash=False)
reaction_uuid: str = field(default=None, compare=False, hash=False) reaction_uuid: str = field(default=None, compare=False, hash=False)
def __hash__(self): def __hash__(self):
@ -68,77 +51,294 @@ class Reaction:
return self._hash return self._hash
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Reaction): if not isinstance(other, SReaction):
return NotImplemented return NotImplemented
return ( return (
sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and
sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles) sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles)
) )
class Dataset(object): @dataclass
class SRule(ABC):
def __init__(self, headers=List['str'], data=List[List[str|int|float]]): @abstractmethod
self.headers = headers def apply(self):
self.data = data
def features(self):
pass
def labels(self):
pass
def to_json(self):
pass
def to_csv(self):
pass
def to_arff(self):
pass pass
@dataclass
class SSimpleRule:
pass
class DatasetGenerator(object):
@dataclass
class SParallelRule:
pass
class Dataset:
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
self.columns: List[str] = columns
self.num_labels: int = num_labels
if data is None:
self.data: List[List[str | int | float]] = list()
else:
self.data = data
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices('feature_')
self._triggered: Tuple[int, int] = self._block_indices('trig_')
self._observed: Tuple[int, int] = self._block_indices('obs_')
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices), max(indices)
def structure_id(self):
return self.data[0][0]
def add_row(self, row: List[str | int | float]):
if len(self.columns) != len(row):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
self.data.append(row)
def struct_features(self) -> Tuple[int, int]:
return self._struct_features
def triggered(self) -> Tuple[int, int]:
return self._triggered
def observed(self) -> Tuple[int, int]:
return self._observed
def at(self, position: int) -> Dataset:
return Dataset(self.columns, self.num_labels, [self.data[position]])
def limit(self, limit: int) -> Dataset:
return Dataset(self.columns, self.num_labels, self.data[:limit])
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
for struct in structures:
if isinstance(struct, str):
struct_id = None
struct_smiles = struct
else:
struct_id = str(struct.uuid)
struct_smiles = struct.smiles
features = FormatConverter.maccs(struct_smiles)
trig = []
prods = []
for rule in applicable_rules:
products = rule.apply(struct_smiles)
if len(products):
trig.append(1)
prods.append(products)
else:
trig.append(0)
prods.append([])
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods)
return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products
@staticmethod @staticmethod
def generate_dataset(compounds: List[Compound], reactions: List[Reaction], applicable_rules: 'Rule', def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
compounds_to_exclude: Optional[Compound] = None, educts_only: bool = False) -> Dataset: _structures = set()
rows = [] for r in reactions:
for e in r.educts.all():
_structures.add(e)
if educts_only: if not educts_only:
compounds = set() for e in r.products:
for r in reactions: _structures.add(e)
for e in r.educts:
compounds.add(e)
compounds = list(compounds)
total = len(compounds) compounds = sorted(_structures, key=lambda x: x.url)
for i, c in enumerate(compounds):
row = [] triggered: Dict[str, Set[str]] = defaultdict(set)
print(f"{i + 1}/{total} - {c.smiles}") observed: Set[str] = set()
for r in applicable_rules:
product_sets = r.rule.apply(c.smiles) # Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0: if len(product_sets) == 0:
row.append([])
continue continue
#triggered.add(f"{r.uuid} + {c.uuid}") key = f"{rule.uuid} + {comp.uuid}"
reacts = set()
for ps in product_sets:
products = []
for p in ps:
products.append(Compound(FormatConverter.standardize(p)))
reacts.add(Reaction([c], products, r)) if key in triggered:
row.append(list(reacts)) logger.info(f"{key} already present. Duplicate reaction?")
rows.append(row) for prod_set in product_sets:
for smi in prod_set:
return rows try:
smi = FormatConverter.standardize(smi)
except Exception:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
pass
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi)
except Exception as e:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
pass
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
else:
pass
ds = None
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
if ds is None:
header = ['structure_id'] + \
[f'feature_{i}' for i, _ in enumerate(feat)] \
+ [f'trig_{r.uuid}' for r in applicable_rules] \
+ [f'obs_{r.uuid}' for r in applicable_rules]
ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
row_key, col_key = key
# Normalize rows
if isinstance(row_key, int):
rows = [self.data[row_key]]
else:
rows = self.data[row_key]
# Normalize columns
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
else [row[i] for i in col_key] for row in rows]
return res
def save(self, path: 'Path'):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: 'Path'):
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: 'Path'):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"
for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
if c == 'structure_id':
arff += f"@attribute {c} string\n"
else:
arff += f"@attribute {c} {{0,1}}\n"
arff += f"\n@data\n"
for d in self.data:
ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
arff += f'{ys},{xs}\n'
with open(path, "w") as fh:
fh.write(arff)
fh.flush()
def __repr__(self):
return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
class SparseLabelECC(BaseEstimator, ClassifierMixin): class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -166,8 +366,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
self.keep_columns_.append(col) self.keep_columns_.append(col)
y_reduced = y[:, self.keep_columns_] y_reduced = y[:, self.keep_columns_]
self.chains_ = [ClassifierChain(self.base_clf, order='random', random_state=i) self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)]
for i in range(self.num_chains)]
for i, chain in enumerate(self.chains_): for i, chain in enumerate(self.chains_):
print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}") print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
@ -208,26 +407,169 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
return accuracy_score(y_true, y_pred, sample_weight=sample_weight) return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
class ApplicabilityDomain(PCA):
def __init__(self, n_components=5): import copy
super().__init__(n_components=n_components)
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
class BinaryRelevance:
def __init__(self, baseline_clf):
self.clf = baseline_clf
self.classifiers = None
def fit(self, X, Y):
if self.classifiers is None:
self.classifiers = []
for l in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, l])]
Y_l = (Y[~np.isnan(Y[:, l]), l])
if len(X_l) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf.fit([X[0]], [0])
self.classifiers.append(clf)
continue
elif len(np.unique(Y_l)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
else:
clf = copy.deepcopy(self.clf)
clf.fit(X_l, Y_l)
self.classifiers.append(clf)
def predict(self, X):
labels = []
for clf in self.classifiers:
labels.append(clf.predict(X))
return np.column_stack(labels)
def predict_proba(self, X):
labels = np.empty((len(X), 0))
for clf in self.classifiers:
pred = clf.predict_proba(X)
if pred.shape[1] > 1:
pred = pred[:, 1]
else:
pred = pred * clf.predict([X[0]])[0]
labels = np.column_stack((labels, pred))
return labels
class MissingValuesClassifierChain:
def __init__(self, base_clf):
self.base_clf = base_clf
self.permutation = None
self.classifiers = None
def fit(self, X, Y):
X = np.array(X)
Y = np.array(Y)
if self.permutation is None:
self.permutation = np.random.permutation(len(Y[0]))
Y = Y[:, self.permutation]
if self.classifiers is None:
self.classifiers = []
for p in range(len(self.permutation)):
X_p = X[~np.isnan(Y[:, p])]
Y_p = Y[~np.isnan(Y[:, p]), p]
if len(X_p) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
self.classifiers.append(clf.fit([X[0]], [0]))
elif len(np.unique(Y_p)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
self.classifiers.append(clf.fit(X_p, Y_p))
else:
clf = copy.deepcopy(self.base_clf)
self.classifiers.append(clf.fit(X_p, Y_p))
newcol = Y[:, p]
pred = clf.predict(X)
newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions
X = np.column_stack((X, newcol))
def predict(self, X):
labels = np.empty((len(X), 0))
for clf in self.classifiers:
pred = clf.predict(np.column_stack((X, labels)))
labels = np.column_stack((labels, pred))
return labels[:, np.argsort(self.permutation)]
def predict_proba(self, X):
labels = np.empty((len(X), 0))
for clf in self.classifiers:
pred = clf.predict_proba(np.column_stack((X, np.round(labels))))
if pred.shape[1] > 1:
pred = pred[:, 1]
else:
pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0]
labels = np.column_stack((labels, pred))
return labels[:, np.argsort(self.permutation)]
class EnsembleClassifierChain:
def __init__(self, base_clf, num_chains=10):
self.base_clf = base_clf
self.num_chains = num_chains
self.num_labels = None
self.classifiers = None
def fit(self, X, Y):
if self.classifiers is None:
self.classifiers = []
if self.num_labels is None:
self.num_labels = len(Y[0])
for p in range(self.num_chains):
print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
clf = MissingValuesClassifierChain(self.base_clf)
clf.fit(X, Y)
self.classifiers.append(clf)
def predict(self, X):
labels = np.zeros((len(X), self.num_labels))
for clf in self.classifiers:
labels += clf.predict(X)
return np.round(labels / self.num_chains)
def predict_proba(self, X):
labels = np.zeros((len(X), self.num_labels))
for clf in self.classifiers:
labels += clf.predict_proba(X)
return labels / self.num_chains
class ApplicabilityDomainPCA(PCA):
def __init__(self, num_neighbours: int = 5):
super().__init__(n_components=num_neighbours)
self.scaler = StandardScaler() self.scaler = StandardScaler()
self.num_neighbours = num_neighbours
self.min_vals = None self.min_vals = None
self.max_vals = None self.max_vals = None
def build(self, X): def build(self, train_dataset: 'Dataset'):
# transform # transform
X_scaled = self.scaler.fit_transform(X) X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca # fit pca
X_pca = self.fit_transform(X_scaled) X_pca = self.fit_transform(X_scaled)
self.max_vals = np.max(X_pca, axis=0) self.max_vals = np.max(X_pca, axis=0)
self.min_vals = np.min(X_pca, axis=0) self.min_vals = np.min(X_pca, axis=0)
def is_applicable(self, instances): def __transform(self, instances):
instances_scaled = self.scaler.transform(instances) instances_scaled = self.scaler.transform(instances)
instances_pca = self.transform(instances_scaled) instances_pca = self.transform(instances_scaled)
return instances_pca
def is_applicable(self, classify_instances: 'Dataset'):
instances_pca = self.__transform(classify_instances.X())
is_applicable = [] is_applicable = []
for i, instance in enumerate(instances_pca): for i, instance in enumerate(instances_pca):
@ -237,3 +579,17 @@ class ApplicabilityDomain(PCA):
is_applicable[i] = False is_applicable[i] = False
return is_applicable return is_applicable
def tanimoto_distance(a: List[int], b: List[int]):
if len(a) != len(b):
raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")
sum_a = sum(a)
sum_b = sum(b)
sum_c = sum(v1 and v2 for v1, v2 in zip(a, b))
if sum_a + sum_b - sum_c == 0:
return 0.0
return 1 - (sum_c / (sum_a + sum_b - sum_c))