forked from enviPath/enviPy
[Feature] Rule Based Model (#92)
Fixes #89 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#92
This commit is contained in:
282
epdb/models.py
282
epdb/models.py
@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Union, List, Optional, Dict, Tuple, Set
|
||||
@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
|
||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
|
||||
@property
|
||||
def root_nodes(self):
|
||||
return Node.objects.filter(pathway=self, depth=0)
|
||||
# sames as return Node.objects.filter(pathway=self, depth=0) but will utilize
|
||||
# potentially prefetched node_set
|
||||
return self.node_set.all().filter(pathway=self, depth=0)
|
||||
|
||||
@property
|
||||
def nodes(self):
|
||||
return Node.objects.filter(pathway=self)
|
||||
# same as Node.objects.filter(pathway=self) but will utilize
|
||||
# potentially prefetched node_set
|
||||
return self.node_set.all()
|
||||
|
||||
def get_node(self, node_url):
|
||||
for n in self.nodes:
|
||||
@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
|
||||
@property
|
||||
def edges(self):
|
||||
return Edge.objects.filter(pathway=self)
|
||||
# same as Edge.objects.filter(pathway=self) but will utilize
|
||||
# potentially prefetched edge_set
|
||||
return self.edge_set.all()
|
||||
|
||||
def _url(self):
|
||||
return '{}/pathway/{}'.format(self.package.url, self.uuid)
|
||||
@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
|
||||
return '{}/model/{}'.format(self.package.url, self.uuid)
|
||||
|
||||
|
||||
class MLRelativeReasoning(EPModel):
|
||||
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
|
||||
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
|
||||
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
|
||||
class PackageBasedModel(EPModel):
|
||||
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages",
|
||||
related_name="%(app_label)s_%(class)s_rule_packages")
|
||||
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages")
|
||||
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages",
|
||||
related_name="%(app_label)s_%(class)s_eval_packages")
|
||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||
default=None)
|
||||
|
||||
INITIAL = "INITIAL"
|
||||
INITIALIZING = "INITIALIZING"
|
||||
@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel):
|
||||
}
|
||||
model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
|
||||
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
|
||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||
default=None)
|
||||
|
||||
def status(self):
|
||||
return self.PROGRESS_STATUS_CHOICES[self.model_status]
|
||||
|
||||
def ready_for_prediction(self) -> bool:
|
||||
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(package: 'Package', rule_packages: List['Package'],
|
||||
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
|
||||
name: 'str' = None, description: str = None, build_app_domain: bool = False,
|
||||
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
|
||||
app_domain_local_compatibility_threshold: float = None):
|
||||
|
||||
mlrr = MLRelativeReasoning()
|
||||
mlrr.package = package
|
||||
|
||||
if name is None or name.strip() == '':
|
||||
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||
|
||||
mlrr.name = name
|
||||
|
||||
if description is not None and description.strip() != '':
|
||||
mlrr.description = description
|
||||
|
||||
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||
|
||||
mlrr.threshold = threshold
|
||||
|
||||
if len(rule_packages) == 0:
|
||||
raise ValueError("At least one rule package must be provided.")
|
||||
|
||||
mlrr.save()
|
||||
|
||||
for p in rule_packages:
|
||||
mlrr.rule_packages.add(p)
|
||||
|
||||
if data_packages:
|
||||
for p in data_packages:
|
||||
mlrr.data_packages.add(p)
|
||||
else:
|
||||
for p in rule_packages:
|
||||
mlrr.data_packages.add(p)
|
||||
|
||||
if eval_packages:
|
||||
for p in eval_packages:
|
||||
mlrr.eval_packages.add(p)
|
||||
|
||||
if build_app_domain:
|
||||
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
|
||||
app_domain_local_compatibility_threshold)
|
||||
mlrr.app_domain = ad
|
||||
|
||||
mlrr.save()
|
||||
|
||||
return mlrr
|
||||
|
||||
@cached_property
|
||||
def applicable_rules(self) -> List['Rule']:
|
||||
"""
|
||||
@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel):
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||
return Dataset.load(ds_path)
|
||||
|
||||
def retrain(self):
|
||||
self.build_dataset()
|
||||
self.build_model()
|
||||
|
||||
def rebuild(self):
|
||||
self.build_model()
|
||||
|
||||
@abstractmethod
|
||||
def build_model(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
||||
res = []
|
||||
for rule, p, smis in zip(rules, probabilities, products):
|
||||
res.append(PredictionResult(smis, p, rule))
|
||||
return res
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
min_count = models.IntegerField(null=False, blank=False, default=10)
|
||||
max_count = models.IntegerField(null=False, blank=False, default=0)
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'],
|
||||
eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0,
|
||||
name: 'str' = None, description: str = None):
|
||||
|
||||
rbrr = RuleBasedRelativeReasoning()
|
||||
rbrr.package = package
|
||||
|
||||
if name is None or name.strip() == '':
|
||||
name = f"MLRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||
|
||||
rbrr.name = name
|
||||
|
||||
if description is not None and description.strip() != '':
|
||||
rbrr.description = description
|
||||
|
||||
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||
|
||||
rbrr.threshold = threshold
|
||||
|
||||
if min_count is None or min_count < 1:
|
||||
raise ValueError("Minimum count must be an int greater than equal 1.")
|
||||
|
||||
rbrr.min_count = min_count
|
||||
|
||||
if max_count is None or max_count > min_count:
|
||||
raise ValueError("Maximum count must be an int and must not be less than min_count.")
|
||||
|
||||
if max_count is None:
|
||||
raise ValueError("Maximum count must be at least 0.")
|
||||
|
||||
if len(rule_packages) == 0:
|
||||
raise ValueError("At least one rule package must be provided.")
|
||||
|
||||
rbrr.save()
|
||||
|
||||
for p in rule_packages:
|
||||
rbrr.rule_packages.add(p)
|
||||
|
||||
if data_packages:
|
||||
for p in data_packages:
|
||||
rbrr.data_packages.add(p)
|
||||
else:
|
||||
for p in rule_packages:
|
||||
rbrr.data_packages.add(p)
|
||||
|
||||
if eval_packages:
|
||||
for p in eval_packages:
|
||||
rbrr.eval_packages.add(p)
|
||||
|
||||
rbrr.save()
|
||||
|
||||
return rbrr
|
||||
|
||||
def build_model(self):
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
|
||||
ds = self.load_dataset()
|
||||
labels = ds.y(na_replacement=None)
|
||||
|
||||
mod = RelativeReasoning(*ds.triggered())
|
||||
mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
|
||||
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
|
||||
joblib.dump(mod, f)
|
||||
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
@cached_property
|
||||
def model(self) -> 'RelativeReasoning':
|
||||
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
|
||||
return mod
|
||||
|
||||
def predict(self, smiles) -> List['PredictionResult']:
|
||||
start = datetime.now()
|
||||
ds = self.load_dataset()
|
||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||
|
||||
mod = self.model
|
||||
|
||||
pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
|
||||
|
||||
res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
|
||||
|
||||
end = datetime.now()
|
||||
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||
return res
|
||||
|
||||
|
||||
class MLRelativeReasoning(PackageBasedModel):
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(package: 'Package', rule_packages: List['Package'],
|
||||
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
|
||||
name: 'str' = None, description: str = None, build_app_domain: bool = False,
|
||||
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
|
||||
app_domain_local_compatibility_threshold: float = None):
|
||||
|
||||
mlrr = MLRelativeReasoning()
|
||||
mlrr.package = package
|
||||
|
||||
if name is None or name.strip() == '':
|
||||
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||
|
||||
mlrr.name = name
|
||||
|
||||
if description is not None and description.strip() != '':
|
||||
mlrr.description = description
|
||||
|
||||
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||
|
||||
mlrr.threshold = threshold
|
||||
|
||||
if len(rule_packages) == 0:
|
||||
raise ValueError("At least one rule package must be provided.")
|
||||
|
||||
mlrr.save()
|
||||
|
||||
for p in rule_packages:
|
||||
mlrr.rule_packages.add(p)
|
||||
|
||||
if data_packages:
|
||||
for p in data_packages:
|
||||
mlrr.data_packages.add(p)
|
||||
else:
|
||||
for p in rule_packages:
|
||||
mlrr.data_packages.add(p)
|
||||
|
||||
if eval_packages:
|
||||
for p in eval_packages:
|
||||
mlrr.eval_packages.add(p)
|
||||
|
||||
if build_app_domain:
|
||||
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
|
||||
app_domain_local_compatibility_threshold)
|
||||
mlrr.app_domain = ad
|
||||
|
||||
mlrr.save()
|
||||
|
||||
return mlrr
|
||||
|
||||
def build_model(self):
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel):
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def retrain(self):
|
||||
self.build_dataset()
|
||||
self.build_model()
|
||||
|
||||
def rebuild(self):
|
||||
self.build_model()
|
||||
|
||||
def evaluate_model(self):
|
||||
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel):
|
||||
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
||||
res = []
|
||||
for rule, p, smis in zip(rules, probabilities, products):
|
||||
res.append(PredictionResult(smis, p, rule))
|
||||
return res
|
||||
|
||||
@property
|
||||
def pr_curve(self):
|
||||
if self.model_status != self.FINISHED:
|
||||
@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
return accuracy
|
||||
|
||||
|
||||
class RuleBaseRelativeReasoning(EPModel):
|
||||
pass
|
||||
|
||||
|
||||
class EnviFormer(EPModel):
|
||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||
@ -2406,6 +2518,12 @@ class EnviFormer(EPModel):
|
||||
def applicable_rules(self):
|
||||
return []
|
||||
|
||||
def status(self):
|
||||
return "Model is built and can be used for predictions, Model is not evaluated yet."
|
||||
|
||||
def ready_for_prediction(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class PluginModel(EPModel):
|
||||
pass
|
||||
|
||||
@ -41,6 +41,12 @@ def evaluate_model(model_pk: int):
|
||||
mod.evaluate_model()
|
||||
|
||||
|
||||
@shared_task(queue='model')
|
||||
def retrain(model_pk: int):
|
||||
mod = EPModel.objects.get(id=model_pk)
|
||||
mod.retrain()
|
||||
|
||||
|
||||
@shared_task(queue='predict')
|
||||
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
|
||||
pw = Pathway.objects.get(id=pw_pk)
|
||||
|
||||
@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required
|
||||
from utilities.misc import HTMLGenerator
|
||||
from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
|
||||
from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
|
||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||
UserPackagePermission, Permission, License, User, Edge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -651,17 +651,26 @@ def package_models(request, package_uuid):
|
||||
|
||||
mod = EnviFormer.create(current_package, name, description, threshold)
|
||||
|
||||
elif model_type == 'ml-relative-reasoning':
|
||||
elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning':
|
||||
# Generic fields for ML and Rule Based
|
||||
rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages')
|
||||
data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages')
|
||||
eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', [])
|
||||
|
||||
# Generic params
|
||||
params = {
|
||||
'package' : current_package,
|
||||
'name' : name,
|
||||
'description' : description,
|
||||
'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages],
|
||||
'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages],
|
||||
'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages],
|
||||
}
|
||||
|
||||
if model_type == 'ml-relative-reasoning':
|
||||
# ML Specific
|
||||
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
||||
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
||||
rule_packages = request.POST.getlist(f'{model_type}-rule-packages')
|
||||
data_packages = request.POST.getlist(f'{model_type}-data-packages')
|
||||
eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', [])
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages]
|
||||
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
|
||||
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
|
||||
|
||||
# App Domain related parameters
|
||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||
@ -669,28 +678,23 @@ def package_models(request, package_uuid):
|
||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||
|
||||
params['threshold'] = threshold
|
||||
# params['fingerprinter'] = fingerprinter
|
||||
params['build_app_domain'] = build_ad
|
||||
params['app_domain_num_neighbours'] = num_neighbors
|
||||
params['app_domain_reliability_threshold'] = reliability_threshold
|
||||
params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
package=current_package,
|
||||
name=name,
|
||||
description=description,
|
||||
rule_packages=rule_package_objs,
|
||||
data_packages=data_package_objs,
|
||||
eval_packages=eval_packages_objs,
|
||||
threshold=threshold,
|
||||
# fingerprinter=fingerprinter,
|
||||
build_app_domain=build_ad,
|
||||
app_domain_num_neighbours=num_neighbors,
|
||||
app_domain_reliability_threshold=reliability_threshold,
|
||||
app_domain_local_compatibility_threshold=local_compatibility_threshold,
|
||||
**params
|
||||
)
|
||||
else:
|
||||
mod = RuleBasedRelativeReasoning.create(
|
||||
**params
|
||||
)
|
||||
|
||||
from .tasks import build_model
|
||||
build_model.delay(mod.pk)
|
||||
|
||||
elif model_type == 'rule-base-relative-reasoning':
|
||||
mod = RuleBaseRelativeReasoning()
|
||||
|
||||
mod.save()
|
||||
else:
|
||||
return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
|
||||
return redirect(mod.url)
|
||||
@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid):
|
||||
else:
|
||||
return HttpResponseBadRequest()
|
||||
else:
|
||||
|
||||
name = request.POST.get('model-name', '').strip()
|
||||
description = request.POST.get('model-description', '').strip()
|
||||
|
||||
if any([name, description]):
|
||||
if name:
|
||||
current_model.name = name
|
||||
|
||||
if description:
|
||||
current_model.description = description
|
||||
|
||||
current_model.save()
|
||||
return redirect(current_model.url)
|
||||
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
else:
|
||||
|
||||
@ -1,4 +1,16 @@
|
||||
{% if meta.can_edit %}
|
||||
<li>
|
||||
<a role="button" data-toggle="modal" data-target="#edit_model_modal">
|
||||
<i class="glyphicon glyphicon-edit"></i> Edit Model</a>
|
||||
</li>
|
||||
<li>
|
||||
<a role="button" data-toggle="modal" data-target="#evaluate_model_modal">
|
||||
<i class="glyphicon glyphicon-ok"></i> Evaluate Model</a>
|
||||
</li>
|
||||
<li>
|
||||
<a role="button" data-toggle="modal" data-target="#retrain_model_modal">
|
||||
<i class="glyphicon glyphicon-repeat"></i> Retrain Model</a>
|
||||
</li>
|
||||
<li>
|
||||
<a class="button" data-toggle="modal" data-target="#generic_delete_modal">
|
||||
<i class="glyphicon glyphicon-trash"></i> Delete Model</a>
|
||||
|
||||
@ -32,11 +32,11 @@
|
||||
<option value="{{ v }}">{{ k }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
<!-- ML Based Form-->
|
||||
<div id="ml-relative-reasoning-specific-form">
|
||||
<!-- ML and Rule Based Based Form-->
|
||||
<div id="package-based-relative-reasoning-specific-form">
|
||||
<!-- Rule Packages -->
|
||||
<label for="ml-relative-reasoning-rule-packages">Rule Packages</label>
|
||||
<select id="ml-relative-reasoning-rule-packages" name="ml-relative-reasoning-rule-packages"
|
||||
<label for="package-based-relative-reasoning-rule-packages">Rule Packages</label>
|
||||
<select id="package-based-relative-reasoning-rule-packages" name="package-based-relative-reasoning-rule-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
{% for obj in meta.readable_packages %}
|
||||
@ -53,8 +53,8 @@
|
||||
{% endfor %}
|
||||
</select>
|
||||
<!-- Data Packages -->
|
||||
<label for="ml-relative-reasoning-data-packages" >Data Packages</label>
|
||||
<select id="ml-relative-reasoning-data-packages" name="ml-relative-reasoning-data-packages"
|
||||
<label for="package-based-relative-reasoning-data-packages" >Data Packages</label>
|
||||
<select id="package-based-relative-reasoning-data-packages" name="package-based-relative-reasoning-data-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
{% for obj in meta.readable_packages %}
|
||||
@ -71,6 +71,7 @@
|
||||
{% endfor %}
|
||||
</select>
|
||||
|
||||
<div id="ml-relative-reasoning-specific-form">
|
||||
<!-- Fingerprinter -->
|
||||
<label for="ml-relative-reasoning-fingerprinter">Fingerprinter</label>
|
||||
<select id="ml-relative-reasoning-fingerprinter" name="ml-relative-reasoning-fingerprinter"
|
||||
@ -79,8 +80,10 @@
|
||||
</select>
|
||||
{% if meta.enabled_features.PLUGINS and additional_descriptors %}
|
||||
<!-- Property Plugins go here -->
|
||||
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter / Descriptors</label>
|
||||
<select id="ml-relative-reasoning-additional-fingerprinter" name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
|
||||
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter /
|
||||
Descriptors</label>
|
||||
<select id="ml-relative-reasoning-additional-fingerprinter"
|
||||
name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
|
||||
<option disabled selected>Select Additional Fingerprinter / Descriptor</option>
|
||||
{% for k, v in additional_descriptors.items %}
|
||||
<option value="{{ v }}">{{ k }}</option>
|
||||
@ -92,33 +95,16 @@
|
||||
<input type="number" min="0" max="1" step="0.05" value="0.5"
|
||||
id="ml-relative-reasoning-threshold"
|
||||
name="ml-relative-reasoning-threshold" class="form-control">
|
||||
|
||||
<!-- Evaluation -->
|
||||
<label for="ml-relative-reasoning-evaluation-packages">Evaluation Packages</label>
|
||||
<select id="ml-relative-reasoning-evaluation-packages" name="ml-relative-reasoning-evaluation-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
{% for obj in meta.readable_packages %}
|
||||
{% if obj.reviewed %}
|
||||
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
<option disabled>Unreviewed Packages</option>
|
||||
{% for obj in meta.readable_packages %}
|
||||
{% if not obj.reviewed %}
|
||||
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</select>
|
||||
|
||||
</div>
|
||||
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
|
||||
<!-- Build AD? -->
|
||||
<div class="checkbox">
|
||||
<label>
|
||||
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an Applicability Domain?
|
||||
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an
|
||||
Applicability Domain?
|
||||
</label>
|
||||
</div>
|
||||
<div id="ad-params" style="display:none">
|
||||
<!-- Num Neighbors -->
|
||||
<label for="num-neighbors">Number of Neighbors</label>
|
||||
<input id="num-neighbors" name="num-neighbors" type="number" class="form-control" value="5"
|
||||
@ -131,11 +117,8 @@
|
||||
<label for="reliability-threshold">Reliability Threshold</label>
|
||||
<input id="reliability-threshold" name="reliability-threshold" type="number"
|
||||
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
||||
{% endif %}
|
||||
</div>
|
||||
<!-- Rule Based Based Form-->
|
||||
<div id="rule-based-relative-reasoning-specific-form">
|
||||
|
||||
{% endif %}
|
||||
</div>
|
||||
<!-- EnviFormer-->
|
||||
<div id="enviformer-specific-form">
|
||||
@ -160,20 +143,38 @@ $(function() {
|
||||
$(this).hide();
|
||||
});
|
||||
|
||||
$("#ml-relative-reasoning-rule-packages").selectpicker();
|
||||
$("#ml-relative-reasoning-data-packages").selectpicker();
|
||||
$("#ml-relative-reasoning-evaluation-packages").selectpicker();
|
||||
$('#model-type').selectpicker();
|
||||
$("#ml-relative-reasoning-fingerprinter").selectpicker();
|
||||
$("#package-based-relative-reasoning-rule-packages").selectpicker();
|
||||
$("#package-based-relative-reasoning-data-packages").selectpicker();
|
||||
$("#package-based-relative-reasoning-evaluation-packages").selectpicker();
|
||||
if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) {
|
||||
$("#ml-relative-reasoning-additional-fingerprinter").selectpicker();
|
||||
}
|
||||
|
||||
$("#build-app-domain").change(function () {
|
||||
if ($(this).is(":checked")) {
|
||||
$('#ad-params').show();
|
||||
} else {
|
||||
$('#ad-params').hide();
|
||||
}
|
||||
});
|
||||
|
||||
// On change hide all and show only selected
|
||||
$("#model-type").change(function() {
|
||||
$("div[id$='-specific-form']").each( function() {
|
||||
$(this).hide();
|
||||
});
|
||||
val = $('option:selected', this).val();
|
||||
|
||||
if (val === 'ml-relative-reasoning' || val === 'rule-based-relative-reasoning') {
|
||||
$("#package-based-relative-reasoning-specific-form").show();
|
||||
if (val === 'ml-relative-reasoning') {
|
||||
$("#ml-relative-reasoning-specific-form").show();
|
||||
}
|
||||
} else {
|
||||
$("#" + val + "-specific-form").show();
|
||||
}
|
||||
});
|
||||
|
||||
$('#new_model_modal_form_submit').on('click', function(e){
|
||||
|
||||
44
templates/modals/objects/edit_model_modal.html
Normal file
44
templates/modals/objects/edit_model_modal.html
Normal file
@ -0,0 +1,44 @@
|
||||
{% load static %}
|
||||
<!-- Edit Model -->
|
||||
<div id="edit_model_modal" class="modal" tabindex="-1">
|
||||
<div class="modal-dialog">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<button type="button" class="close" data-dismiss="modal" aria-label="Close">
|
||||
<span aria-hidden="true">×</span>
|
||||
</button>
|
||||
<h3 class="modal-title">Update Model</h3>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<p>Alter Name and Description of the Model.</p>
|
||||
<form id="edit-model-modal-form" accept-charset="UTF-8" action="" data-remote="true" method="post">
|
||||
{% csrf_token %}
|
||||
<p>
|
||||
<label for="model-name">Name</label>
|
||||
<input id="model-name" type="text" class="form-control" name="model-name"
|
||||
value="{{ model.name }}">
|
||||
</p>
|
||||
<p>
|
||||
<label for="model-description">Description</label>
|
||||
<input id="model-description" type="text" class="form-control" name="model-description"
|
||||
value="{{ model.description }}">
|
||||
</p>
|
||||
</form>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<button type="button" class="btn btn-secondary" data-dismiss="modal">Close</button>
|
||||
<button type="button" class="btn btn-primary" id="edit-model-modal-submit">Update</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
$(function () {
|
||||
|
||||
$('#edit-model-modal-submit').click(function (e) {
|
||||
e.preventDefault();
|
||||
$('#edit-model-modal-form').submit();
|
||||
});
|
||||
|
||||
})
|
||||
</script>
|
||||
62
templates/modals/objects/evaluate_model_modal.html
Normal file
62
templates/modals/objects/evaluate_model_modal.html
Normal file
@ -0,0 +1,62 @@
|
||||
<div class="modal fade" tabindex="-1" id="evaluate_model_modal" role="dialog" aria-labelledby="evaluate_model_modal"
|
||||
aria-hidden="true">
|
||||
<div class="modal-dialog modal-lg">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<button type="button" class="close" data-dismiss="modal">
|
||||
<span aria-hidden="true">×</span>
|
||||
<span class="sr-only">Close</span>
|
||||
</button>
|
||||
<h4 class="modal-title">Evaluate Model</h4>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<form id="evaluate_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
|
||||
data-remote="true" method="post">
|
||||
{% csrf_token %}
|
||||
<div class="jumbotron">
|
||||
For evaluation, you need to select the packages you want to use.
|
||||
While the model is evaluating, you can use the model for predictions.
|
||||
</div>
|
||||
<!-- Evaluation -->
|
||||
<label for="relative-reasoning-evaluation-packages">Evaluation Packages</label>
|
||||
<select id="relative-reasoning-evaluation-packages" name=relative-reasoning-evaluation-packages"
|
||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||
<option disabled>Reviewed Packages</option>
|
||||
{% for obj in meta.readable_packages %}
|
||||
{% if obj.reviewed %}
|
||||
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
<option disabled>Unreviewed Packages</option>
|
||||
{% for obj in meta.readable_packages %}
|
||||
{% if not obj.reviewed %}
|
||||
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</select>
|
||||
</form>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<a id="evaluate_model_form_submit" class="btn btn-primary" href="#">Evaluate</a>
|
||||
<button type="button" class="btn btn-default" data-dismiss="modal">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
||||
$(function () {
|
||||
|
||||
$("#relative-reasoning-evaluation-packages").selectpicker();
|
||||
|
||||
$('#evaluate_model_form_submit').on('click', function (e) {
|
||||
e.preventDefault();
|
||||
$('#evaluate_model_form').submit();
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
|
||||
</script>
|
||||
43
templates/modals/objects/retrain_model_modal.html
Normal file
43
templates/modals/objects/retrain_model_modal.html
Normal file
@ -0,0 +1,43 @@
|
||||
<div class="modal fade" tabindex="-1" id="retrain_model_modal" role="dialog" aria-labelledby="retrain_model_modal"
|
||||
aria-hidden="true">
|
||||
<div class="modal-dialog modal-lg">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<button type="button" class="close" data-dismiss="modal">
|
||||
<span aria-hidden="true">×</span>
|
||||
<span class="sr-only">Close</span>
|
||||
</button>
|
||||
<h4 class="modal-title">Retrain Model</h4>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<form id="retrain_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
|
||||
data-remote="true" method="post">
|
||||
<div class="jumbotron">
|
||||
To reflect changes in the rule or data packages, you can use the "Retrain" button,
|
||||
to let the model reflect the changes without creating a new model.
|
||||
While the model is retraining, it will be unavailable for prediction.
|
||||
</div>
|
||||
{% csrf_token %}
|
||||
<input type="hidden" name="action" value="retrain">
|
||||
</form>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<a id="retrain_model_form_submit" class="btn btn-primary" href="#">Retrain</a>
|
||||
<button type="button" class="btn btn-default" data-dismiss="modal">Cancel</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
||||
$(function () {
|
||||
|
||||
$('#retrain_model_form_submit').on('click', function (e) {
|
||||
e.preventDefault();
|
||||
$('#retrain_model_form').submit();
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
</script>
|
||||
@ -4,6 +4,9 @@
|
||||
{% block content %}
|
||||
|
||||
{% block action_modals %}
|
||||
{% include "modals/objects/edit_model_modal.html" %}
|
||||
{% include "modals/objects/evaluate_model_modal.html" %}
|
||||
{% include "modals/objects/retrain_model_modal.html" %}
|
||||
{% include "modals/objects/generic_delete_modal.html" %}
|
||||
{% endblock action_modals %}
|
||||
|
||||
@ -32,7 +35,7 @@
|
||||
<div class="panel-body">
|
||||
<p> {{ model.description }} </p>
|
||||
</div>
|
||||
{% if model|classname == 'MLRelativeReasoning' %}
|
||||
{% if model|classname == 'MLRelativeReasoning' or model|classname == 'RuleBasedRelativeReasoning'%}
|
||||
<!-- Rule Packages -->
|
||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||
<h4 class="panel-title">
|
||||
|
||||
@ -289,6 +289,12 @@ class Dataset:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
def trig(self, na_replacement=0):
|
||||
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
|
||||
if na_replacement is not None:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
|
||||
@ -324,7 +330,7 @@ class Dataset:
|
||||
pickle.dump(self, fh)
|
||||
|
||||
@staticmethod
|
||||
def load(path: 'Path'):
|
||||
def load(path: 'Path') -> 'Dataset':
|
||||
import pickle
|
||||
return pickle.load(open(path, "rb"))
|
||||
|
||||
@ -553,6 +559,68 @@ class EnsembleClassifierChain:
|
||||
return labels / self.num_chains
|
||||
|
||||
|
||||
class RelativeReasoning:
|
||||
def __init__(self, start_index: int, end_index: int):
|
||||
self.start_index: int = start_index
|
||||
self.end_index: int = end_index
|
||||
self.winmap: Dict[int, List[int]] = defaultdict(list)
|
||||
self.min_count: int = 5
|
||||
self.max_count: int = 0
|
||||
|
||||
def fit(self, X, Y):
|
||||
n_instances = len(Y)
|
||||
n_attributes = len(Y[0])
|
||||
|
||||
for i in range(n_attributes):
|
||||
for j in range(n_attributes):
|
||||
if i == j:
|
||||
continue
|
||||
|
||||
countwin = 0
|
||||
countloose = 0
|
||||
countboth = 0
|
||||
|
||||
for k in range(n_instances):
|
||||
vi = Y[k][i]
|
||||
vj = Y[k][j]
|
||||
|
||||
if vi is None or vj is None:
|
||||
continue
|
||||
|
||||
if vi < vj:
|
||||
countwin += 1
|
||||
elif vi > vj:
|
||||
countloose += 1
|
||||
elif vi == vj and vi == 1: # tie
|
||||
countboth += 1
|
||||
|
||||
# We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
|
||||
if (
|
||||
countwin >= self.min_count and
|
||||
countwin > countloose and
|
||||
(
|
||||
countloose <= self.max_count or
|
||||
self.max_count < 0
|
||||
) and
|
||||
countboth == 0
|
||||
):
|
||||
self.winmap[i].append(j)
|
||||
|
||||
def predict(self, X):
|
||||
res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
|
||||
|
||||
for inst_idx, inst in enumerate(X):
|
||||
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
res[inst_idx][i] = t
|
||||
if t:
|
||||
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
|
||||
res[inst_idx][i] = 0
|
||||
|
||||
return res
|
||||
|
||||
def predict_proba(self, X):
|
||||
return self.predict(X)
|
||||
|
||||
|
||||
class ApplicabilityDomainPCA(PCA):
|
||||
|
||||
Reference in New Issue
Block a user