forked from enviPath/enviPy
[Feature] Rule Based Model (#92)
Fixes #89 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#92
This commit is contained in:
282
epdb/models.py
282
epdb/models.py
@ -4,6 +4,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Union, List, Optional, Dict, Tuple, Set
|
from typing import Union, List, Optional, Dict, Tuple, Set
|
||||||
@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
|||||||
from sklearn.model_selection import ShuffleSplit
|
from sklearn.model_selection import ShuffleSplit
|
||||||
|
|
||||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
|
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def root_nodes(self):
|
def root_nodes(self):
|
||||||
return Node.objects.filter(pathway=self, depth=0)
|
# sames as return Node.objects.filter(pathway=self, depth=0) but will utilize
|
||||||
|
# potentially prefetched node_set
|
||||||
|
return self.node_set.all().filter(pathway=self, depth=0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self):
|
def nodes(self):
|
||||||
return Node.objects.filter(pathway=self)
|
# same as Node.objects.filter(pathway=self) but will utilize
|
||||||
|
# potentially prefetched node_set
|
||||||
|
return self.node_set.all()
|
||||||
|
|
||||||
def get_node(self, node_url):
|
def get_node(self, node_url):
|
||||||
for n in self.nodes:
|
for n in self.nodes:
|
||||||
@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def edges(self):
|
def edges(self):
|
||||||
return Edge.objects.filter(pathway=self)
|
# same as Edge.objects.filter(pathway=self) but will utilize
|
||||||
|
# potentially prefetched edge_set
|
||||||
|
return self.edge_set.all()
|
||||||
|
|
||||||
def _url(self):
|
def _url(self):
|
||||||
return '{}/pathway/{}'.format(self.package.url, self.uuid)
|
return '{}/pathway/{}'.format(self.package.url, self.uuid)
|
||||||
@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
|
|||||||
return '{}/model/{}'.format(self.package.url, self.uuid)
|
return '{}/model/{}'.format(self.package.url, self.uuid)
|
||||||
|
|
||||||
|
|
||||||
class MLRelativeReasoning(EPModel):
|
class PackageBasedModel(EPModel):
|
||||||
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
|
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages",
|
||||||
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
|
related_name="%(app_label)s_%(class)s_rule_packages")
|
||||||
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
|
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages",
|
||||||
|
related_name="%(app_label)s_%(class)s_data_packages")
|
||||||
|
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages",
|
||||||
|
related_name="%(app_label)s_%(class)s_eval_packages")
|
||||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||||
|
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||||
|
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||||
|
default=None)
|
||||||
|
|
||||||
INITIAL = "INITIAL"
|
INITIAL = "INITIAL"
|
||||||
INITIALIZING = "INITIALIZING"
|
INITIALIZING = "INITIALIZING"
|
||||||
@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel):
|
|||||||
}
|
}
|
||||||
model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
|
model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
|
||||||
|
|
||||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
|
||||||
|
|
||||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
|
||||||
default=None)
|
|
||||||
|
|
||||||
def status(self):
|
def status(self):
|
||||||
return self.PROGRESS_STATUS_CHOICES[self.model_status]
|
return self.PROGRESS_STATUS_CHOICES[self.model_status]
|
||||||
|
|
||||||
def ready_for_prediction(self) -> bool:
|
def ready_for_prediction(self) -> bool:
|
||||||
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
|
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@transaction.atomic
|
|
||||||
def create(package: 'Package', rule_packages: List['Package'],
|
|
||||||
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
|
|
||||||
name: 'str' = None, description: str = None, build_app_domain: bool = False,
|
|
||||||
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
|
|
||||||
app_domain_local_compatibility_threshold: float = None):
|
|
||||||
|
|
||||||
mlrr = MLRelativeReasoning()
|
|
||||||
mlrr.package = package
|
|
||||||
|
|
||||||
if name is None or name.strip() == '':
|
|
||||||
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
|
|
||||||
|
|
||||||
mlrr.name = name
|
|
||||||
|
|
||||||
if description is not None and description.strip() != '':
|
|
||||||
mlrr.description = description
|
|
||||||
|
|
||||||
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
|
||||||
raise ValueError("Threshold must be a float between 0 and 1.")
|
|
||||||
|
|
||||||
mlrr.threshold = threshold
|
|
||||||
|
|
||||||
if len(rule_packages) == 0:
|
|
||||||
raise ValueError("At least one rule package must be provided.")
|
|
||||||
|
|
||||||
mlrr.save()
|
|
||||||
|
|
||||||
for p in rule_packages:
|
|
||||||
mlrr.rule_packages.add(p)
|
|
||||||
|
|
||||||
if data_packages:
|
|
||||||
for p in data_packages:
|
|
||||||
mlrr.data_packages.add(p)
|
|
||||||
else:
|
|
||||||
for p in rule_packages:
|
|
||||||
mlrr.data_packages.add(p)
|
|
||||||
|
|
||||||
if eval_packages:
|
|
||||||
for p in eval_packages:
|
|
||||||
mlrr.eval_packages.add(p)
|
|
||||||
|
|
||||||
if build_app_domain:
|
|
||||||
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
|
|
||||||
app_domain_local_compatibility_threshold)
|
|
||||||
mlrr.app_domain = ad
|
|
||||||
|
|
||||||
mlrr.save()
|
|
||||||
|
|
||||||
return mlrr
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def applicable_rules(self) -> List['Rule']:
|
def applicable_rules(self) -> List['Rule']:
|
||||||
"""
|
"""
|
||||||
@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel):
|
|||||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||||
return Dataset.load(ds_path)
|
return Dataset.load(ds_path)
|
||||||
|
|
||||||
|
def retrain(self):
|
||||||
|
self.build_dataset()
|
||||||
|
self.build_model()
|
||||||
|
|
||||||
|
def rebuild(self):
|
||||||
|
self.build_model()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_model(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
||||||
|
res = []
|
||||||
|
for rule, p, smis in zip(rules, probabilities, products):
|
||||||
|
res.append(PredictionResult(smis, p, rule))
|
||||||
|
return res
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
|
||||||
|
class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||||
|
min_count = models.IntegerField(null=False, blank=False, default=10)
|
||||||
|
max_count = models.IntegerField(null=False, blank=False, default=0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@transaction.atomic
|
||||||
|
def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'],
|
||||||
|
eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0,
|
||||||
|
name: 'str' = None, description: str = None):
|
||||||
|
|
||||||
|
rbrr = RuleBasedRelativeReasoning()
|
||||||
|
rbrr.package = package
|
||||||
|
|
||||||
|
if name is None or name.strip() == '':
|
||||||
|
name = f"MLRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||||
|
|
||||||
|
rbrr.name = name
|
||||||
|
|
||||||
|
if description is not None and description.strip() != '':
|
||||||
|
rbrr.description = description
|
||||||
|
|
||||||
|
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||||
|
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||||
|
|
||||||
|
rbrr.threshold = threshold
|
||||||
|
|
||||||
|
if min_count is None or min_count < 1:
|
||||||
|
raise ValueError("Minimum count must be an int greater than equal 1.")
|
||||||
|
|
||||||
|
rbrr.min_count = min_count
|
||||||
|
|
||||||
|
if max_count is None or max_count > min_count:
|
||||||
|
raise ValueError("Maximum count must be an int and must not be less than min_count.")
|
||||||
|
|
||||||
|
if max_count is None:
|
||||||
|
raise ValueError("Maximum count must be at least 0.")
|
||||||
|
|
||||||
|
if len(rule_packages) == 0:
|
||||||
|
raise ValueError("At least one rule package must be provided.")
|
||||||
|
|
||||||
|
rbrr.save()
|
||||||
|
|
||||||
|
for p in rule_packages:
|
||||||
|
rbrr.rule_packages.add(p)
|
||||||
|
|
||||||
|
if data_packages:
|
||||||
|
for p in data_packages:
|
||||||
|
rbrr.data_packages.add(p)
|
||||||
|
else:
|
||||||
|
for p in rule_packages:
|
||||||
|
rbrr.data_packages.add(p)
|
||||||
|
|
||||||
|
if eval_packages:
|
||||||
|
for p in eval_packages:
|
||||||
|
rbrr.eval_packages.add(p)
|
||||||
|
|
||||||
|
rbrr.save()
|
||||||
|
|
||||||
|
return rbrr
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
self.model_status = self.BUILDING
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
ds = self.load_dataset()
|
||||||
|
labels = ds.y(na_replacement=None)
|
||||||
|
|
||||||
|
mod = RelativeReasoning(*ds.triggered())
|
||||||
|
mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
|
||||||
|
|
||||||
|
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
|
||||||
|
joblib.dump(mod, f)
|
||||||
|
|
||||||
|
self.model_status = self.BUILT_NOT_EVALUATED
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model(self) -> 'RelativeReasoning':
|
||||||
|
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
|
||||||
|
return mod
|
||||||
|
|
||||||
|
def predict(self, smiles) -> List['PredictionResult']:
|
||||||
|
start = datetime.now()
|
||||||
|
ds = self.load_dataset()
|
||||||
|
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||||
|
|
||||||
|
mod = self.model
|
||||||
|
|
||||||
|
pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
|
||||||
|
|
||||||
|
res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
|
||||||
|
|
||||||
|
end = datetime.now()
|
||||||
|
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class MLRelativeReasoning(PackageBasedModel):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@transaction.atomic
|
||||||
|
def create(package: 'Package', rule_packages: List['Package'],
|
||||||
|
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
|
||||||
|
name: 'str' = None, description: str = None, build_app_domain: bool = False,
|
||||||
|
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
|
||||||
|
app_domain_local_compatibility_threshold: float = None):
|
||||||
|
|
||||||
|
mlrr = MLRelativeReasoning()
|
||||||
|
mlrr.package = package
|
||||||
|
|
||||||
|
if name is None or name.strip() == '':
|
||||||
|
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||||
|
|
||||||
|
mlrr.name = name
|
||||||
|
|
||||||
|
if description is not None and description.strip() != '':
|
||||||
|
mlrr.description = description
|
||||||
|
|
||||||
|
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||||
|
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||||
|
|
||||||
|
mlrr.threshold = threshold
|
||||||
|
|
||||||
|
if len(rule_packages) == 0:
|
||||||
|
raise ValueError("At least one rule package must be provided.")
|
||||||
|
|
||||||
|
mlrr.save()
|
||||||
|
|
||||||
|
for p in rule_packages:
|
||||||
|
mlrr.rule_packages.add(p)
|
||||||
|
|
||||||
|
if data_packages:
|
||||||
|
for p in data_packages:
|
||||||
|
mlrr.data_packages.add(p)
|
||||||
|
else:
|
||||||
|
for p in rule_packages:
|
||||||
|
mlrr.data_packages.add(p)
|
||||||
|
|
||||||
|
if eval_packages:
|
||||||
|
for p in eval_packages:
|
||||||
|
mlrr.eval_packages.add(p)
|
||||||
|
|
||||||
|
if build_app_domain:
|
||||||
|
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
|
||||||
|
app_domain_local_compatibility_threshold)
|
||||||
|
mlrr.app_domain = ad
|
||||||
|
|
||||||
|
mlrr.save()
|
||||||
|
|
||||||
|
return mlrr
|
||||||
|
|
||||||
def build_model(self):
|
def build_model(self):
|
||||||
self.model_status = self.BUILDING
|
self.model_status = self.BUILDING
|
||||||
self.save()
|
self.save()
|
||||||
@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel):
|
|||||||
self.model_status = self.BUILT_NOT_EVALUATED
|
self.model_status = self.BUILT_NOT_EVALUATED
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def retrain(self):
|
|
||||||
self.build_dataset()
|
|
||||||
self.build_model()
|
|
||||||
|
|
||||||
def rebuild(self):
|
|
||||||
self.build_model()
|
|
||||||
|
|
||||||
def evaluate_model(self):
|
def evaluate_model(self):
|
||||||
|
|
||||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||||
@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel):
|
|||||||
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
|
||||||
res = []
|
|
||||||
for rule, p, smis in zip(rules, probabilities, products):
|
|
||||||
res.append(PredictionResult(smis, p, rule))
|
|
||||||
return res
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pr_curve(self):
|
def pr_curve(self):
|
||||||
if self.model_status != self.FINISHED:
|
if self.model_status != self.FINISHED:
|
||||||
@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
|
||||||
class RuleBaseRelativeReasoning(EPModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EnviFormer(EPModel):
|
class EnviFormer(EPModel):
|
||||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||||
@ -2406,6 +2518,12 @@ class EnviFormer(EPModel):
|
|||||||
def applicable_rules(self):
|
def applicable_rules(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def status(self):
|
||||||
|
return "Model is built and can be used for predictions, Model is not evaluated yet."
|
||||||
|
|
||||||
|
def ready_for_prediction(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class PluginModel(EPModel):
|
class PluginModel(EPModel):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -41,6 +41,12 @@ def evaluate_model(model_pk: int):
|
|||||||
mod.evaluate_model()
|
mod.evaluate_model()
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(queue='model')
|
||||||
|
def retrain(model_pk: int):
|
||||||
|
mod = EPModel.objects.get(id=model_pk)
|
||||||
|
mod.retrain()
|
||||||
|
|
||||||
|
|
||||||
@shared_task(queue='predict')
|
@shared_task(queue='predict')
|
||||||
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
|
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
|
||||||
pw = Pathway.objects.get(id=pw_pk)
|
pw = Pathway.objects.get(id=pw_pk)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required
|
|||||||
from utilities.misc import HTMLGenerator
|
from utilities.misc import HTMLGenerator
|
||||||
from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
|
from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
|
||||||
from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
|
from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
|
||||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||||
UserPackagePermission, Permission, License, User, Edge
|
UserPackagePermission, Permission, License, User, Edge
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -651,46 +651,50 @@ def package_models(request, package_uuid):
|
|||||||
|
|
||||||
mod = EnviFormer.create(current_package, name, description, threshold)
|
mod = EnviFormer.create(current_package, name, description, threshold)
|
||||||
|
|
||||||
elif model_type == 'ml-relative-reasoning':
|
elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning':
|
||||||
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
# Generic fields for ML and Rule Based
|
||||||
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages')
|
||||||
rule_packages = request.POST.getlist(f'{model_type}-rule-packages')
|
data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages')
|
||||||
data_packages = request.POST.getlist(f'{model_type}-data-packages')
|
eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', [])
|
||||||
eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', [])
|
|
||||||
|
|
||||||
# get Package objects from urls
|
# Generic params
|
||||||
rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages]
|
params = {
|
||||||
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
|
'package' : current_package,
|
||||||
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
|
'name' : name,
|
||||||
|
'description' : description,
|
||||||
|
'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages],
|
||||||
|
'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages],
|
||||||
|
'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages],
|
||||||
|
}
|
||||||
|
|
||||||
# App Domain related parameters
|
if model_type == 'ml-relative-reasoning':
|
||||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
# ML Specific
|
||||||
num_neighbors = request.POST.get('num-neighbors', 5)
|
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
||||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
||||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
|
||||||
|
|
||||||
mod = MLRelativeReasoning.create(
|
# App Domain related parameters
|
||||||
package=current_package,
|
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||||
name=name,
|
num_neighbors = request.POST.get('num-neighbors', 5)
|
||||||
description=description,
|
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||||
rule_packages=rule_package_objs,
|
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||||
data_packages=data_package_objs,
|
|
||||||
eval_packages=eval_packages_objs,
|
params['threshold'] = threshold
|
||||||
threshold=threshold,
|
# params['fingerprinter'] = fingerprinter
|
||||||
# fingerprinter=fingerprinter,
|
params['build_app_domain'] = build_ad
|
||||||
build_app_domain=build_ad,
|
params['app_domain_num_neighbours'] = num_neighbors
|
||||||
app_domain_num_neighbours=num_neighbors,
|
params['app_domain_reliability_threshold'] = reliability_threshold
|
||||||
app_domain_reliability_threshold=reliability_threshold,
|
params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold
|
||||||
app_domain_local_compatibility_threshold=local_compatibility_threshold,
|
|
||||||
)
|
mod = MLRelativeReasoning.create(
|
||||||
|
**params
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mod = RuleBasedRelativeReasoning.create(
|
||||||
|
**params
|
||||||
|
)
|
||||||
|
|
||||||
from .tasks import build_model
|
from .tasks import build_model
|
||||||
build_model.delay(mod.pk)
|
build_model.delay(mod.pk)
|
||||||
|
|
||||||
elif model_type == 'rule-base-relative-reasoning':
|
|
||||||
mod = RuleBaseRelativeReasoning()
|
|
||||||
|
|
||||||
mod.save()
|
|
||||||
else:
|
else:
|
||||||
return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
|
return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
|
||||||
return redirect(mod.url)
|
return redirect(mod.url)
|
||||||
@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid):
|
|||||||
else:
|
else:
|
||||||
return HttpResponseBadRequest()
|
return HttpResponseBadRequest()
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
name = request.POST.get('model-name', '').strip()
|
||||||
|
description = request.POST.get('model-description', '').strip()
|
||||||
|
|
||||||
|
if any([name, description]):
|
||||||
|
if name:
|
||||||
|
current_model.name = name
|
||||||
|
|
||||||
|
if description:
|
||||||
|
current_model.description = description
|
||||||
|
|
||||||
|
current_model.save()
|
||||||
|
return redirect(current_model.url)
|
||||||
|
|
||||||
return HttpResponseBadRequest()
|
return HttpResponseBadRequest()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,4 +1,16 @@
|
|||||||
{% if meta.can_edit %}
|
{% if meta.can_edit %}
|
||||||
|
<li>
|
||||||
|
<a role="button" data-toggle="modal" data-target="#edit_model_modal">
|
||||||
|
<i class="glyphicon glyphicon-edit"></i> Edit Model</a>
|
||||||
|
</li>
|
||||||
|
<li>
|
||||||
|
<a role="button" data-toggle="modal" data-target="#evaluate_model_modal">
|
||||||
|
<i class="glyphicon glyphicon-ok"></i> Evaluate Model</a>
|
||||||
|
</li>
|
||||||
|
<li>
|
||||||
|
<a role="button" data-toggle="modal" data-target="#retrain_model_modal">
|
||||||
|
<i class="glyphicon glyphicon-repeat"></i> Retrain Model</a>
|
||||||
|
</li>
|
||||||
<li>
|
<li>
|
||||||
<a class="button" data-toggle="modal" data-target="#generic_delete_modal">
|
<a class="button" data-toggle="modal" data-target="#generic_delete_modal">
|
||||||
<i class="glyphicon glyphicon-trash"></i> Delete Model</a>
|
<i class="glyphicon glyphicon-trash"></i> Delete Model</a>
|
||||||
|
|||||||
@ -32,11 +32,11 @@
|
|||||||
<option value="{{ v }}">{{ k }}</option>
|
<option value="{{ v }}">{{ k }}</option>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</select>
|
</select>
|
||||||
<!-- ML Based Form-->
|
<!-- ML and Rule Based Based Form-->
|
||||||
<div id="ml-relative-reasoning-specific-form">
|
<div id="package-based-relative-reasoning-specific-form">
|
||||||
<!-- Rule Packages -->
|
<!-- Rule Packages -->
|
||||||
<label for="ml-relative-reasoning-rule-packages">Rule Packages</label>
|
<label for="package-based-relative-reasoning-rule-packages">Rule Packages</label>
|
||||||
<select id="ml-relative-reasoning-rule-packages" name="ml-relative-reasoning-rule-packages"
|
<select id="package-based-relative-reasoning-rule-packages" name="package-based-relative-reasoning-rule-packages"
|
||||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||||
<option disabled>Reviewed Packages</option>
|
<option disabled>Reviewed Packages</option>
|
||||||
{% for obj in meta.readable_packages %}
|
{% for obj in meta.readable_packages %}
|
||||||
@ -53,8 +53,8 @@
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
</select>
|
</select>
|
||||||
<!-- Data Packages -->
|
<!-- Data Packages -->
|
||||||
<label for="ml-relative-reasoning-data-packages" >Data Packages</label>
|
<label for="package-based-relative-reasoning-data-packages" >Data Packages</label>
|
||||||
<select id="ml-relative-reasoning-data-packages" name="ml-relative-reasoning-data-packages"
|
<select id="package-based-relative-reasoning-data-packages" name="package-based-relative-reasoning-data-packages"
|
||||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||||
<option disabled>Reviewed Packages</option>
|
<option disabled>Reviewed Packages</option>
|
||||||
{% for obj in meta.readable_packages %}
|
{% for obj in meta.readable_packages %}
|
||||||
@ -71,71 +71,54 @@
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
</select>
|
</select>
|
||||||
|
|
||||||
<!-- Fingerprinter -->
|
<div id="ml-relative-reasoning-specific-form">
|
||||||
<label for="ml-relative-reasoning-fingerprinter">Fingerprinter</label>
|
<!-- Fingerprinter -->
|
||||||
<select id="ml-relative-reasoning-fingerprinter" name="ml-relative-reasoning-fingerprinter"
|
<label for="ml-relative-reasoning-fingerprinter">Fingerprinter</label>
|
||||||
class="form-control">
|
<select id="ml-relative-reasoning-fingerprinter" name="ml-relative-reasoning-fingerprinter"
|
||||||
<option value="MACCS" selected>MACCS Fingerprinter</option>
|
class="form-control">
|
||||||
</select>
|
<option value="MACCS" selected>MACCS Fingerprinter</option>
|
||||||
{% if meta.enabled_features.PLUGINS and additional_descriptors %}
|
</select>
|
||||||
<!-- Property Plugins go here -->
|
{% if meta.enabled_features.PLUGINS and additional_descriptors %}
|
||||||
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter / Descriptors</label>
|
<!-- Property Plugins go here -->
|
||||||
<select id="ml-relative-reasoning-additional-fingerprinter" name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
|
<label for="ml-relative-reasoning-additional-fingerprinter">Additional Fingerprinter /
|
||||||
<option disabled selected>Select Additional Fingerprinter / Descriptor</option>
|
Descriptors</label>
|
||||||
{% for k, v in additional_descriptors.items %}
|
<select id="ml-relative-reasoning-additional-fingerprinter"
|
||||||
<option value="{{ v }}">{{ k }}</option>
|
name="ml-relative-reasoning-additional-fingerprinter" class="form-control">
|
||||||
{% endfor %}
|
<option disabled selected>Select Additional Fingerprinter / Descriptor</option>
|
||||||
</select>
|
{% for k, v in additional_descriptors.items %}
|
||||||
{% endif %}
|
<option value="{{ v }}">{{ k }}</option>
|
||||||
|
{% endfor %}
|
||||||
<label for="ml-relative-reasoning-threshold">Threshold</label>
|
</select>
|
||||||
<input type="number" min="0" max="1" step="0.05" value="0.5"
|
|
||||||
id="ml-relative-reasoning-threshold"
|
|
||||||
name="ml-relative-reasoning-threshold" class="form-control">
|
|
||||||
|
|
||||||
<!-- Evaluation -->
|
|
||||||
<label for="ml-relative-reasoning-evaluation-packages">Evaluation Packages</label>
|
|
||||||
<select id="ml-relative-reasoning-evaluation-packages" name="ml-relative-reasoning-evaluation-packages"
|
|
||||||
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
|
||||||
<option disabled>Reviewed Packages</option>
|
|
||||||
{% for obj in meta.readable_packages %}
|
|
||||||
{% if obj.reviewed %}
|
|
||||||
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
<option disabled>Unreviewed Packages</option>
|
|
||||||
{% for obj in meta.readable_packages %}
|
|
||||||
{% if not obj.reviewed %}
|
|
||||||
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
|
||||||
{% endif %}
|
|
||||||
{% endfor %}
|
|
||||||
</select>
|
|
||||||
|
|
||||||
|
<label for="ml-relative-reasoning-threshold">Threshold</label>
|
||||||
|
<input type="number" min="0" max="1" step="0.05" value="0.5"
|
||||||
|
id="ml-relative-reasoning-threshold"
|
||||||
|
name="ml-relative-reasoning-threshold" class="form-control">
|
||||||
|
</div>
|
||||||
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
|
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
|
||||||
<!-- Build AD? -->
|
<!-- Build AD? -->
|
||||||
<div class="checkbox">
|
<div class="checkbox">
|
||||||
<label>
|
<label>
|
||||||
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an Applicability Domain?
|
<input type="checkbox" id="build-app-domain" name="build-app-domain">Also build an
|
||||||
|
Applicability Domain?
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
<!-- Num Neighbors -->
|
<div id="ad-params" style="display:none">
|
||||||
<label for="num-neighbors">Number of Neighbors</label>
|
<!-- Num Neighbors -->
|
||||||
<input id="num-neighbors" name="num-neighbors" type="number" class="form-control" value="5"
|
<label for="num-neighbors">Number of Neighbors</label>
|
||||||
step="1" min="0" max="10">
|
<input id="num-neighbors" name="num-neighbors" type="number" class="form-control" value="5"
|
||||||
<!-- Local Compatibility -->
|
step="1" min="0" max="10">
|
||||||
<label for="local-compatibility-threshold">Local Compatibility Threshold</label>
|
<!-- Local Compatibility -->
|
||||||
<input id="local-compatibility-threshold" name="local-compatibility-threshold" type="number"
|
<label for="local-compatibility-threshold">Local Compatibility Threshold</label>
|
||||||
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
<input id="local-compatibility-threshold" name="local-compatibility-threshold" type="number"
|
||||||
<!-- Reliability -->
|
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
||||||
<label for="reliability-threshold">Reliability Threshold</label>
|
<!-- Reliability -->
|
||||||
<input id="reliability-threshold" name="reliability-threshold" type="number"
|
<label for="reliability-threshold">Reliability Threshold</label>
|
||||||
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
<input id="reliability-threshold" name="reliability-threshold" type="number"
|
||||||
|
class="form-control" value="0.5" step="0.01" min="0" max="1">
|
||||||
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</div>
|
|
||||||
<!-- Rule Based Based Form-->
|
|
||||||
<div id="rule-based-relative-reasoning-specific-form">
|
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
<!-- EnviFormer-->
|
<!-- EnviFormer-->
|
||||||
<div id="enviformer-specific-form">
|
<div id="enviformer-specific-form">
|
||||||
@ -160,20 +143,38 @@ $(function() {
|
|||||||
$(this).hide();
|
$(this).hide();
|
||||||
});
|
});
|
||||||
|
|
||||||
$("#ml-relative-reasoning-rule-packages").selectpicker();
|
$('#model-type').selectpicker();
|
||||||
$("#ml-relative-reasoning-data-packages").selectpicker();
|
$("#ml-relative-reasoning-fingerprinter").selectpicker();
|
||||||
$("#ml-relative-reasoning-evaluation-packages").selectpicker();
|
$("#package-based-relative-reasoning-rule-packages").selectpicker();
|
||||||
|
$("#package-based-relative-reasoning-data-packages").selectpicker();
|
||||||
|
$("#package-based-relative-reasoning-evaluation-packages").selectpicker();
|
||||||
if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) {
|
if ($('#ml-relative-reasoning-additional-fingerprinter').length > 0) {
|
||||||
$("#ml-relative-reasoning-additional-fingerprinter").selectpicker();
|
$("#ml-relative-reasoning-additional-fingerprinter").selectpicker();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
$("#build-app-domain").change(function () {
|
||||||
|
if ($(this).is(":checked")) {
|
||||||
|
$('#ad-params').show();
|
||||||
|
} else {
|
||||||
|
$('#ad-params').hide();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// On change hide all and show only selected
|
// On change hide all and show only selected
|
||||||
$("#model-type").change(function() {
|
$("#model-type").change(function() {
|
||||||
$("div[id$='-specific-form']").each( function() {
|
$("div[id$='-specific-form']").each( function() {
|
||||||
$(this).hide();
|
$(this).hide();
|
||||||
});
|
});
|
||||||
val = $('option:selected', this).val();
|
val = $('option:selected', this).val();
|
||||||
$("#" + val + "-specific-form").show();
|
|
||||||
|
if (val === 'ml-relative-reasoning' || val === 'rule-based-relative-reasoning') {
|
||||||
|
$("#package-based-relative-reasoning-specific-form").show();
|
||||||
|
if (val === 'ml-relative-reasoning') {
|
||||||
|
$("#ml-relative-reasoning-specific-form").show();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
$("#" + val + "-specific-form").show();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
$('#new_model_modal_form_submit').on('click', function(e){
|
$('#new_model_modal_form_submit').on('click', function(e){
|
||||||
|
|||||||
44
templates/modals/objects/edit_model_modal.html
Normal file
44
templates/modals/objects/edit_model_modal.html
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
{% load static %}
|
||||||
|
<!-- Edit Model -->
|
||||||
|
<div id="edit_model_modal" class="modal" tabindex="-1">
|
||||||
|
<div class="modal-dialog">
|
||||||
|
<div class="modal-content">
|
||||||
|
<div class="modal-header">
|
||||||
|
<button type="button" class="close" data-dismiss="modal" aria-label="Close">
|
||||||
|
<span aria-hidden="true">×</span>
|
||||||
|
</button>
|
||||||
|
<h3 class="modal-title">Update Model</h3>
|
||||||
|
</div>
|
||||||
|
<div class="modal-body">
|
||||||
|
<p>Alter Name and Description of the Model.</p>
|
||||||
|
<form id="edit-model-modal-form" accept-charset="UTF-8" action="" data-remote="true" method="post">
|
||||||
|
{% csrf_token %}
|
||||||
|
<p>
|
||||||
|
<label for="model-name">Name</label>
|
||||||
|
<input id="model-name" type="text" class="form-control" name="model-name"
|
||||||
|
value="{{ model.name }}">
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<label for="model-description">Description</label>
|
||||||
|
<input id="model-description" type="text" class="form-control" name="model-description"
|
||||||
|
value="{{ model.description }}">
|
||||||
|
</p>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
<div class="modal-footer">
|
||||||
|
<button type="button" class="btn btn-secondary" data-dismiss="modal">Close</button>
|
||||||
|
<button type="button" class="btn btn-primary" id="edit-model-modal-submit">Update</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script>
|
||||||
|
$(function () {
|
||||||
|
|
||||||
|
$('#edit-model-modal-submit').click(function (e) {
|
||||||
|
e.preventDefault();
|
||||||
|
$('#edit-model-modal-form').submit();
|
||||||
|
});
|
||||||
|
|
||||||
|
})
|
||||||
|
</script>
|
||||||
62
templates/modals/objects/evaluate_model_modal.html
Normal file
62
templates/modals/objects/evaluate_model_modal.html
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
<div class="modal fade" tabindex="-1" id="evaluate_model_modal" role="dialog" aria-labelledby="evaluate_model_modal"
|
||||||
|
aria-hidden="true">
|
||||||
|
<div class="modal-dialog modal-lg">
|
||||||
|
<div class="modal-content">
|
||||||
|
<div class="modal-header">
|
||||||
|
<button type="button" class="close" data-dismiss="modal">
|
||||||
|
<span aria-hidden="true">×</span>
|
||||||
|
<span class="sr-only">Close</span>
|
||||||
|
</button>
|
||||||
|
<h4 class="modal-title">Evaluate Model</h4>
|
||||||
|
</div>
|
||||||
|
<div class="modal-body">
|
||||||
|
<form id="evaluate_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
|
||||||
|
data-remote="true" method="post">
|
||||||
|
{% csrf_token %}
|
||||||
|
<div class="jumbotron">
|
||||||
|
For evaluation, you need to select the packages you want to use.
|
||||||
|
While the model is evaluating, you can use the model for predictions.
|
||||||
|
</div>
|
||||||
|
<!-- Evaluation -->
|
||||||
|
<label for="relative-reasoning-evaluation-packages">Evaluation Packages</label>
|
||||||
|
<select id="relative-reasoning-evaluation-packages" name=relative-reasoning-evaluation-packages"
|
||||||
|
data-actions-box='true' class="form-control" multiple data-width='100%'>
|
||||||
|
<option disabled>Reviewed Packages</option>
|
||||||
|
{% for obj in meta.readable_packages %}
|
||||||
|
{% if obj.reviewed %}
|
||||||
|
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
<option disabled>Unreviewed Packages</option>
|
||||||
|
{% for obj in meta.readable_packages %}
|
||||||
|
{% if not obj.reviewed %}
|
||||||
|
<option value="{{ obj.url }}">{{ obj.name }}</option>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
<div class="modal-footer">
|
||||||
|
<a id="evaluate_model_form_submit" class="btn btn-primary" href="#">Evaluate</a>
|
||||||
|
<button type="button" class="btn btn-default" data-dismiss="modal">Cancel</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
|
||||||
|
$(function () {
|
||||||
|
|
||||||
|
$("#relative-reasoning-evaluation-packages").selectpicker();
|
||||||
|
|
||||||
|
$('#evaluate_model_form_submit').on('click', function (e) {
|
||||||
|
e.preventDefault();
|
||||||
|
$('#evaluate_model_form').submit();
|
||||||
|
});
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
</script>
|
||||||
43
templates/modals/objects/retrain_model_modal.html
Normal file
43
templates/modals/objects/retrain_model_modal.html
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
<div class="modal fade" tabindex="-1" id="retrain_model_modal" role="dialog" aria-labelledby="retrain_model_modal"
|
||||||
|
aria-hidden="true">
|
||||||
|
<div class="modal-dialog modal-lg">
|
||||||
|
<div class="modal-content">
|
||||||
|
<div class="modal-header">
|
||||||
|
<button type="button" class="close" data-dismiss="modal">
|
||||||
|
<span aria-hidden="true">×</span>
|
||||||
|
<span class="sr-only">Close</span>
|
||||||
|
</button>
|
||||||
|
<h4 class="modal-title">Retrain Model</h4>
|
||||||
|
</div>
|
||||||
|
<div class="modal-body">
|
||||||
|
<form id="retrain_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
|
||||||
|
data-remote="true" method="post">
|
||||||
|
<div class="jumbotron">
|
||||||
|
To reflect changes in the rule or data packages, you can use the "Retrain" button,
|
||||||
|
to let the model reflect the changes without creating a new model.
|
||||||
|
While the model is retraining, it will be unavailable for prediction.
|
||||||
|
</div>
|
||||||
|
{% csrf_token %}
|
||||||
|
<input type="hidden" name="action" value="retrain">
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
<div class="modal-footer">
|
||||||
|
<a id="retrain_model_form_submit" class="btn btn-primary" href="#">Retrain</a>
|
||||||
|
<button type="button" class="btn btn-default" data-dismiss="modal">Cancel</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
|
||||||
|
$(function () {
|
||||||
|
|
||||||
|
$('#retrain_model_form_submit').on('click', function (e) {
|
||||||
|
e.preventDefault();
|
||||||
|
$('#retrain_model_form').submit();
|
||||||
|
});
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
</script>
|
||||||
@ -4,6 +4,9 @@
|
|||||||
{% block content %}
|
{% block content %}
|
||||||
|
|
||||||
{% block action_modals %}
|
{% block action_modals %}
|
||||||
|
{% include "modals/objects/edit_model_modal.html" %}
|
||||||
|
{% include "modals/objects/evaluate_model_modal.html" %}
|
||||||
|
{% include "modals/objects/retrain_model_modal.html" %}
|
||||||
{% include "modals/objects/generic_delete_modal.html" %}
|
{% include "modals/objects/generic_delete_modal.html" %}
|
||||||
{% endblock action_modals %}
|
{% endblock action_modals %}
|
||||||
|
|
||||||
@ -32,7 +35,7 @@
|
|||||||
<div class="panel-body">
|
<div class="panel-body">
|
||||||
<p> {{ model.description }} </p>
|
<p> {{ model.description }} </p>
|
||||||
</div>
|
</div>
|
||||||
{% if model|classname == 'MLRelativeReasoning' %}
|
{% if model|classname == 'MLRelativeReasoning' or model|classname == 'RuleBasedRelativeReasoning'%}
|
||||||
<!-- Rule Packages -->
|
<!-- Rule Packages -->
|
||||||
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
<div class="panel panel-default panel-heading list-group-item" style="background-color:silver">
|
||||||
<h4 class="panel-title">
|
<h4 class="panel-title">
|
||||||
|
|||||||
@ -289,6 +289,12 @@ class Dataset:
|
|||||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def trig(self, na_replacement=0):
|
||||||
|
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
|
||||||
|
if na_replacement is not None:
|
||||||
|
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def y(self, na_replacement=0):
|
def y(self, na_replacement=0):
|
||||||
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
|
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
|
||||||
@ -324,7 +330,7 @@ class Dataset:
|
|||||||
pickle.dump(self, fh)
|
pickle.dump(self, fh)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(path: 'Path'):
|
def load(path: 'Path') -> 'Dataset':
|
||||||
import pickle
|
import pickle
|
||||||
return pickle.load(open(path, "rb"))
|
return pickle.load(open(path, "rb"))
|
||||||
|
|
||||||
@ -553,6 +559,68 @@ class EnsembleClassifierChain:
|
|||||||
return labels / self.num_chains
|
return labels / self.num_chains
|
||||||
|
|
||||||
|
|
||||||
|
class RelativeReasoning:
|
||||||
|
def __init__(self, start_index: int, end_index: int):
|
||||||
|
self.start_index: int = start_index
|
||||||
|
self.end_index: int = end_index
|
||||||
|
self.winmap: Dict[int, List[int]] = defaultdict(list)
|
||||||
|
self.min_count: int = 5
|
||||||
|
self.max_count: int = 0
|
||||||
|
|
||||||
|
def fit(self, X, Y):
|
||||||
|
n_instances = len(Y)
|
||||||
|
n_attributes = len(Y[0])
|
||||||
|
|
||||||
|
for i in range(n_attributes):
|
||||||
|
for j in range(n_attributes):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
|
||||||
|
countwin = 0
|
||||||
|
countloose = 0
|
||||||
|
countboth = 0
|
||||||
|
|
||||||
|
for k in range(n_instances):
|
||||||
|
vi = Y[k][i]
|
||||||
|
vj = Y[k][j]
|
||||||
|
|
||||||
|
if vi is None or vj is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if vi < vj:
|
||||||
|
countwin += 1
|
||||||
|
elif vi > vj:
|
||||||
|
countloose += 1
|
||||||
|
elif vi == vj and vi == 1: # tie
|
||||||
|
countboth += 1
|
||||||
|
|
||||||
|
# We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
|
||||||
|
if (
|
||||||
|
countwin >= self.min_count and
|
||||||
|
countwin > countloose and
|
||||||
|
(
|
||||||
|
countloose <= self.max_count or
|
||||||
|
self.max_count < 0
|
||||||
|
) and
|
||||||
|
countboth == 0
|
||||||
|
):
|
||||||
|
self.winmap[i].append(j)
|
||||||
|
|
||||||
|
def predict(self, X):
|
||||||
|
res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
|
||||||
|
|
||||||
|
for inst_idx, inst in enumerate(X):
|
||||||
|
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||||
|
res[inst_idx][i] = t
|
||||||
|
if t:
|
||||||
|
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||||
|
if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
|
||||||
|
res[inst_idx][i] = 0
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def predict_proba(self, X):
|
||||||
|
return self.predict(X)
|
||||||
|
|
||||||
|
|
||||||
class ApplicabilityDomainPCA(PCA):
|
class ApplicabilityDomainPCA(PCA):
|
||||||
|
|||||||
Reference in New Issue
Block a user