From 5477b5b3d4b22e8f4a4d1994ee66d5a554d9c685 Mon Sep 17 00:00:00 2001 From: jebus Date: Tue, 9 Sep 2025 19:32:12 +1200 Subject: [PATCH] [Feature] Rule Based Model (#92) Fixes #89 Co-authored-by: Tim Lorsbach Reviewed-on: https://git.envipath.com/enviPath/enviPy/pulls/92 --- epdb/models.py | 282 +++++++++++++----- epdb/tasks.py | 6 + epdb/views.py | 88 +++--- templates/actions/objects/model.html | 12 + .../modals/collections/new_model_modal.html | 133 +++++---- .../modals/objects/edit_model_modal.html | 44 +++ .../modals/objects/evaluate_model_modal.html | 62 ++++ .../modals/objects/retrain_model_modal.html | 43 +++ templates/objects/model.html | 5 +- utilities/ml.py | 70 ++++- 10 files changed, 560 insertions(+), 185 deletions(-) create mode 100644 templates/modals/objects/edit_model_modal.html create mode 100644 templates/modals/objects/evaluate_model_modal.html create mode 100644 templates/modals/objects/retrain_model_modal.html diff --git a/epdb/models.py b/epdb/models.py index fb4baf97..e82e05af 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -4,6 +4,7 @@ import json import logging import os import secrets +from abc import abstractmethod from collections import defaultdict from datetime import datetime from typing import Union, List, Optional, Dict, Tuple, Set @@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score from sklearn.model_selection import ShuffleSplit from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils -from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain +from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning logger = logging.getLogger(__name__) @@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): @property def root_nodes(self): - return Node.objects.filter(pathway=self, depth=0) + # sames as return Node.objects.filter(pathway=self, depth=0) but will utilize + # potentially prefetched node_set + return self.node_set.all().filter(pathway=self, depth=0) @property def nodes(self): - return Node.objects.filter(pathway=self) + # same as Node.objects.filter(pathway=self) but will utilize + # potentially prefetched node_set + return self.node_set.all() def get_node(self, node_url): for n in self.nodes: @@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): @property def edges(self): - return Edge.objects.filter(pathway=self) + # same as Edge.objects.filter(pathway=self) but will utilize + # potentially prefetched edge_set + return self.edge_set.all() def _url(self): return '{}/pathway/{}'.format(self.package.url, self.uuid) @@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel): return '{}/model/{}'.format(self.package.url, self.uuid) -class MLRelativeReasoning(EPModel): - rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages") - data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages") - eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages") +class PackageBasedModel(EPModel): + rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", + related_name="%(app_label)s_%(class)s_rule_packages") + data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", + related_name="%(app_label)s_%(class)s_data_packages") + eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", + related_name="%(app_label)s_%(class)s_eval_packages") threshold = models.FloatField(null=False, blank=False, default=0.5) + eval_results = JSONField(null=True, blank=True, default=dict) + app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True, + default=None) INITIAL = "INITIAL" INITIALIZING = "INITIALIZING" @@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel): } model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL) - eval_results = JSONField(null=True, blank=True, default=dict) - - app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True, - default=None) - def status(self): return self.PROGRESS_STATUS_CHOICES[self.model_status] def ready_for_prediction(self) -> bool: return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED] - @staticmethod - @transaction.atomic - def create(package: 'Package', rule_packages: List['Package'], - data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5, - name: 'str' = None, description: str = None, build_app_domain: bool = False, - app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None, - app_domain_local_compatibility_threshold: float = None): - - mlrr = MLRelativeReasoning() - mlrr.package = package - - if name is None or name.strip() == '': - name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}" - - mlrr.name = name - - if description is not None and description.strip() != '': - mlrr.description = description - - if threshold is None or (threshold <= 0 or 1 <= threshold): - raise ValueError("Threshold must be a float between 0 and 1.") - - mlrr.threshold = threshold - - if len(rule_packages) == 0: - raise ValueError("At least one rule package must be provided.") - - mlrr.save() - - for p in rule_packages: - mlrr.rule_packages.add(p) - - if data_packages: - for p in data_packages: - mlrr.data_packages.add(p) - else: - for p in rule_packages: - mlrr.data_packages.add(p) - - if eval_packages: - for p in eval_packages: - mlrr.eval_packages.add(p) - - if build_app_domain: - ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold, - app_domain_local_compatibility_threshold) - mlrr.app_domain = ad - - mlrr.save() - - return mlrr - @cached_property def applicable_rules(self) -> List['Rule']: """ @@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel): ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl") return Dataset.load(ds_path) + def retrain(self): + self.build_dataset() + self.build_model() + + def rebuild(self): + self.build_model() + + @abstractmethod + def build_model(self): + pass + + @staticmethod + def combine_products_and_probs(rules: List['Rule'], probabilities, products): + res = [] + for rule, p, smis in zip(rules, probabilities, products): + res.append(PredictionResult(smis, p, rule)) + return res + + class Meta: + abstract = True + + +class RuleBasedRelativeReasoning(PackageBasedModel): + min_count = models.IntegerField(null=False, blank=False, default=10) + max_count = models.IntegerField(null=False, blank=False, default=0) + + @staticmethod + @transaction.atomic + def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'], + eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0, + name: 'str' = None, description: str = None): + + rbrr = RuleBasedRelativeReasoning() + rbrr.package = package + + if name is None or name.strip() == '': + name = f"MLRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}" + + rbrr.name = name + + if description is not None and description.strip() != '': + rbrr.description = description + + if threshold is None or (threshold <= 0 or 1 <= threshold): + raise ValueError("Threshold must be a float between 0 and 1.") + + rbrr.threshold = threshold + + if min_count is None or min_count < 1: + raise ValueError("Minimum count must be an int greater than equal 1.") + + rbrr.min_count = min_count + + if max_count is None or max_count > min_count: + raise ValueError("Maximum count must be an int and must not be less than min_count.") + + if max_count is None: + raise ValueError("Maximum count must be at least 0.") + + if len(rule_packages) == 0: + raise ValueError("At least one rule package must be provided.") + + rbrr.save() + + for p in rule_packages: + rbrr.rule_packages.add(p) + + if data_packages: + for p in data_packages: + rbrr.data_packages.add(p) + else: + for p in rule_packages: + rbrr.data_packages.add(p) + + if eval_packages: + for p in eval_packages: + rbrr.eval_packages.add(p) + + rbrr.save() + + return rbrr + + def build_model(self): + self.model_status = self.BUILDING + self.save() + + ds = self.load_dataset() + labels = ds.y(na_replacement=None) + + mod = RelativeReasoning(*ds.triggered()) + mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)) + + f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl") + joblib.dump(mod, f) + + self.model_status = self.BUILT_NOT_EVALUATED + self.save() + + @cached_property + def model(self) -> 'RelativeReasoning': + mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl')) + return mod + + def predict(self, smiles) -> List['PredictionResult']: + start = datetime.now() + ds = self.load_dataset() + classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules) + + mod = self.model + + pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None)) + + res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0]) + + end = datetime.now() + logger.info(f"Full predict took {(end - start).total_seconds()}s") + return res + + +class MLRelativeReasoning(PackageBasedModel): + + @staticmethod + @transaction.atomic + def create(package: 'Package', rule_packages: List['Package'], + data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5, + name: 'str' = None, description: str = None, build_app_domain: bool = False, + app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None, + app_domain_local_compatibility_threshold: float = None): + + mlrr = MLRelativeReasoning() + mlrr.package = package + + if name is None or name.strip() == '': + name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}" + + mlrr.name = name + + if description is not None and description.strip() != '': + mlrr.description = description + + if threshold is None or (threshold <= 0 or 1 <= threshold): + raise ValueError("Threshold must be a float between 0 and 1.") + + mlrr.threshold = threshold + + if len(rule_packages) == 0: + raise ValueError("At least one rule package must be provided.") + + mlrr.save() + + for p in rule_packages: + mlrr.rule_packages.add(p) + + if data_packages: + for p in data_packages: + mlrr.data_packages.add(p) + else: + for p in rule_packages: + mlrr.data_packages.add(p) + + if eval_packages: + for p in eval_packages: + mlrr.eval_packages.add(p) + + if build_app_domain: + ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold, + app_domain_local_compatibility_threshold) + mlrr.app_domain = ad + + mlrr.save() + + return mlrr + def build_model(self): self.model_status = self.BUILDING self.save() @@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel): self.model_status = self.BUILT_NOT_EVALUATED self.save() - def retrain(self): - self.build_dataset() - self.build_model() - - def rebuild(self): - self.build_model() - def evaluate_model(self): if self.model_status != self.BUILT_NOT_EVALUATED: @@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel): logger.info(f"Full predict took {(end - start).total_seconds()}s") return res - @staticmethod - def combine_products_and_probs(rules: List['Rule'], probabilities, products): - res = [] - for rule, p, smis in zip(rules, probabilities, products): - res.append(PredictionResult(smis, p, rule)) - return res - @property def pr_curve(self): if self.model_status != self.FINISHED: @@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel): return accuracy -class RuleBaseRelativeReasoning(EPModel): - pass - class EnviFormer(EPModel): threshold = models.FloatField(null=False, blank=False, default=0.5) @@ -2406,6 +2518,12 @@ class EnviFormer(EPModel): def applicable_rules(self): return [] + def status(self): + return "Model is built and can be used for predictions, Model is not evaluated yet." + + def ready_for_prediction(self) -> bool: + return True + class PluginModel(EPModel): pass diff --git a/epdb/tasks.py b/epdb/tasks.py index 4ca4d183..3cbd9386 100644 --- a/epdb/tasks.py +++ b/epdb/tasks.py @@ -41,6 +41,12 @@ def evaluate_model(model_pk: int): mod.evaluate_model() +@shared_task(queue='model') +def retrain(model_pk: int): + mod = EPModel.objects.get(id=model_pk) + mod.retrain() + + @shared_task(queue='predict') def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway: pw = Pathway.objects.get(id=pw_pk) diff --git a/epdb/views.py b/epdb/views.py index da98412f..2ce006d2 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required from utilities.misc import HTMLGenerator from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \ - EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \ + EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \ UserPackagePermission, Permission, License, User, Edge logger = logging.getLogger(__name__) @@ -651,46 +651,50 @@ def package_models(request, package_uuid): mod = EnviFormer.create(current_package, name, description, threshold) - elif model_type == 'ml-relative-reasoning': - threshold = float(request.POST.get(f'{model_type}-threshold', 0.5)) - fingerprinter = request.POST.get(f'{model_type}-fingerprinter') - rule_packages = request.POST.getlist(f'{model_type}-rule-packages') - data_packages = request.POST.getlist(f'{model_type}-data-packages') - eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', []) + elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning': + # Generic fields for ML and Rule Based + rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages') + data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages') + eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', []) - # get Package objects from urls - rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages] - data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages] - eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages] + # Generic params + params = { + 'package' : current_package, + 'name' : name, + 'description' : description, + 'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages], + 'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages], + 'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages], + } - # App Domain related parameters - build_ad = request.POST.get('build-app-domain', False) == 'on' - num_neighbors = request.POST.get('num-neighbors', 5) - reliability_threshold = request.POST.get('reliability-threshold', 0.5) - local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5) + if model_type == 'ml-relative-reasoning': + # ML Specific + threshold = float(request.POST.get(f'{model_type}-threshold', 0.5)) + fingerprinter = request.POST.get(f'{model_type}-fingerprinter') - mod = MLRelativeReasoning.create( - package=current_package, - name=name, - description=description, - rule_packages=rule_package_objs, - data_packages=data_package_objs, - eval_packages=eval_packages_objs, - threshold=threshold, - # fingerprinter=fingerprinter, - build_app_domain=build_ad, - app_domain_num_neighbours=num_neighbors, - app_domain_reliability_threshold=reliability_threshold, - app_domain_local_compatibility_threshold=local_compatibility_threshold, - ) + # App Domain related parameters + build_ad = request.POST.get('build-app-domain', False) == 'on' + num_neighbors = request.POST.get('num-neighbors', 5) + reliability_threshold = request.POST.get('reliability-threshold', 0.5) + local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5) + + params['threshold'] = threshold + # params['fingerprinter'] = fingerprinter + params['build_app_domain'] = build_ad + params['app_domain_num_neighbours'] = num_neighbors + params['app_domain_reliability_threshold'] = reliability_threshold + params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold + + mod = MLRelativeReasoning.create( + **params + ) + else: + mod = RuleBasedRelativeReasoning.create( + **params + ) from .tasks import build_model build_model.delay(mod.pk) - - elif model_type == 'rule-base-relative-reasoning': - mod = RuleBaseRelativeReasoning() - - mod.save() else: return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."') return redirect(mod.url) @@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid): else: return HttpResponseBadRequest() else: + + name = request.POST.get('model-name', '').strip() + description = request.POST.get('model-description', '').strip() + + if any([name, description]): + if name: + current_model.name = name + + if description: + current_model.description = description + + current_model.save() + return redirect(current_model.url) + return HttpResponseBadRequest() else: diff --git a/templates/actions/objects/model.html b/templates/actions/objects/model.html index 466bcc66..acff9f23 100644 --- a/templates/actions/objects/model.html +++ b/templates/actions/objects/model.html @@ -1,4 +1,16 @@ {% if meta.can_edit %} +
  • + + Edit Model +
  • +
  • + + Evaluate Model +
  • +
  • + + Retrain Model +
  • Delete Model diff --git a/templates/modals/collections/new_model_modal.html b/templates/modals/collections/new_model_modal.html index 5283275a..b58a65ed 100644 --- a/templates/modals/collections/new_model_modal.html +++ b/templates/modals/collections/new_model_modal.html @@ -32,11 +32,11 @@ {% endfor %} - -
    + +
    - - {% for obj in meta.readable_packages %} @@ -53,8 +53,8 @@ {% endfor %} - - {% for obj in meta.readable_packages %} @@ -71,71 +71,54 @@ {% endfor %} - - - - {% if meta.enabled_features.PLUGINS and additional_descriptors %} - - - - {% endif %} - - - - - - - + + + {% if meta.enabled_features.PLUGINS and additional_descriptors %} + + + {% endif %} - {% endfor %} - - - {% for obj in meta.readable_packages %} - {% if not obj.reviewed %} - - {% endif %} - {% endfor %} - + + +
    {% if meta.enabled_features.APPLICABILITY_DOMAIN %}
    - - - - - - - - - + {% endif %} -
    - -
    -
    @@ -160,20 +143,38 @@ $(function() { $(this).hide(); }); - $("#ml-relative-reasoning-rule-packages").selectpicker(); - $("#ml-relative-reasoning-data-packages").selectpicker(); - $("#ml-relative-reasoning-evaluation-packages").selectpicker(); + $('#model-type').selectpicker(); + $("#ml-relative-reasoning-fingerprinter").selectpicker(); + $("#package-based-relative-reasoning-rule-packages").selectpicker(); + $("#package-based-relative-reasoning-data-packages").selectpicker(); + $("#package-based-relative-reasoning-evaluation-packages").selectpicker(); if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) { $("#ml-relative-reasoning-additional-fingerprinter").selectpicker(); } + $("#build-app-domain").change(function () { + if ($(this).is(":checked")) { + $('#ad-params').show(); + } else { + $('#ad-params').hide(); + } + }); + // On change hide all and show only selected $("#model-type").change(function() { $("div[id$='-specific-form']").each( function() { $(this).hide(); }); val = $('option:selected', this).val(); - $("#" + val + "-specific-form").show(); + + if (val === 'ml-relative-reasoning' || val === 'rule-based-relative-reasoning') { + $("#package-based-relative-reasoning-specific-form").show(); + if (val === 'ml-relative-reasoning') { + $("#ml-relative-reasoning-specific-form").show(); + } + } else { + $("#" + val + "-specific-form").show(); + } }); $('#new_model_modal_form_submit').on('click', function(e){ diff --git a/templates/modals/objects/edit_model_modal.html b/templates/modals/objects/edit_model_modal.html new file mode 100644 index 00000000..59c6705a --- /dev/null +++ b/templates/modals/objects/edit_model_modal.html @@ -0,0 +1,44 @@ +{% load static %} + + + diff --git a/templates/modals/objects/evaluate_model_modal.html b/templates/modals/objects/evaluate_model_modal.html new file mode 100644 index 00000000..42af6586 --- /dev/null +++ b/templates/modals/objects/evaluate_model_modal.html @@ -0,0 +1,62 @@ + + + diff --git a/templates/modals/objects/retrain_model_modal.html b/templates/modals/objects/retrain_model_modal.html new file mode 100644 index 00000000..9ed7745f --- /dev/null +++ b/templates/modals/objects/retrain_model_modal.html @@ -0,0 +1,43 @@ + + + diff --git a/templates/objects/model.html b/templates/objects/model.html index c23f477e..2c859780 100644 --- a/templates/objects/model.html +++ b/templates/objects/model.html @@ -4,6 +4,9 @@ {% block content %} {% block action_modals %} + {% include "modals/objects/edit_model_modal.html" %} + {% include "modals/objects/evaluate_model_modal.html" %} + {% include "modals/objects/retrain_model_modal.html" %} {% include "modals/objects/generic_delete_modal.html" %} {% endblock action_modals %} @@ -32,7 +35,7 @@

    {{ model.description }}

    - {% if model|classname == 'MLRelativeReasoning' %} + {% if model|classname == 'MLRelativeReasoning' or model|classname == 'RuleBasedRelativeReasoning'%}

    diff --git a/utilities/ml.py b/utilities/ml.py index dc810951..936dffba 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -289,6 +289,12 @@ class Dataset: res = [[x if x is not None else na_replacement for x in row] for row in res] return res + def trig(self, na_replacement=0): + res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1]))) + if na_replacement is not None: + res = [[x if x is not None else na_replacement for x in row] for row in res] + return res + def y(self, na_replacement=0): res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) @@ -324,7 +330,7 @@ class Dataset: pickle.dump(self, fh) @staticmethod - def load(path: 'Path'): + def load(path: 'Path') -> 'Dataset': import pickle return pickle.load(open(path, "rb")) @@ -553,6 +559,68 @@ class EnsembleClassifierChain: return labels / self.num_chains +class RelativeReasoning: + def __init__(self, start_index: int, end_index: int): + self.start_index: int = start_index + self.end_index: int = end_index + self.winmap: Dict[int, List[int]] = defaultdict(list) + self.min_count: int = 5 + self.max_count: int = 0 + + def fit(self, X, Y): + n_instances = len(Y) + n_attributes = len(Y[0]) + + for i in range(n_attributes): + for j in range(n_attributes): + if i == j: + continue + + countwin = 0 + countloose = 0 + countboth = 0 + + for k in range(n_instances): + vi = Y[k][i] + vj = Y[k][j] + + if vi is None or vj is None: + continue + + if vi < vj: + countwin += 1 + elif vi > vj: + countloose += 1 + elif vi == vj and vi == 1: # tie + countboth += 1 + + # We've seen more than self.min_count wins, more wins than loosing, no looses and no ties + if ( + countwin >= self.min_count and + countwin > countloose and + ( + countloose <= self.max_count or + self.max_count < 0 + ) and + countboth == 0 + ): + self.winmap[i].append(j) + + def predict(self, X): + res = np.zeros((len(X), (self.end_index + 1 - self.start_index))) + + for inst_idx, inst in enumerate(X): + for i, t in enumerate(inst[self.start_index: self.end_index + 1]): + res[inst_idx][i] = t + if t: + for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]): + if i != i2 and i2 in self.winmap.get(i, []) and X[t2]: + res[inst_idx][i] = 0 + + return res + + def predict_proba(self, X): + return self.predict(X) class ApplicabilityDomainPCA(PCA):