forked from enviPath/enviPy
[Feature] Rule Based Model (#92)
Fixes #89 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#92
This commit is contained in:
282
epdb/models.py
282
epdb/models.py
@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Union, List, Optional, Dict, Tuple, Set
|
||||
@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
|
||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
|
||||
@property
|
||||
def root_nodes(self):
|
||||
return Node.objects.filter(pathway=self, depth=0)
|
||||
# same as return Node.objects.filter(pathway=self, depth=0), but written against node_set.
|
||||
# NOTE(review): chaining .filter() on node_set.all() issues a new query, so a prefetched node_set is not reused here
|
||||
return self.node_set.all().filter(pathway=self, depth=0)
|
||||
|
||||
@property
|
||||
def nodes(self):
|
||||
return Node.objects.filter(pathway=self)
|
||||
# same as Node.objects.filter(pathway=self) but will utilize
|
||||
# potentially prefetched node_set
|
||||
return self.node_set.all()
|
||||
|
||||
def get_node(self, node_url):
|
||||
for n in self.nodes:
|
||||
@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
|
||||
@property
|
||||
def edges(self):
|
||||
return Edge.objects.filter(pathway=self)
|
||||
# same as Edge.objects.filter(pathway=self) but will utilize
|
||||
# potentially prefetched edge_set
|
||||
return self.edge_set.all()
|
||||
|
||||
def _url(self):
|
||||
return '{}/pathway/{}'.format(self.package.url, self.uuid)
|
||||
@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
|
||||
return '{}/model/{}'.format(self.package.url, self.uuid)
|
||||
|
||||
|
||||
class MLRelativeReasoning(EPModel):
|
||||
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
|
||||
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
|
||||
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
|
||||
class PackageBasedModel(EPModel):
|
||||
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages",
|
||||
related_name="%(app_label)s_%(class)s_rule_packages")
|
||||
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages")
|
||||
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages",
|
||||
related_name="%(app_label)s_%(class)s_eval_packages")
|
||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||
default=None)
|
||||
|
||||
INITIAL = "INITIAL"
|
||||
INITIALIZING = "INITIALIZING"
|
||||
@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel):
|
||||
}
|
||||
model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
|
||||
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
|
||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||
default=None)
|
||||
|
||||
def status(self):
    """Return the ``PROGRESS_STATUS_CHOICES`` entry stored under the current ``model_status``."""
    choices = self.PROGRESS_STATUS_CHOICES
    return choices[self.model_status]
