forked from enviPath/enviPy
[Feature] Eval package evaluation
`evaluate_model` in `PackageBasedModel` and `EnviFormer` now use evaluation packages if any are present instead of the random splits. Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com> Reviewed-on: enviPath/enviPy#148 Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz> Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
@ -2130,6 +2130,19 @@ class PackageBasedModel(EPModel):
|
||||
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
||||
return mg_acc, precision, recall
|
||||
|
||||
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
|
||||
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
else:
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
|
||||
self.eval_results = self.compute_averages([single_gen_result])
|
||||
else:
|
||||
ds = self.load_dataset()
|
||||
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
@ -2152,16 +2165,12 @@ class PackageBasedModel(EPModel):
|
||||
self.eval_results = self.compute_averages(evaluations)
|
||||
|
||||
if self.multigen_eval:
|
||||
|
||||
# We have to consider 2 cases here:
|
||||
# 1. No eval packages provided -> Split Train data X times and train and evaluate model
|
||||
# 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
|
||||
|
||||
# If there are eval packages perform multi generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||
evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||
return
|
||||
|
||||
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
|
||||
else:
|
||||
pathway_qs = Pathway.objects.prefetch_related(
|
||||
'node_set',
|
||||
'node_set__out_edges',
|
||||
@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
|
||||
def predict_batch(self, smiles_list):
|
||||
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
|
||||
canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
|
||||
logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
|
||||
logger.info(f"Submitting {canon_smiles} to {self.name}")
|
||||
products_list = self.model.predict_batch(canon_smiles)
|
||||
logger.info(f"Got results {products_list}")
|
||||
|
||||
@ -2943,6 +2952,19 @@ class EnviFormer(PackageBasedModel):
|
||||
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
||||
return mg_acc, precision, recall
|
||||
|
||||
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
ds = []
|
||||
for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
|
||||
educts = ".".join(
|
||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
||||
products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
|
||||
reaction.products.all()])
|
||||
ds.append(f"{educts}>>{products}")
|
||||
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
|
||||
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
||||
self.eval_results = self.compute_averages([single_gen_result])
|
||||
else:
|
||||
from enviformer.finetune import fine_tune
|
||||
ds = self.load_dataset()
|
||||
n_splits = 20
|
||||
@ -2965,6 +2987,11 @@ class EnviFormer(PackageBasedModel):
|
||||
self.eval_results = self.compute_averages(single_gen_results)
|
||||
|
||||
if self.multigen_eval:
|
||||
if self.eval_packages.count() > 0:
|
||||
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
|
||||
else:
|
||||
pathway_qs = Pathway.objects.prefetch_related(
|
||||
'node_set',
|
||||
'node_set__out_edges',
|
||||
|
||||
@ -20,7 +20,7 @@ class EnviFormerTest(TestCase):
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
eval_packages_objs = [self.BBD_SUBSET]
|
||||
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
|
||||
|
||||
mod.build_dataset()
|
||||
|
||||
@ -24,7 +24,7 @@ class ModelTest(TestCase):
|
||||
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
eval_packages_objs = [self.BBD_SUBSET]
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
@ -52,7 +52,7 @@ class ModelTest(TestCase):
|
||||
mod.build_model()
|
||||
mod.multigen_eval = True
|
||||
mod.save()
|
||||
# mod.evaluate_model()
|
||||
mod.evaluate_model()
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user