[Feature] Eval package evaluation

`evaluate_model` in `PackageBasedModel` and `EnviFormer` now evaluates on the model's evaluation packages, if any are present, instead of on random splits of the training data.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#148
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
commit 22f0bbe10b
parent 36879c266b
committed by jebus, 2025-10-08 19:03:21 +13:00
3 changed files with 195 additions and 168 deletions
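The behavioural change is easiest to see from the caller's side. A minimal sketch, pieced together from the test changes below (`bbd_subset` and `holdout_package` are hypothetical stand-ins for real Package objects; the `create(...)` signature mirrors the tests):

    threshold = 0.5
    data_package_objs = [bbd_subset]        # packages the model trains on
    eval_packages_objs = [holdout_package]  # held-out packages; an empty list restores the old behaviour

    mod = EnviFormer.create(package, data_package_objs, eval_packages_objs, threshold=threshold)
    mod.build_dataset()
    mod.build_model()
    mod.multigen_eval = True
    mod.save()
    # With eval packages present, evaluate_model() scores the already-built model
    # on them once; without them it falls back to 20 random ShuffleSplit folds.
    mod.evaluate_model()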


@@ -2130,116 +2130,125 @@ class PackageBasedModel(EPModel):
             recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
             return mg_acc, precision, recall
 
-        ds = self.load_dataset()
-
-        if isinstance(self, RuleBasedRelativeReasoning):
-            X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-            y = np.array(ds.y(na_replacement=np.nan))
-        else:
-            X = np.array(ds.X(na_replacement=np.nan))
-            y = np.array(ds.y(na_replacement=np.nan))
-
-        n_splits = 20
-
-        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
-        splits = list(shuff.split(X))
-
-        from joblib import Parallel, delayed
-        models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
-        evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
-                                          for model, (_, test_index) in zip(models, splits))
-
-        self.eval_results = self.compute_averages(evaluations)
+        # If there are eval packages perform single generation evaluation on them instead of random splits
+        if self.eval_packages.count() > 0:
+            eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
+            ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
+            if isinstance(self, RuleBasedRelativeReasoning):
+                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
+                y = np.array(ds.y(na_replacement=np.nan))
+            else:
+                X = np.array(ds.X(na_replacement=np.nan))
+                y = np.array(ds.y(na_replacement=np.nan))
+            single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
+            self.eval_results = self.compute_averages([single_gen_result])
+        else:
+            ds = self.load_dataset()
+
+            if isinstance(self, RuleBasedRelativeReasoning):
+                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
+                y = np.array(ds.y(na_replacement=np.nan))
+            else:
+                X = np.array(ds.X(na_replacement=np.nan))
+                y = np.array(ds.y(na_replacement=np.nan))
+
+            n_splits = 20
+
+            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
+            splits = list(shuff.split(X))
+
+            from joblib import Parallel, delayed
+            models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
+            evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
+                                              for model, (_, test_index) in zip(models, splits))
+
+            self.eval_results = self.compute_averages(evaluations)
 
         if self.multigen_eval:
+            # If there are eval packages perform multi generation evaluation on them instead of random splits
+            # We have to consider 2 cases here:
+            # 1. No eval packages provided -> Split Train data X times and train and evaluate model
+            # 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
             if self.eval_packages.count() > 0:
                 pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
-                evaluate_mg(self.model, pathway_qs, self.threshold)
-                return
-
-            pathway_qs = Pathway.objects.prefetch_related(
-                'node_set',
-                'node_set__out_edges',
-                'node_set__default_node_label',
-                'node_set__scenarios',
-                'edge_set',
-                'edge_set__start_nodes',
-                'edge_set__end_nodes',
-                'edge_set__edge_label',
-                'edge_set__scenarios'
-            ).filter(package__in=self.data_packages.all()).distinct()
-
-            pathways = []
-            for pathway in pathway_qs:
-                # There is one pathway with no root compounds, so this check is required
-                if len(pathway.root_nodes) > 0:
-                    pathways.append(pathway)
-                else:
-                    logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
-
-            # build lookup reaction -> {uuid1, uuid2} for overlap check
-            reaction_to_educts = defaultdict(set)
-            for pathway in pathways:
-                for reaction in pathway.edges:
-                    for e in reaction.edge_label.educts.all():
-                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
-
-            # build lookup to avoid recalculation of features, labels
-            id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
-
-            # Compute splits of the collected pathway
-            splits = []
-            for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
-                train_pathways = [pathways[i] for i in train]
-                test_pathways = [pathways[i] for i in test]
-
-                # Collect structures from test pathways
-                test_educts = set()
-                for pathway in test_pathways:
-                    for reaction in pathway.edges:
-                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
-
-                split_ids = []
-                overlap = 0
-                # Collect indices of the structures contained in train pathways iff they're not present in any of
-                # the test pathways
-                for pathway in train_pathways:
-                    for reaction in pathway.edges:
-                        for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
-                            # Ensure compounds in the training set do not appear in the test set
-                            if educt not in test_educts:
-                                if educt in id_to_index:
-                                    split_ids.append(id_to_index[educt])
-                                else:
-                                    logger.debug(f"Couldn't find features in X for compound {educt}")
-                            else:
-                                overlap += 1
-
-                logging.debug(
-                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
-                # Get the rows from the dataset corresponding to compounds in the training set pathways
-                split_x, split_y = X[split_ids], y[split_ids]
-                splits.append([(split_x, split_y), test_pathways])
-
-            # Build model on subsets obtained by pathway split
-            trained_models = Parallel(n_jobs=10)(
-                delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
-            )
-            # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
-            multi_ret_vals = Parallel(n_jobs=1)(
-                delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
-                zip(trained_models, splits)
-            )
-            self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
+                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
+            else:
+                pathway_qs = Pathway.objects.prefetch_related(
+                    'node_set',
+                    'node_set__out_edges',
+                    'node_set__default_node_label',
+                    'node_set__scenarios',
+                    'edge_set',
+                    'edge_set__start_nodes',
+                    'edge_set__end_nodes',
+                    'edge_set__edge_label',
+                    'edge_set__scenarios'
+                ).filter(package__in=self.data_packages.all()).distinct()
+
+                pathways = []
+                for pathway in pathway_qs:
+                    # There is one pathway with no root compounds, so this check is required
+                    if len(pathway.root_nodes) > 0:
+                        pathways.append(pathway)
+                    else:
+                        logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
+
+                # build lookup reaction -> {uuid1, uuid2} for overlap check
+                reaction_to_educts = defaultdict(set)
+                for pathway in pathways:
+                    for reaction in pathway.edges:
+                        for e in reaction.edge_label.educts.all():
+                            reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
+
+                # build lookup to avoid recalculation of features, labels
+                id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
+
+                # Compute splits of the collected pathway
+                splits = []
+                for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
+                    train_pathways = [pathways[i] for i in train]
+                    test_pathways = [pathways[i] for i in test]
+
+                    # Collect structures from test pathways
+                    test_educts = set()
+                    for pathway in test_pathways:
+                        for reaction in pathway.edges:
+                            test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+
+                    split_ids = []
+                    overlap = 0
+                    # Collect indices of the structures contained in train pathways iff they're not present in any of
+                    # the test pathways
+                    for pathway in train_pathways:
+                        for reaction in pathway.edges:
+                            for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
+                                # Ensure compounds in the training set do not appear in the test set
+                                if educt not in test_educts:
+                                    if educt in id_to_index:
+                                        split_ids.append(id_to_index[educt])
+                                    else:
+                                        logger.debug(f"Couldn't find features in X for compound {educt}")
+                                else:
+                                    overlap += 1
+
+                    logging.debug(
+                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
+                    # Get the rows from the dataset corresponding to compounds in the training set pathways
+                    split_x, split_y = X[split_ids], y[split_ids]
+                    splits.append([(split_x, split_y), test_pathways])
+
+                # Build model on subsets obtained by pathway split
+                trained_models = Parallel(n_jobs=10)(
+                    delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
+                )
+                # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
+                multi_ret_vals = Parallel(n_jobs=1)(
+                    delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
+                    zip(trained_models, splits)
+                )
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
 
         self.model_status = self.FINISHED
         self.save()
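The random-split multigen branch above guards against data leakage: pathways are split, and any training educt that also occurs in a test pathway is dropped. A toy sketch of just that filter, with plain strings standing in for enviPy educt UUIDs (names illustrative, not enviPy API):

    from sklearn.model_selection import ShuffleSplit

    # each "pathway" is reduced to the list of educt ids it contains
    pathways = [["a", "b"], ["b", "c"], ["d"], ["e", "f"]]

    for train, test in ShuffleSplit(n_splits=2, test_size=0.25, random_state=42).split(pathways):
        test_educts = {educt for i in test for educt in pathways[i]}
        train_educts, overlap = [], 0
        for i in train:
            for educt in pathways[i]:
                if educt in test_educts:
                    overlap += 1  # educt leaks into a test pathway, exclude it
                else:
                    train_educts.append(educt)
        print(f"train educts {train_educts}, dropped {overlap} due to overlap")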
@@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
     def predict_batch(self, smiles_list):
         # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
         canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
-        logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
+        logger.info(f"Submitting {canon_smiles} to {self.name}")
         products_list = self.model.predict_batch(canon_smiles)
         logger.info(f"Got results {products_list}")
@@ -2943,89 +2952,107 @@ class EnviFormer(PackageBasedModel):
             recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
             return mg_acc, precision, recall
 
-        from enviformer.finetune import fine_tune
-        ds = self.load_dataset()
-        n_splits = 20
-        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
-        # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
-        # this helps reduce the memory footprint.
-        single_gen_results = []
-        for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
-            train = [ds[i] for i in train_index]
-            test = [ds[i] for i in test_index]
-            start = datetime.now()
-            model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
-            end = datetime.now()
-            logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
-            model.to(s.ENVIFORMER_DEVICE)
-            test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
-            single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
-        self.eval_results = self.compute_averages(single_gen_results)
+        # If there are eval packages perform single generation evaluation on them instead of random splits
+        if self.eval_packages.count() > 0:
+            ds = []
+            for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
+                educts = ".".join(
+                    [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+                products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
+                                     reaction.products.all()])
+                ds.append(f"{educts}>>{products}")
+            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
+            self.eval_results = self.compute_averages([single_gen_result])
+        else:
+            from enviformer.finetune import fine_tune
+            ds = self.load_dataset()
+            n_splits = 20
+            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
+            # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
+            # this helps reduce the memory footprint.
+            single_gen_results = []
+            for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
+                train = [ds[i] for i in train_index]
+                test = [ds[i] for i in test_index]
+                start = datetime.now()
+                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+                end = datetime.now()
+                logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
+                model.to(s.ENVIFORMER_DEVICE)
+                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+                single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
+            self.eval_results = self.compute_averages(single_gen_results)
 
         if self.multigen_eval:
-            pathway_qs = Pathway.objects.prefetch_related(
-                'node_set',
-                'node_set__out_edges',
-                'node_set__default_node_label',
-                'node_set__scenarios',
-                'edge_set',
-                'edge_set__start_nodes',
-                'edge_set__end_nodes',
-                'edge_set__edge_label',
-                'edge_set__scenarios'
-            ).filter(package__in=self.data_packages.all()).distinct()
-
-            pathways = []
-            for pathway in pathway_qs:
-                # There is one pathway with no root compounds, so this check is required
-                if len(pathway.root_nodes) > 0:
-                    pathways.append(pathway)
-                else:
-                    logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
-
-            # build lookup reaction -> {uuid1, uuid2} for overlap check
-            reaction_to_educts = defaultdict(set)
-            for pathway in pathways:
-                for reaction in pathway.edges:
-                    for e in reaction.edge_label.educts.all():
-                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
-
-            multi_gen_results = []
-            # Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
-            # iteration instead of storing all trained models.
-            for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
-                train_pathways = [pathways[i] for i in train]
-                test_pathways = [pathways[i] for i in test]
-
-                # Collect structures from test pathways
-                test_educts = set()
-                for pathway in test_pathways:
-                    for reaction in pathway.edges:
-                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
-
-                train_reactions = []
-                overlap = 0
-                # Collect indices of the structures contained in train pathways iff they're not present in any of
-                # the test pathways
-                for pathway in train_pathways:
-                    for reaction in pathway.edges:
-                        reaction = reaction.edge_label
-                        if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
-                            overlap += 1
-                            continue
-                        educts = ".".join(
-                            [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
-                        products = ".".join(
-                            [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
-                        train_reactions.append(f"{educts}>>{products}")
-                logging.debug(
-                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
-                model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
-                multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
-            self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
+            if self.eval_packages.count() > 0:
+                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
+                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
+            else:
+                pathway_qs = Pathway.objects.prefetch_related(
+                    'node_set',
+                    'node_set__out_edges',
+                    'node_set__default_node_label',
+                    'node_set__scenarios',
+                    'edge_set',
+                    'edge_set__start_nodes',
+                    'edge_set__end_nodes',
+                    'edge_set__edge_label',
+                    'edge_set__scenarios'
+                ).filter(package__in=self.data_packages.all()).distinct()
+
+                pathways = []
+                for pathway in pathway_qs:
+                    # There is one pathway with no root compounds, so this check is required
+                    if len(pathway.root_nodes) > 0:
+                        pathways.append(pathway)
+                    else:
+                        logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
+
+                # build lookup reaction -> {uuid1, uuid2} for overlap check
+                reaction_to_educts = defaultdict(set)
+                for pathway in pathways:
+                    for reaction in pathway.edges:
+                        for e in reaction.edge_label.educts.all():
+                            reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
+
+                multi_gen_results = []
+                # Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
+                # iteration instead of storing all trained models.
+                for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
+                    train_pathways = [pathways[i] for i in train]
+                    test_pathways = [pathways[i] for i in test]
+
+                    # Collect structures from test pathways
+                    test_educts = set()
+                    for pathway in test_pathways:
+                        for reaction in pathway.edges:
+                            test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+
+                    train_reactions = []
+                    overlap = 0
+                    # Collect indices of the structures contained in train pathways iff they're not present in any of
+                    # the test pathways
+                    for pathway in train_pathways:
+                        for reaction in pathway.edges:
+                            reaction = reaction.edge_label
+                            if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
+                                overlap += 1
+                                continue
+                            educts = ".".join(
+                                [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+                            products = ".".join(
+                                [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
+                            train_reactions.append(f"{educts}>>{products}")
+                    logging.debug(
+                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
+                    model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+                    multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
 
         self.model_status = self.FINISHED
         self.save()
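Both EnviFormer branches above serialize each reaction as an "educts>>products" SMILES string and hand only the educt side to predict_batch. A toy illustration (the SMILES values are made up for the example):

    # two reactions in the "educts>>products" form built above
    ds = ["CCO>>CC=O", "CC(=O)O.CO>>CC(=O)OC"]

    # evaluation feeds only the educt side of each reaction to the model
    educt_inputs = [smirk.split(">>")[0] for smirk in ds]
    assert educt_inputs == ["CCO", "CC(=O)O.CO"]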


@@ -20,7 +20,7 @@ class EnviFormerTest(TestCase):
         with self.settings(MODEL_DIR=tmpdir):
             threshold = float(0.5)
             data_package_objs = [self.BBD_SUBSET]
-            eval_packages_objs = []
+            eval_packages_objs = [self.BBD_SUBSET]
             mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
             mod.build_dataset()


@@ -24,7 +24,7 @@ class ModelTest(TestCase):
         rule_package_objs = [self.BBD_SUBSET]
         data_package_objs = [self.BBD_SUBSET]
-        eval_packages_objs = []
+        eval_packages_objs = [self.BBD_SUBSET]
 
         mod = MLRelativeReasoning.create(
             self.package,
@@ -52,7 +52,7 @@ class ModelTest(TestCase):
         mod.build_model()
         mod.multigen_eval = True
         mod.save()
-        # mod.evaluate_model()
+        mod.evaluate_model()
 
         results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')