diff --git a/epdb/models.py b/epdb/models.py
index fb4baf97..e82e05af 100644
--- a/epdb/models.py
+++ b/epdb/models.py
@@ -4,6 +4,7 @@ import json
 import logging
 import os
 import secrets
+from abc import abstractmethod
 from collections import defaultdict
 from datetime import datetime
 from typing import Union, List, Optional, Dict, Tuple, Set
@@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
+from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
 
 logger = logging.getLogger(__name__)
 
@@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
 
     @property
     def root_nodes(self):
-        return Node.objects.filter(pathway=self, depth=0)
+        # same as Node.objects.filter(pathway=self, depth=0) but will utilize a
+        # potentially prefetched node_set
+        return self.node_set.all().filter(pathway=self, depth=0)
 
     @property
     def nodes(self):
-        return Node.objects.filter(pathway=self)
+        # same as Node.objects.filter(pathway=self) but will utilize a
+        # potentially prefetched node_set
+        return self.node_set.all()
 
     def get_node(self, node_url):
         for n in self.nodes:
@@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
 
     @property
     def edges(self):
-        return Edge.objects.filter(pathway=self)
+        # same as Edge.objects.filter(pathway=self) but will utilize a
+        # potentially prefetched edge_set
+        return self.edge_set.all()
 
     def _url(self):
         return '{}/pathway/{}'.format(self.package.url, self.uuid)
@@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
         return '{}/model/{}'.format(self.package.url, self.uuid)
 
 
-class MLRelativeReasoning(EPModel):
-    rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
-    data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
-    eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
+class PackageBasedModel(EPModel):
+    rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages",
+                                           related_name="%(app_label)s_%(class)s_rule_packages")
+    data_packages = models.ManyToManyField("Package", verbose_name="Data Packages",
+                                           related_name="%(app_label)s_%(class)s_data_packages")
+    eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages",
+                                           related_name="%(app_label)s_%(class)s_eval_packages")
     threshold = models.FloatField(null=False, blank=False, default=0.5)
+    eval_results = JSONField(null=True, blank=True, default=dict)
+    app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
+                                   default=None)
 
     INITIAL = "INITIAL"
     INITIALIZING = "INITIALIZING"
@@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel):
     }
 
     model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
 
-    eval_results = JSONField(null=True, blank=True, default=dict)
-
-    app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
-                                   default=None)
-
     def status(self):
         return self.PROGRESS_STATUS_CHOICES[self.model_status]
 
     def ready_for_prediction(self) -> bool:
         return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING,
                                      self.FINISHED]
 
-    @staticmethod
-    @transaction.atomic
-    def create(package: 'Package', rule_packages: List['Package'],
-               data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
-               name: 'str' = None, description: str = None, build_app_domain: bool = False,
-               app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
-               app_domain_local_compatibility_threshold: float = None):
-
-        mlrr = MLRelativeReasoning()
-        mlrr.package = package
-
-        if name is None or name.strip() == '':
-            name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
-
-        mlrr.name = name
-
-        if description is not None and description.strip() != '':
-            mlrr.description = description
-
-        if threshold is None or (threshold <= 0 or 1 <= threshold):
-            raise ValueError("Threshold must be a float between 0 and 1.")
-
-        mlrr.threshold = threshold
-
-        if len(rule_packages) == 0:
-            raise ValueError("At least one rule package must be provided.")
-
-        mlrr.save()
-
-        for p in rule_packages:
-            mlrr.rule_packages.add(p)
-
-        if data_packages:
-            for p in data_packages:
-                mlrr.data_packages.add(p)
-        else:
-            for p in rule_packages:
-                mlrr.data_packages.add(p)
-
-        if eval_packages:
-            for p in eval_packages:
-                mlrr.eval_packages.add(p)
-
-        if build_app_domain:
-            ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
-                                            app_domain_local_compatibility_threshold)
-            mlrr.app_domain = ad
-
-        mlrr.save()
-
-        return mlrr
-
     @cached_property
     def applicable_rules(self) -> List['Rule']:
         """
@@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel):
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
         return Dataset.load(ds_path)
 
+    def retrain(self):
+        self.build_dataset()
+        self.build_model()
+
+    def rebuild(self):
+        self.build_model()
+
+    @abstractmethod
+    def build_model(self):
+        pass
+
+    @staticmethod
+    def combine_products_and_probs(rules: List['Rule'], probabilities, products):
+        res = []
+        for rule, p, smis in zip(rules, probabilities, products):
+            res.append(PredictionResult(smis, p, rule))
+        return res
+
+    class Meta:
+        abstract = True
+
+
+class RuleBasedRelativeReasoning(PackageBasedModel):
+    min_count = models.IntegerField(null=False, blank=False, default=10)
+    max_count = models.IntegerField(null=False, blank=False, default=0)
+
+    @staticmethod
+    @transaction.atomic
+    def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'],
+               eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0,
+               name: 'str' = None, description: str = None):
+
+        rbrr = RuleBasedRelativeReasoning()
+        rbrr.package = package
+
+        if name is None or name.strip() == '':
+            name = f"RuleBasedRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"
+
+        rbrr.name = name
+
+        if description is not None and description.strip() != '':
+            rbrr.description = description
+
+        if threshold is None or (threshold <= 0 or 1 <= threshold):
+            raise ValueError("Threshold must be a float between 0 and 1.")
+
+        rbrr.threshold = threshold
+
+        if min_count is None or min_count < 1:
+            raise ValueError("Minimum count must be an int greater than or equal to 1.")
+
+        rbrr.min_count = min_count
+
+        if max_count is None or max_count < 0:
+            raise ValueError("Maximum count must be an int greater than or equal to 0.")
+
+        if max_count != 0 and max_count < min_count:
+            raise ValueError("Maximum count must not be less than min_count.")
+
+        rbrr.max_count = max_count
+
+        if len(rule_packages) == 0:
+            raise ValueError("At least one rule package must be provided.")
+
+        rbrr.save()
+
+        for p in rule_packages:
+            rbrr.rule_packages.add(p)
+
+        if data_packages:
+            for p in data_packages:
+                rbrr.data_packages.add(p)
+        else:
+            for p in rule_packages:
+                rbrr.data_packages.add(p)
+
+        if eval_packages:
+            for p in eval_packages:
+                rbrr.eval_packages.add(p)
+
+        rbrr.save()
+
+        return rbrr
+
+    def build_model(self):
+        self.model_status = self.BUILDING
+        self.save()
+
+        ds = self.load_dataset()
+        labels = ds.y(na_replacement=None)
+
+        mod = RelativeReasoning(*ds.triggered())
+        mod.fit(ds.X(exclude_id_col=False, na_replacement=None), labels)
+
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
+        joblib.dump(mod, f)
+
+        self.model_status = self.BUILT_NOT_EVALUATED
+        self.save()
+
+    @cached_property
+    def model(self) -> 'RelativeReasoning':
+        mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
+        return mod
+
+    def predict(self, smiles) -> List['PredictionResult']:
+        start = datetime.now()
+        ds = self.load_dataset()
+        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
+
+        mod = self.model
+
+        pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
+
+        res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
+
+        end = datetime.now()
+        logger.info(f"Full predict took {(end - start).total_seconds()}s")
+        return res
+
+
+class MLRelativeReasoning(PackageBasedModel):
+
+    @staticmethod
+    @transaction.atomic
+    def create(package: 'Package', rule_packages: List['Package'],
+               data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
+               name: 'str' = None, description: str = None, build_app_domain: bool = False,
+               app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
+               app_domain_local_compatibility_threshold: float = None):
+
+        mlrr = MLRelativeReasoning()
+        mlrr.package = package
+
+        if name is None or name.strip() == '':
+            name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
+
+        mlrr.name = name
+
+        if description is not None and description.strip() != '':
+            mlrr.description = description
+
+        if threshold is None or (threshold <= 0 or 1 <= threshold):
+            raise ValueError("Threshold must be a float between 0 and 1.")
+
+        mlrr.threshold = threshold
+
+        if len(rule_packages) == 0:
+            raise ValueError("At least one rule package must be provided.")
+
+        mlrr.save()
+
+        for p in rule_packages:
+            mlrr.rule_packages.add(p)
+
+        if data_packages:
+            for p in data_packages:
+                mlrr.data_packages.add(p)
+        else:
+            for p in rule_packages:
+                mlrr.data_packages.add(p)
+
+        if eval_packages:
+            for p in eval_packages:
+                mlrr.eval_packages.add(p)
+
+        if build_app_domain:
+            ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
+                                            app_domain_local_compatibility_threshold)
+            mlrr.app_domain = ad
+
+        mlrr.save()
+
+        return mlrr
+
     def build_model(self):
         self.model_status = self.BUILDING
         self.save()
@@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel):
         self.model_status = self.BUILT_NOT_EVALUATED
         self.save()
 
-    def retrain(self):
-        self.build_dataset()
-        self.build_model()
-
-    def rebuild(self):
-        self.build_model()
-
     def evaluate_model(self):
 
         if self.model_status != self.BUILT_NOT_EVALUATED:
@@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel):
         logger.info(f"Full predict took {(end - start).total_seconds()}s")
         return res
 
-    @staticmethod
-    def combine_products_and_probs(rules: List['Rule'], probabilities, products):
-        res = []
-        for rule, p, smis in zip(rules, probabilities, products):
-            res.append(PredictionResult(smis, p, rule))
-        return res
-
     @property
     def pr_curve(self):
         if self.model_status != self.FINISHED:
@@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel):
 
         return accuracy
 
-class RuleBaseRelativeReasoning(EPModel):
-    pass
-
 
 class EnviFormer(EPModel):
     threshold = models.FloatField(null=False, blank=False, default=0.5)
@@ -2406,6 +2518,12 @@ class EnviFormer(EPModel):
     def applicable_rules(self):
         return []
 
+    def status(self):
+        return "Model is built and can be used for predictions, but it has not been evaluated yet."
+
+    def ready_for_prediction(self) -> bool:
+        return True
+
 
 class PluginModel(EPModel):
     pass
diff --git a/epdb/tasks.py b/epdb/tasks.py
index 4ca4d183..3cbd9386 100644
--- a/epdb/tasks.py
+++ b/epdb/tasks.py
@@ -41,6 +41,12 @@ def evaluate_model(model_pk: int):
     mod.evaluate_model()
 
 
+@shared_task(queue='model')
+def retrain(model_pk: int):
+    mod = EPModel.objects.get(id=model_pk)
+    mod.retrain()
+
+
 @shared_task(queue='predict')
 def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
     pw = Pathway.objects.get(id=pw_pk)
diff --git a/epdb/views.py b/epdb/views.py
index da98412f..2ce006d2 100644
--- a/epdb/views.py
+++ b/epdb/views.py
@@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required
 from utilities.misc import HTMLGenerator
 from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
 from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
-    EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
+    EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
     UserPackagePermission, Permission, License, User, Edge
 
 logger = logging.getLogger(__name__)
@@ -651,46 +651,50 @@ def package_models(request, package_uuid):
 
             mod = EnviFormer.create(current_package, name, description, threshold)
 
-        elif model_type == 'ml-relative-reasoning':
-            threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
-            fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
-            rule_packages = request.POST.getlist(f'{model_type}-rule-packages')
-            data_packages = request.POST.getlist(f'{model_type}-data-packages')
-            eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', [])
+        elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning':
+            # Fields shared by ML and rule-based relative reasoning
+            rule_packages = request.POST.getlist('package-based-relative-reasoning-rule-packages')
+            data_packages = request.POST.getlist('package-based-relative-reasoning-data-packages')
+            eval_packages = request.POST.getlist('package-based-relative-reasoning-evaluation-packages', [])
 
-            # get Package objects from urls
-            rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages]
-            data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
-            eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
+            # Parameters shared by both model types
+            params = {
+                'package': current_package,
+                'name': name,
+                'description': description,
+                'rule_packages': [PackageManager.get_package_by_url(current_user, p) for p in rule_packages],
+                'data_packages': [PackageManager.get_package_by_url(current_user, p) for p in data_packages],
+                'eval_packages': [PackageManager.get_package_by_url(current_user, p) for p in eval_packages],
+            }
 
-            # App Domain related parameters
-            build_ad = request.POST.get('build-app-domain', False) == 'on'
-            num_neighbors = request.POST.get('num-neighbors', 5)
-            reliability_threshold = request.POST.get('reliability-threshold', 0.5)
-            local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
+            if model_type == 'ml-relative-reasoning':
+                # ML specific fields
+                threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
+                fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
 
-            mod = MLRelativeReasoning.create(
-                package=current_package,
-                name=name,
-                description=description,
-                rule_packages=rule_package_objs,
-                data_packages=data_package_objs,
-                eval_packages=eval_packages_objs,
-                threshold=threshold,
-                # fingerprinter=fingerprinter,
-                build_app_domain=build_ad,
-                app_domain_num_neighbours=num_neighbors,
-                app_domain_reliability_threshold=reliability_threshold,
-                app_domain_local_compatibility_threshold=local_compatibility_threshold,
-            )
+                # App Domain related parameters
+                build_ad = request.POST.get('build-app-domain', False) == 'on'
+                num_neighbors = request.POST.get('num-neighbors', 5)
+                reliability_threshold = request.POST.get('reliability-threshold', 0.5)
+                local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
+
+                params['threshold'] = threshold
+                # params['fingerprinter'] = fingerprinter
+                params['build_app_domain'] = build_ad
+                params['app_domain_num_neighbours'] = num_neighbors
+                params['app_domain_reliability_threshold'] = reliability_threshold
+                params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold
+
+                mod = MLRelativeReasoning.create(
+                    **params
+                )
+            else:
+                mod = RuleBasedRelativeReasoning.create(
+                    **params
+                )
 
             from .tasks import build_model
             build_model.delay(mod.pk)
-
-        elif model_type == 'rule-base-relative-reasoning':
-            mod = RuleBaseRelativeReasoning()
-
-            mod.save()
         else:
             return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
         return redirect(mod.url)
@@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid):
             else:
                 return HttpResponseBadRequest()
         else:
+
+            name = request.POST.get('model-name', '').strip()
+            description = request.POST.get('model-description', '').strip()
+
+            if any([name, description]):
+                if name:
+                    current_model.name = name
+
+                if description:
+                    current_model.description = description
+
+                current_model.save()
+                return redirect(current_model.url)
+
             return HttpResponseBadRequest()
 
     else:
diff --git a/templates/actions/objects/model.html b/templates/actions/objects/model.html
index 466bcc66..acff9f23 100644
--- a/templates/actions/objects/model.html
+++ b/templates/actions/objects/model.html
@@ -1,4 +1,16 @@
 {% if meta.can_edit %}
+{{ model.description }}
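
For reference, a minimal usage sketch of the refactored API introduced above. It is illustrative only: the package lookups, the model name and description, and the timing of the retrain call are placeholders rather than part of this change; the create signature and the Celery tasks are taken from the diff itself.

# Illustrative sketch only -- assumes a configured Django environment for the epdb app.
from epdb.models import Package, RuleBasedRelativeReasoning
from epdb.tasks import build_model, retrain

owner_package = Package.objects.first()   # placeholder: package that will own the model
rule_package = Package.objects.last()     # placeholder: package providing the rules

model = RuleBasedRelativeReasoning.create(
    package=owner_package,
    rule_packages=[rule_package],
    data_packages=[],                     # an empty list falls back to rule_packages
    eval_packages=[],
    threshold=0.5,
    min_count=10,
    max_count=0,
    name="Example rule-based relative reasoning",
    description="Created for illustration",
)

# Queue the initial build on the 'model' Celery queue, mirroring what package_models() does.
build_model.delay(model.pk)

# The new retrain task rebuilds the dataset and the model; typically queued later,
# e.g. after the underlying data packages changed.
retrain.delay(model.pk)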