|
||||
|
||||
def ready_for_prediction(self) -> bool:
    """Return True when ``model_status`` is BUILT_NOT_EVALUATED, EVALUATING or FINISHED."""
    usable_states = [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
    return self.model_status in usable_states
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(package: 'Package', rule_packages: List['Package'],
|
||||
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
|
||||
name: 'str' = None, description: str = None, build_app_domain: bool = False,
|
||||
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
|
||||
app_domain_local_compatibility_threshold: float = None):
|
||||
|
||||
mlrr = MLRelativeReasoning()
|
||||
mlrr.package = package
|
||||
|
||||
if name is None or name.strip() == '':
|
||||
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
|
||||
|
||||
mlrr.name = name
|
||||
|
||||
if description is not None and description.strip() != '':
|
||||
mlrr.description = description
|
||||
|
||||
if threshold is None or (threshold <= 0 or 1 <= threshold):
|
||||
raise ValueError("Threshold must be a float between 0 and 1.")
|
||||
|
||||
mlrr.threshold = threshold
|
||||
|
||||
if len(rule_packages) == 0:
|
||||
raise ValueError("At least one rule package must be provided.")
|
||||
|
||||
mlrr.save()
|
||||
|
||||
for p in rule_packages:
|
||||
mlrr.rule_packages.add(p)
|
||||
|
||||
if data_packages:
|
||||
for p in data_packages:
|
||||
mlrr.data_packages.add(p)
|
||||
else:
|
||||
for p in rule_packages:
|
||||
mlrr.data_packages.add(p)
|
||||
|
||||
if eval_packages:
|
||||
for p in eval_packages:
|
||||
mlrr.eval_packages.add(p)
|
||||
|
||||
if build_app_domain:
|
||||
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
|
||||
app_domain_local_compatibility_threshold)
|
||||
mlrr.app_domain = ad
|
||||
|
||||
mlrr.save()
|
||||
|
||||
return mlrr
|
||||
|
||||
@cached_property
|
||||
def applicable_rules(self) -> List['Rule']:
|
||||
"""
|
||||
@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel):
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||
return Dataset.load(ds_path)
|
||||
|
||||
def retrain(self):
    """Run ``build_dataset()`` followed by ``build_model()``."""
    # Order matters: the dataset step runs before the model step.
    self.build_dataset()
    self.build_model()
|
||||
|
||||
def rebuild(self):
    """Run ``build_model()`` only; the dataset step is not repeated."""
    self.build_model()
|
||||
|
||||
@abstractmethod
def build_model(self):
    """Build the model; the base implementation does nothing and returns None."""
|
||||
|
||||
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
    """Pair each rule with its probability and products as a ``PredictionResult``.

    The three sequences are walked in lockstep with ``zip``, so the result has
    as many entries as the shortest of them. Each entry is constructed as
    ``PredictionResult(products_i, probability_i, rule_i)``.
    """
    return [
        PredictionResult(product_smiles, probability, rule)
        for rule, probability, product_smiles in zip(rules, probabilities, products)
    ]
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class RuleBasedRelativeReasoning(PackageBasedModel):
    """Package-based model whose estimator is a ``RelativeReasoning`` instance.

    The fitted estimator is pickled to ``<MODEL_DIR>/<uuid>_mod.pkl`` by
    ``build_model`` and loaded back lazily through the ``model`` property.
    """

    # Count bounds persisted with the model. This class only stores them;
    # NOTE(review): how they are consumed is not visible here - confirm.
    min_count = models.IntegerField(null=False, blank=False, default=10)
    max_count = models.IntegerField(null=False, blank=False, default=0)

    @staticmethod
    @transaction.atomic
    def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'],
               eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0,
               name: 'str' = None, description: str = None):
        """Validate the arguments, then create and persist a new model.

        :param package: package the new model belongs to.
        :param rule_packages: packages added to ``rule_packages``; at least one is required.
        :param data_packages: packages added to ``data_packages``; when empty,
            ``rule_packages`` are used instead.
        :param eval_packages: packages added to ``eval_packages``; may be empty.
        :param threshold: value strictly between 0 and 1.
        :param min_count: integer >= 1.
        :param max_count: integer >= 0 that must not exceed ``min_count``.
        :param name: display name; a numbered default is generated when blank.
        :param description: optional description; ignored when blank.
        :return: the saved ``RuleBasedRelativeReasoning`` instance.
        :raises ValueError: if any of the constraints above is violated.
        """
        # All validation happens before the first database access.
        if threshold is None or threshold <= 0 or threshold >= 1:
            raise ValueError("Threshold must be a float between 0 and 1.")

        if min_count is None or min_count < 1:
            raise ValueError("Minimum count must be an int greater than equal 1.")

        # The original second max_count check repeated `is None` and was
        # unreachable; its message shows that a lower bound of 0 was intended.
        if max_count is None or max_count < 0:
            raise ValueError("Maximum count must be at least 0.")

        # NOTE(review): the condition is kept as originally enforced (the
        # defaults min_count=10 / max_count=0 only pass this way); the message
        # was corrected to describe it. Confirm the intended relation.
        if max_count > min_count:
            raise ValueError("Maximum count must be an int and must not be greater than min_count.")

        if not rule_packages:
            raise ValueError("At least one rule package must be provided.")

        rbrr = RuleBasedRelativeReasoning()
        rbrr.package = package

        if name is None or name.strip() == '':
            # Default name is numbered by the models of this type already in the package.
            existing = RuleBasedRelativeReasoning.objects.filter(package=package).count()
            name = f"RuleBasedRelativeReasoning {existing + 1}"
        rbrr.name = name

        if description is not None and description.strip() != '':
            rbrr.description = description

        rbrr.threshold = threshold
        rbrr.min_count = min_count
        # Previously validated but never assigned, so the field default was always stored.
        rbrr.max_count = max_count

        # The instance is saved before the many-to-many relations are filled.
        rbrr.save()

        for p in rule_packages:
            rbrr.rule_packages.add(p)

        # Fall back to the rule packages when no dedicated data packages are given.
        for p in (data_packages or rule_packages):
            rbrr.data_packages.add(p)

        for p in (eval_packages or []):
            rbrr.eval_packages.add(p)

        rbrr.save()

        return rbrr

    def build_model(self):
        """Fit a ``RelativeReasoning`` estimator on the stored dataset and pickle it.

        ``model_status`` is set to BUILDING before the fit and to
        BUILT_NOT_EVALUATED after the estimator has been dumped; both changes
        are saved immediately.
        """
        self.model_status = self.BUILDING
        self.save()

        ds = self.load_dataset()
        # Computed once and reused for the fit (the original computed it twice
        # and left the first result unused).
        labels = ds.y(na_replacement=None)

        mod = RelativeReasoning(*ds.triggered())
        mod.fit(ds.X(exclude_id_col=False, na_replacement=None), labels)

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
        joblib.dump(mod, f)

        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    @cached_property
    def model(self) -> 'RelativeReasoning':
        """Load the pickled estimator from ``MODEL_DIR``; cached per instance."""
        mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
        return mod

    def predict(self, smiles) -> List['PredictionResult']:
        """Predict for a single SMILES string.

        A classification dataset is built for ``smiles`` from the stored
        dataset and ``applicable_rules``, the loaded estimator predicts on it,
        and the first prediction row is combined with the first product row
        into ``PredictionResult`` objects.
        """
        start = datetime.now()
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)

        mod = self.model

        pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))

        # Only one SMILES was passed in, hence index 0 of both results.
        res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])

        end = datetime.now()
        logger.info(f"Full predict took {(end - start).total_seconds()}s")
        return res
|
||||
|
||||
|
||||
class MLRelativeReasoning(PackageBasedModel):
|
||||
|
||||
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'],
           data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
           name: 'str' = None, description: str = None, build_app_domain: bool = False,
           app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
           app_domain_local_compatibility_threshold: float = None):
    """Create and persist an ``MLRelativeReasoning`` model.

    :param package: package the new model belongs to.
    :param rule_packages: packages added to ``rule_packages``; must not be empty.
    :param data_packages: packages added to ``data_packages``; when empty,
        ``rule_packages`` are used instead.
    :param eval_packages: packages added to ``eval_packages``; may be empty.
    :param threshold: value strictly between 0 and 1.
    :param name: display name; a numbered default is generated when blank.
    :param description: optional description; ignored when blank.
    :param build_app_domain: when true, an ``ApplicabilityDomain`` is created
        from the three ``app_domain_*`` arguments and attached to the model.
    :return: the saved model instance.
    :raises ValueError: for an out-of-range threshold or empty ``rule_packages``.
    """
    model = MLRelativeReasoning()
    model.package = package

    if name is None or name.strip() == '':
        # Default name is numbered by the models of this type already in the package.
        existing = MLRelativeReasoning.objects.filter(package=package).count()
        name = f"MLRelativeReasoning {existing + 1}"
    model.name = name

    if description is not None and description.strip() != '':
        model.description = description

    if threshold is None or threshold <= 0 or 1 <= threshold:
        raise ValueError("Threshold must be a float between 0 and 1.")
    model.threshold = threshold

    if len(rule_packages) == 0:
        raise ValueError("At least one rule package must be provided.")

    # The instance is saved before the many-to-many relations are filled.
    model.save()

    for pkg in rule_packages:
        model.rule_packages.add(pkg)

    # Fall back to the rule packages when no data packages were supplied.
    data_source = data_packages if data_packages else rule_packages
    for pkg in data_source:
        model.data_packages.add(pkg)

    if eval_packages:
        for pkg in eval_packages:
            model.eval_packages.add(pkg)

    if build_app_domain:
        model.app_domain = ApplicabilityDomain.create(
            model,
            app_domain_num_neighbours,
            app_domain_reliability_threshold,
            app_domain_local_compatibility_threshold,
        )

    model.save()
    return model
|
||||
|
||||
def build_model(self):
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel):
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def retrain(self):
|
||||
self.build_dataset()
|
||||
self.build_model()
|
||||
|
||||
def rebuild(self):
|
||||
self.build_model()
|
||||
|
||||
def evaluate_model(self):
|
||||
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel):
|
||||
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
||||
res = []
|
||||
for rule, p, smis in zip(rules, probabilities, products):
|
||||
res.append(PredictionResult(smis, p, rule))
|
||||
return res
|
||||
|
||||
@property
|
||||
def pr_curve(self):
|
||||
if self.model_status != self.FINISHED:
|
||||
@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
return accuracy
|
||||
|
||||
|
||||
class RuleBaseRelativeReasoning(EPModel):
|
||||
pass
|
||||
|
||||
|
||||
class EnviFormer(EPModel):
|
||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||
@ -2406,6 +2518,12 @@ class EnviFormer(EPModel):
|
||||
def applicable_rules(self):
|
||||
return []
|
||||
|
||||
def status(self):
    """Return a constant status message; no instance state is consulted."""
    message = "Model is built and can be used for predictions, Model is not evaluated yet."
    return message
|
||||
|
||||
def ready_for_prediction(self) -> bool:
    """Always report the model as ready; no instance state is consulted."""
    is_ready = True
    return is_ready
|
||||
|
||||
|
||||
class PluginModel(EPModel):
|
||||
pass
|
||||
|
||||
@ -41,6 +41,12 @@ def evaluate_model(model_pk: int):
|
||||
mod.evaluate_model()
|
||||
|
||||
|
||||
@shared_task(queue='model')
def retrain(model_pk: int):
    """Task on the ``model`` queue: fetch the ``EPModel`` with id ``model_pk`` and call its ``retrain()``."""
    target = EPModel.objects.get(id=model_pk)
    target.retrain()
|
||||
|
||||
|
||||
@shared_task(queue='predict')
|
||||
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
|
||||
pw = Pathway.objects.get(id=pw_pk)
|
||||
|
||||
@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required
|
||||
from utilities.misc import HTMLGenerator
|
||||
from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
|
||||
from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
|
||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||
UserPackagePermission, Permission, License, User, Edge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -651,46 +651,50 @@ def package_models(request, package_uuid):
|
||||
|
||||
mod = EnviFormer.create(current_package, name, description, threshold)
|
||||
|
||||
elif model_type == 'ml-relative-reasoning':
|
||||
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
||||
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
||||
rule_packages = request.POST.getlist(f'{model_type}-rule-packages')
|
||||
data_packages = request.POST.getlist(f'{model_type}-data-packages')
|
||||
eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', [])
|
||||
elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning':
|
||||
# Generic fields for ML and Rule Based
|
||||
rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages')
|
||||
data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages')
|
||||
eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', [])
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages]
|
||||
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
|
||||
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
|
||||
# Generic params
|
||||
params = {
|
||||
'package' : current_package,
|
||||
'name' : name,
|
||||
'description' : description,
|
||||
'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages],
|
||||
'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages],
|
||||
'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages],
|
||||
}
|
||||
|
||||
# App Domain related parameters
|
||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||
num_neighbors = request.POST.get('num-neighbors', 5)
|
||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||
if model_type == 'ml-relative-reasoning':
|
||||
# ML Specific
|
||||
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
||||
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
package=current_package,
|
||||
name=name,
|
||||
description=description,
|
||||
rule_packages=rule_package_objs,
|
||||
data_packages=data_package_objs,
|
||||
eval_packages=eval_packages_objs,
|
||||
threshold=threshold,
|
||||
# fingerprinter=fingerprinter,
|
||||
build_app_domain=build_ad,
|
||||
app_domain_num_neighbours=num_neighbors,
|
||||
app_domain_reliability_threshold=reliability_threshold,
|
||||
app_domain_local_compatibility_threshold=local_compatibility_threshold,
|
||||
)
|
||||
# App Domain related parameters
|
||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||
num_neighbors = request.POST.get('num-neighbors', 5)
|
||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||
|
||||
params['threshold'] = threshold
|
||||
# params['fingerprinter'] = fingerprinter
|
||||
params['build_app_domain'] = build_ad
|
||||
params['app_domain_num_neighbours'] = num_neighbors
|
||||
params['app_domain_reliability_threshold'] = reliability_threshold
|
||||
params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
**params
|
||||
)
|
||||
else:
|
||||
mod = RuleBasedRelativeReasoning.create(
|
||||
**params
|
||||
)
|
||||
|
||||
from .tasks import build_model
|
||||
build_model.delay(mod.pk)
|
||||
|
||||
elif model_type == 'rule-base-relative-reasoning':
|
||||
mod = RuleBaseRelativeReasoning()
|
||||
|
||||
mod.save()
|
||||
else:
|
||||
return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
|
||||
return redirect(mod.url)
|
||||
@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid):
|
||||
else:
|
||||
return HttpResponseBadRequest()
|
||||
else:
|
||||
|
||||
name = request.POST.get('model-name', '').strip()
|
||||
description = request.POST.get('model-description', '').strip()
|
||||
|
||||
if any([name, description]):
|
||||
if name:
|
||||
current_model.name = name
|
||||
|
||||
if description:
|
||||
current_model.description = description
|
||||
|
||||
current_model.save()
|
||||
return redirect(current_model.url)
|
||||
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user