[Feature] Eval package evaluation

`evaluate_model` in `PackageBasedModel` and `EnviFormer` now use evaluation packages if any are present instead of the random splits.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#148
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
2025-10-08 19:03:21 +13:00
committed by jebus
parent 36879c266b
commit 22f0bbe10b
3 changed files with 195 additions and 168 deletions

View File

@ -2130,6 +2130,19 @@ class PackageBasedModel(EPModel):
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()} recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
return mg_acc, precision, recall return mg_acc, precision, recall
# If there are eval packages perform single generation evaluation on them instead of random splits
if self.eval_packages.count() > 0:
eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
else:
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
ds = self.load_dataset() ds = self.load_dataset()
if isinstance(self, RuleBasedRelativeReasoning): if isinstance(self, RuleBasedRelativeReasoning):
@ -2152,16 +2165,12 @@ class PackageBasedModel(EPModel):
self.eval_results = self.compute_averages(evaluations) self.eval_results = self.compute_averages(evaluations)
if self.multigen_eval: if self.multigen_eval:
# If there are eval packages perform multi generation evaluation on them instead of random splits
# We have to consider 2 cases here:
# 1. No eval packages provided -> Split Train data X times and train and evaluate model
# 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
if self.eval_packages.count() > 0: if self.eval_packages.count() > 0:
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct() pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
evaluate_mg(self.model, pathway_qs, self.threshold) multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
return self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
else:
pathway_qs = Pathway.objects.prefetch_related( pathway_qs = Pathway.objects.prefetch_related(
'node_set', 'node_set',
'node_set__out_edges', 'node_set__out_edges',
@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
def predict_batch(self, smiles_list): def predict_batch(self, smiles_list):
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list] canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
logger.info(f"Submitting {canon_smiles} to {hash(self.model)}") logger.info(f"Submitting {canon_smiles} to {self.name}")
products_list = self.model.predict_batch(canon_smiles) products_list = self.model.predict_batch(canon_smiles)
logger.info(f"Got results {products_list}") logger.info(f"Got results {products_list}")
@ -2943,6 +2952,19 @@ class EnviFormer(PackageBasedModel):
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()} recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
return mg_acc, precision, recall return mg_acc, precision, recall
# If there are eval packages perform single generation evaluation on them instead of random splits
if self.eval_packages.count() > 0:
ds = []
for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
educts = ".".join(
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
reaction.products.all()])
ds.append(f"{educts}>>{products}")
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
from enviformer.finetune import fine_tune from enviformer.finetune import fine_tune
ds = self.load_dataset() ds = self.load_dataset()
n_splits = 20 n_splits = 20
@ -2965,6 +2987,11 @@ class EnviFormer(PackageBasedModel):
self.eval_results = self.compute_averages(single_gen_results) self.eval_results = self.compute_averages(single_gen_results)
if self.multigen_eval: if self.multigen_eval:
if self.eval_packages.count() > 0:
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
else:
pathway_qs = Pathway.objects.prefetch_related( pathway_qs = Pathway.objects.prefetch_related(
'node_set', 'node_set',
'node_set__out_edges', 'node_set__out_edges',

View File

@ -20,7 +20,7 @@ class EnviFormerTest(TestCase):
with self.settings(MODEL_DIR=tmpdir): with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5) threshold = float(0.5)
data_package_objs = [self.BBD_SUBSET] data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [] eval_packages_objs = [self.BBD_SUBSET]
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold) mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
mod.build_dataset() mod.build_dataset()

View File

@ -24,7 +24,7 @@ class ModelTest(TestCase):
rule_package_objs = [self.BBD_SUBSET] rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET] data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [] eval_packages_objs = [self.BBD_SUBSET]
mod = MLRelativeReasoning.create( mod = MLRelativeReasoning.create(
self.package, self.package,
@ -52,7 +52,7 @@ class ModelTest(TestCase):
mod.build_model() mod.build_model()
mod.multigen_eval = True mod.multigen_eval = True
mod.save() mod.save()
# mod.evaluate_model() mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C') results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')