[Feature] Rule Based Model (#92)

Fixes #89

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#92
This commit is contained in:
2025-09-09 19:32:12 +12:00
parent 1a6608287d
commit 5477b5b3d4
10 changed files with 560 additions and 185 deletions

View File

@ -4,6 +4,7 @@ import json
import logging
import os
import secrets
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import Union, List, Optional, Dict, Tuple, Set
@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
logger = logging.getLogger(__name__)
@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
@property
def root_nodes(self):
return Node.objects.filter(pathway=self, depth=0)
# sames as return Node.objects.filter(pathway=self, depth=0) but will utilize
# potentially prefetched node_set
return self.node_set.all().filter(pathway=self, depth=0)
@property
def nodes(self):
return Node.objects.filter(pathway=self)
# same as Node.objects.filter(pathway=self) but will utilize
# potentially prefetched node_set
return self.node_set.all()
def get_node(self, node_url):
for n in self.nodes:
@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
@property
def edges(self):
return Edge.objects.filter(pathway=self)
# same as Edge.objects.filter(pathway=self) but will utilize
# potentially prefetched edge_set
return self.edge_set.all()
def _url(self):
return '{}/pathway/{}'.format(self.package.url, self.uuid)
@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
return '{}/model/{}'.format(self.package.url, self.uuid)
class MLRelativeReasoning(EPModel):
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
class PackageBasedModel(EPModel):
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages",
related_name="%(app_label)s_%(class)s_rule_packages")
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages",
related_name="%(app_label)s_%(class)s_data_packages")
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages",
related_name="%(app_label)s_%(class)s_eval_packages")
threshold = models.FloatField(null=False, blank=False, default=0.5)
eval_results = JSONField(null=True, blank=True, default=dict)
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
default=None)
INITIAL = "INITIAL"
INITIALIZING = "INITIALIZING"
@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel):
}
model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
eval_results = JSONField(null=True, blank=True, default=dict)
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
default=None)
def status(self):
return self.PROGRESS_STATUS_CHOICES[self.model_status]
def ready_for_prediction(self) -> bool:
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'],
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
name: 'str' = None, description: str = None, build_app_domain: bool = False,
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
app_domain_local_compatibility_threshold: float = None):
mlrr = MLRelativeReasoning()
mlrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
mlrr.name = name
if description is not None and description.strip() != '':
mlrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
mlrr.threshold = threshold
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
mlrr.save()
for p in rule_packages:
mlrr.rule_packages.add(p)
if data_packages:
for p in data_packages:
mlrr.data_packages.add(p)
else:
for p in rule_packages:
mlrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mlrr.eval_packages.add(p)
if build_app_domain:
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
app_domain_local_compatibility_threshold)
mlrr.app_domain = ad
mlrr.save()
return mlrr
@cached_property
def applicable_rules(self) -> List['Rule']:
"""
@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel):
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return Dataset.load(ds_path)
def retrain(self):
self.build_dataset()
self.build_model()
def rebuild(self):
self.build_model()
@abstractmethod
def build_model(self):
pass
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
for rule, p, smis in zip(rules, probabilities, products):
res.append(PredictionResult(smis, p, rule))
return res
class Meta:
abstract = True
class RuleBasedRelativeReasoning(PackageBasedModel):
min_count = models.IntegerField(null=False, blank=False, default=10)
max_count = models.IntegerField(null=False, blank=False, default=0)
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'],
eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0,
name: 'str' = None, description: str = None):
rbrr = RuleBasedRelativeReasoning()
rbrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"
rbrr.name = name
if description is not None and description.strip() != '':
rbrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
rbrr.threshold = threshold
if min_count is None or min_count < 1:
raise ValueError("Minimum count must be an int greater than equal 1.")
rbrr.min_count = min_count
if max_count is None or max_count > min_count:
raise ValueError("Maximum count must be an int and must not be less than min_count.")
if max_count is None:
raise ValueError("Maximum count must be at least 0.")
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
rbrr.save()
for p in rule_packages:
rbrr.rule_packages.add(p)
if data_packages:
for p in data_packages:
rbrr.data_packages.add(p)
else:
for p in rule_packages:
rbrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
rbrr.eval_packages.add(p)
rbrr.save()
return rbrr
def build_model(self):
self.model_status = self.BUILDING
self.save()
ds = self.load_dataset()
labels = ds.y(na_replacement=None)
mod = RelativeReasoning(*ds.triggered())
mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f)
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
@cached_property
def model(self) -> 'RelativeReasoning':
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
return mod
def predict(self, smiles) -> List['PredictionResult']:
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
mod = self.model
pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
end = datetime.now()
logger.info(f"Full predict took {(end - start).total_seconds()}s")
return res
class MLRelativeReasoning(PackageBasedModel):
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'],
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
name: 'str' = None, description: str = None, build_app_domain: bool = False,
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
app_domain_local_compatibility_threshold: float = None):
mlrr = MLRelativeReasoning()
mlrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
mlrr.name = name
if description is not None and description.strip() != '':
mlrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
mlrr.threshold = threshold
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
mlrr.save()
for p in rule_packages:
mlrr.rule_packages.add(p)
if data_packages:
for p in data_packages:
mlrr.data_packages.add(p)
else:
for p in rule_packages:
mlrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mlrr.eval_packages.add(p)
if build_app_domain:
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
app_domain_local_compatibility_threshold)
mlrr.app_domain = ad
mlrr.save()
return mlrr
def build_model(self):
self.model_status = self.BUILDING
self.save()
@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def retrain(self):
self.build_dataset()
self.build_model()
def rebuild(self):
self.build_model()
def evaluate_model(self):
if self.model_status != self.BUILT_NOT_EVALUATED:
@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel):
logger.info(f"Full predict took {(end - start).total_seconds()}s")
return res
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
for rule, p, smis in zip(rules, probabilities, products):
res.append(PredictionResult(smis, p, rule))
return res
@property
def pr_curve(self):
if self.model_status != self.FINISHED:
@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel):
return accuracy
class RuleBaseRelativeReasoning(EPModel):
pass
class EnviFormer(EPModel):
threshold = models.FloatField(null=False, blank=False, default=0.5)
@ -2406,6 +2518,12 @@ class EnviFormer(EPModel):
def applicable_rules(self):
return []
def status(self):
return "Model is built and can be used for predictions, Model is not evaluated yet."
def ready_for_prediction(self) -> bool:
return True
class PluginModel(EPModel):
pass

View File

@ -41,6 +41,12 @@ def evaluate_model(model_pk: int):
mod.evaluate_model()
@shared_task(queue='model')
def retrain(model_pk: int):
mod = EPModel.objects.get(id=model_pk)
mod.retrain()
@shared_task(queue='predict')
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
pw = Pathway.objects.get(id=pw_pk)

View File

@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required
from utilities.misc import HTMLGenerator
from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
UserPackagePermission, Permission, License, User, Edge
logger = logging.getLogger(__name__)
@ -651,17 +651,26 @@ def package_models(request, package_uuid):
mod = EnviFormer.create(current_package, name, description, threshold)
elif model_type == 'ml-relative-reasoning':
elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning':
# Generic fields for ML and Rule Based
rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages')
data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages')
eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', [])
# Generic params
params = {
'package' : current_package,
'name' : name,
'description' : description,
'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages],
'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages],
'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages],
}
if model_type == 'ml-relative-reasoning':
# ML Specific
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
rule_packages = request.POST.getlist(f'{model_type}-rule-packages')
data_packages = request.POST.getlist(f'{model_type}-data-packages')
eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', [])
# get Package objects from urls
rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages]
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
# App Domain related parameters
build_ad = request.POST.get('build-app-domain', False) == 'on'
@ -669,28 +678,23 @@ def package_models(request, package_uuid):
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
params['threshold'] = threshold
# params['fingerprinter'] = fingerprinter
params['build_app_domain'] = build_ad
params['app_domain_num_neighbours'] = num_neighbors
params['app_domain_reliability_threshold'] = reliability_threshold
params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold
mod = MLRelativeReasoning.create(
package=current_package,
name=name,
description=description,
rule_packages=rule_package_objs,
data_packages=data_package_objs,
eval_packages=eval_packages_objs,
threshold=threshold,
# fingerprinter=fingerprinter,
build_app_domain=build_ad,
app_domain_num_neighbours=num_neighbors,
app_domain_reliability_threshold=reliability_threshold,
app_domain_local_compatibility_threshold=local_compatibility_threshold,
**params
)
else:
mod = RuleBasedRelativeReasoning.create(
**params
)
from .tasks import build_model
build_model.delay(mod.pk)
elif model_type == 'rule-base-relative-reasoning':
mod = RuleBaseRelativeReasoning()
mod.save()
else:
return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
return redirect(mod.url)
@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid):
else:
return HttpResponseBadRequest()
else:
name = request.POST.get('model-name', '').strip()
description = request.POST.get('model-description', '').strip()
if any([name, description]):
if name:
current_model.name = name
if description:
current_model.description = description
current_model.save()
return redirect(current_model.url)
return HttpResponseBadRequest()
else:

View File

@ -1,4 +1,16 @@
{% if meta.can_edit %}
<li>
<a role="button" data-toggle="modal" data-target="#edit_model_modal">
<i class="glyphicon glyphicon-edit"></i> Edit Model</a>
</li>
<li>
<a role="button" data-toggle="modal" data-target="#evaluate_model_modal">
<i class="glyphicon glyphicon-ok"></i> Evaluate Model</a>
</li>
<li>
<a role="button" data-toggle="modal" data-target="#retrain_model_modal">
<i class="glyphicon glyphicon-repeat"></i> Retrain Model</a>
</li>
<li>
<a class="button" data-toggle="modal" data-target="#generic_delete_modal">
<i class="glyphicon glyphicon-trash"></i> Delete Model</a>

View File

@ -32,11 +32,11 @@
<option value="{{ v }}">{{ k }}</option>
{% endfor %}
</select>
<!-- ML Based Form-->
<div id="ml-relative-reasoning-specific-form">
<!-- ML and Rule Based Based Form-->
<div id="package-based-relative-reasoning-specific-form">
<!-- Rule Packages -->
<label for="ml-relative-reasoning-rule-packages">Rule Packages</label>
<select id="ml-relative-reasoning-rule-packages" name="ml-relative-reasoning-rule-packages"
<label for="package-based-relative-reasoning-rule-packages">Rule Packages</label>
<select id="package-based-relative-reasoning-rule-packages" name="package-based-relative-reasoning-rule-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option>
{% for obj in meta.readable_packages %}
@ -53,8 +53,8 @@
{% endfor %}
</select>
<!-- Data Packages -->
<label for="ml-relative-reasoning-data-packages" >Data Packages</label>
<select id="ml-relative-reasoning-data-packages" name="ml-relative-reasoning-data-packages"
<label for="package-based-relative-reasoning-data-packages" >Data Packages</label>
<select id="package-based-relative-reasoning-data-packages" name="package-based-relative-reasoning-data-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option>
{% for obj in meta.readable_packages %}
@ -71,6 +71,7 @@
{% endfor %}
</select>
<div id="ml-relative-reasoning-specific-form">
<!-- Fingerprinter -->
<label for="ml-relative-reasoning-fingerprinter">Fingerprinter</label>
<select id="ml-relative-reasoning-fingerprinter" name="ml-relative-reasoning-fingerprinter"
@ -79,8 +80,10 @@
</select>
{% if meta.enabled_features.PLUGINS and additional_descriptors %}
<!-- Property Plugins go here -->
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter / Descriptors</label>
<select id="ml-relative-reasoning-additional-fingerprinter" name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter /
Descriptors</label>
<select id="ml-relative-reasoning-additional-fingerprinter"
name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
<option disabled selected>Select Additional Fingerprinter / Descriptor</option>
{% for k, v in additional_descriptors.items %}
<option value="{{ v }}">{{ k }}</option>
@ -92,33 +95,16 @@
<input type="number" min="0" max="1" step="0.05" value="0.5"
id="ml-relative-reasoning-threshold"
name="ml-relative-reasoning-threshold" class="form-control">
<!-- Evaluation -->
<label for="ml-relative-reasoning-evaluation-packages">Evaluation Packages</label>
<select id="ml-relative-reasoning-evaluation-packages" name="ml-relative-reasoning-evaluation-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option>
{% for obj in meta.readable_packages %}
{% if obj.reviewed %}
<option value="{{ obj.url }}">{{ obj.name }}</option>
{% endif %}
{% endfor %}
<option disabled>Unreviewed Packages</option>
{% for obj in meta.readable_packages %}
{% if not obj.reviewed %}
<option value="{{ obj.url }}">{{ obj.name }}</option>
{% endif %}
{% endfor %}
</select>
</div>
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
<!-- Build AD? -->
<div class="checkbox">
<label>
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an Applicability Domain?
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an
Applicability Domain?
</label>
</div>
<div id="ad-params" style="display:none">
<!-- Num Neighbors -->
<label for="num-neighbors">Number of Neighbors</label>
<input id="num-neighbors" name="num-neighbors" type="number" class="form-control" value="5"
@ -131,11 +117,8 @@
<label for="reliability-threshold">Reliability Threshold</label>
<input id="reliability-threshold" name="reliability-threshold" type="number"
class="form-control" value="0.5" step="0.01" min="0" max="1">
{% endif %}
</div>
<!-- Rule Based Based Form-->
<div id="rule-based-relative-reasoning-specific-form">
{% endif %}
</div>
<!-- EnviFormer-->
<div id="enviformer-specific-form">
@ -160,20 +143,38 @@ $(function() {
$(this).hide();
});
$("#ml-relative-reasoning-rule-packages").selectpicker();
$("#ml-relative-reasoning-data-packages").selectpicker();
$("#ml-relative-reasoning-evaluation-packages").selectpicker();
$('#model-type').selectpicker();
$("#ml-relative-reasoning-fingerprinter").selectpicker();
$("#package-based-relative-reasoning-rule-packages").selectpicker();
$("#package-based-relative-reasoning-data-packages").selectpicker();
$("#package-based-relative-reasoning-evaluation-packages").selectpicker();
if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) {
$("#ml-relative-reasoning-additional-fingerprinter").selectpicker();
}
$("#build-app-domain").change(function () {
if ($(this).is(":checked")) {
$('#ad-params').show();
} else {
$('#ad-params').hide();
}
});
// On change hide all and show only selected
$("#model-type").change(function() {
$("div[id$='-specific-form']").each( function() {
$(this).hide();
});
val = $('option:selected', this).val();
if (val === 'ml-relative-reasoning' || val === 'rule-based-relative-reasoning') {
$("#package-based-relative-reasoning-specific-form").show();
if (val === 'ml-relative-reasoning') {
$("#ml-relative-reasoning-specific-form").show();
}
} else {
$("#" + val + "-specific-form").show();
}
});
$('#new_model_modal_form_submit').on('click', function(e){

View File

@ -0,0 +1,44 @@
{% load static %}
<!-- Edit Model -->
<div id="edit_model_modal" class="modal" tabindex="-1">
<div class="modal-dialog">
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal" aria-label="Close">
<span aria-hidden="true">&times;</span>
</button>
<h3 class="modal-title">Update Model</h3>
</div>
<div class="modal-body">
<p>Alter Name and Description of the Model.</p>
<form id="edit-model-modal-form" accept-charset="UTF-8" action="" data-remote="true" method="post">
{% csrf_token %}
<p>
<label for="model-name">Name</label>
<input id="model-name" type="text" class="form-control" name="model-name"
value="{{ model.name }}">
</p>
<p>
<label for="model-description">Description</label>
<input id="model-description" type="text" class="form-control" name="model-description"
value="{{ model.description }}">
</p>
</form>
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-dismiss="modal">Close</button>
<button type="button" class="btn btn-primary" id="edit-model-modal-submit">Update</button>
</div>
</div>
</div>
</div>
<script>
$(function () {
$('#edit-model-modal-submit').click(function (e) {
e.preventDefault();
$('#edit-model-modal-form').submit();
});
})
</script>

View File

@ -0,0 +1,62 @@
<div class="modal fade" tabindex="-1" id="evaluate_model_modal" role="dialog" aria-labelledby="evaluate_model_modal"
aria-hidden="true">
<div class="modal-dialog modal-lg">
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal">
<span aria-hidden="true">&times;</span>
<span class="sr-only">Close</span>
</button>
<h4 class="modal-title">Evaluate Model</h4>
</div>
<div class="modal-body">
<form id="evaluate_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
data-remote="true" method="post">
{% csrf_token %}
<div class="jumbotron">
For evaluation, you need to select the packages you want to use.
While the model is evaluating, you can use the model for predictions.
</div>
<!-- Evaluation -->
<label for="relative-reasoning-evaluation-packages">Evaluation Packages</label>
<select id="relative-reasoning-evaluation-packages" name=relative-reasoning-evaluation-packages"
data-actions-box='true' class="form-control" multiple data-width='100%'>
<option disabled>Reviewed Packages</option>
{% for obj in meta.readable_packages %}
{% if obj.reviewed %}
<option value="{{ obj.url }}">{{ obj.name }}</option>
{% endif %}
{% endfor %}
<option disabled>Unreviewed Packages</option>
{% for obj in meta.readable_packages %}
{% if not obj.reviewed %}
<option value="{{ obj.url }}">{{ obj.name }}</option>
{% endif %}
{% endfor %}
</select>
</form>
</div>
<div class="modal-footer">
<a id="evaluate_model_form_submit" class="btn btn-primary" href="#">Evaluate</a>
<button type="button" class="btn btn-default" data-dismiss="modal">Cancel</button>
</div>
</div>
</div>
</div>
<script>
$(function () {
$("#relative-reasoning-evaluation-packages").selectpicker();
$('#evaluate_model_form_submit').on('click', function (e) {
e.preventDefault();
$('#evaluate_model_form').submit();
});
});
</script>

View File

@ -0,0 +1,43 @@
<div class="modal fade" tabindex="-1" id="retrain_model_modal" role="dialog" aria-labelledby="retrain_model_modal"
aria-hidden="true">
<div class="modal-dialog modal-lg">
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal">
<span aria-hidden="true">&times;</span>
<span class="sr-only">Close</span>
</button>
<h4 class="modal-title">Retrain Model</h4>
</div>
<div class="modal-body">
<form id="retrain_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
data-remote="true" method="post">
<div class="jumbotron">
To reflect changes in the rule or data packages, you can use the "Retrain" button,
to let the model reflect the changes without creating a new model.
While the model is retraining, it will be unavailable for prediction.
</div>
{% csrf_token %}
<input type="hidden" name="action" value="retrain">
</form>
</div>
<div class="modal-footer">
<a id="retrain_model_form_submit" class="btn btn-primary" href="#">Retrain</a>
<button type="button" class="btn btn-default" data-dismiss="modal">Cancel</button>
</div>
</div>
</div>
</div>
<script>
$(function () {
$('#retrain_model_form_submit').on('click', function (e) {
e.preventDefault();
$('#retrain_model_form').submit();
});
});
</script>

View File

@ -4,6 +4,9 @@
{% block content %}
{% block action_modals %}
{% include "modals/objects/edit_model_modal.html" %}
{% include "modals/objects/evaluate_model_modal.html" %}
{% include "modals/objects/retrain_model_modal.html" %}
{% include "modals/objects/generic_delete_modal.html" %}
{% endblock action_modals %}
@ -32,7 +35,7 @@
<div class="panel-body">
<p> {{ model.description }} </p>
</div>
{% if model|classname == 'MLRelativeReasoning' %}
{% if model|classname == 'MLRelativeReasoning' or model|classname == 'RuleBasedRelativeReasoning'%}
<!-- Rule Packages -->
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
<h4 class="panel-title">

View File

@ -289,6 +289,12 @@ class Dataset:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def trig(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
@ -324,7 +330,7 @@ class Dataset:
pickle.dump(self, fh)
@staticmethod
def load(path: 'Path'):
def load(path: 'Path') -> 'Dataset':
import pickle
return pickle.load(open(path, "rb"))
@ -553,6 +559,68 @@ class EnsembleClassifierChain:
return labels / self.num_chains
class RelativeReasoning:
def __init__(self, start_index: int, end_index: int):
self.start_index: int = start_index
self.end_index: int = end_index
self.winmap: Dict[int, List[int]] = defaultdict(list)
self.min_count: int = 5
self.max_count: int = 0
def fit(self, X, Y):
n_instances = len(Y)
n_attributes = len(Y[0])
for i in range(n_attributes):
for j in range(n_attributes):
if i == j:
continue
countwin = 0
countloose = 0
countboth = 0
for k in range(n_instances):
vi = Y[k][i]
vj = Y[k][j]
if vi is None or vj is None:
continue
if vi < vj:
countwin += 1
elif vi > vj:
countloose += 1
elif vi == vj and vi == 1: # tie
countboth += 1
# We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
if (
countwin >= self.min_count and
countwin > countloose and
(
countloose <= self.max_count or
self.max_count < 0
) and
countboth == 0
):
self.winmap[i].append(j)
def predict(self, X):
res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
for inst_idx, inst in enumerate(X):
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
res[inst_idx][i] = t
if t:
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
res[inst_idx][i] = 0
return res
def predict_proba(self, X):
return self.predict(X)
class ApplicabilityDomainPCA(PCA):