[Feature] Eval package evaluation
`evaluate_model` in `PackageBasedModel` and `EnviFormer` now uses evaluation packages, if any are present, instead of random splits.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#148
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
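In short: when a model has evaluation packages attached, `evaluate_model` now scores the already-built model against those packages; only when none are attached does it fall back to the previous behaviour of random train/test splits. A condensed sketch of the control flow (names are taken from the diff below, everything else is elided):

    def evaluate_model(self):
        if self.eval_packages.count() > 0:
            # score the already-built model on the reactions (single generation)
            # and pathways (multi generation) from the eval packages
            ...
        else:
            # previous behaviour: random splits over the training data
            ...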
@@ -2130,6 +2130,19 @@ class PackageBasedModel(EPModel):
             recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
             return mg_acc, precision, recall
 
+        # If there are eval packages perform single generation evaluation on them instead of random splits
+        if self.eval_packages.count() > 0:
+            eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
+            ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
+            if isinstance(self, RuleBasedRelativeReasoning):
+                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
+                y = np.array(ds.y(na_replacement=np.nan))
+            else:
+                X = np.array(ds.X(na_replacement=np.nan))
+                y = np.array(ds.y(na_replacement=np.nan))
+            single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
+            self.eval_results = self.compute_averages([single_gen_result])
+        else:
             ds = self.load_dataset()
 
             if isinstance(self, RuleBasedRelativeReasoning):
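Note the index argument: `np.arange(len(X))` marks every row of the eval-package dataset as test data, so `evaluate_sg` scores the full set rather than a held-out fraction. An illustrative toy example (shapes and values are made up):

    import numpy as np

    X = np.zeros((5, 3))          # stand-in for ds.X(...)
    test_idx = np.arange(len(X))  # array([0, 1, 2, 3, 4]) -- every row is evaluated,
                                  # unlike a random split that holds out a subset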
@@ -2152,16 +2165,12 @@ class PackageBasedModel(EPModel):
         self.eval_results = self.compute_averages(evaluations)
 
         if self.multigen_eval:
-            # We have to consider 2 cases here:
-            # 1. No eval packages provided -> Split Train data X times and train and evaluate model
-            # 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
+            # If there are eval packages perform multi generation evaluation on them instead of random splits
 
             if self.eval_packages.count() > 0:
                 pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
-                evaluate_mg(self.model, pathway_qs, self.threshold)
-                return
+                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
+            else:
                 pathway_qs = Pathway.objects.prefetch_related(
                     'node_set',
                     'node_set__out_edges',
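Instead of the removed early `return`, the multi-generation metrics are now merged into `eval_results` under `multigen_`-prefixed keys, so single- and multi-generation numbers live in one dict. A minimal illustration of the prefixing idiom (metric names and values are hypothetical):

    single = {"accuracy": 0.81, "precision": 0.77}
    multigen = {"accuracy": 0.64}
    results = dict(single)
    results.update({f"multigen_{k}": v for k, v in multigen.items()})
    # {'accuracy': 0.81, 'precision': 0.77, 'multigen_accuracy': 0.64}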
@@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
     def predict_batch(self, smiles_list):
         # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
         canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
-        logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
+        logger.info(f"Submitting {canon_smiles} to {self.name}")
         products_list = self.model.predict_batch(canon_smiles)
         logger.info(f"Got results {products_list}")
 
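As the comment in `predict_batch` notes, the standardizer keeps only one compound of a multi-compound SMILES, so each dot-separated compound is standardized on its own and the parts are rejoined. A sketch of the idea for a single input (the input string is hypothetical; `FormatConverter.standardize` is used exactly as in the diff and its import is elided here):

    raw = "CCO.CC(=O)O"  # two compounds in one SMILES string
    canon = ".".join(
        FormatConverter.standardize(part, remove_stereo=True)  # one compound at a time
        for part in raw.split(".")
    )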
@@ -2943,6 +2952,19 @@ class EnviFormer(PackageBasedModel):
             recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
             return mg_acc, precision, recall
 
+        # If there are eval packages perform single generation evaluation on them instead of random splits
+        if self.eval_packages.count() > 0:
+            ds = []
+            for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
+                educts = ".".join(
+                    [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+                products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
+                                     reaction.products.all()])
+                ds.append(f"{educts}>>{products}")
+            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
+            self.eval_results = self.compute_averages([single_gen_result])
+        else:
             from enviformer.finetune import fine_tune
             ds = self.load_dataset()
             n_splits = 20
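For `EnviFormer` the eval-package dataset is a list of reaction SMILES strings of the form f"{educts}>>{products}"; only the educt side is sent to the model, while the full string lets `evaluate_sg` compare predictions against the true products. Illustrated on one hypothetical entry:

    smirk = "CCO>>CC=O"                   # one element of ds
    educt_side = smirk.split(">>")[0]     # "CCO"   -- the model input
    true_products = smirk.split(">>")[1]  # "CC=O"  -- the reference for scoring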
@@ -2965,6 +2987,11 @@ class EnviFormer(PackageBasedModel):
         self.eval_results = self.compute_averages(single_gen_results)
 
         if self.multigen_eval:
+            if self.eval_packages.count() > 0:
+                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
+                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
+            else:
                 pathway_qs = Pathway.objects.prefetch_related(
                     'node_set',
                     'node_set__out_edges',
@@ -20,7 +20,7 @@ class EnviFormerTest(TestCase):
         with self.settings(MODEL_DIR=tmpdir):
             threshold = float(0.5)
             data_package_objs = [self.BBD_SUBSET]
-            eval_packages_objs = []
+            eval_packages_objs = [self.BBD_SUBSET]
             mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
 
             mod.build_dataset()
@@ -24,7 +24,7 @@ class ModelTest(TestCase):
 
         rule_package_objs = [self.BBD_SUBSET]
         data_package_objs = [self.BBD_SUBSET]
-        eval_packages_objs = []
+        eval_packages_objs = [self.BBD_SUBSET]
 
         mod = MLRelativeReasoning.create(
             self.package,
@@ -52,7 +52,7 @@ class ModelTest(TestCase):
         mod.build_model()
         mod.multigen_eval = True
         mod.save()
-        # mod.evaluate_model()
+        mod.evaluate_model()
 
         results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
 
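With `mod.evaluate_model()` re-enabled and an eval package set, the test now exercises both evaluation paths; the resulting metrics dict can be inspected afterwards. A hypothetical check (the exact key names depend on `compute_averages` and are not spelled out in the diff):

    mod.evaluate_model()
    for key, value in mod.eval_results.items():
        # expect plain single-generation keys plus their "multigen_"-prefixed counterparts
        print(key, value)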