[Feature] Rule Based Model (#92)

Fixes #89

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#92
This commit is contained in:
2025-09-09 19:32:12 +12:00
parent 1a6608287d
commit 5477b5b3d4
10 changed files with 560 additions and 185 deletions

View File

@ -4,6 +4,7 @@ import json
import logging
import os
import secrets
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import Union, List, Optional, Dict, Tuple, Set
@ -27,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
logger = logging.getLogger(__name__)
@ -1321,11 +1322,15 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
@property
def root_nodes(self):
return Node.objects.filter(pathway=self, depth=0)
# sames as return Node.objects.filter(pathway=self, depth=0) but will utilize
# potentially prefetched node_set
return self.node_set.all().filter(pathway=self, depth=0)
@property
def nodes(self):
return Node.objects.filter(pathway=self)
# same as Node.objects.filter(pathway=self) but will utilize
# potentially prefetched node_set
return self.node_set.all()
def get_node(self, node_url):
for n in self.nodes:
@ -1335,7 +1340,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
@property
def edges(self):
return Edge.objects.filter(pathway=self)
# same as Edge.objects.filter(pathway=self) but will utilize
# potentially prefetched edge_set
return self.edge_set.all()
def _url(self):
return '{}/pathway/{}'.format(self.package.url, self.uuid)
@ -1808,11 +1815,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
return '{}/model/{}'.format(self.package.url, self.uuid)
class MLRelativeReasoning(EPModel):
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
class PackageBasedModel(EPModel):
rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages",
related_name="%(app_label)s_%(class)s_rule_packages")
data_packages = models.ManyToManyField("Package", verbose_name="Data Packages",
related_name="%(app_label)s_%(class)s_data_packages")
eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages",
related_name="%(app_label)s_%(class)s_eval_packages")
threshold = models.FloatField(null=False, blank=False, default=0.5)
eval_results = JSONField(null=True, blank=True, default=dict)
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
default=None)
INITIAL = "INITIAL"
INITIALIZING = "INITIALIZING"
@ -1832,69 +1845,12 @@ class MLRelativeReasoning(EPModel):
}
model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)
eval_results = JSONField(null=True, blank=True, default=dict)
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
default=None)
def status(self):
return self.PROGRESS_STATUS_CHOICES[self.model_status]
def ready_for_prediction(self) -> bool:
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'],
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
name: 'str' = None, description: str = None, build_app_domain: bool = False,
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
app_domain_local_compatibility_threshold: float = None):
mlrr = MLRelativeReasoning()
mlrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
mlrr.name = name
if description is not None and description.strip() != '':
mlrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
mlrr.threshold = threshold
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
mlrr.save()
for p in rule_packages:
mlrr.rule_packages.add(p)
if data_packages:
for p in data_packages:
mlrr.data_packages.add(p)
else:
for p in rule_packages:
mlrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mlrr.eval_packages.add(p)
if build_app_domain:
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
app_domain_local_compatibility_threshold)
mlrr.app_domain = ad
mlrr.save()
return mlrr
@cached_property
def applicable_rules(self) -> List['Rule']:
"""
@ -1963,6 +1919,179 @@ class MLRelativeReasoning(EPModel):
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return Dataset.load(ds_path)
def retrain(self):
self.build_dataset()
self.build_model()
def rebuild(self):
self.build_model()
@abstractmethod
def build_model(self):
pass
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
for rule, p, smis in zip(rules, probabilities, products):
res.append(PredictionResult(smis, p, rule))
return res
class Meta:
abstract = True
class RuleBasedRelativeReasoning(PackageBasedModel):
min_count = models.IntegerField(null=False, blank=False, default=10)
max_count = models.IntegerField(null=False, blank=False, default=0)
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'],
eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0,
name: 'str' = None, description: str = None):
rbrr = RuleBasedRelativeReasoning()
rbrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"
rbrr.name = name
if description is not None and description.strip() != '':
rbrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
rbrr.threshold = threshold
if min_count is None or min_count < 1:
raise ValueError("Minimum count must be an int greater than equal 1.")
rbrr.min_count = min_count
if max_count is None or max_count > min_count:
raise ValueError("Maximum count must be an int and must not be less than min_count.")
if max_count is None:
raise ValueError("Maximum count must be at least 0.")
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
rbrr.save()
for p in rule_packages:
rbrr.rule_packages.add(p)
if data_packages:
for p in data_packages:
rbrr.data_packages.add(p)
else:
for p in rule_packages:
rbrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
rbrr.eval_packages.add(p)
rbrr.save()
return rbrr
def build_model(self):
self.model_status = self.BUILDING
self.save()
ds = self.load_dataset()
labels = ds.y(na_replacement=None)
mod = RelativeReasoning(*ds.triggered())
mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f)
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
@cached_property
def model(self) -> 'RelativeReasoning':
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
return mod
def predict(self, smiles) -> List['PredictionResult']:
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
mod = self.model
pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
end = datetime.now()
logger.info(f"Full predict took {(end - start).total_seconds()}s")
return res
class MLRelativeReasoning(PackageBasedModel):
@staticmethod
@transaction.atomic
def create(package: 'Package', rule_packages: List['Package'],
data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
name: 'str' = None, description: str = None, build_app_domain: bool = False,
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
app_domain_local_compatibility_threshold: float = None):
mlrr = MLRelativeReasoning()
mlrr.package = package
if name is None or name.strip() == '':
name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"
mlrr.name = name
if description is not None and description.strip() != '':
mlrr.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
mlrr.threshold = threshold
if len(rule_packages) == 0:
raise ValueError("At least one rule package must be provided.")
mlrr.save()
for p in rule_packages:
mlrr.rule_packages.add(p)
if data_packages:
for p in data_packages:
mlrr.data_packages.add(p)
else:
for p in rule_packages:
mlrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mlrr.eval_packages.add(p)
if build_app_domain:
ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
app_domain_local_compatibility_threshold)
mlrr.app_domain = ad
mlrr.save()
return mlrr
def build_model(self):
self.model_status = self.BUILDING
self.save()
@ -1991,13 +2120,6 @@ class MLRelativeReasoning(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def retrain(self):
self.build_dataset()
self.build_model()
def rebuild(self):
self.build_model()
def evaluate_model(self):
if self.model_status != self.BUILT_NOT_EVALUATED:
@ -2098,13 +2220,6 @@ class MLRelativeReasoning(EPModel):
logger.info(f"Full predict took {(end - start).total_seconds()}s")
return res
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
for rule, p, smis in zip(rules, probabilities, products):
res.append(PredictionResult(smis, p, rule))
return res
@property
def pr_curve(self):
if self.model_status != self.FINISHED:
@ -2358,9 +2473,6 @@ class ApplicabilityDomain(EnviPathModel):
return accuracy
class RuleBaseRelativeReasoning(EPModel):
pass
class EnviFormer(EPModel):
threshold = models.FloatField(null=False, blank=False, default=0.5)
@ -2406,6 +2518,12 @@ class EnviFormer(EPModel):
def applicable_rules(self):
return []
def status(self):
return "Model is built and can be used for predictions, Model is not evaluated yet."
def ready_for_prediction(self) -> bool:
return True
class PluginModel(EPModel):
pass