forked from enviPath/enviPy
Experimental App Domain (#43)
Backend App Domain done, Frontend missing.
Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#43
@ -261,7 +261,7 @@ CELERY_ACCEPT_CONTENT = ['json']
|
||||
CELERY_TASK_SERIALIZER = 'json'
|
||||
|
||||
MODEL_BUILDING_ENABLED = os.environ.get('MODEL_BUILDING_ENABLED', 'False') == 'True'
|
||||
|
||||
APPLICABILITY_DOMAIN_ENABLED = os.environ.get('APPLICABILITY_DOMAIN_ENABLED', 'False') == 'True'
|
||||
DEFAULT_RF_MODEL_PARAMS = {
|
||||
'base_clf': RandomForestClassifier(
|
||||
n_estimators=100,
|
||||
@ -275,14 +275,14 @@ DEFAULT_RF_MODEL_PARAMS = {
|
||||
'num_chains': 10,
|
||||
}
|
||||
|
||||
DEFAULT_DT_MODEL_PARAMS = {
|
||||
DEFAULT_MODEL_PARAMS = {
|
||||
'base_clf': DecisionTreeClassifier(
|
||||
criterion='entropy',
|
||||
max_depth=3,
|
||||
min_samples_split=5,
|
||||
min_samples_leaf=5,
|
||||
# min_samples_leaf=5,
|
||||
max_features='sqrt',
|
||||
class_weight='balanced',
|
||||
# class_weight='balanced',
|
||||
random_state=42
|
||||
),
|
||||
'num_chains': 10,
|
||||
@ -322,4 +322,5 @@ FLAGS = {
|
||||
'PLUGINS': PLUGINS_ENABLED,
|
||||
'SENTRY': SENTRY_ENABLED,
|
||||
'ENVIFORMER': ENVIFORMER_PRESENT,
|
||||
'APPLICABILITY_DOMAIN': APPLICABILITY_DOMAIN_ENABLED,
|
||||
}
|
||||
|
||||
@ -1,40 +1,105 @@
|
||||
from django.contrib import admin
|
||||
|
||||
from .models import User, Group, UserPackagePermission, GroupPackagePermission, Setting, SimpleAmbitRule, Scenario
|
||||
from .models import (
|
||||
User,
|
||||
UserPackagePermission,
|
||||
Group,
|
||||
GroupPackagePermission,
|
||||
Package,
|
||||
MLRelativeReasoning,
|
||||
Compound,
|
||||
CompoundStructure,
|
||||
SimpleAmbitRule,
|
||||
ParallelRule,
|
||||
Reaction,
|
||||
Pathway,
|
||||
Node,
|
||||
Edge,
|
||||
Scenario,
|
||||
Setting
|
||||
)
|
||||
|
||||
|
||||
class UserAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class GroupAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class UserPackagePermissionAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class GroupAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class GroupPackagePermissionAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class SettingAdmin(admin.ModelAdmin):
|
||||
class EPAdmin(admin.ModelAdmin):
|
||||
search_fields = ['name', 'description']
|
||||
|
||||
|
||||
class PackageAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
class MLRelativeReasoningAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class SimpleAmbitRuleAdmin(admin.ModelAdmin):
|
||||
class CompoundAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class ScenarioAdmin(admin.ModelAdmin):
|
||||
class CompoundStructureAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class SimpleAmbitRuleAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelRuleAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class ReactionAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class PathwayAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class NodeAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class EdgeAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class ScenarioAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class SettingAdmin(EPAdmin):
|
||||
pass
|
||||
|
||||
|
||||
admin.site.register(User, UserAdmin)
|
||||
admin.site.register(Group, GroupAdmin)
|
||||
admin.site.register(UserPackagePermission, UserPackagePermissionAdmin)
|
||||
admin.site.register(Group, GroupAdmin)
|
||||
admin.site.register(GroupPackagePermission, GroupPackagePermissionAdmin)
|
||||
admin.site.register(Setting, SettingAdmin)
|
||||
admin.site.register(Package, PackageAdmin)
|
||||
admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin)
|
||||
admin.site.register(Compound, CompoundAdmin)
|
||||
admin.site.register(CompoundStructure, CompoundStructureAdmin)
|
||||
admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin)
|
||||
admin.site.register(ParallelRule, ParallelRuleAdmin)
|
||||
admin.site.register(Reaction, ReactionAdmin)
|
||||
admin.site.register(Pathway, PathwayAdmin)
|
||||
admin.site.register(Node, NodeAdmin)
|
||||
admin.site.register(Edge, EdgeAdmin)
|
||||
admin.site.register(Setting, SettingAdmin)
|
||||
admin.site.register(Scenario, ScenarioAdmin)
|
||||
|
||||
@ -339,7 +339,7 @@ class PackageManager(object):
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def import_package(data: dict, owner: User, keep_ids=False):
|
||||
def import_package(data: dict, owner: User, keep_ids=False, add_import_timestamp=True):
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
@ -349,7 +349,12 @@ class PackageManager(object):
|
||||
|
||||
pack = Package()
|
||||
pack.uuid = UUID(data['id'].split('/')[-1]) if keep_ids else uuid4()
|
||||
|
||||
if add_import_timestamp:
|
||||
pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M'))
|
||||
else:
|
||||
pack.name = data['name']
|
||||
|
||||
pack.reviewed = True if data['reviewStatus'] == 'reviewed' else False
|
||||
pack.description = data['description']
|
||||
pack.save()
|
||||
|
||||
@ -58,7 +58,7 @@ class Command(BaseCommand):
|
||||
return anon, admin, g, jebus
|
||||
|
||||
def import_package(self, data, owner):
|
||||
return PackageManager.import_package(data, owner, keep_ids=True)
|
||||
return PackageManager.import_package(data, owner, keep_ids=True, add_import_timestamp=False)
|
||||
|
||||
def create_default_setting(self, owner, packages):
|
||||
s = SettingManager.create_setting(
|
||||
|
||||
484
epdb/models.py
@ -3,8 +3,8 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, date
|
||||
from typing import Union, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union, List, Optional, Dict, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import joblib
|
||||
@ -14,7 +14,7 @@ from django.contrib.auth.hashers import make_password, check_password
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db import models, transaction
|
||||
from django.db.models import JSONField, Count, Q
|
||||
from django.db.models import JSONField, Count, Q, QuerySet
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
from model_utils.models import TimeStampedModel
|
||||
@ -23,7 +23,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||
from utilities.ml import SparseLabelECC
|
||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -172,6 +172,9 @@ class EnviPathModel(TimeStampedModel):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name} (pk={self.pk})"
|
||||
|
||||
|
||||
class AliasMixin(models.Model):
|
||||
aliases = ArrayField(
|
||||
@ -844,7 +847,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
|
||||
# We shouldn't lose or make up nodes...
|
||||
assert len(nodes) == len(self.nodes)
|
||||
print(f"Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")
|
||||
logger.debug(f"{self.name}: Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")
|
||||
|
||||
links = [e.d3_json() for e in self.edges]
|
||||
|
||||
@ -1136,19 +1139,44 @@ class MLRelativeReasoning(EPModel):
|
||||
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
|
||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||
default=None)
|
||||
|
||||
def status(self):
|
||||
return self.PROGRESS_STATUS_CHOICES[self.model_status]
|
||||
|
||||
def ready_for_prediction(self) -> bool:
|
||||
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(package, name, description, rule_packages, data_packages, eval_packages, threshold):
|
||||
def create(package: 'Package', rule_packages: List['Package'],
|
||||
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
|
||||
name: 'str' = None, description: str = None, build_app_domain: bool = False,
|
||||
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
|
||||
app_domain_local_compatibility_threshold: float = None):
|
||||
|
||||
mlrr = MLRelativeReasoning()
|
||||
mlrr.package = package
|
||||
|
||||
if name is None or name.strip() == '':
|
||||
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||
|
||||
mlrr.name = name
|
||||
|
||||
if description is not None and description.strip() != '':
|
||||
mlrr.description = description
|
||||
|
||||
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||
|
||||
mlrr.threshold = threshold
|
||||
|
||||
if len(rule_packages) == 0:
|
||||
raise ValueError("At least one rule package must be provided.")
|
||||
|
||||
mlrr.save()
|
||||
|
||||
for p in rule_packages:
|
||||
mlrr.rule_packages.add(p)
|
||||
|
||||
@ -1163,11 +1191,17 @@ class MLRelativeReasoning(EPModel):
|
||||
for p in eval_packages:
|
||||
mlrr.eval_packages.add(p)
|
||||
|
||||
if build_app_domain:
|
||||
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
|
||||
app_domain_local_compatibility_threshold)
|
||||
mlrr.app_domain = ad
|
||||
|
||||
mlrr.save()
|
||||
|
||||
return mlrr
|
||||
|
||||
@cached_property
|
||||
def applicable_rules(self):
|
||||
def applicable_rules(self) -> List['Rule']:
|
||||
"""
Returns an ordered set of rules where the following applies:
1. All composite rules will be added to the result
|
||||
@ -1195,6 +1229,7 @@ class MLRelativeReasoning(EPModel):
|
||||
rules.append(r)
|
||||
|
||||
rules = sorted(rules, key=lambda x: x.url)
|
||||
|
||||
return rules
|
||||
|
||||
def _get_excludes(self):
|
||||
@ -1209,197 +1244,79 @@ class MLRelativeReasoning(EPModel):
|
||||
pathway_qs = pathway_qs.distinct()
|
||||
return pathway_qs
|
||||
|
||||
|
||||
def _get_reactions(self) -> QuerySet:
|
||||
return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()
|
||||
|
||||
def build_dataset(self):
|
||||
self.model_status = self.INITIALIZING
|
||||
self.save()
|
||||
from datetime import datetime
|
||||
|
||||
start = datetime.now()
|
||||
|
||||
applicable_rules = self.applicable_rules
|
||||
print("got rules")
|
||||
|
||||
# if s.DEBUG:
|
||||
# pathways = self._get_pathways().order_by('-name')[:20]
|
||||
# else:
|
||||
pathways = self._get_pathways()
|
||||
|
||||
print("got pathways")
|
||||
excludes = self._get_excludes()
|
||||
|
||||
# Collect all compounds
|
||||
compounds = set()
|
||||
reactions = set()
|
||||
for i, p in enumerate(pathways):
|
||||
print(f"{i + 1}/{len(pathways)}...")
|
||||
for n in p.nodes:
|
||||
cs = n.default_node_label.compound.default_structure
|
||||
# TODO too many lookups
|
||||
if cs.smiles in excludes:
|
||||
continue
|
||||
|
||||
compounds.add(cs)
|
||||
|
||||
for e in p.edges:
|
||||
reactions.add(e.edge_label)
|
||||
|
||||
print(len(compounds))
|
||||
print(len(reactions))
|
||||
|
||||
triggered = set()
|
||||
observed = set()
|
||||
|
||||
# TODO naming
|
||||
|
||||
pw = defaultdict(lambda: defaultdict(set))
|
||||
|
||||
for i, c in enumerate(compounds):
|
||||
print(f"{i + 1}/{len(compounds)}...")
|
||||
for r in applicable_rules:
|
||||
# TODO check normalization
|
||||
product_sets = r.apply(c.smiles)
|
||||
|
||||
if len(product_sets) == 0:
|
||||
continue
|
||||
|
||||
triggered.add(f"{r.uuid} + {c.uuid}")
|
||||
|
||||
for ps in product_sets:
|
||||
for p in ps:
|
||||
pw[c][r].add(p)
|
||||
|
||||
for r in reactions:
|
||||
if r is None:
|
||||
print(r)
|
||||
continue
|
||||
if len(r.educts.all()) != 1:
|
||||
print(f"Skipping {r.url}")
|
||||
continue
|
||||
|
||||
# Loop will run only once
|
||||
for c in r.educts.all():
|
||||
if c not in pw:
|
||||
continue
|
||||
|
||||
for rule in pw[c].keys():
|
||||
# standardize...
|
||||
|
||||
if 0 != len(pw[c][rule]) and len(pw[c][rule]) == len(r.products.all()):
|
||||
print(f"potential match for {c.smiles} and {r.uuid} ({r.name})")
|
||||
|
||||
standardized_products = []
|
||||
for cs in r.products.all():
|
||||
smi = cs.smiles
|
||||
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi)
|
||||
except Exception as e:
|
||||
# :shrug:
|
||||
pass
|
||||
|
||||
standardized_products.append(smi)
|
||||
|
||||
standardized_pred_products = []
|
||||
for smi in pw[c][rule]:
|
||||
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi)
|
||||
except Exception as e:
|
||||
# :shrug:
|
||||
pass
|
||||
|
||||
standardized_pred_products.append(smi)
|
||||
|
||||
if sorted(list(set(standardized_products))) == sorted(list(set(standardized_pred_products))):
|
||||
observed.add(f"{rule.uuid} + {c.uuid}")
|
||||
print(f"Adding observed, current count {len(observed)}")
|
||||
|
||||
header = None
|
||||
X = []
|
||||
y = []
|
||||
for i, c in enumerate(compounds):
|
||||
print(f'{i + 1}/{len(compounds)}...')
|
||||
# Features
|
||||
feat = FormatConverter.maccs(c.smiles)
|
||||
trig = []
|
||||
obs = []
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {c.uuid}"
|
||||
|
||||
# Check triggered
|
||||
if key in triggered:
|
||||
trig.append(1)
|
||||
else:
|
||||
trig.append(0)
|
||||
|
||||
# Check obs
|
||||
if key in observed:
|
||||
obs.append(1)
|
||||
else:
|
||||
obs.append(0)
|
||||
|
||||
if header is None:
|
||||
header = [f'feature_{i}' for i, _ in enumerate(feat)] \
|
||||
+ [f'trig_{r.uuid}' for r in applicable_rules] \
|
||||
+ [f'corr_{r.uuid}' for r in applicable_rules]
|
||||
X.append(feat + trig)
|
||||
y.append(obs)
|
||||
reactions = list(self._get_reactions())
|
||||
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
||||
|
||||
end = datetime.now()
|
||||
print(f"Duration {(end - start).total_seconds()}s")
|
||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||
|
||||
data = {
|
||||
'X': X,
|
||||
'y': y,
|
||||
'header': header
|
||||
}
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
|
||||
json.dump(data, open(f, 'w'))
|
||||
return X, y
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||
ds.save(f)
|
||||
return ds
|
||||
|
||||
def load_dataset(self):
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
|
||||
return json.load(open(ds_path, 'r'))
|
||||
def load_dataset(self) -> 'Dataset':
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||
return Dataset.load(ds_path)
|
||||
|
||||
def build_model(self, X, y):
|
||||
def build_model(self):
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
|
||||
mod = SparseLabelECC(
|
||||
**s.DEFAULT_DT_MODEL_PARAMS
|
||||
)
|
||||
start = datetime.now()
|
||||
|
||||
ds = self.load_dataset()
|
||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||
|
||||
mod = EnsembleClassifierChain(
|
||||
**s.DEFAULT_MODEL_PARAMS
|
||||
)
|
||||
mod.fit(X, y)
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}.pkl")
|
||||
|
||||
end = datetime.now()
|
||||
logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")
|
||||
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
|
||||
joblib.dump(mod, f)
|
||||
|
||||
if self.app_domain is not None:
|
||||
logger.debug("Building applicability domain...")
|
||||
self.app_domain.build()
|
||||
logger.debug("Done building applicability domain.")
|
||||
|
||||
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def retrain(self):
|
||||
self.build_dataset()
|
||||
self.build_model()
|
||||
|
||||
def rebuild(self):
|
||||
data = self.load_dataset()
|
||||
self.build_model(data['X'], data['y'])
|
||||
self.build_model()
|
||||
|
||||
def evaluate_model(self):
|
||||
"""
|
||||
Evaluates the model via repeated shuffle-split cross-validation on the multi-label dataset.

The dataset is loaded from disk for this model instance, so no feature matrix,
targets or base classifier need to be passed in.

Returns:
    Jaccard score together with precision and recall values across a range of thresholds.
|
||||
"""
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json")
|
||||
data = json.load(open(f))
|
||||
ds = self.load_dataset()
|
||||
|
||||
X = np.array(data['X'])
|
||||
y = np.array(data['y'])
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
|
||||
n_splits = 20
|
||||
|
||||
@ -1409,22 +1326,32 @@ class MLRelativeReasoning(EPModel):
|
||||
X_train, X_test = X[train_index], X[test_index]
|
||||
y_train, y_test = y[train_index], y[test_index]
|
||||
|
||||
model = SparseLabelECC(
|
||||
**s.DEFAULT_DT_MODEL_PARAMS
|
||||
model = EnsembleClassifierChain(
|
||||
**s.DEFAULT_MODEL_PARAMS
|
||||
)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
y_pred = model.predict_proba(X_test)
|
||||
y_thresholded = (y_pred >= threshold).astype(int)
|
||||
|
||||
acc = jaccard_score(y_test, y_thresholded, average='samples', zero_division=0)
|
||||
# Flatten them to get rid of np.nan
|
||||
y_test = np.asarray(y_test).flatten()
|
||||
y_pred = np.asarray(y_pred).flatten()
|
||||
y_thresholded = np.asarray(y_thresholded).flatten()
|
||||
|
||||
mask = ~np.isnan(y_test)
|
||||
y_test_filtered = y_test[mask]
|
||||
y_pred_filtered = y_pred[mask]
|
||||
y_thresholded_filtered = y_thresholded[mask]
|
||||
|
||||
acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
|
||||
|
||||
prec, rec = dict(), dict()
|
||||
|
||||
for t in np.arange(0, 1.05, 0.05):
|
||||
temp_thresholded = (y_pred >= t).astype(int)
|
||||
prec[f"{t:.2f}"] = precision_score(y_test, temp_thresholded, average='samples', zero_division=0)
|
||||
rec[f"{t:.2f}"] = recall_score(y_test, temp_thresholded, average='samples', zero_division=0)
|
||||
temp_thresholded = (y_pred_filtered >= t).astype(int)
|
||||
prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
|
||||
rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
|
||||
|
||||
return acc, prec, rec
|
||||
|
||||
@ -1462,38 +1389,30 @@ class MLRelativeReasoning(EPModel):
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}.pkl'))
|
||||
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
|
||||
mod.base_clf.n_jobs = -1
|
||||
return mod
|
||||
|
||||
def predict(self, smiles) -> List['PredictionResult']:
|
||||
start = datetime.now()
|
||||
features = FormatConverter.maccs(smiles)
|
||||
ds = self.load_dataset()
|
||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||
pred = self.model.predict_proba(classify_ds.X())
|
||||
|
||||
trig = []
|
||||
prods = []
|
||||
for rule in self.applicable_rules:
|
||||
products = rule.apply(smiles)
|
||||
|
||||
if len(products):
|
||||
trig.append(1)
|
||||
prods.append(products)
|
||||
else:
|
||||
trig.append(0)
|
||||
prods.append([])
|
||||
|
||||
end_ds_gen = datetime.now()
|
||||
logger.info(f"Gen predict dataset took {(end_ds_gen - start).total_seconds()}s")
|
||||
pred = self.model.predict_proba([features + trig])
|
||||
|
||||
res = []
|
||||
for rule, p, smis in zip(self.applicable_rules, pred[0], prods):
|
||||
res.append(PredictionResult(smis, p, rule))
|
||||
res = MLRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
|
||||
|
||||
end = datetime.now()
|
||||
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||
return res
|
||||
|
||||
|
||||
@staticmethod
|
||||
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
||||
res = []
|
||||
for rule, p, smis in zip(rules, probabilities, products):
|
||||
res.append(PredictionResult(smis, p, rule))
|
||||
return res
|
||||
|
||||
@property
|
||||
def pr_curve(self):
|
||||
if self.model_status != self.FINISHED:
|
||||
@ -1515,26 +1434,171 @@ class MLRelativeReasoning(EPModel):
|
||||
class ApplicabilityDomain(EnviPathModel):
|
||||
model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)
|
||||
|
||||
num_neighbours = models.FloatField(blank=False, null=False, default=5)
|
||||
num_neighbours = models.IntegerField(blank=False, null=False, default=5)
|
||||
reliability_threshold = models.FloatField(blank=False, null=False, default=0.5)
|
||||
local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5)
|
||||
|
||||
def build_applicability_domain(self):
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(mlrr: MLRelativeReasoning, num_neighbours: int = 5, reliability_threshold: float = 0.5,
|
||||
local_compatibility_threshold: float = 0.5):
|
||||
ad = ApplicabilityDomain()
|
||||
ad.model = mlrr
|
||||
# ad.uuid = mlrr.uuid
|
||||
ad.name = f"AD for {mlrr.name}"
|
||||
ad.num_neighbours = num_neighbours
|
||||
ad.reliability_threshold = reliability_threshold
|
||||
ad.local_compatibilty_threshold = local_compatibility_threshold
|
||||
ad.save()
|
||||
return ad
|
||||
|
||||
@cached_property
|
||||
def pca(self) -> ApplicabilityDomainPCA:
|
||||
pca = joblib.load(os.path.join(s.MODEL_DIR, f'{self.model.uuid}_pca.pkl'))
|
||||
return pca
|
||||
|
||||
@cached_property
|
||||
def training_set_probs(self):
|
||||
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
|
||||
|
||||
def build(self):
|
||||
ds = self.model.load_dataset()
|
||||
X = ds['X']
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
pca = PCA(n_components=5) # choose number of components
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
start = datetime.now()
|
||||
|
||||
max_vals = np.max(X_pca, axis=0)
|
||||
min_vals = np.min(X_pca, axis=0)
|
||||
# Get Trainingset probs and dump them as they're required when using the app domain
|
||||
probs = self.model.model.predict_proba(ds.X())
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
|
||||
joblib.dump(probs, f)
|
||||
|
||||
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
|
||||
ad.build(ds)
|
||||
|
||||
end = datetime.now()
|
||||
logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds")
|
||||
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")
|
||||
joblib.dump(ad, f)
|
||||
|
||||
def assess(self, structure: Union[str, 'CompoundStructure']):
|
||||
ds = self.model.load_dataset()
|
||||
|
||||
assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)
|
||||
|
||||
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
||||
# {
|
||||
# assessment_structure_index: {
|
||||
# rule_index: [training_structure_indices_with_same_triggered_reaction]
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# For each structure in the assessment dataset and each rule (represented by a trigger feature),
|
||||
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
|
||||
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
|
||||
# with a given assessment structure under a particular rule.
|
||||
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(lambda: defaultdict(list))
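# Illustrative example of the structure above (hypothetical indices): if assessment
# structure 0 triggers rule 2, and training rows 5 and 9 trigger that same rule,
# the mapping contains {0: {2: [5, 9]}}.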
|
||||
|
||||
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
|
||||
feature = ds.columns[feature_index]
|
||||
if feature.startswith('trig_'):
|
||||
# TODO unroll loop
|
||||
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
|
||||
if int(cx[feature_index]) == 1:
|
||||
for j, tx in enumerate(ds.X(exclude_id_col=False)):
|
||||
if int(tx[feature_index]) == 1:
|
||||
qualified_neighbours_per_rule[i][rule_idx].append(j)
|
||||
|
||||
probs = self.training_set_probs
|
||||
# preds = self.model.model.predict_proba(assessment_ds.X())
|
||||
preds = self.model.combine_products_and_probs(self.model.applicable_rules,
|
||||
self.model.model.predict_proba(assessment_ds.X())[0],
|
||||
assessment_prods[0])
|
||||
|
||||
res = list()
|
||||
|
||||
# loop through our assessment dataset
|
||||
for i, instance in enumerate(assessment_ds):
|
||||
|
||||
rule_reliabilities = dict()
|
||||
local_compatibilities = dict()
|
||||
neighbours_per_rule = dict()
|
||||
|
||||
# loop through rule indices together with the collected neighbours indices from train dataset
|
||||
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
|
||||
|
||||
# collect the train dataset instances and store it along with the index (a.k.a. row number) of the
|
||||
# train dataset
|
||||
train_instances = []
|
||||
for v in vals:
|
||||
train_instances.append((v, ds.at(v)))
|
||||
|
||||
# sf is a tuple with start/end index of the features
|
||||
sf = ds.struct_features()
|
||||
|
||||
# compute tanimoto distance for all neighbours
# result is a list of tuples with train index and computed distance
dists = self._compute_distances(
|
||||
instance.X()[0][sf[0]:sf[1]],
|
||||
[ti[1].X()[0][sf[0]:sf[1]] for ti in train_instances]
|
||||
)
|
||||
|
||||
dists_with_index = list()
|
||||
for ti, dist in zip(train_instances, dists):
|
||||
dists_with_index.append((ti[0], dist[1]))
|
||||
|
||||
# sort them in a descending way and take at most `self.num_neighbours`
|
||||
dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
|
||||
dists_with_index = dists_with_index[:self.num_neighbours]
|
||||
|
||||
# compute average distance
|
||||
rule_reliabilities[rule_idx] = sum([d[1] for d in dists_with_index]) / len(dists_with_index) if len(dists_with_index) > 0 else 0.0
|
||||
|
||||
# for local_compatibility we'll need the datasets for the indices having the highest similarity
|
||||
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
|
||||
local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
|
||||
neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
|
||||
|
||||
# Assemble result for instance
|
||||
res.append({
|
||||
'in_ad': self.pca.is_applicable(instance)[0],
|
||||
'rule_reliabilities': rule_reliabilities,
|
||||
'local_compatibilities': local_compatibilities,
|
||||
'neighbours': neighbours_per_rule,
|
||||
'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
|
||||
'prob': preds
|
||||
})
|
||||
|
||||
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
|
||||
from utilities.ml import tanimoto_distance
|
||||
distances = [(i, tanimoto_distance(classify_instance, train)) for i, train in
|
||||
enumerate(train_instances)]
|
||||
return distances
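# Note: tanimoto_distance is imported from utilities.ml and is not part of this diff.
# As a hedged sketch only -- assuming the helper returns a Tanimoto *similarity* on
# binary fingerprints (higher = more similar), which matches the descending sort and
# the "highest similarity" comment in assess() above -- it could be implemented as:

def tanimoto_similarity(a, b):
    """Tanimoto/Jaccard coefficient |A & B| / |A | B| for two binary fingerprints."""
    both = sum(1 for x, y in zip(a, b) if x == 1 and y == 1)
    either = sum(1 for x, y in zip(a, b) if x == 1 or y == 1)
    return both / either if either else 0.0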
|
||||
|
||||
@staticmethod
|
||||
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, 'Dataset']]):
|
||||
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
|
||||
accuracy = 0.0
|
||||
|
||||
for n in neighbours:
|
||||
obs = n[1].y()[0][rule_idx]
|
||||
pred = preds[n[0]][rule_idx]
|
||||
if obs and pred:
|
||||
tp += 1
|
||||
elif not obs and pred:
|
||||
fp += 1
|
||||
elif obs and not pred:
|
||||
fn += 1
|
||||
else:
|
||||
tn += 1
|
||||
# Simple agreement (accuracy) over the neighbours' observed vs. predicted labels
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
||||
|
||||
return accuracy
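# The reliability_threshold and local_compatibilty_threshold fields stored on this
# model are not applied inside assess() itself in this diff; the per-rule values are
# simply returned. A hypothetical helper showing how a caller might gate a rule-level
# prediction with those thresholds (an assumption, not part of the commit):

def rule_within_domain(reliability, compatibility,
                       reliability_threshold=0.5, compatibility_threshold=0.5):
    # Accept a rule-level prediction only if the average neighbour similarity
    # (reliability) and the neighbour agreement (local compatibility) both
    # reach their configured thresholds.
    return (reliability >= reliability_threshold
            and compatibility >= compatibility_threshold)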
|
||||
|
||||
|
||||
class RuleBaseRelativeReasoning(EPModel):
|
||||
@ -1574,10 +1638,6 @@ class EnviFormer(EPModel):
|
||||
logger.info(f"Submitting {kek} to {hash(self.model)}")
|
||||
products = self.model.predict(kek)
|
||||
logger.info(f"Got results {products}")
|
||||
# from pprint import pprint
|
||||
#
|
||||
# print(smiles)
|
||||
# pprint(products)
|
||||
|
||||
res = []
|
||||
for smi, prob in products.items():
|
||||
@ -1715,9 +1775,7 @@ class Setting(EnviPathModel):
|
||||
|
||||
transformations = []
|
||||
if self.model is not None:
|
||||
print(self.model)
|
||||
pred_results = self.model.predict(current_node.smiles)
|
||||
print(pred_results)
|
||||
for pred_result in pred_results:
|
||||
if pred_result.probability >= self.model_threshold:
|
||||
transformations.append(pred_result)
|
||||
|
||||
@ -31,8 +31,8 @@ def send_registration_mail(user_pk: int):
|
||||
@shared_task(queue='model')
|
||||
def build_model(model_pk: int):
|
||||
mod = EPModel.objects.get(id=model_pk)
|
||||
X, y = mod.build_dataset()
|
||||
mod.build_model(X, y)
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
|
||||
|
||||
@shared_task(queue='model')
|
||||
|
||||
@ -103,7 +103,10 @@ def login(request):
|
||||
else:
|
||||
context['message'] = "Account has been created! You'll receive a mail to activate your account shortly."
|
||||
return render(request, 'login.html', context)
|
||||
|
||||
else:
|
||||
return HttpResponseBadRequest()
|
||||
else:
|
||||
return HttpResponseNotAllowed(['GET', 'POST'])
|
||||
|
||||
def logout(request):
|
||||
if request.method == 'POST':
|
||||
@ -136,7 +139,7 @@ def editable(request, user):
|
||||
f"{s.SERVER_URL}/group", f"{s.SERVER_URL}/search"]:
|
||||
return True
|
||||
else:
|
||||
print(f"Unknown url: {url}")
|
||||
logger.debug(f"Unknown url: {url}")
|
||||
return False
|
||||
|
||||
|
||||
@ -584,6 +587,9 @@ def package_models(request, package_uuid):
|
||||
return render(request, 'collections/objects_list.html', context)
|
||||
|
||||
elif request.method == 'POST':
|
||||
|
||||
log_post_params(request)
|
||||
|
||||
name = request.POST.get('model-name')
|
||||
description = request.POST.get('model-description')
|
||||
|
||||
@ -606,14 +612,25 @@ def package_models(request, package_uuid):
|
||||
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
|
||||
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
|
||||
|
||||
# App Domain related parameters
|
||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||
num_neighbors = request.POST.get('num-neighbors', 5)
|
||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
current_package,
|
||||
name,
|
||||
description,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold
|
||||
package=current_package,
|
||||
name=name,
|
||||
description=description,
|
||||
rule_packages=rule_package_objs,
|
||||
data_packages=data_package_objs,
|
||||
eval_packages=eval_packages_objs,
|
||||
threshold=threshold,
|
||||
# fingerprinter=fingerprinter,
|
||||
build_app_domain=build_ad,
|
||||
app_domain_num_neighbours=num_neighbors,
|
||||
app_domain_reliability_threshold=reliability_threshold,
|
||||
app_domain_local_compatibility_threshold=local_compatibility_threshold,
|
||||
)
|
||||
|
||||
from .tasks import build_model
|
||||
@ -649,7 +666,7 @@ def package_model(request, package_uuid, model_uuid):
|
||||
if len(pr) > 0:
|
||||
products = []
|
||||
for prod_set in pr.product_sets:
|
||||
print(f"Checking {prod_set}")
|
||||
logger.debug(f"Checking {prod_set}")
|
||||
products.append(tuple([x for x in prod_set]))
|
||||
|
||||
res.append({
|
||||
@ -660,6 +677,12 @@ def package_model(request, package_uuid, model_uuid):
|
||||
|
||||
return JsonResponse(res, safe=False)
|
||||
|
||||
elif request.GET.get('app-domain-assessment', False):
|
||||
smiles = request.GET['smiles']
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
|
||||
return JsonResponse(app_domain_assessment, safe=False)
|
||||
|
||||
context = get_base_context(request)
|
||||
context['title'] = f'enviPath - {current_package.name} - {current_model.name}'
|
||||
|
||||
@ -1717,8 +1740,6 @@ def user(request, user_uuid):
|
||||
}
|
||||
}
|
||||
|
||||
print(setting)
|
||||
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
else:
|
||||
@ -1781,9 +1802,7 @@ def group(request, group_uuid):
|
||||
|
||||
elif request.method == 'POST':
|
||||
|
||||
if s.DEBUG:
|
||||
for k, v in request.POST.items():
|
||||
print(k, v)
|
||||
log_post_params(request)
|
||||
|
||||
if hidden := request.POST.get('hidden', None):
|
||||
if hidden == 'delete-group':
|
||||
|
||||
@ -16,14 +16,14 @@
|
||||
<div class="jumbotron">Create a new Model to
|
||||
limit the number of degradation products in the
|
||||
prediction. You just need to set a name and the packages
|
||||
you want the object to be based on. If you want to use the
|
||||
default options suggested by us, simply click Submit,
|
||||
otherwise click Advanced Options.
|
||||
you want the object to be based on. There are multiple types of models available.
|
||||
For additional information have a look at our
|
||||
<a target="_blank" href="https://wiki.envipath.org/index.php/relative-reasoning" role="button">wiki >></a>
|
||||
</div>
|
||||
<label for="name">Name</label>
|
||||
<input id="name" name="model-name" class="form-control" placeholder="Name"/>
|
||||
<label for="description">Description</label>
|
||||
<input id="description" name="model-description" class="form-control"
|
||||
<label for="model-name">Name</label>
|
||||
<input id="model-name" name="model-name" class="form-control" placeholder="Name"/>
|
||||
<label for="model-description">Description</label>
|
||||
<input id="model-description" name="model-description" class="form-control"
|
||||
placeholder="Description"/>
|
||||
<label for="model-type">Model Type</label>
|
||||
<select id="model-type" name="model-type" class="form-control" data-width='100%'>
|
||||
@ -35,7 +35,7 @@
|
||||
<!-- ML Based Form-->
|
||||
<div id="ml-relative-reasoning-specific-form">
|
||||
<!-- Rule Packages -->
|
||||
<label>Rule Packages</label><br>
|
||||
<label for="ml-relative-reasoning-rule-packages">Rule Packages</label>
|
||||
<select id="ml-relative-reasoning-rule-packages" name="ml-relative-reasoning-rule-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
@ -53,7 +53,7 @@
|
||||
{% endfor %}
|
||||
</select>
|
||||
<!-- Data Packages -->
|
||||
<label>Data Packages</label><br>
|
||||
<label for="ml-relative-reasoning-data-packages" >Data Packages</label>
|
||||
<select id="ml-relative-reasoning-data-packages" name="ml-relative-reasoning-data-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
@ -77,22 +77,24 @@
|
||||
class="form-control">
|
||||
<option value="MACCS" selected>MACCS Fingerprinter</option>
|
||||
</select>
|
||||
{% if meta.enabled_features.PLUGINS %}
|
||||
{% if meta.enabled_features.PLUGINS and additional_descriptors %}
|
||||
<!-- Property Plugins go here -->
|
||||
<label for="ml-relative-reasoning-additional-fingerprinter">Fingerprinter</label>
|
||||
<select id="ml-relative-reasoning-additional-fingerprinter"
|
||||
name="ml-relative-reasoning-additional-fingerprinter"
|
||||
class="form-control">
|
||||
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter / Descriptors</label>
|
||||
<select id="ml-relative-reasoning-additional-fingerprinter" name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
|
||||
<option disabled selected>Select Additional Fingerprinter / Descriptor</option>
|
||||
{% for k, v in additional_descriptors.items %}
|
||||
<option value="{{ v }}">{{ k }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
{% endif %}
|
||||
|
||||
<label for="ml-relative-reasoning-threshold">Threshold</label>
|
||||
<input type="number" min="0" max="1" step="0.05" value="0.5"
|
||||
id="ml-relative-reasoning-threshold"
|
||||
name="ml-relative-reasoning-threshold" class="form-control">
|
||||
|
||||
<!-- Evaluation -->
|
||||
|
||||
<label>Evaluation Packages</label><br>
|
||||
<label for="ml-relative-reasoning-evaluation-packages">Evaluation Packages</label>
|
||||
<select id="ml-relative-reasoning-evaluation-packages" name="ml-relative-reasoning-evaluation-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
@ -110,6 +112,26 @@
|
||||
{% endfor %}
|
||||
</select>
|
||||
|
||||
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
|
||||
<!-- Build AD? -->
|
||||
<div class="checkbox">
|
||||
<label>
|
||||
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an Applicability Domain?
|
||||
</label>
|
||||
</div>
|
||||
<!-- Num Neighbors -->
|
||||
<label for="num-neighbors">Number of Neighbors</label>
|
||||
<input id="num-neighbors" name="num-neighbors" type="number" class="form-control" value="5"
|
||||
step="1" min="0" max="10">
|
||||
<!-- Local Compatibility -->
|
||||
<label for="local-compatibility-threshold">Local Compatibility Threshold</label>
|
||||
<input id="local-compatibility-threshold" name="local-compatibility-threshold" type="number"
|
||||
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
||||
<!-- Reliability -->
|
||||
<label for="reliability-threshold">Reliability Threshold</label>
|
||||
<input id="reliability-threshold" name="reliability-threshold" type="number"
|
||||
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
||||
{% endif %}
|
||||
</div>
|
||||
<!-- Rule Based Based Form-->
|
||||
<div id="rule-based-relative-reasoning-specific-form">
|
||||
@ -118,47 +140,9 @@
|
||||
<!-- EnviFormer-->
|
||||
<div id="enviformer-specific-form">
|
||||
<label for="enviformer-threshold">Threshold</label>
|
||||
<input type="number" min="0" , max="1" step="0.05" value="0.5" id="enviformer-threshold"
|
||||
<input type="number" min="0" max="1" step="0.05" value="0.5" id="enviformer-threshold"
|
||||
name="enviformer-threshold" class="form-control">
|
||||
</div>
|
||||
|
||||
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
|
||||
<div class="modal-body hide" data-step="3" data-title="Advanced Options II">
|
||||
<div class="jumbotron">Selection of parameter values for the Applicability Domain process.
|
||||
Number of Neighbours refers to a requirement on the minimum number of compounds from the
|
||||
training
|
||||
dataset that has at least one triggered transformation rule that is common with the compound
|
||||
being
|
||||
analyzed.
|
||||
Reliability Threshold is a requirement on the average tanimoto distance to the set number of
|
||||
"nearest neighbours" (Number of neighbours with the smallest tanimoto distances).
|
||||
Local Compatibility Threshold is a requirement on the average F1 score determined from the
|
||||
number of
|
||||
nearest neighbours, using their respective precision and recall values computed from the
|
||||
agreement
|
||||
between their observed and triggered rules.
|
||||
You can learn more about it in our wiki!
|
||||
</div>
|
||||
<!-- Use AD? -->
|
||||
<div class="checkbox">
|
||||
<label>
|
||||
<input type="checkbox" id="buildAD" name="buildAD">Also build an Applicability Domain?
|
||||
</label>
|
||||
</div>
|
||||
<!-- Num Neighbours -->
|
||||
<label for="adK">Number of Neighbours</label>
|
||||
<input id="adK" name="adK" type="number" class="form-control" value="5" step="1" min="0"
|
||||
max="10">
|
||||
<!-- F1 Threshold -->
|
||||
<label for="localCompatibilityThreshold">Local Compatibility Threshold</label>
|
||||
<input id="localCompatibilityThreshold" name="localCompatibilityThreshold" type="number"
|
||||
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
||||
<!-- Percentile Threshold -->
|
||||
<label for="reliabilityThreshold">Reliability Threshold</label>
|
||||
<input id="reliabilityThreshold" name="reliabilityThreshold" type="number" class="form-control"
|
||||
value="0.5" step="0.01" min="0" max="1">
|
||||
</div>
|
||||
{% endif %}
|
||||
</form>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
@ -179,6 +163,9 @@ $(function() {
|
||||
$("#ml-relative-reasoning-rule-packages").selectpicker();
|
||||
$("#ml-relative-reasoning-data-packages").selectpicker();
|
||||
$("#ml-relative-reasoning-evaluation-packages").selectpicker();
|
||||
if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) {
|
||||
$("#ml-relative-reasoning-additional-fingerprinter").selectpicker();
|
||||
}
|
||||
|
||||
// On change hide all and show only selected
|
||||
$("#model-type").change(function() {
|
||||
|
||||
@ -90,6 +90,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if model.ready_for_prediction %}
|
||||
<!-- Predict Panel -->
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
@ -106,11 +107,36 @@
|
||||
<button class="btn btn-default" type="submit" id="predict-button">Predict!</button>
|
||||
</span>
|
||||
</div>
|
||||
<div id="loading"></div>
|
||||
<div id="predictLoading"></div>
|
||||
<div id="predictResultTable"></div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- End Predict Panel -->
|
||||
{% endif %}
|
||||
|
||||
{% if model.app_domain %}
|
||||
<!-- App Domain -->
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
<a id="app-domain-assessment-link" data-toggle="collapse" data-parent="#model-detail"
|
||||
href="#app-domain-assessment">Applicability Domain Assessment</a>
|
||||
</h4>
|
||||
</div>
|
||||
<div id="app-domain-assessment" class="panel-collapse collapse in">
|
||||
<div class="panel-body list-group-item">
|
||||
<div class="input-group">
|
||||
<input id="smiles-to-assess" type="text" class="form-control" placeholder="CCN(CC)C(=O)C1=CC(=CC=C1)C">
|
||||
<span class="input-group-btn">
|
||||
<button class="btn btn-default" type="submit" id="assess-button">Assess!</button>
|
||||
</span>
|
||||
</div>
|
||||
<div id="appDomainLoading"></div>
|
||||
<div id="appDomainAssessmentResultTable"></div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- End App Domain -->
|
||||
{% endif %}
|
||||
|
||||
{% if model.model_status == 'FINISHED' %}
|
||||
<!-- Single Gen Curve Panel -->
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
@ -277,9 +303,9 @@
|
||||
$("#predictResultTable").append(res);
|
||||
}
|
||||
|
||||
function clear() {
|
||||
$("#predictResultTable").removeClass("alert alert-danger");
|
||||
$("#predictResultTable").empty();
|
||||
function clear(divid) {
|
||||
$("#" + divid).removeClass("alert alert-danger");
|
||||
$("#" + divid).empty();
|
||||
}
|
||||
|
||||
if ($('#predict-button').length > 0) {
|
||||
@ -291,32 +317,69 @@
|
||||
"classify": "ILikeCats!"
|
||||
}
|
||||
|
||||
clear();
|
||||
clear("predictResultTable");
|
||||
|
||||
makeLoadingGif("#loading", "{% static '/images/wait.gif' %}");
|
||||
makeLoadingGif("#predictLoading", "{% static '/images/wait.gif' %}");
|
||||
$.ajax({
|
||||
type: 'get',
|
||||
data: data,
|
||||
url: '',
|
||||
success: function (data, textStatus) {
|
||||
try {
|
||||
$("#loading").empty();
|
||||
$("#predictLoading").empty();
|
||||
handleResponse(data);
|
||||
} catch (error) {
|
||||
console.log("Error");
|
||||
$("#loading").empty();
|
||||
$("#predictLoading").empty();
|
||||
$("#predictResultTable").addClass("alert alert-danger");
|
||||
$("#predictResultTable").append("Error while processing request :/");
|
||||
}
|
||||
},
|
||||
error: function (jqXHR, textStatus, errorThrown) {
|
||||
$("#loading").empty();
|
||||
$("#predictLoading").empty();
|
||||
$("#predictResultTable").addClass("alert alert-danger");
|
||||
$("#predictResultTable").append("Error while processing request :/");
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
if ($('#assess-button').length > 0) {
|
||||
$("#assess-button").on("click", function (e) {
|
||||
e.preventDefault();
|
||||
|
||||
data = {
|
||||
"smiles": $("#smiles-to-assess").val(),
|
||||
"app-domain-assessment": "ILikeCats!"
|
||||
}
|
||||
|
||||
clear("appDomainAssessmentResultTable");
|
||||
|
||||
makeLoadingGif("#appDomainLoading", "{% static '/images/wait.gif' %}");
|
||||
$.ajax({
|
||||
type: 'get',
|
||||
data: data,
|
||||
url: '',
|
||||
success: function (data, textStatus) {
|
||||
try {
|
||||
$("#appDomainLoading").empty();
|
||||
console.log(data);
|
||||
} catch (error) {
|
||||
console.log("Error");
|
||||
$("#appDomainLoading").empty();
|
||||
$("#appDomainAssessmentResultTable").addClass("alert alert-danger");
|
||||
$("#appDomainAssessmentResultTable").append("Error while processing request :/");
|
||||
}
|
||||
},
|
||||
error: function (jqXHR, textStatus, errorThrown) {
|
||||
$("#appDomainLoading").empty();
|
||||
$("#appDomainAssessmentResultTable").addClass("alert alert-danger");
|
||||
$("#appDomainAssessmentResultTable").append("Error while processing request :/");
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
{% endblock content %}
|
||||
|
||||
52
tests/test_dataset.py
Normal file
@ -0,0 +1,52 @@
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import Reaction, Compound, User, Rule
|
||||
from utilities.ml import Dataset
|
||||
|
||||
|
||||
class DatasetTest(TestCase):
|
||||
fixtures = ["test_fixture.cleaned.json"]
|
||||
|
||||
def setUp(self):
|
||||
self.cs1 = Compound.create(
|
||||
self.package,
|
||||
name='2,6-Dibromohydroquinone',
|
||||
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b',
|
||||
smiles='C1=C(C(=C(C=C1O)Br)O)Br',
|
||||
).default_structure
|
||||
|
||||
self.cs2 = Compound.create(
|
||||
self.package,
|
||||
smiles='O=C(O)CC(=O)/C=C(/Br)C(=O)O',
|
||||
).default_structure
|
||||
|
||||
self.rule1 = Rule.create(
|
||||
rule_type='SimpleAmbitRule',
|
||||
package=self.package,
|
||||
smirks='[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\[#6:3]=[#6:2](\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]',
|
||||
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6'
|
||||
)
|
||||
|
||||
self.reaction1 = Reaction.create(
|
||||
package=self.package,
|
||||
educts=[self.cs1],
|
||||
products=[self.cs2],
|
||||
rules=[self.rule1],
|
||||
multi_step=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(DatasetTest, cls).setUpClass()
|
||||
cls.user = User.objects.get(username='anonymous')
|
||||
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
|
||||
|
||||
def test_smoke(self):
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.package)]
|
||||
applicable_rules = [self.rule1]
|
||||
|
||||
ds = Dataset.generate_dataset(reactions, applicable_rules)
|
||||
|
||||
self.assertEqual(len(ds.y()), 1)
|
||||
self.assertEqual(sum(ds.y()[0]), 1)
|
||||
@ -1,111 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.models import ParallelRule
|
||||
from utilities.ml import Compound, Reaction, DatasetGenerator
|
||||
|
||||
|
||||
class CompoundTest(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C", uuid='c1')
|
||||
self.c2 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C", uuid='c2')
|
||||
|
||||
def test_compound_eq_ignores_uuid(self):
|
||||
self.assertEqual(self.c1, self.c2)
|
||||
|
||||
|
||||
class ReactionTest(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||
self.c2 = Compound(smiles="CCN(CCO)C(=O)C1=CC(C)=CC=C1")
|
||||
# self.r1 = Rule(uuid="bt0334")
|
||||
# c1 --r1--> c2
|
||||
self.c3_1 = Compound(smiles="CCNC(=O)C1=CC(C)=CC=C1")
|
||||
self.c3_2 = Compound(smiles="CC=O")
|
||||
# self.r2 = Rule(uuid="bt0243")
|
||||
# c1 --r2--> c3_1, c3_2
|
||||
|
||||
def test_reaction_equality_ignores_uuid(self):
|
||||
r1 = Reaction([self.c1], [self.c2], self.r1, uuid="abc")
|
||||
r2 = Reaction([self.c1], [self.c2], self.r1, uuid="xyz")
|
||||
self.assertEqual(r1, r2)
|
||||
|
||||
def test_reaction_inequality_on_data_change(self):
|
||||
r1 = Reaction([self.c1], [self.c2], self.r1)
|
||||
r2 = Reaction([self.c1], [self.c3_1], self.r1)
|
||||
self.assertNotEqual(r1, r2)
|
||||
|
||||
def test_reaction_is_hashable(self):
|
||||
r = Reaction([self.c1], [self.c2], self.r1)
|
||||
reactions = {r}
|
||||
self.assertIn(Reaction([self.c1], [self.c2], self.r1), reactions)
|
||||
|
||||
def test_rule_is_optional(self):
|
||||
r = Reaction([self.c1], [self.c2])
|
||||
self.assertIsNone(r.rule)
|
||||
|
||||
def test_uuid_is_optional(self):
|
||||
r = Reaction([self.c1], [self.c2], self.r1)
|
||||
self.assertIsNone(r.uuid)
|
||||
|
||||
def test_repr_includes_uuid(self):
|
||||
r = Reaction([self.c1], [self.c2], self.r1, uuid="abc")
|
||||
self.assertIn("abc", repr(r))
|
||||
|
||||
def test_reaction_equality_with_multiple_compounds_different_ordering(self):
|
||||
r1 = Reaction([self.c1], [self.c3_1, self.c3_2], self.r2)
|
||||
r2 = Reaction([self.c1], [self.c3_2, self.c3_1], self.r2)
|
||||
|
||||
self.assertEqual(r1, r2, "Reaction equality should not rely on list order")
|
||||
|
||||
|
||||
class RuleTest(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
# self.r1 = Rule(uuid="bt0334")
|
||||
# self.r2 = Rule(uuid="bt0243")
|
||||
|
||||
|
||||
class DatasetGeneratorTest(TestCase):
|
||||
fixtures = ['bootstrap.json']
|
||||
|
||||
def setUp(self):
|
||||
self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||
self.c2 = Compound(smiles="CCN(CCO)C(=O)C1=CC(C)=CC=C1")
|
||||
self.c3_1 = Compound(smiles="CCNC(=O)C1=CC(C)=CC=C1")
|
||||
self.c3_2 = Compound(smiles="CC=O")
|
||||
|
||||
# self.r1 = Rule(uuid="bt0334") # trig
|
||||
# self.r2 = Rule(uuid="bt0243") # trig
|
||||
# self.r3 = Rule(uuid="bt0003") # no trig
|
||||
|
||||
self.reaction1 = Reaction([self.c1], [self.c2], self.r3)
|
||||
self.reaction2 = Reaction([self.c1], [self.c3_1, self.c3_2], self.r2)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_test(self):
|
||||
compounds = [
|
||||
self.c1,
|
||||
self.c2,
|
||||
self.c3_1,
|
||||
self.c3_2,
|
||||
]
|
||||
|
||||
reactions = [
|
||||
self.reaction1,
|
||||
self.reaction2,
|
||||
]
|
||||
|
||||
applicable_rules = [
|
||||
# Rule('bt0334', ParallelRule.objects.get(name='bt0334')),
|
||||
# Rule('bt0243', ParallelRule.objects.get(name='bt0243')),
|
||||
# Rule('bt0003', ParallelRule.objects.get(name='bt0003')),
|
||||
]
|
||||
|
||||
ds = DatasetGenerator.generate_dataset(compounds, reactions, applicable_rules)
|
||||
|
||||
self.assertIsNotNone(ds)
|
||||
55
tests/test_model.py
Normal file
@ -0,0 +1,55 @@
|
||||
import json
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import Compound, User, CompoundStructure, Reaction, Rule, MLRelativeReasoning
|
||||
|
||||
|
||||
class ModelTest(TestCase):
|
||||
fixtures = ["test_fixture.cleaned.json"]
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ModelTest, cls).setUpClass()
|
||||
cls.user = User.objects.get(username='anonymous')
|
||||
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
|
||||
bbd_data = json.load(open('fixtures/packages/2025-07-18/EAWAG-BBD.json'))
|
||||
cls.BBD = PackageManager.import_package(bbd_data, cls.user)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_smoke(self):
|
||||
threshold = float(0.5)
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [self.BBD]
|
||||
data_package_objs = [self.BBD]
|
||||
eval_packages_objs = []
|
||||
|
||||
mod = MLRelativeReasoning.create(
package=self.package,
name='ECC - BBD - 0.5',
description='Created MLRelativeReasoning in Testcase',
rule_packages=rule_package_objs,
data_packages=data_package_objs,
eval_packages=eval_packages_objs,
threshold=threshold
)
|
||||
ds = mod.build_dataset()
|
||||
|
||||
mod.build_model()
|
||||
print("Model built!")
|
||||
mod.evaluate_model()
|
||||
print("Model Evaluated")
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
print(results)
|
||||
@ -131,6 +131,24 @@ class FormatConverter(object):
|
||||
# TODO call to AMBIT Service
|
||||
return smiles
|
||||
|
||||
@staticmethod
|
||||
def ep_standardize(smiles):
|
||||
change = True
|
||||
while change:
|
||||
change = False
|
||||
for standardizer in MATCH_STANDARDIZER:
|
||||
tmp_smiles = standardizer.standardize(smiles)
|
||||
|
||||
if tmp_smiles != smiles:
|
||||
print(f"change {smiles} to {tmp_smiles}")
|
||||
change = True
|
||||
smiles = tmp_smiles
|
||||
|
||||
if change is False:
|
||||
print(f"nothing changed")
|
||||
|
||||
return smiles
|
||||
|
||||
@staticmethod
|
||||
def standardize(smiles):
|
||||
# Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
|
||||
@ -180,54 +198,6 @@ class FormatConverter(object):
|
||||
atom.UpdatePropertyCache()
|
||||
return mol
|
||||
|
||||
# @staticmethod
|
||||
# def apply(smiles, smirks, preprocess_smiles=True, bracketize=False, standardize=True):
|
||||
# logger.debug(f'Applying {smirks} on {smiles}')
|
||||
#
|
||||
# if bracketize:
|
||||
# smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")"
|
||||
#
|
||||
# res = set()
|
||||
# try:
|
||||
# rxn = rdChemReactions.ReactionFromSmarts(smirks)
|
||||
# mol = Chem.MolFromSmiles(smiles)
|
||||
#
|
||||
# # Inplace
|
||||
# if preprocess_smiles:
|
||||
# Chem.SanitizeMol(mol)
|
||||
# mol = Chem.AddHs(mol)
|
||||
#
|
||||
# # apply!
|
||||
# reacts = rxn.RunReactants((mol,))
|
||||
# if len(reacts):
|
||||
# # Sanitize mols
|
||||
# for product_set in reacts:
|
||||
# prod_set = list()
|
||||
# for product in product_set:
|
||||
# # Fixes
|
||||
# # [2025-01-30 23:00:50] ERROR chem - Sanitizing and converting failed:
|
||||
# # non-ring atom 3 marked aromatic
|
||||
# # But does not improve overall performance
|
||||
# #
|
||||
# # for a in product.GetAtoms():
|
||||
# # if (not a.IsInRing()) and a.GetIsAromatic():
|
||||
# # a.SetIsAromatic(False)
|
||||
# # for b in product.GetBonds():
|
||||
# # if (not b.IsInRing()) and b.GetIsAromatic():
|
||||
# # b.SetIsAromatic(False)
|
||||
#
|
||||
# try:
|
||||
# Chem.SanitizeMol(product)
|
||||
# prod_set.append(FormatConverter.standardize(Chem.MolToSmiles(product)))
|
||||
# except ValueError as e:
|
||||
# logger.error(f'Sanitizing and converting failed:\n{e}')
|
||||
# continue
|
||||
# res.add(tuple(list(set(prod_set))))
|
||||
# except Exception as e:
|
||||
# logger.error(f'Applying {smirks} on {smiles} failed:\n{e}')
|
||||
#
|
||||
# return list(res)
|
||||
|
||||
@staticmethod
|
||||
def is_valid_smirks(smirks: str) -> bool:
|
||||
try:
|
||||
|
||||
514
utilities/ml.py
@ -1,46 +1,29 @@
from __future__ import annotations

import dataclasses
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import List, Optional
from typing import List, Dict, Set, Tuple

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# @dataclasses.dataclass
# class Feature:
#     name: str
#     value: float
#
#
#
# class Row:
#     def __init__(self, compound_uuid: str, compound_smiles: str, descriptors: List[int]):
#         self.data = {}
#
#
#
# class DataSet(object):
#
#     def __init__(self):
#         self.rows: List[Row] = []
#
#     def add_row(self, row: Row):
#         pass
logger = logging.getLogger(__name__)


from dataclasses import dataclass, field

from utilities.chem import FormatConverter
from utilities.chem import FormatConverter, PredictionResult


@dataclass
class Compound:
class SCompound:
    smiles: str
    uuid: str = field(default=None, compare=False, hash=False)

@ -53,10 +36,10 @@ class Compound:


@dataclass
class Reaction:
    educts: List[Compound]
    products: List[Compound]
    rule_uuid: str = field(default=None, compare=False, hash=False)
class SReaction:
    educts: List[SCompound]
    products: List[SCompound]
    rule_uuid: SRule = field(default=None, compare=False, hash=False)
    reaction_uuid: str = field(default=None, compare=False, hash=False)

    def __hash__(self):
@ -68,7 +51,7 @@ class Reaction:
        return self._hash

    def __eq__(self, other):
        if not isinstance(other, Reaction):
        if not isinstance(other, SReaction):
            return NotImplemented
        return (
            sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and
@ -76,69 +59,286 @@ class Reaction:
        )


class Dataset(object):
@dataclass
class SRule(ABC):

    def __init__(self, headers=List['str'], data=List[List[str|int|float]]):
        self.headers = headers
    @abstractmethod
    def apply(self):
        pass


@dataclass
class SSimpleRule:
    pass


@dataclass
class SParallelRule:
    pass


class Dataset:

    def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
        self.columns: List[str] = columns
        self.num_labels: int = num_labels

        if data is None:
            self.data: List[List[str | int | float]] = list()
        else:
            self.data = data

        self.num_features: int = len(columns) - self.num_labels
        self._struct_features: Tuple[int, int] = self._block_indices('feature_')
        self._triggered: Tuple[int, int] = self._block_indices('trig_')
        self._observed: Tuple[int, int] = self._block_indices('obs_')

    def features(self):
        pass
    def _block_indices(self, prefix) -> Tuple[int, int]:
        indices: List[int] = []
        for i, feature in enumerate(self.columns):
            if feature.startswith(prefix):
                indices.append(i)

    def labels(self):
        pass
        return min(indices), max(indices)

    def to_json(self):
        pass
    def structure_id(self):
        return self.data[0][0]

    def to_csv(self):
        pass
    def add_row(self, row: List[str | int | float]):
        if len(self.columns) != len(row):
            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
        self.data.append(row)

    def to_arff(self):
        pass
    def struct_features(self) -> Tuple[int, int]:
        return self._struct_features

    def triggered(self) -> Tuple[int, int]:
        return self._triggered

    def observed(self) -> Tuple[int, int]:
        return self._observed

    def at(self, position: int) -> Dataset:
        return Dataset(self.columns, self.num_labels, [self.data[position]])

    def limit(self, limit: int) -> Dataset:
        return Dataset(self.columns, self.num_labels, self.data[:limit])

    def __iter__(self):
        return (self.at(i) for i, _ in enumerate(self.data))


    def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
        classify_data = []
        classify_products = []
        for struct in structures:

class DatasetGenerator(object):
            if isinstance(struct, str):
                struct_id = None
                struct_smiles = struct
            else:
                struct_id = str(struct.uuid)
                struct_smiles = struct.smiles

            features = FormatConverter.maccs(struct_smiles)

            trig = []
            prods = []
            for rule in applicable_rules:
                products = rule.apply(struct_smiles)

                if len(products):
                    trig.append(1)
                    prods.append(products)
                else:
                    trig.append(0)
                    prods.append([])

            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
            classify_products.append(prods)

        return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products

    @staticmethod
    def generate_dataset(compounds: List[Compound], reactions: List[Reaction], applicable_rules: 'Rule',
                         compounds_to_exclude: Optional[Compound] = None, educts_only: bool = False) -> Dataset:
    def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
        _structures = set()

        rows = []

        if educts_only:
            compounds = set()
            for r in reactions:
                for e in r.educts:
                    compounds.add(e)
            compounds = list(compounds)
                for e in r.educts.all():
                    _structures.add(e)

        total = len(compounds)
        for i, c in enumerate(compounds):
            row = []
            print(f"{i + 1}/{total} - {c.smiles}")
            for r in applicable_rules:
                product_sets = r.rule.apply(c.smiles)
            if not educts_only:
                for e in r.products:
                    _structures.add(e)

        compounds = sorted(_structures, key=lambda x: x.url)

        triggered: Dict[str, Set[str]] = defaultdict(set)
        observed: Set[str] = set()

        # Apply rules on collected compounds and store tps
        for i, comp in enumerate(compounds):
            logger.debug(f"{i + 1}/{len(compounds)}...")

            for rule in applicable_rules:
                product_sets = rule.apply(comp.smiles)

                if len(product_sets) == 0:
                    row.append([])
                    continue

                #triggered.add(f"{r.uuid} + {c.uuid}")
                reacts = set()
                for ps in product_sets:
                    products = []
                    for p in ps:
                        products.append(Compound(FormatConverter.standardize(p)))
                key = f"{rule.uuid} + {comp.uuid}"

                    reacts.add(Reaction([c], products, r))
                row.append(list(reacts))
                if key in triggered:
                    logger.info(f"{key} already present. Duplicate reaction?")

            rows.append(row)
                for prod_set in product_sets:
                    for smi in prod_set:

        return rows
                        try:
                            smi = FormatConverter.standardize(smi)
                        except Exception:
                            # :shrug:
                            logger.debug(f'Standardizing SMILES failed for {smi}')
                            pass

                        triggered[key].add(smi)

        for i, r in enumerate(reactions):
            logger.debug(f"{i + 1}/{len(reactions)}...")

            if len(r.educts.all()) != 1:
                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
                continue

            for comp in r.educts.all():
                for rule in applicable_rules:
                    key = f"{rule.uuid} + {comp.uuid}"

                    if key not in triggered:
                        continue

                    # standardize products from reactions for comparison
                    standardized_products = []
                    for cs in r.products.all():
                        smi = cs.smiles

                        try:
                            smi = FormatConverter.standardize(smi)
                        except Exception as e:
                            # :shrug:
                            logger.debug(f'Standardizing SMILES failed for {smi}')
                            pass

                        standardized_products.append(smi)

                    if len(set(standardized_products).difference(triggered[key])) == 0:
                        observed.add(key)
                    else:
                        pass

        ds = None

        for i, comp in enumerate(compounds):
            # Features
            feat = FormatConverter.maccs(comp.smiles)
            trig = []
            obs = []

            for rule in applicable_rules:
                key = f"{rule.uuid} + {comp.uuid}"

                # Check triggered
                if key in triggered:
                    trig.append(1)
                else:
                    trig.append(0)

                # Check obs
                if key in observed:
                    obs.append(1)
                elif key not in triggered:
                    obs.append(None)
                else:
                    obs.append(0)

            if ds is None:
                header = ['structure_id'] + \
                         [f'feature_{i}' for i, _ in enumerate(feat)] \
                         + [f'trig_{r.uuid}' for r in applicable_rules] \
                         + [f'obs_{r.uuid}' for r in applicable_rules]
                ds = Dataset(header, len(applicable_rules))

            ds.add_row([str(comp.uuid)] + feat + trig + obs)

        return ds


    def X(self, exclude_id_col=True, na_replacement=0):
        res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)))
        if na_replacement is not None:
            res = [[x if x is not None else na_replacement for x in row] for row in res]
        return res


    def y(self, na_replacement=0):
        res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
        if na_replacement is not None:
            res = [[x if x is not None else na_replacement for x in row] for row in res]
        return res


    def __getitem__(self, key):
        if not isinstance(key, tuple):
            raise TypeError("Dataset must be indexed with dataset[rows, columns]")

        row_key, col_key = key

        # Normalize rows
        if isinstance(row_key, int):
            rows = [self.data[row_key]]
        else:
            rows = self.data[row_key]

        # Normalize columns
        if isinstance(col_key, int):
            res = [row[col_key] for row in rows]
        else:
            res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
                   else [row[i] for i in col_key] for row in rows]

        return res

    def save(self, path: 'Path'):
        import pickle
        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: 'Path'):
        import pickle
        return pickle.load(open(path, "rb"))

    def to_arff(self, path: 'Path'):
        arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
        arff += "\n"
        for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
            if c == 'structure_id':
                arff += f"@attribute {c} string\n"
            else:
                arff += f"@attribute {c} {{0,1}}\n"

        arff += f"\n@data\n"
        for d in self.data:
            ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
            xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
            arff += f'{ys},{xs}\n'

        with open(path, "w") as fh:
            fh.write(arff)
            fh.flush()

    def __repr__(self):
        return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
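
Illustrative usage (not part of the committed diff): assuming the new Dataset class is importable from utilities.ml, it expects an id column followed by feature_*, trig_* and obs_* blocks, and splits them via X() and y(). A minimal sketch:

from utilities.ml import Dataset  # assumed import path

columns = ['structure_id', 'feature_0', 'feature_1', 'trig_rule-a', 'obs_rule-a']
ds = Dataset(columns=columns, num_labels=1)
ds.add_row(['compound-1', 1, 0, 1, 1])
ds.add_row(['compound-2', 0, 1, 1, None])  # triggered but not observed -> missing label

print(ds.X())                  # feature + trig columns, id column dropped: [[1, 0, 1], [0, 1, 1]]
print(ds.y(na_replacement=0))  # obs columns with None replaced: [[1], [0]]
print(ds)                      # <Dataset #rows=2 #cols=5 #labels=1>
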

class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -166,8 +366,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
                self.keep_columns_.append(col)

        y_reduced = y[:, self.keep_columns_]
        self.chains_ = [ClassifierChain(self.base_clf, order='random', random_state=i)
                        for i in range(self.num_chains)]
        self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)]

        for i, chain in enumerate(self.chains_):
            print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
@ -208,26 +407,169 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
        return accuracy_score(y_true, y_pred, sample_weight=sample_weight)


class ApplicabilityDomain(PCA):

    def __init__(self, n_components=5):
        super().__init__(n_components=n_components)
import copy

import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier


class BinaryRelevance:
    def __init__(self, baseline_clf):
        self.clf = baseline_clf
        self.classifiers = None

    def fit(self, X, Y):
        if self.classifiers is None:
            self.classifiers = []

        for l in range(len(Y[0])):
            X_l = X[~np.isnan(Y[:, l])]
            Y_l = (Y[~np.isnan(Y[:, l]), l])
            if len(X_l) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy='constant', constant=0)
                clf.fit([X[0]], [0])
                self.classifiers.append(clf)
                continue
            elif len(np.unique(Y_l)) == 1:  # only one class -> predict that class
                clf = DummyClassifier(strategy='most_frequent')
            else:
                clf = copy.deepcopy(self.clf)
            clf.fit(X_l, Y_l)
            self.classifiers.append(clf)

    def predict(self, X):
        labels = []
        for clf in self.classifiers:
            labels.append(clf.predict(X))
        return np.column_stack(labels)

    def predict_proba(self, X):
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict_proba(X)
            if pred.shape[1] > 1:
                pred = pred[:, 1]
            else:
                pred = pred * clf.predict([X[0]])[0]
            labels = np.column_stack((labels, pred))
        return labels

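
Illustrative usage (not part of the committed diff): a small sketch of BinaryRelevance on a toy multi-label matrix with a missing (NaN) label, assuming the class is importable from utilities.ml:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

from utilities.ml import BinaryRelevance  # assumed import path

X = np.array([[0, 1], [1, 0], [1, 1], [0, 0]], dtype=float)
Y = np.array([[1, np.nan], [0, 1], [1, 1], [0, 0]], dtype=float)  # NaN = label not observed

br = BinaryRelevance(DecisionTreeClassifier(max_depth=2, random_state=0))
br.fit(X, Y)                # one classifier per label, trained only on rows where that label is known
print(br.predict(X))        # shape (4, 2), one column per label
print(br.predict_proba(X))  # per-label probability of class 1
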
class MissingValuesClassifierChain:
    def __init__(self, base_clf):
        self.base_clf = base_clf
        self.permutation = None
        self.classifiers = None

    def fit(self, X, Y):
        X = np.array(X)
        Y = np.array(Y)
        if self.permutation is None:
            self.permutation = np.random.permutation(len(Y[0]))

        Y = Y[:, self.permutation]

        if self.classifiers is None:
            self.classifiers = []

        for p in range(len(self.permutation)):
            X_p = X[~np.isnan(Y[:, p])]
            Y_p = Y[~np.isnan(Y[:, p]), p]
            if len(X_p) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy='constant', constant=0)
                self.classifiers.append(clf.fit([X[0]], [0]))
            elif len(np.unique(Y_p)) == 1:  # only one class -> predict that class
                clf = DummyClassifier(strategy='most_frequent')
                self.classifiers.append(clf.fit(X_p, Y_p))
            else:
                clf = copy.deepcopy(self.base_clf)
                self.classifiers.append(clf.fit(X_p, Y_p))
            newcol = Y[:, p]
            pred = clf.predict(X)
            newcol[np.isnan(newcol)] = pred[np.isnan(newcol)]  # fill in missing values with clf predictions
            X = np.column_stack((X, newcol))

    def predict(self, X):
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict(np.column_stack((X, labels)))
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]

    def predict_proba(self, X):
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict_proba(np.column_stack((X, np.round(labels))))
            if pred.shape[1] > 1:
                pred = pred[:, 1]
            else:
                pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0]
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]


class EnsembleClassifierChain:
    def __init__(self, base_clf, num_chains=10):
        self.base_clf = base_clf
        self.num_chains = num_chains
        self.num_labels = None
        self.classifiers = None

    def fit(self, X, Y):
        if self.classifiers is None:
            self.classifiers = []

        if self.num_labels is None:
            self.num_labels = len(Y[0])

        for p in range(self.num_chains):
            print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
            clf = MissingValuesClassifierChain(self.base_clf)
            clf.fit(X, Y)
            self.classifiers.append(clf)

    def predict(self, X):
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict(X)
        return np.round(labels / self.num_chains)

    def predict_proba(self, X):
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict_proba(X)
        return labels / self.num_chains

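
Illustrative usage (not part of the committed diff): a sketch of EnsembleClassifierChain on the same toy data, assuming the class is importable from utilities.ml. Chain orders are drawn with np.random.permutation inside fit, so seeding the global NumPy RNG makes the run reproducible:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

from utilities.ml import EnsembleClassifierChain  # assumed import path

np.random.seed(42)  # fixes the label permutations drawn by each chain

X = np.array([[0, 1], [1, 0], [1, 1], [0, 0]], dtype=float)
Y = np.array([[1, np.nan], [0, 1], [1, 1], [0, 0]], dtype=float)

ecc = EnsembleClassifierChain(DecisionTreeClassifier(max_depth=2, random_state=0), num_chains=3)
ecc.fit(X, Y)
print(ecc.predict(X))        # rounded vote across the chains
print(ecc.predict_proba(X))  # chain-averaged per-label probabilities
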
class ApplicabilityDomainPCA(PCA):

    def __init__(self, num_neighbours: int = 5):
        super().__init__(n_components=num_neighbours)
        self.scaler = StandardScaler()
        self.num_neighbours = num_neighbours
        self.min_vals = None
        self.max_vals = None

    def build(self, X):
    def build(self, train_dataset: 'Dataset'):
        # transform
        X_scaled = self.scaler.fit_transform(X)
        X_scaled = self.scaler.fit_transform(train_dataset.X())
        # fit pca
        X_pca = self.fit_transform(X_scaled)

        self.max_vals = np.max(X_pca, axis=0)
        self.min_vals = np.min(X_pca, axis=0)

    def is_applicable(self, instances):
    def __transform(self, instances):
        instances_scaled = self.scaler.transform(instances)
        instances_pca = self.transform(instances_scaled)
        return instances_pca

    def is_applicable(self, classify_instances: 'Dataset'):
        instances_pca = self.__transform(classify_instances.X())

        is_applicable = []
        for i, instance in enumerate(instances_pca):
@ -237,3 +579,17 @@ class ApplicabilityDomain(PCA):
                is_applicable[i] = False

        return is_applicable

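
Background sketch (not part of the committed diff): the PCA applicability domain scales the training features, projects them onto a few principal components, and keeps the per-component min/max as a bounding box; a query is considered in-domain if its projection falls inside that box. The same idea with plain scikit-learn pieces on synthetic data (not enviPy code):

import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_train = rng.integers(0, 2, size=(50, 10)).astype(float)  # toy binary fingerprints

scaler = StandardScaler()
pca = PCA(n_components=3)
X_pca = pca.fit_transform(scaler.fit_transform(X_train))
min_vals, max_vals = X_pca.min(axis=0), X_pca.max(axis=0)  # bounding box in PCA space

query = rng.integers(0, 2, size=(1, 10)).astype(float)
q_pca = pca.transform(scaler.transform(query))
in_domain = bool(np.all((q_pca >= min_vals) & (q_pca <= max_vals)))
print(in_domain)
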
def tanimoto_distance(a: List[int], b: List[int]):
    if len(a) != len(b):
        raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")

    sum_a = sum(a)
    sum_b = sum(b)
    sum_c = sum(v1 and v2 for v1, v2 in zip(a, b))

    if sum_a + sum_b - sum_c == 0:
        return 0.0

    return 1 - (sum_c / (sum_a + sum_b - sum_c))
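
For reference (not part of the committed diff), a worked example of tanimoto_distance on two small bit vectors, assuming it is imported from utilities.ml:

a = [1, 1, 0, 1]
b = [1, 0, 0, 1]
# sum_a = 3, sum_b = 2, shared on-bits sum_c = 2
# distance = 1 - 2 / (3 + 2 - 2) = 1 - 2/3 = 0.333...
print(tanimoto_distance(a, b))
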