Merge remote-tracking branch 'origin/develop' into enhancement/dataset

# Conflicts:
#	epdb/models.py
#	tests/test_enviformer.py
#	tests/test_model.py
This commit is contained in:
Liam Brydon
2025-11-07 08:28:03 +13:00
25 changed files with 1024 additions and 280 deletions

View File

@ -2226,10 +2226,18 @@ class PackageBasedModel(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self, **kwargs):
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
if multigen:
self.multigen_eval = multigen
self.save()
if eval_packages is not None:
for p in eval_packages:
self.eval_packages.add(p)
self.model_status = self.EVALUATING
self.save()
@ -2526,7 +2534,6 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
package: "Package",
rule_packages: List["Package"],
data_packages: List["Package"],
eval_packages: List["Package"],
threshold: float = 0.5,
min_count: int = 10,
max_count: int = 0,
@ -2575,10 +2582,6 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
for p in rule_packages:
rbrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
rbrr.eval_packages.add(p)
rbrr.save()
return rbrr
@ -2633,7 +2636,6 @@ class MLRelativeReasoning(PackageBasedModel):
package: "Package",
rule_packages: List["Package"],
data_packages: List["Package"],
eval_packages: List["Package"],
threshold: float = 0.5,
name: "str" = None,
description: str = None,
@ -2673,10 +2675,6 @@ class MLRelativeReasoning(PackageBasedModel):
for p in rule_packages:
mlrr.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mlrr.eval_packages.add(p)
if build_app_domain:
ad = ApplicabilityDomain.create(
mlrr,
@ -2953,7 +2951,6 @@ class EnviFormer(PackageBasedModel):
def create(
package: "Package",
data_packages: List["Package"],
eval_packages: List["Package"],
threshold: float = 0.5,
name: "str" = None,
description: str = None,
@ -2986,10 +2983,6 @@ class EnviFormer(PackageBasedModel):
for p in data_packages:
mod.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mod.eval_packages.add(p)
# if build_app_domain:
# ad = ApplicabilityDomain.create(mod, app_domain_num_neighbours, app_domain_reliability_threshold,
# app_domain_local_compatibility_threshold)
@ -3082,10 +3075,18 @@ class EnviFormer(PackageBasedModel):
args = {"clz": "EnviFormer"}
return args
def evaluate_model(self, **kwargs):
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
if multigen:
self.multigen_eval = multigen
self.save()
if eval_packages is not None:
for p in eval_packages:
self.eval_packages.add(p)
self.model_status = self.EVALUATING
self.save()
@ -3226,7 +3227,7 @@ class EnviFormer(PackageBasedModel):
ds = self.load_dataset()
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
# this helps reduce the memory footprint.
@ -3294,7 +3295,7 @@ class EnviFormer(PackageBasedModel):
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
# iteration instead of storing all trained models.
for split_id, (train, test) in enumerate(
ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)
ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42).split(pathways)
):
train_pathways = [pathways[i] for i in train]
test_pathways = [pathways[i] for i in test]
@ -3577,3 +3578,53 @@ class Setting(EnviPathModel):
self.public = True
self.global_default = True
self.save()
class JobLogStatus(models.TextChoices):
INITIAL = "INITIAL", "Initial"
SUCCESS = "SUCCESS", "Success"
FAILURE = "FAILURE", "Failure"
REVOKED = "REVOKED", "Revoked"
IGNORED = "IGNORED", "Ignored"
class JobLog(TimeStampedModel):
user = models.ForeignKey("epdb.User", models.CASCADE)
task_id = models.UUIDField(unique=True)
job_name = models.TextField(null=False, blank=False)
status = models.CharField(
max_length=20,
choices=JobLogStatus.choices,
default=JobLogStatus.INITIAL,
)
done_at = models.DateTimeField(null=True, blank=True, default=None)
task_result = models.TextField(null=True, blank=True, default=None)
def check_for_update(self):
async_res = self.get_result()
new_status = async_res.state
TERMINAL_STATES = [
"SUCCESS",
"FAILURE",
"REVOKED",
"IGNORED",
]
if new_status != self.status and new_status in TERMINAL_STATES:
self.status = new_status
self.done_at = async_res.date_done
if new_status == "SUCCESS":
self.task_result = async_res.result
self.save()
return True
return False
def get_result(self):
from celery.result import AsyncResult
return AsyncResult(str(self.task_id))