diff --git a/envipath/settings.py b/envipath/settings.py index 29fcca60..18ffe150 100644 --- a/envipath/settings.py +++ b/envipath/settings.py @@ -261,7 +261,7 @@ CELERY_ACCEPT_CONTENT = ['json'] CELERY_TASK_SERIALIZER = 'json' MODEL_BUILDING_ENABLED = os.environ.get('MODEL_BUILDING_ENABLED', 'False') == 'True' - +APPLICABILITY_DOMAIN_ENABLED = os.environ.get('APPLICABILITY_DOMAIN_ENABLED', 'False') == 'True' DEFAULT_RF_MODEL_PARAMS = { 'base_clf': RandomForestClassifier( n_estimators=100, @@ -275,14 +275,14 @@ DEFAULT_RF_MODEL_PARAMS = { 'num_chains': 10, } -DEFAULT_DT_MODEL_PARAMS = { +DEFAULT_MODEL_PARAMS = { 'base_clf': DecisionTreeClassifier( criterion='entropy', max_depth=3, min_samples_split=5, - min_samples_leaf=5, + # min_samples_leaf=5, max_features='sqrt', - class_weight='balanced', + # class_weight='balanced', random_state=42 ), 'num_chains': 10, @@ -322,4 +322,5 @@ FLAGS = { 'PLUGINS': PLUGINS_ENABLED, 'SENTRY': SENTRY_ENABLED, 'ENVIFORMER': ENVIFORMER_PRESENT, + 'APPLICABILITY_DOMAIN': APPLICABILITY_DOMAIN_ENABLED, } diff --git a/epdb/admin.py b/epdb/admin.py index bd3a3dee..d21b67b4 100644 --- a/epdb/admin.py +++ b/epdb/admin.py @@ -1,40 +1,105 @@ from django.contrib import admin -from .models import User, Group, UserPackagePermission, GroupPackagePermission, Setting, SimpleAmbitRule, Scenario +from .models import ( + User, + UserPackagePermission, + Group, + GroupPackagePermission, + Package, + MLRelativeReasoning, + Compound, + CompoundStructure, + SimpleAmbitRule, + ParallelRule, + Reaction, + Pathway, + Node, + Edge, + Scenario, + Setting +) class UserAdmin(admin.ModelAdmin): pass -class GroupAdmin(admin.ModelAdmin): - pass - - class UserPackagePermissionAdmin(admin.ModelAdmin): pass +class GroupAdmin(admin.ModelAdmin): + pass + + class GroupPackagePermissionAdmin(admin.ModelAdmin): pass -class SettingAdmin(admin.ModelAdmin): +class EPAdmin(admin.ModelAdmin): + search_fields = ['name', 'description'] + + +class PackageAdmin(EPAdmin): + pass + +class MLRelativeReasoningAdmin(EPAdmin): pass -class SimpleAmbitRuleAdmin(admin.ModelAdmin): +class CompoundAdmin(EPAdmin): pass -class ScenarioAdmin(admin.ModelAdmin): +class CompoundStructureAdmin(EPAdmin): + pass + + +class SimpleAmbitRuleAdmin(EPAdmin): + pass + + +class ParallelRuleAdmin(EPAdmin): + pass + + +class ReactionAdmin(EPAdmin): + pass + + +class PathwayAdmin(EPAdmin): + pass + + +class NodeAdmin(EPAdmin): + pass + + +class EdgeAdmin(EPAdmin): + pass + + +class ScenarioAdmin(EPAdmin): + pass + + +class SettingAdmin(EPAdmin): pass admin.site.register(User, UserAdmin) -admin.site.register(Group, GroupAdmin) admin.site.register(UserPackagePermission, UserPackagePermissionAdmin) +admin.site.register(Group, GroupAdmin) admin.site.register(GroupPackagePermission, GroupPackagePermissionAdmin) -admin.site.register(Setting, SettingAdmin) +admin.site.register(Package, PackageAdmin) +admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin) +admin.site.register(Compound, CompoundAdmin) +admin.site.register(CompoundStructure, CompoundStructureAdmin) admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin) +admin.site.register(ParallelRule, ParallelRuleAdmin) +admin.site.register(Reaction, ReactionAdmin) +admin.site.register(Pathway, PathwayAdmin) +admin.site.register(Node, NodeAdmin) +admin.site.register(Edge, EdgeAdmin) +admin.site.register(Setting, SettingAdmin) admin.site.register(Scenario, ScenarioAdmin) diff --git a/epdb/logic.py b/epdb/logic.py index 6fa01c67..8c6ce056 100644 --- a/epdb/logic.py +++ 
b/epdb/logic.py @@ -339,7 +339,7 @@ class PackageManager(object): @staticmethod @transaction.atomic - def import_package(data: dict, owner: User, keep_ids=False): + def import_package(data: dict, owner: User, keep_ids=False, add_import_timestamp=True): from uuid import UUID, uuid4 from datetime import datetime from collections import defaultdict @@ -349,7 +349,12 @@ class PackageManager(object): pack = Package() pack.uuid = UUID(data['id'].split('/')[-1]) if keep_ids else uuid4() - pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M')) + + if add_import_timestamp: + pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M')) + else: + pack.name = data['name'] + pack.reviewed = True if data['reviewStatus'] == 'reviewed' else False pack.description = data['description'] pack.save() diff --git a/epdb/management/commands/bootstrap.py b/epdb/management/commands/bootstrap.py index 6ca3e18e..c937ab78 100644 --- a/epdb/management/commands/bootstrap.py +++ b/epdb/management/commands/bootstrap.py @@ -58,7 +58,7 @@ class Command(BaseCommand): return anon, admin, g, jebus def import_package(self, data, owner): - return PackageManager.import_package(data, owner, keep_ids=True) + return PackageManager.import_package(data, owner, keep_ids=True, add_import_timestamp=False) def create_default_setting(self, owner, packages): s = SettingManager.create_setting( diff --git a/epdb/models.py b/epdb/models.py index 5c1770b2..ae85a3d5 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -3,8 +3,8 @@ import json import logging import os from collections import defaultdict -from datetime import datetime, timedelta, date -from typing import Union, List, Optional +from datetime import datetime, timedelta +from typing import Union, List, Optional, Dict, Tuple from uuid import uuid4 import joblib @@ -14,7 +14,7 @@ from django.contrib.auth.hashers import make_password, check_password from django.contrib.auth.models import AbstractUser from django.contrib.postgres.fields import ArrayField from django.db import models, transaction -from django.db.models import JSONField, Count, Q +from django.db.models import JSONField, Count, Q, QuerySet from django.utils import timezone from django.utils.functional import cached_property from model_utils.models import TimeStampedModel @@ -23,7 +23,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score from sklearn.model_selection import ShuffleSplit from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils -from utilities.ml import SparseLabelECC +from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain logger = logging.getLogger(__name__) @@ -172,6 +172,9 @@ class EnviPathModel(TimeStampedModel): class Meta: abstract = True + def __str__(self): + return f"{self.name} (pk={self.pk})" + class AliasMixin(models.Model): aliases = ArrayField( @@ -844,7 +847,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): # We shouldn't lose or make up nodes... assert len(nodes) == len(self.nodes) - print(f"Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}") + logger.debug(f"{self.name}: Num Nodes {len(nodes)} vs. 
DB Nodes {len(self.nodes)}")
 
         links = [e.d3_json() for e in self.edges]
 
@@ -1136,19 +1139,44 @@ class MLRelativeReasoning(EPModel):
 
     eval_results = JSONField(null=True, blank=True, default=dict)
 
+    app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
+                                   default=None)
+
     def status(self):
         return self.PROGRESS_STATUS_CHOICES[self.model_status]
 
+    def ready_for_prediction(self) -> bool:
+        return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
+
     @staticmethod
     @transaction.atomic
-    def create(package, name, description, rule_packages, data_packages, eval_packages, threshold):
+    def create(package: 'Package', rule_packages: List['Package'],
+               data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
+               name: str = None, description: str = None, build_app_domain: bool = False,
+               app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
+               app_domain_local_compatibility_threshold: float = None):
+
         mlrr = MLRelativeReasoning()
         mlrr.package = package
+
+        if name is None or name.strip() == '':
+            name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
+
         mlrr.name = name
-        mlrr.description = description
+
+        if description is not None and description.strip() != '':
+            mlrr.description = description
+
+        if threshold is None or (threshold <= 0 or 1 <= threshold):
+            raise ValueError("Threshold must be a float between 0 and 1.")
+
         mlrr.threshold = threshold
 
+        if len(rule_packages) == 0:
+            raise ValueError("At least one rule package must be provided.")
+
         mlrr.save()
+
         for p in rule_packages:
             mlrr.rule_packages.add(p)
 
@@ -1163,11 +1191,17 @@ class MLRelativeReasoning(EPModel):
         for p in eval_packages:
             mlrr.eval_packages.add(p)
 
+        if build_app_domain:
+            ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
+                                            app_domain_local_compatibility_threshold)
+            mlrr.app_domain = ad
+            mlrr.save()
+
         return mlrr
 
     @cached_property
-    def applicable_rules(self):
+    def applicable_rules(self) -> List['Rule']:
         """
         Returns an ordered set of rules where the following applies:
         1.
All Composite will be added to result @@ -1195,6 +1229,7 @@ class MLRelativeReasoning(EPModel): rules.append(r) rules = sorted(rules, key=lambda x: x.url) + return rules def _get_excludes(self): @@ -1209,197 +1244,79 @@ class MLRelativeReasoning(EPModel): pathway_qs = pathway_qs.distinct() return pathway_qs + + def _get_reactions(self) -> QuerySet: + return Reaction.objects.filter(package__in=self.data_packages.all()).distinct() + def build_dataset(self): self.model_status = self.INITIALIZING self.save() - from datetime import datetime + start = datetime.now() + applicable_rules = self.applicable_rules - print("got rules") - - # if s.DEBUG: - # pathways = self._get_pathways().order_by('-name')[:20] - # else: - pathways = self._get_pathways() - - print("got pathways") - excludes = self._get_excludes() - - # Collect all compounds - compounds = set() - reactions = set() - for i, p in enumerate(pathways): - print(f"{i + 1}/{len(pathways)}...") - for n in p.nodes: - cs = n.default_node_label.compound.default_structure - # TODO too many lookups - if cs.smiles in excludes: - continue - - compounds.add(cs) - - for e in p.edges: - reactions.add(e.edge_label) - - print(len(compounds)) - print(len(reactions)) - - triggered = set() - observed = set() - - # TODO naming - - pw = defaultdict(lambda: defaultdict(set)) - - for i, c in enumerate(compounds): - print(f"{i + 1}/{len(compounds)}...") - for r in applicable_rules: - # TODO check normalization - product_sets = r.apply(c.smiles) - - if len(product_sets) == 0: - continue - - triggered.add(f"{r.uuid} + {c.uuid}") - - for ps in product_sets: - for p in ps: - pw[c][r].add(p) - - for r in reactions: - if r is None: - print(r) - continue - if len(r.educts.all()) != 1: - print(f"Skipping {r.url}") - continue - - # Loop will run only once - for c in r.educts.all(): - if c not in pw: - continue - - for rule in pw[c].keys(): - # standardize... 
- - if 0 != len(pw[c][rule]) and len(pw[c][rule]) == len(r.products.all()): - print(f"potential match for {c.smiles} and {r.uuid} ({r.name})") - - standardized_products = [] - for cs in r.products.all(): - smi = cs.smiles - - try: - smi = FormatConverter.standardize(smi) - except Exception as e: - # :shrug: - pass - - standardized_products.append(smi) - - standardized_pred_products = [] - for smi in pw[c][rule]: - - try: - smi = FormatConverter.standardize(smi) - except Exception as e: - # :shrug: - pass - - standardized_pred_products.append(smi) - - if sorted(list(set(standardized_products))) == sorted(list(set(standardized_pred_products))): - observed.add(f"{rule.uuid} + {c.uuid}") - print(f"Adding observed, current count {len(observed)}") - - header = None - X = [] - y = [] - for i, c in enumerate(compounds): - print(f'{i + 1}/{len(compounds)}...') - # Features - feat = FormatConverter.maccs(c.smiles) - trig = [] - obs = [] - for rule in applicable_rules: - key = f"{rule.uuid} + {c.uuid}" - - # Check triggered - if key in triggered: - trig.append(1) - else: - trig.append(0) - - # Check obs - if key in observed: - obs.append(1) - else: - obs.append(0) - - if header is None: - header = [f'feature_{i}' for i, _ in enumerate(feat)] \ - + [f'trig_{r.uuid}' for r in applicable_rules] \ - + [f'corr_{r.uuid}' for r in applicable_rules] - X.append(feat + trig) - y.append(obs) + reactions = list(self._get_reactions()) + ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True) end = datetime.now() - print(f"Duration {(end - start).total_seconds()}s") + logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds") - data = { - 'X': X, - 'y': y, - 'header': header - } - f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json") - json.dump(data, open(f, 'w')) - return X, y + f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl") + ds.save(f) + return ds - def load_dataset(self): - ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}.json") - return json.load(open(ds_path, 'r')) + def load_dataset(self) -> 'Dataset': + ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl") + return Dataset.load(ds_path) - def build_model(self, X, y): + def build_model(self): self.model_status = self.BUILDING self.save() - mod = SparseLabelECC( - **s.DEFAULT_DT_MODEL_PARAMS - ) + start = datetime.now() + ds = self.load_dataset() + X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan) + + mod = EnsembleClassifierChain( + **s.DEFAULT_MODEL_PARAMS + ) mod.fit(X, y) - f = os.path.join(s.MODEL_DIR, f"{self.uuid}.pkl") + + end = datetime.now() + logger.debug(f"fitting model took {(end - start).total_seconds()} seconds") + + f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl") joblib.dump(mod, f) + + if self.app_domain is not None: + logger.debug("Building applicability domain...") + self.app_domain.build() + logger.debug("Done building applicability domain.") + + self.model_status = self.BUILT_NOT_EVALUATED self.save() + def retrain(self): + self.build_dataset() + self.build_model() + def rebuild(self): - data = self.load_dataset() - self.build_model(data['X'], data['y']) + self.build_model() def evaluate_model(self): - """ - Performs Leave-One-Out cross-validation on a multi-label dataset. - Parameters: - X (list of lists): Feature matrix. - y (list of lists): Multi-label targets. - classifier (sklearn estimator, optional): Base classifier. Defaults to RandomForest. - - Returns: - float: Average accuracy across all LOO splits. 
- """ if self.model_status != self.BUILT_NOT_EVALUATED: raise ValueError(f"Can't evaluate a model in state {self.model_status}!") self.model_status = self.EVALUATING self.save() - f = os.path.join(s.MODEL_DIR, f"{self.uuid}.json") - data = json.load(open(f)) + ds = self.load_dataset() - X = np.array(data['X']) - y = np.array(data['y']) + X = np.array(ds.X(na_replacement=np.nan)) + y = np.array(ds.y(na_replacement=np.nan)) n_splits = 20 @@ -1409,22 +1326,32 @@ class MLRelativeReasoning(EPModel): X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index] - model = SparseLabelECC( - **s.DEFAULT_DT_MODEL_PARAMS + model = EnsembleClassifierChain( + **s.DEFAULT_MODEL_PARAMS ) model.fit(X_train, y_train) y_pred = model.predict_proba(X_test) y_thresholded = (y_pred >= threshold).astype(int) - acc = jaccard_score(y_test, y_thresholded, average='samples', zero_division=0) + # Flatten them to get rid of np.nan + y_test = np.asarray(y_test).flatten() + y_pred = np.asarray(y_pred).flatten() + y_thresholded = np.asarray(y_thresholded).flatten() + + mask = ~np.isnan(y_test) + y_test_filtered = y_test[mask] + y_pred_filtered = y_pred[mask] + y_thresholded_filtered = y_thresholded[mask] + + acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0) prec, rec = dict(), dict() for t in np.arange(0, 1.05, 0.05): - temp_thresholded = (y_pred >= t).astype(int) - prec[f"{t:.2f}"] = precision_score(y_test, temp_thresholded, average='samples', zero_division=0) - rec[f"{t:.2f}"] = recall_score(y_test, temp_thresholded, average='samples', zero_division=0) + temp_thresholded = (y_pred_filtered >= t).astype(int) + prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0) + rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0) return acc, prec, rec @@ -1462,38 +1389,30 @@ class MLRelativeReasoning(EPModel): @cached_property def model(self): - mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}.pkl')) + mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl')) mod.base_clf.n_jobs = -1 return mod def predict(self, smiles) -> List['PredictionResult']: start = datetime.now() - features = FormatConverter.maccs(smiles) + ds = self.load_dataset() + classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules) + pred = self.model.predict_proba(classify_ds.X()) - trig = [] - prods = [] - for rule in self.applicable_rules: - products = rule.apply(smiles) - - if len(products): - trig.append(1) - prods.append(products) - else: - trig.append(0) - prods.append([]) - - end_ds_gen = datetime.now() - logger.info(f"Gen predict dataset took {(end_ds_gen - start).total_seconds()}s") - pred = self.model.predict_proba([features + trig]) - - res = [] - for rule, p, smis in zip(self.applicable_rules, pred[0], prods): - res.append(PredictionResult(smis, p, rule)) + res = MLRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0]) end = datetime.now() logger.info(f"Full predict took {(end - start).total_seconds()}s") return res + + @staticmethod + def combine_products_and_probs(rules: List['Rule'], probabilities, products): + res = [] + for rule, p, smis in zip(rules, probabilities, products): + res.append(PredictionResult(smis, p, rule)) + return res + @property def pr_curve(self): if self.model_status != self.FINISHED: @@ -1515,26 +1434,171 @@ class MLRelativeReasoning(EPModel): class ApplicabilityDomain(EnviPathModel): model = 
models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE) - num_neighbours = models.FloatField(blank=False, null=False, default=5) + num_neighbours = models.IntegerField(blank=False, null=False, default=5) reliability_threshold = models.FloatField(blank=False, null=False, default=0.5) local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5) - def build_applicability_domain(self): + @staticmethod + @transaction.atomic + def create(mlrr: MLRelativeReasoning, num_neighbours: int = 5, reliability_threshold: float = 0.5, + local_compatibility_threshold: float = 0.5): + ad = ApplicabilityDomain() + ad.model = mlrr + # ad.uuid = mlrr.uuid + ad.name = f"AD for {mlrr.name}" + ad.num_neighbours = num_neighbours + ad.reliability_threshold = reliability_threshold + ad.local_compatibilty_threshold = local_compatibility_threshold + ad.save() + return ad + + @cached_property + def pca(self) -> ApplicabilityDomainPCA: + pca = joblib.load(os.path.join(s.MODEL_DIR, f'{self.model.uuid}_pca.pkl')) + return pca + + @cached_property + def training_set_probs(self): + return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")) + + def build(self): ds = self.model.load_dataset() - X = ds['X'] - import numpy as np - from sklearn.decomposition import PCA - from sklearn.preprocessing import StandardScaler - scaler = StandardScaler() - X_scaled = scaler.fit_transform(X) - pca = PCA(n_components=5) # choose number of components - X_pca = pca.fit_transform(X_scaled) + start = datetime.now() - max_vals = np.max(X_pca, axis=0) - min_vals = np.min(X_pca, axis=0) + # Get Trainingset probs and dump them as they're required when using the app domain + probs = self.model.model.predict_proba(ds.X()) + f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl") + joblib.dump(probs, f) + + ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours) + ad.build(ds) + + end = datetime.now() + logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds") + + f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl") + joblib.dump(ad, f) + + def assess(self, structure: Union[str, 'CompoundStructure']): + ds = self.model.load_dataset() + + assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules) + + # qualified_neighbours_per_rule is a nested dictionary structured as: + # { + # assessment_structure_index: { + # rule_index: [training_structure_indices_with_same_triggered_reaction] + # } + # } + # + # For each structure in the assessment dataset and each rule (represented by a trigger feature), + # it identifies all training structures that have the same trigger reaction activated (i.e., value 1). + # This is used to find "qualified neighbours" — training examples that share the same triggered feature + # with a given assessment structure under a particular rule. 
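+        # Example (hypothetical indices): {0: {3: [12, 45, 77]}} would mean that assessment
+        # structure 0 triggers rule 3, and that training structures 12, 45 and 77 trigger it too.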
+        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(lambda: defaultdict(list))
+
+        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
+            feature = ds.columns[feature_index]
+            if feature.startswith('trig_'):
+                # TODO unroll loop
+                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
+                    if int(cx[feature_index]) == 1:
+                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
+                            if int(tx[feature_index]) == 1:
+                                qualified_neighbours_per_rule[i][rule_idx].append(j)
+
+        probs = self.training_set_probs
+        # preds = self.model.model.predict_proba(assessment_ds.X())
+        preds = self.model.combine_products_and_probs(self.model.applicable_rules,
+                                                      self.model.model.predict_proba(assessment_ds.X())[0],
+                                                      assessment_prods[0])
+
+        res = list()
+
+        # loop through our assessment dataset
+        for i, instance in enumerate(assessment_ds):
+
+            rule_reliabilities = dict()
+            local_compatibilities = dict()
+            neighbours_per_rule = dict()
+
+            # loop through rule indices together with the collected neighbour indices from the train dataset
+            for rule_idx, vals in qualified_neighbours_per_rule[i].items():
+
+                # collect the train dataset instances and store them along with the index (a.k.a. row number) of the
+                # train dataset
+                train_instances = []
+                for v in vals:
+                    train_instances.append((v, ds.at(v)))
+
+                # sf is a tuple with start/end index of the features
+                sf = ds.struct_features()
+
+                # compute the tanimoto distance for all qualified neighbours
+                # the result is a list of tuples with train index and computed distance
+                dists = self._compute_distances(
+                    instance.X()[0][sf[0]:sf[1]],
+                    [ti[1].X()[0][sf[0]:sf[1]] for ti in train_instances]
+                )
+
+                # convert the distances to similarities, so that sorting in descending order puts
+                # the most similar training structures first
+                sims_with_index = list()
+                for ti, dist in zip(train_instances, dists):
+                    sims_with_index.append((ti[0], 1.0 - dist[1]))
+
+                # sort them in a descending way and take at most `self.num_neighbours`
+                sims_with_index = sorted(sims_with_index, key=lambda x: x[1], reverse=True)
+                sims_with_index = sims_with_index[:self.num_neighbours]
+
+                # compute the average similarity as the rule reliability
+                rule_reliabilities[rule_idx] = sum([s[1] for s in sims_with_index]) / len(sims_with_index) if len(sims_with_index) > 0 else 0.0
+
+                # for local_compatibility we'll need the datasets for the indices having the highest similarity
+                neighbour_datasets = [(s[0], ds.at(s[0])) for s in sims_with_index]
+                local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
+                neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=nd[1].structure_id()) for nd in neighbour_datasets]
+
+            # Assemble result for instance
+            res.append({
+                'in_ad': self.pca.is_applicable(instance)[0],
+                'rule_reliabilities': rule_reliabilities,
+                'local_compatibilities': local_compatibilities,
+                'neighbours': neighbours_per_rule,
+                'rule_lookup': [Rule.objects.get(uuid=r.replace('obs_', '')) for r in instance.columns[instance.observed()[0]: instance.observed()[1]]],
+                'prob': preds
+            })
+        return res
+
+    @staticmethod
+    def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
+        from utilities.ml import tanimoto_distance
+        distances = [(i, tanimoto_distance(classify_instance, train)) for i, train in
+                     enumerate(train_instances)]
+        return distances
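+
+    # Worked example for the two scores (hypothetical numbers): with num_neighbours=3 and
+    # neighbour similarities [0.9, 0.8, 0.4], the rule reliability is their mean, 0.7; if the
+    # stored training predictions agree with the observed label for two of those three
+    # neighbours, _compute_compatibility below yields (tp + tn) / (tp + tn + fp + fn) = 2/3.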
+
+    @staticmethod
+    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, 'Dataset']]):
+        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
+        accuracy = 0.0
+
+        for n in neighbours:
+            obs = n[1].y()[0][rule_idx]
+            pred = preds[n[0]][rule_idx]
+            if obs and pred:
+                tp += 1
+            elif not obs and pred:
+                fp += 1
+            elif obs and not pred:
+                fn += 1
+            else:
+                tn += 1
+
+        # accuracy, i.e. the fraction of neighbours where prediction and observation agree
+        if tp + tn > 0.0:
+            accuracy = (tp + tn) / (tp + tn + fp + fn)
+
+        return accuracy
 
 
 class RuleBaseRelativeReasoning(EPModel):
@@ -1574,10 +1638,6 @@ class EnviFormer(EPModel):
             logger.info(f"Submitting {kek} to {hash(self.model)}")
             products = self.model.predict(kek)
             logger.info(f"Got results {products}")
-            # from pprint import pprint
-            #
-            # print(smiles)
-            # pprint(products)
 
             res = []
             for smi, prob in products.items():
@@ -1715,9 +1775,7 @@ class Setting(EnviPathModel):
         transformations = []
 
         if self.model is not None:
-            print(self.model)
             pred_results = self.model.predict(current_node.smiles)
-            print(pred_results)
             for pred_result in pred_results:
                 if pred_result.probability >= self.model_threshold:
                     transformations.append(pred_result)
diff --git a/epdb/tasks.py b/epdb/tasks.py
index 2dcb5b65..3664caf5 100644
--- a/epdb/tasks.py
+++ b/epdb/tasks.py
@@ -31,8 +31,8 @@ def send_registration_mail(user_pk: int):
 @shared_task(queue='model')
 def build_model(model_pk: int):
     mod = EPModel.objects.get(id=model_pk)
-    X, y = mod.build_dataset()
-    mod.build_model(X, y)
+    mod.build_dataset()
+    mod.build_model()
 
 
 @shared_task(queue='model')
diff --git a/epdb/views.py b/epdb/views.py
index a178cfc3..74d835e8 100644
--- a/epdb/views.py
+++ b/epdb/views.py
@@ -103,7 +103,10 @@ def login(request):
             else:
                 context['message'] = "Account has been created! You'll receive a mail to activate your account shortly."
                 return render(request, 'login.html', context)
-
+        else:
+            return HttpResponseBadRequest()
+    else:
+        return HttpResponseNotAllowed(['GET', 'POST'])
+
 
 def logout(request):
     if request.method == 'POST':
@@ -136,7 +139,7 @@ def editable(request, user):
                    f"{s.SERVER_URL}/group", f"{s.SERVER_URL}/search"]:
         return True
     else:
-        print(f"Unknown url: {url}")
+        logger.debug(f"Unknown url: {url}")
         return False
 
 
@@ -584,6 +587,9 @@ def package_models(request, package_uuid):
         return render(request, 'collections/objects_list.html', context)
 
     elif request.method == 'POST':
+
+        log_post_params(request)
+
         name = request.POST.get('model-name')
         description = request.POST.get('model-description')
 
@@ -606,14 +612,25 @@ def package_models(request, package_uuid):
         data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
         eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
 
+        # App Domain related parameters
+        build_ad = request.POST.get('build-app-domain', False) == 'on'
+        num_neighbors = int(request.POST.get('num-neighbors', 5))
+        reliability_threshold = float(request.POST.get('reliability-threshold', 0.5))
+        local_compatibility_threshold = float(request.POST.get('local-compatibility-threshold', 0.5))
+
         mod = MLRelativeReasoning.create(
-            current_package,
-            name,
-            description,
-            rule_package_objs,
-            data_package_objs,
-            eval_packages_objs,
-            threshold
+            package=current_package,
+            name=name,
+            description=description,
+            rule_packages=rule_package_objs,
+            data_packages=data_package_objs,
+            eval_packages=eval_packages_objs,
+            threshold=threshold,
+            # fingerprinter=fingerprinter,
+            build_app_domain=build_ad,
+            app_domain_num_neighbours=num_neighbors,
+            app_domain_reliability_threshold=reliability_threshold,
+            app_domain_local_compatibility_threshold=local_compatibility_threshold,
         )
 
         from .tasks import build_model
@@ -649,7 +666,7 @@ def package_model(request, package_uuid, model_uuid):
                 if len(pr) > 0:
                     products = []
                     for prod_set in pr.product_sets:
-                        print(f"Checking {prod_set}")
+
logger.debug(f"Checking {prod_set}") products.append(tuple([x for x in prod_set])) res.append({ @@ -660,6 +677,12 @@ def package_model(request, package_uuid, model_uuid): return JsonResponse(res, safe=False) + elif request.GET.get('app-domain-assessment', False): + smiles = request.GET['smiles'] + stand_smiles = FormatConverter.standardize(smiles) + app_domain_assessment = current_model.app_domain.assess(stand_smiles) + return JsonResponse(app_domain_assessment, safe=False) + context = get_base_context(request) context['title'] = f'enviPath - {current_package.name} - {current_model.name}' @@ -1717,8 +1740,6 @@ def user(request, user_uuid): } } - print(setting) - return HttpResponseBadRequest() else: @@ -1781,9 +1802,7 @@ def group(request, group_uuid): elif request.method == 'POST': - if s.DEBUG: - for k, v in request.POST.items(): - print(k, v) + log_post_params(request) if hidden := request.POST.get('hidden', None): if hidden == 'delete-group': diff --git a/templates/modals/collections/new_model_modal.html b/templates/modals/collections/new_model_modal.html index 1d23d58e..5283275a 100644 --- a/templates/modals/collections/new_model_modal.html +++ b/templates/modals/collections/new_model_modal.html @@ -16,14 +16,14 @@
                        Create a new Model to limit the number of degradation
                        products in the prediction. You just need to set a name and the packages
-                        you want the object to be based on. If you want to use the
-                        default options suggested by us, simply click Submit,
-                        otherwise click Advanced Options.
+                        you want the object to be based on. There are multiple types of models available.
+                        For additional information, have a look at our
+                        wiki >>
- - - - Name + + + @@ -53,7 +53,7 @@ {% endfor %} -
+ - {% if meta.enabled_features.PLUGINS %} + {% if meta.enabled_features.PLUGINS and additional_descriptors %} - - + + {% for k, v in additional_descriptors.items %} + + {% endfor %} {% endif %} + - -
+ + {% if meta.enabled_features.APPLICABILITY_DOMAIN %} + +
+ +
+ + + + + + + + + + {% endif %}
@@ -118,47 +140,9 @@
-
- - {% if meta.enabled_features.APPLICABILITY_DOMAIN %} - - {% endif %}
{% endif %} - -
-

- Predict -

-
-
-
-
- - + {% if model.ready_for_prediction %} + +
+

+ Predict +

+
+
+
+
+ + +
+
+
-
-
-
- + + {% endif %} + + {% if model.app_domain %} + + +
+
+
+ + + + +
+
+
+
+
+ + {% endif %} + {% if model.model_status == 'FINISHED' %}
@@ -277,9 +303,9 @@ $("#predictResultTable").append(res); } - function clear() { - $("#predictResultTable").removeClass("alert alert-danger"); - $("#predictResultTable").empty(); + function clear(divid) { + $("#" + divid).removeClass("alert alert-danger"); + $("#" + divid).empty(); } if ($('#predict-button').length > 0) { @@ -291,32 +317,69 @@ "classify": "ILikeCats!" } - clear(); + clear("predictResultTable"); - makeLoadingGif("#loading", "{% static '/images/wait.gif' %}"); + makeLoadingGif("#predictLoading", "{% static '/images/wait.gif' %}"); $.ajax({ type: 'get', data: data, url: '', success: function (data, textStatus) { try { - $("#loading").empty(); + $("#predictLoading").empty(); handleResponse(data); } catch (error) { console.log("Error"); - $("#loading").empty(); + $("#predictLoading").empty(); $("#predictResultTable").addClass("alert alert-danger"); $("#predictResultTable").append("Error while processing request :/"); } }, error: function (jqXHR, textStatus, errorThrown) { - $("#loading").empty(); + $("#predictLoading").empty(); $("#predictResultTable").addClass("alert alert-danger"); $("#predictResultTable").append("Error while processing request :/"); } }); }); } + + if ($('#assess-button').length > 0) { + $("#assess-button").on("click", function (e) { + e.preventDefault(); + + data = { + "smiles": $("#smiles-to-assess").val(), + "app-domain-assessment": "ILikeCats!" + } + + clear("appDomainAssessmentResultTable"); + + makeLoadingGif("#appDomainLoading", "{% static '/images/wait.gif' %}"); + $.ajax({ + type: 'get', + data: data, + url: '', + success: function (data, textStatus) { + try { + $("#appDomainLoading").empty(); + console.log(data); + } catch (error) { + console.log("Error"); + $("#appDomainLoading").empty(); + $("#appDomainAssessmentResultTable").addClass("alert alert-danger"); + $("#appDomainAssessmentResultTable").append("Error while processing request :/"); + } + }, + error: function (jqXHR, textStatus, errorThrown) { + $("#appDomainLoading").empty(); + $("#appDomainAssessmentResultTable").addClass("alert alert-danger"); + $("#appDomainAssessmentResultTable").append("Error while processing request :/"); + } + }); + }); + } + {% endblock content %} diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 00000000..1841d663 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,52 @@ +from django.test import TestCase + +from epdb.logic import PackageManager +from epdb.models import Reaction, Compound, User, Rule +from utilities.ml import Dataset + + +class DatasetTest(TestCase): + fixtures = ["test_fixture.cleaned.json"] + + def setUp(self): + self.cs1 = Compound.create( + self.package, + name='2,6-Dibromohydroquinone', + description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b', + smiles='C1=C(C(=C(C=C1O)Br)O)Br', + ).default_structure + + self.cs2 = Compound.create( + self.package, + smiles='O=C(O)CC(=O)/C=C(/Br)C(=O)O', + ).default_structure + + self.rule1 = Rule.create( + rule_type='SimpleAmbitRule', + package=self.package, + smirks='[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\[#6:3]=[#6:2](\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]', + description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6' + ) + + self.reaction1 = Reaction.create( + 
package=self.package,
+            educts=[self.cs1],
+            products=[self.cs2],
+            rules=[self.rule1],
+            multi_step=False
+        )
+
+    @classmethod
+    def setUpClass(cls):
+        super(DatasetTest, cls).setUpClass()
+        cls.user = User.objects.get(username='anonymous')
+        cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
+
+    def test_smoke(self):
+        reactions = [r for r in Reaction.objects.filter(package=self.package)]
+        applicable_rules = [self.rule1]
+
+        ds = Dataset.generate_dataset(reactions, applicable_rules)
+
+        self.assertEqual(len(ds.y()), 1)
+        self.assertEqual(sum(ds.y()[0]), 1)
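+
+        # Sketch of the expectation (assuming rule1 both triggers on cs1 and reproduces the
+        # reaction products): the dataset holds one row for the single educt structure and
+        # one obs_ label column for rule1, so ds.y() is [[1]].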
diff --git a/tests/test_datasetgenerator.py b/tests/test_datasetgenerator.py
deleted file mode 100644
index 6c03a075..00000000
--- a/tests/test_datasetgenerator.py
+++ /dev/null
@@ -1,111 +0,0 @@
-from django.test import TestCase
-
-from epdb.models import ParallelRule
-from utilities.ml import Compound, Reaction, DatasetGenerator
-
-
-class CompoundTest(TestCase):
-
-    def setUp(self):
-        self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C", uuid='c1')
-        self.c2 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C", uuid='c2')
-
-    def test_compound_eq_ignores_uuid(self):
-        self.assertEqual(self.c1, self.c2)
-
-
-class ReactionTest(TestCase):
-
-    def setUp(self):
-        self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C")
-        self.c2 = Compound(smiles="CCN(CCO)C(=O)C1=CC(C)=CC=C1")
-        # self.r1 = Rule(uuid="bt0334")
-        # c1 --r1--> c2
-        self.c3_1 = Compound(smiles="CCNC(=O)C1=CC(C)=CC=C1")
-        self.c3_2 = Compound(smiles="CC=O")
-        # self.r2 = Rule(uuid="bt0243")
-        # c1 --r2--> c3_1, c3_2
-
-    def test_reaction_equality_ignores_uuid(self):
-        r1 = Reaction([self.c1], [self.c2], self.r1, uuid="abc")
-        r2 = Reaction([self.c1], [self.c2], self.r1, uuid="xyz")
-        self.assertEqual(r1, r2)
-
-    def test_reaction_inequality_on_data_change(self):
-        r1 = Reaction([self.c1], [self.c2], self.r1)
-        r2 = Reaction([self.c1], [self.c3_1], self.r1)
-        self.assertNotEqual(r1, r2)
-
-    def test_reaction_is_hashable(self):
-        r = Reaction([self.c1], [self.c2], self.r1)
-        reactions = {r}
-        self.assertIn(Reaction([self.c1], [self.c2], self.r1), reactions)
-
-    def test_rule_is_optional(self):
-        r = Reaction([self.c1], [self.c2])
-        self.assertIsNone(r.rule)
-
-    def test_uuid_is_optional(self):
-        r = Reaction([self.c1], [self.c2], self.r1)
-        self.assertIsNone(r.uuid)
-
-    def test_repr_includes_uuid(self):
-        r = Reaction([self.c1], [self.c2], self.r1, uuid="abc")
-        self.assertIn("abc", repr(r))
-
-    def test_reaction_equality_with_multiple_compounds_different_ordering(self):
-        r1 = Reaction([self.c1], [self.c3_1, self.c3_2], self.r2)
-        r2 = Reaction([self.c1], [self.c3_2, self.c3_1], self.r2)
-
-        self.assertEqual(r1, r2, "Reaction equality should not rely on list order")
-
-
-class RuleTest(TestCase):
-
-    def setUp(self):
-        pass
-        # self.r1 = Rule(uuid="bt0334")
-        # self.r2 = Rule(uuid="bt0243")
-
-
-class DatasetGeneratorTest(TestCase):
-    fixtures = ['bootstrap.json']
-
-    def setUp(self):
-        self.c1 = Compound(smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C")
-        self.c2 = Compound(smiles="CCN(CCO)C(=O)C1=CC(C)=CC=C1")
-        self.c3_1 = Compound(smiles="CCNC(=O)C1=CC(C)=CC=C1")
-        self.c3_2 = Compound(smiles="CC=O")
-
-        # self.r1 = Rule(uuid="bt0334")  # trig
-        # self.r2 = Rule(uuid="bt0243")  # trig
-        # self.r3 = Rule(uuid="bt0003")  # no trig
-
-        self.reaction1 = Reaction([self.c1], [self.c2], self.r3)
-        self.reaction2 = Reaction([self.c1], [self.c3_1, self.c3_2], self.r2)
-
-    def test_test(self):
-        compounds = [
-            self.c1,
-            self.c2,
-            self.c3_1,
-            self.c3_2,
-        ]
-
-        reactions = [
-            self.reaction1,
-            self.reaction2,
-        ]
-
-        applicable_rules = [
-            # Rule('bt0334', ParallelRule.objects.get(name='bt0334')),
-            # Rule('bt0243', ParallelRule.objects.get(name='bt0243')),
-            # Rule('bt0003', ParallelRule.objects.get(name='bt0003')),
-        ]
-
-        ds = DatasetGenerator.generate_dataset(compounds, reactions, applicable_rules)
-
-        self.assertIsNotNone(ds)
\ No newline at end of file
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 00000000..f1652b03
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,55 @@
+import json
+
+from django.test import TestCase
+
+from epdb.logic import PackageManager
+from epdb.models import Compound, User, CompoundStructure, Reaction, Rule, MLRelativeReasoning
+
+
+class ModelTest(TestCase):
+    fixtures = ["test_fixture.cleaned.json"]
+
+    def setUp(self):
+        pass
+
+    @classmethod
+    def setUpClass(cls):
+        super(ModelTest, cls).setUpClass()
+        cls.user = User.objects.get(username='anonymous')
+        cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
+        bbd_data = json.load(open('fixtures/packages/2025-07-18/EAWAG-BBD.json'))
+        cls.BBD = PackageManager.import_package(bbd_data, cls.user)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ModelTest, cls).tearDownClass()
+
+    def tearDown(self):
+        pass
+
+    def test_smoke(self):
+        threshold = float(0.5)
+
+        rule_package_objs = [self.BBD]
+        data_package_objs = [self.BBD]
+        eval_packages_objs = []
+
+        mod = MLRelativeReasoning.create(
+            package=self.package,
+            name='ECC - BBD - 0.5',
+            description='Created MLRelativeReasoning in Testcase',
+            rule_packages=rule_package_objs,
+            data_packages=data_package_objs,
+            eval_packages=eval_packages_objs,
+            threshold=threshold
+        )
+
+        mod.build_dataset()
+        mod.build_model()
+        print("Model built!")
+        mod.evaluate_model()
+        print("Model Evaluated")
+
+        results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
+        print(results)
diff --git a/utilities/chem.py b/utilities/chem.py
index 244aa099..77d61b3a 100644
--- a/utilities/chem.py
+++ b/utilities/chem.py
@@ -131,6 +131,24 @@ class FormatConverter(object):
         # TODO call to AMBIT Service
         return smiles
 
+    @staticmethod
+    def ep_standardize(smiles):
+        change = True
+        while change:
+            change = False
+            for standardizer in MATCH_STANDARDIZER:
+                tmp_smiles = standardizer.standardize(smiles)
+
+                if tmp_smiles != smiles:
+                    logger.debug(f"change {smiles} to {tmp_smiles}")
+                    change = True
+                    smiles = tmp_smiles
+
+            if change is False:
+                logger.debug("nothing changed")
+
+        return smiles
+
     @staticmethod
     def standardize(smiles):
         # Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
@@ -180,54 +198,6 @@ class FormatConverter(object):
                 atom.UpdatePropertyCache()
         return mol
 
-    # @staticmethod
-    # def apply(smiles, smirks, preprocess_smiles=True, bracketize=False, standardize=True):
-    #     logger.debug(f'Applying {smirks} on {smiles}')
-    #
-    #     if bracketize:
-    #         smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")"
-    #
-    #     res = set()
-    #     try:
-    #         rxn = rdChemReactions.ReactionFromSmarts(smirks)
-    #         mol = Chem.MolFromSmiles(smiles)
-    #
-    #         # Inplace
-    #         if preprocess_smiles:
-    #             Chem.SanitizeMol(mol)
-    #             mol = Chem.AddHs(mol)
-    #
-    #         # apply!
- # reacts = rxn.RunReactants((mol,)) - # if len(reacts): - # # Sanitize mols - # for product_set in reacts: - # prod_set = list() - # for product in product_set: - # # Fixes - # # [2025-01-30 23:00:50] ERROR chem - Sanitizing and converting failed: - # # non-ring atom 3 marked aromatic - # # But does not improve overall performance - # # - # # for a in product.GetAtoms(): - # # if (not a.IsInRing()) and a.GetIsAromatic(): - # # a.SetIsAromatic(False) - # # for b in product.GetBonds(): - # # if (not b.IsInRing()) and b.GetIsAromatic(): - # # b.SetIsAromatic(False) - # - # try: - # Chem.SanitizeMol(product) - # prod_set.append(FormatConverter.standardize(Chem.MolToSmiles(product))) - # except ValueError as e: - # logger.error(f'Sanitizing and converting failed:\n{e}') - # continue - # res.add(tuple(list(set(prod_set)))) - # except Exception as e: - # logger.error(f'Applying {smirks} on {smiles} failed:\n{e}') - # - # return list(res) - @staticmethod def is_valid_smirks(smirks: str) -> bool: try: diff --git a/utilities/ml.py b/utilities/ml.py index 85cacd84..18ffeb97 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -1,46 +1,29 @@ from __future__ import annotations -import dataclasses +import logging +from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime -from typing import List, Optional +from typing import List, Dict, Set, Tuple + import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.decomposition import PCA +from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score from sklearn.multioutput import ClassifierChain from sklearn.preprocessing import StandardScaler -from sklearn.tree import DecisionTreeClassifier -from sklearn.ensemble import RandomForestClassifier -# @dataclasses.dataclass -# class Feature: -# name: str -# value: float -# -# -# -# class Row: -# def __init__(self, compound_uuid: str, compound_smiles: str, descriptors: List[int]): -# self.data = {} -# -# -# -# class DataSet(object): -# -# def __init__(self): -# self.rows: List[Row] = [] -# -# def add_row(self, row: Row): -# pass +logger = logging.getLogger(__name__) + from dataclasses import dataclass, field -from utilities.chem import FormatConverter +from utilities.chem import FormatConverter, PredictionResult @dataclass -class Compound: +class SCompound: smiles: str uuid: str = field(default=None, compare=False, hash=False) @@ -53,10 +36,10 @@ class Compound: @dataclass -class Reaction: - educts: List[Compound] - products: List[Compound] - rule_uuid: str = field(default=None, compare=False, hash=False) +class SReaction: + educts: List[SCompound] + products: List[SCompound] + rule_uuid: SRule = field(default=None, compare=False, hash=False) reaction_uuid: str = field(default=None, compare=False, hash=False) def __hash__(self): @@ -68,77 +51,294 @@ class Reaction: return self._hash def __eq__(self, other): - if not isinstance(other, Reaction): + if not isinstance(other, SReaction): return NotImplemented return ( - sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and - sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles) + sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and + sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles) ) -class Dataset(object): +@dataclass +class SRule(ABC): - def __init__(self, headers=List['str'], 
data=List[List[str|int|float]]):
-        self.headers = headers
-        self.data = data
-
-
-    def features(self):
-        pass
-
-    def labels(self):
-        pass
-
-    def to_json(self):
-        pass
-
-    def to_csv(self):
-        pass
-
-    def to_arff(self):
+    @abstractmethod
+    def apply(self):
         pass
 
 
+@dataclass
+class SSimpleRule:
+    pass
 
-class DatasetGenerator(object):
+
+@dataclass
+class SParallelRule:
+    pass
+
+
+class Dataset:
+
+    def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
+        self.columns: List[str] = columns
+        self.num_labels: int = num_labels
+
+        if data is None:
+            self.data: List[List[str | int | float]] = list()
+        else:
+            self.data = data
+
+        self.num_features: int = len(columns) - self.num_labels
+        self._struct_features: Tuple[int, int] = self._block_indices('feature_')
+        self._triggered: Tuple[int, int] = self._block_indices('trig_')
+        self._observed: Tuple[int, int] = self._block_indices('obs_')
+
+    def _block_indices(self, prefix) -> Tuple[int, int]:
+        indices: List[int] = []
+        for i, feature in enumerate(self.columns):
+            if feature.startswith(prefix):
+                indices.append(i)
+
+        # exclusive upper bound, so the result can be used directly in slices and range()
+        return min(indices), max(indices) + 1
+
+    def structure_id(self):
+        return self.data[0][0]
+
+    def add_row(self, row: List[str | int | float]):
+        if len(self.columns) != len(row):
+            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
+        self.data.append(row)
+
+    def struct_features(self) -> Tuple[int, int]:
+        return self._struct_features
+
+    def triggered(self) -> Tuple[int, int]:
+        return self._triggered
+
+    def observed(self) -> Tuple[int, int]:
+        return self._observed
+
+    def at(self, position: int) -> Dataset:
+        return Dataset(self.columns, self.num_labels, [self.data[position]])
+
+    def limit(self, limit: int) -> Dataset:
+        return Dataset(self.columns, self.num_labels, self.data[:limit])
+
+    def __iter__(self):
+        return (self.at(i) for i, _ in enumerate(self.data))
+
+    def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
+        classify_data = []
+        classify_products = []
+        for struct in structures:
+
+            if isinstance(struct, str):
+                struct_id = None
+                struct_smiles = struct
+            else:
+                struct_id = str(struct.uuid)
+                struct_smiles = struct.smiles
+
+            features = FormatConverter.maccs(struct_smiles)
+
+            trig = []
+            prods = []
+            for rule in applicable_rules:
+                products = rule.apply(struct_smiles)
+
+                if len(products):
+                    trig.append(1)
+                    prods.append(products)
+                else:
+                    trig.append(0)
+                    prods.append([])
+
+            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
+            classify_products.append(prods)
+
+        return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products
 
     @staticmethod
-    def generate_dataset(compounds: List[Compound], reactions: List[Reaction], applicable_rules: 'Rule',
-                         compounds_to_exclude: Optional[Compound] = None, educts_only: bool = False) -> Dataset:
+    def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
+        _structures = set()
 
-        rows = []
+        for r in reactions:
+            for e in r.educts.all():
+                _structures.add(e)
 
-        if educts_only:
-            compounds = set()
-            for r in reactions:
-                for e in r.educts:
-                    compounds.add(e)
-            compounds = list(compounds)
+            if not educts_only:
+                for e in r.products.all():
+                    _structures.add(e)
 
-        total = len(compounds)
-        for i, c in enumerate(compounds):
-            row = []
-            print(f"{i + 1}/{total} - {c.smiles}")
-            for r in
applicable_rules: - product_sets = r.rule.apply(c.smiles) + compounds = sorted(_structures, key=lambda x: x.url) + + triggered: Dict[str, Set[str]] = defaultdict(set) + observed: Set[str] = set() + + # Apply rules on collected compounds and store tps + for i, comp in enumerate(compounds): + logger.debug(f"{i + 1}/{len(compounds)}...") + + for rule in applicable_rules: + product_sets = rule.apply(comp.smiles) if len(product_sets) == 0: - row.append([]) continue - #triggered.add(f"{r.uuid} + {c.uuid}") - reacts = set() - for ps in product_sets: - products = [] - for p in ps: - products.append(Compound(FormatConverter.standardize(p))) + key = f"{rule.uuid} + {comp.uuid}" - reacts.add(Reaction([c], products, r)) - row.append(list(reacts)) + if key in triggered: + logger.info(f"{key} already present. Duplicate reaction?") - rows.append(row) + for prod_set in product_sets: + for smi in prod_set: - return rows + try: + smi = FormatConverter.standardize(smi) + except Exception: + # :shrug: + logger.debug(f'Standardizing SMILES failed for {smi}') + pass + + triggered[key].add(smi) + + for i, r in enumerate(reactions): + logger.debug(f"{i + 1}/{len(reactions)}...") + + if len(r.educts.all()) != 1: + logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") + continue + + for comp in r.educts.all(): + for rule in applicable_rules: + key = f"{rule.uuid} + {comp.uuid}" + + if key not in triggered: + continue + + # standardize products from reactions for comparison + standardized_products = [] + for cs in r.products.all(): + smi = cs.smiles + + try: + smi = FormatConverter.standardize(smi) + except Exception as e: + # :shrug: + logger.debug(f'Standardizing SMILES failed for {smi}') + pass + + standardized_products.append(smi) + + if len(set(standardized_products).difference(triggered[key])) == 0: + observed.add(key) + else: + pass + + ds = None + + for i, comp in enumerate(compounds): + # Features + feat = FormatConverter.maccs(comp.smiles) + trig = [] + obs = [] + + for rule in applicable_rules: + key = f"{rule.uuid} + {comp.uuid}" + + # Check triggered + if key in triggered: + trig.append(1) + else: + trig.append(0) + + # Check obs + if key in observed: + obs.append(1) + elif key not in triggered: + obs.append(None) + else: + obs.append(0) + + if ds is None: + header = ['structure_id'] + \ + [f'feature_{i}' for i, _ in enumerate(feat)] \ + + [f'trig_{r.uuid}' for r in applicable_rules] \ + + [f'obs_{r.uuid}' for r in applicable_rules] + ds = Dataset(header, len(applicable_rules)) + + ds.add_row([str(comp.uuid)] + feat + trig + obs) + + return ds + + + def X(self, exclude_id_col=True, na_replacement=0): + res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))) + if na_replacement is not None: + res = [[x if x is not None else na_replacement for x in row] for row in res] + return res + + + def y(self, na_replacement=0): + res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) + if na_replacement is not None: + res = [[x if x is not None else na_replacement for x in row] for row in res] + return res + + + def __getitem__(self, key): + if not isinstance(key, tuple): + raise TypeError("Dataset must be indexed with dataset[rows, columns]") + + row_key, col_key = key + + # Normalize rows + if isinstance(row_key, int): + rows = [self.data[row_key]] + else: + rows = self.data[row_key] + + # Normalize columns + if isinstance(col_key, int): + res = [row[col_key] for row in rows] + else: + res = [[row[i] for 
i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
+                   else [row[i] for i in col_key] for row in rows]
+
+        return res
+
+    def save(self, path: 'Path'):
+        import pickle
+        with open(path, "wb") as fh:
+            pickle.dump(self, fh)
+
+    @staticmethod
+    def load(path: 'Path') -> 'Dataset':
+        import pickle
+        with open(path, "rb") as fh:
+            return pickle.load(fh)
+
+    def to_arff(self, path: 'Path'):
+        arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
+        arff += "\n"
+        for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
+            if c == 'structure_id':
+                arff += f"@attribute {c} string\n"
+            else:
+                arff += f"@attribute {c} {{0,1}}\n"
+
+        arff += "\n@data\n"
+        for d in self.data:
+            ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
+            xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
+            arff += f'{ys},{xs}\n'
+
+        with open(path, "w") as fh:
+            fh.write(arff)
+            fh.flush()
+
+    def __repr__(self):
+        return f"<Dataset: {len(self.data)} rows, {len(self.columns)} columns, {self.num_labels} labels>"
 
 
 class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -166,8 +366,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
                 self.keep_columns_.append(col)
         y_reduced = y[:, self.keep_columns_]
 
-        self.chains_ = [ClassifierChain(self.base_clf, order='random', random_state=i)
-                        for i in range(self.num_chains)]
+        self.chains_ = [ClassifierChain(self.base_clf) for _ in range(self.num_chains)]
 
         for i, chain in enumerate(self.chains_):
             print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
@@ -208,26 +407,169 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
         return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
 
 
-class ApplicabilityDomain(PCA):
-    def __init__(self, n_components=5):
-        super().__init__(n_components=n_components)
+import copy
+
+import numpy as np
+from sklearn.dummy import DummyClassifier
+from sklearn.tree import DecisionTreeClassifier
+
+
+class BinaryRelevance:
+    def __init__(self, baseline_clf):
+        self.clf = baseline_clf
+        self.classifiers = None
+
+    def fit(self, X, Y):
+        if self.classifiers is None:
+            self.classifiers = []
+
+        for l in range(len(Y[0])):
+            X_l = X[~np.isnan(Y[:, l])]
+            Y_l = (Y[~np.isnan(Y[:, l]), l])
+            if len(X_l) == 0:  # all labels are nan -> predict 0
+                clf = DummyClassifier(strategy='constant', constant=0)
+                clf.fit([X[0]], [0])
+                self.classifiers.append(clf)
+                continue
+            elif len(np.unique(Y_l)) == 1:  # only one class -> predict that class
+                clf = DummyClassifier(strategy='most_frequent')
+            else:
+                clf = copy.deepcopy(self.clf)
+            clf.fit(X_l, Y_l)
+            self.classifiers.append(clf)
+
+    def predict(self, X):
+        labels = []
+        for clf in self.classifiers:
+            labels.append(clf.predict(X))
+        return np.column_stack(labels)
+
+    def predict_proba(self, X):
+        labels = np.empty((len(X), 0))
+        for clf in self.classifiers:
+            pred = clf.predict_proba(X)
+            if pred.shape[1] > 1:
+                pred = pred[:, 1]
+            else:
+                pred = pred * clf.predict([X[0]])[0]
+            labels = np.column_stack((labels, pred))
+        return labels
+
+
+class MissingValuesClassifierChain:
+    def __init__(self, base_clf):
+        self.base_clf = base_clf
+        self.permutation = None
+        self.classifiers = None
+
+    def fit(self, X, Y):
+        X = np.array(X)
+        Y = np.array(Y)
+        if self.permutation is None:
+            self.permutation = np.random.permutation(len(Y[0]))
+
+        Y = Y[:, self.permutation]
+
+        if self.classifiers is None:
+            self.classifiers = []
+
+        for p in range(len(self.permutation)):
+            X_p = X[~np.isnan(Y[:, p])]
+            Y_p = Y[~np.isnan(Y[:, p]), p]
+            if len(X_p) == 0:  # all labels are nan -> predict 0
+                clf = DummyClassifier(strategy='constant',
constant=0)
+                self.classifiers.append(clf.fit([X[0]], [0]))
+            elif len(np.unique(Y_p)) == 1:  # only one class -> predict that class
+                clf = DummyClassifier(strategy='most_frequent')
+                self.classifiers.append(clf.fit(X_p, Y_p))
+            else:
+                clf = copy.deepcopy(self.base_clf)
+                self.classifiers.append(clf.fit(X_p, Y_p))
+            newcol = Y[:, p]
+            pred = clf.predict(X)
+            newcol[np.isnan(newcol)] = pred[np.isnan(newcol)]  # fill in missing values with clf predictions
+            X = np.column_stack((X, newcol))
+
+    def predict(self, X):
+        labels = np.empty((len(X), 0))
+        for clf in self.classifiers:
+            pred = clf.predict(np.column_stack((X, labels)))
+            labels = np.column_stack((labels, pred))
+        return labels[:, np.argsort(self.permutation)]
+
+    def predict_proba(self, X):
+        labels = np.empty((len(X), 0))
+        for clf in self.classifiers:
+            pred = clf.predict_proba(np.column_stack((X, np.round(labels))))
+            if pred.shape[1] > 1:
+                pred = pred[:, 1]
+            else:
+                pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0]
+            labels = np.column_stack((labels, pred))
+        return labels[:, np.argsort(self.permutation)]
+
+
+class EnsembleClassifierChain:
+    def __init__(self, base_clf, num_chains=10):
+        self.base_clf = base_clf
+        self.num_chains = num_chains
+        self.num_labels = None
+        self.classifiers = None
+
+    def fit(self, X, Y):
+        if self.classifiers is None:
+            self.classifiers = []
+
+        if self.num_labels is None:
+            self.num_labels = len(Y[0])
+
+        for p in range(self.num_chains):
+            logger.debug(f"fitting chain {p + 1}/{self.num_chains}")
+            clf = MissingValuesClassifierChain(self.base_clf)
+            clf.fit(X, Y)
+            self.classifiers.append(clf)
+
+    def predict(self, X):
+        labels = np.zeros((len(X), self.num_labels))
+        for clf in self.classifiers:
+            labels += clf.predict(X)
+        return np.round(labels / self.num_chains)
+
+    def predict_proba(self, X):
+        labels = np.zeros((len(X), self.num_labels))
+        for clf in self.classifiers:
+            labels += clf.predict_proba(X)
+        return labels / self.num_chains
+
+
+class ApplicabilityDomainPCA(PCA):
+
+    def __init__(self, num_neighbours: int = 5, n_components: int = 5):
+        # keep the number of PCA components separate from the number of neighbours
+        super().__init__(n_components=n_components)
         self.scaler = StandardScaler()
+        self.num_neighbours = num_neighbours
         self.min_vals = None
         self.max_vals = None
 
-    def build(self, X):
+    def build(self, train_dataset: 'Dataset'):
         # transform
-        X_scaled = self.scaler.fit_transform(X)
+        X_scaled = self.scaler.fit_transform(train_dataset.X())
 
         # fit pca
         X_pca = self.fit_transform(X_scaled)
 
         self.max_vals = np.max(X_pca, axis=0)
         self.min_vals = np.min(X_pca, axis=0)
 
-    def is_applicable(self, instances):
+    def __transform(self, instances):
         instances_scaled = self.scaler.transform(instances)
         instances_pca = self.transform(instances_scaled)
+        return instances_pca
+
+    def is_applicable(self, classify_instances: 'Dataset'):
+        instances_pca = self.__transform(classify_instances.X())
 
         is_applicable = []
         for i, instance in enumerate(instances_pca):
@@ -237,3 +579,17 @@
                 is_applicable[i] = False
 
         return is_applicable
+
+
+def tanimoto_distance(a: List[int], b: List[int]):
+    if len(a) != len(b):
+        raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")
+
+    sum_a = sum(a)
+    sum_b = sum(b)
+    sum_c = sum(v1 and v2 for v1, v2 in zip(a, b))
+
+    if sum_a + sum_b - sum_c == 0:
+        return 0.0
+
+    return 1 - (sum_c / (sum_a + sum_b - sum_c))
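
Usage sketch (not part of the diff): how the new EnsembleClassifierChain deals with missing labels, and what tanimoto_distance returns. The toy data below is hypothetical; np.nan marks unobserved rule/compound pairs, mirroring Dataset.y(na_replacement=np.nan).

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from utilities.ml import EnsembleClassifierChain, tanimoto_distance

    X = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
    Y = np.array([[1, np.nan], [0, 1], [np.nan, 0]])  # per-label NaNs are skipped during fitting

    ecc = EnsembleClassifierChain(DecisionTreeClassifier(max_depth=2), num_chains=3)
    ecc.fit(X, Y)
    print(ecc.predict_proba(X))  # chain-averaged probabilities, shape (3, 2)

    # tanimoto_distance returns 1 - Jaccard similarity on 0/1 vectors
    assert tanimoto_distance([1, 1, 0], [1, 0, 0]) == 0.5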