forked from enviPath/enviPy
[Feature] Eval package evaluation
`evaluate_model` in `PackageBasedModel` and `EnviFormer` now uses evaluation packages, if any are present, instead of random splits.
Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#148
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
357
epdb/models.py
357
epdb/models.py
@ -2130,116 +2130,125 @@ class PackageBasedModel(EPModel):
|
||||
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
||||
return mg_acc, precision, recall
|
||||
|
||||
ds = self.load_dataset()
|
||||
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
|
||||
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
else:
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
|
||||
self.eval_results = self.compute_averages([single_gen_result])
|
||||
else:
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
ds = self.load_dataset()
|
||||
|
||||
n_splits = 20
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
else:
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
splits = list(shuff.split(X))
|
||||
n_splits = 20
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
|
||||
evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
||||
for model, (_, test_index) in zip(models, splits))
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
splits = list(shuff.split(X))
|
||||
|
||||
self.eval_results = self.compute_averages(evaluations)
|
||||
from joblib import Parallel, delayed
|
||||
models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
|
||||
evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
||||
for model, (_, test_index) in zip(models, splits))
|
||||
|
||||
self.eval_results = self.compute_averages(evaluations)
|
||||
|
||||
if self.multigen_eval:
|
||||
|
||||
# We have to consider 2 cases here:
|
||||
# 1. No eval packages provided -> Split Train data X times and train and evaluate model
|
||||
# 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
|
||||
|
||||
# If there are eval packages perform multi generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||
evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||
return
|
||||
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
|
||||
else:
|
||||
pathway_qs = Pathway.objects.prefetch_related(
|
||||
'node_set',
|
||||
'node_set__out_edges',
|
||||
'node_set__default_node_label',
|
||||
'node_set__scenarios',
|
||||
'edge_set',
|
||||
'edge_set__start_nodes',
|
||||
'edge_set__end_nodes',
|
||||
'edge_set__edge_label',
|
||||
'edge_set__scenarios'
|
||||
).filter(package__in=self.data_packages.all()).distinct()
|
||||
|
||||
pathway_qs = Pathway.objects.prefetch_related(
|
||||
'node_set',
|
||||
'node_set__out_edges',
|
||||
'node_set__default_node_label',
|
||||
'node_set__scenarios',
|
||||
'edge_set',
|
||||
'edge_set__start_nodes',
|
||||
'edge_set__end_nodes',
|
||||
'edge_set__edge_label',
|
||||
'edge_set__scenarios'
|
||||
).filter(package__in=self.data_packages.all()).distinct()
|
||||
pathways = []
|
||||
for pathway in pathway_qs:
|
||||
# There is one pathway with no root compounds, so this check is required
|
||||
if len(pathway.root_nodes) > 0:
|
||||
pathways.append(pathway)
|
||||
else:
|
||||
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
||||
|
||||
pathways = []
|
||||
for pathway in pathway_qs:
|
||||
# There is one pathway with no root compounds, so this check is required
|
||||
if len(pathway.root_nodes) > 0:
|
||||
pathways.append(pathway)
|
||||
else:
|
||||
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
||||
|
||||
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
||||
reaction_to_educts = defaultdict(set)
|
||||
for pathway in pathways:
|
||||
for reaction in pathway.edges:
|
||||
for e in reaction.edge_label.educts.all():
|
||||
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
||||
|
||||
# build lookup to avoid recalculation of features, labels
|
||||
id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
|
||||
|
||||
# Compute splits of the collected pathway
|
||||
splits = []
|
||||
for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
|
||||
train_pathways = [pathways[i] for i in train]
|
||||
test_pathways = [pathways[i] for i in test]
|
||||
|
||||
# Collect structures from test pathways
|
||||
test_educts = set()
|
||||
for pathway in test_pathways:
|
||||
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
||||
reaction_to_educts = defaultdict(set)
|
||||
for pathway in pathways:
|
||||
for reaction in pathway.edges:
|
||||
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
||||
for e in reaction.edge_label.educts.all():
|
||||
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
||||
|
||||
split_ids = []
|
||||
overlap = 0
|
||||
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
||||
# the test pathways
|
||||
for pathway in train_pathways:
|
||||
for reaction in pathway.edges:
|
||||
for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
|
||||
# Ensure compounds in the training set do not appear in the test set
|
||||
if educt not in test_educts:
|
||||
if educt in id_to_index:
|
||||
split_ids.append(id_to_index[educt])
|
||||
# build lookup to avoid recalculation of features, labels
|
||||
id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
|
||||
|
||||
# Compute splits of the collected pathway
|
||||
splits = []
|
||||
for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
|
||||
train_pathways = [pathways[i] for i in train]
|
||||
test_pathways = [pathways[i] for i in test]
|
||||
|
||||
# Collect structures from test pathways
|
||||
test_educts = set()
|
||||
for pathway in test_pathways:
|
||||
for reaction in pathway.edges:
|
||||
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
||||
|
||||
split_ids = []
|
||||
overlap = 0
|
||||
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
||||
# the test pathways
|
||||
for pathway in train_pathways:
|
||||
for reaction in pathway.edges:
|
||||
for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
|
||||
# Ensure compounds in the training set do not appear in the test set
|
||||
if educt not in test_educts:
|
||||
if educt in id_to_index:
|
||||
split_ids.append(id_to_index[educt])
|
||||
else:
|
||||
logger.debug(f"Couldn't find features in X for compound {educt}")
|
||||
else:
|
||||
logger.debug(f"Couldn't find features in X for compound {educt}")
|
||||
else:
|
||||
overlap += 1
|
||||
overlap += 1
|
||||
|
||||
logging.debug(
|
||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
||||
logging.debug(
|
||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
||||
|
||||
# Get the rows from the dataset corresponding to compounds in the training set pathways
|
||||
split_x, split_y = X[split_ids], y[split_ids]
|
||||
splits.append([(split_x, split_y), test_pathways])
|
||||
# Get the rows from the dataset corresponding to compounds in the training set pathways
|
||||
split_x, split_y = X[split_ids], y[split_ids]
|
||||
splits.append([(split_x, split_y), test_pathways])
|
||||
|
||||
|
||||
# Build model on subsets obtained by pathway split
|
||||
trained_models = Parallel(n_jobs=10)(
|
||||
delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
|
||||
)
|
||||
# Build model on subsets obtained by pathway split
|
||||
trained_models = Parallel(n_jobs=10)(
|
||||
delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
|
||||
)
|
||||
|
||||
# Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
|
||||
multi_ret_vals = Parallel(n_jobs=1)(
|
||||
delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
|
||||
zip(trained_models, splits)
|
||||
)
|
||||
# Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
|
||||
multi_ret_vals = Parallel(n_jobs=1)(
|
||||
delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
|
||||
zip(trained_models, splits)
|
||||
)
|
||||
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
|
||||
|
||||
self.model_status = self.FINISHED
|
||||
self.save()
|
||||
@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
|
||||
def predict_batch(self, smiles_list):
|
||||
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
|
||||
canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
|
||||
logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
|
||||
logger.info(f"Submitting {canon_smiles} to {self.name}")
|
||||
products_list = self.model.predict_batch(canon_smiles)
|
||||
logger.info(f"Got results {products_list}")
|
||||
|
||||
@ -2943,89 +2952,107 @@ class EnviFormer(PackageBasedModel):
|
||||
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
||||
return mg_acc, precision, recall
|
||||
|
||||
from enviformer.finetune import fine_tune
|
||||
ds = self.load_dataset()
|
||||
n_splits = 20
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
ds = []
|
||||
for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
|
||||
educts = ".".join(
|
||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
||||
products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
|
||||
reaction.products.all()])
|
||||
ds.append(f"{educts}>>{products}")
|
||||
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
|
||||
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
||||
self.eval_results = self.compute_averages([single_gen_result])
|
||||
else:
|
||||
from enviformer.finetune import fine_tune
|
||||
ds = self.load_dataset()
|
||||
n_splits = 20
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
|
||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||
# this helps reduce the memory footprint.
|
||||
single_gen_results = []
|
||||
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
||||
train = [ds[i] for i in train_index]
|
||||
test = [ds[i] for i in test_index]
|
||||
start = datetime.now()
|
||||
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||
end = datetime.now()
|
||||
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
||||
model.to(s.ENVIFORMER_DEVICE)
|
||||
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
||||
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||
# this helps reduce the memory footprint.
|
||||
single_gen_results = []
|
||||
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
||||
train = [ds[i] for i in train_index]
|
||||
test = [ds[i] for i in test_index]
|
||||
start = datetime.now()
|
||||
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||
end = datetime.now()
|
||||
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
||||
model.to(s.ENVIFORMER_DEVICE)
|
||||
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
||||
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
||||
|
||||
self.eval_results = self.compute_averages(single_gen_results)
|
||||
self.eval_results = self.compute_averages(single_gen_results)
|
||||
|
||||
if self.multigen_eval:
|
||||
pathway_qs = Pathway.objects.prefetch_related(
|
||||
'node_set',
|
||||
'node_set__out_edges',
|
||||
'node_set__default_node_label',
|
||||
'node_set__scenarios',
|
||||
'edge_set',
|
||||
'edge_set__start_nodes',
|
||||
'edge_set__end_nodes',
|
||||
'edge_set__edge_label',
|
||||
'edge_set__scenarios'
|
||||
).filter(package__in=self.data_packages.all()).distinct()
|
||||
if self.eval_packages.count() > 0:
|
||||
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
|
||||
else:
|
||||
pathway_qs = Pathway.objects.prefetch_related(
|
||||
'node_set',
|
||||
'node_set__out_edges',
|
||||
'node_set__default_node_label',
|
||||
'node_set__scenarios',
|
||||
'edge_set',
|
||||
'edge_set__start_nodes',
|
||||
'edge_set__end_nodes',
|
||||
'edge_set__edge_label',
|
||||
'edge_set__scenarios'
|
||||
).filter(package__in=self.data_packages.all()).distinct()
|
||||
|
||||
pathways = []
|
||||
for pathway in pathway_qs:
|
||||
# There is one pathway with no root compounds, so this check is required
|
||||
if len(pathway.root_nodes) > 0:
|
||||
pathways.append(pathway)
|
||||
else:
|
||||
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
||||
pathways = []
|
||||
for pathway in pathway_qs:
|
||||
# There is one pathway with no root compounds, so this check is required
|
||||
if len(pathway.root_nodes) > 0:
|
||||
pathways.append(pathway)
|
||||
else:
|
||||
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
||||
|
||||
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
||||
reaction_to_educts = defaultdict(set)
|
||||
for pathway in pathways:
|
||||
for reaction in pathway.edges:
|
||||
for e in reaction.edge_label.educts.all():
|
||||
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
||||
|
||||
multi_gen_results = []
|
||||
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
|
||||
# iteration instead of storing all trained models.
|
||||
for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
|
||||
train_pathways = [pathways[i] for i in train]
|
||||
test_pathways = [pathways[i] for i in test]
|
||||
|
||||
# Collect structures from test pathways
|
||||
test_educts = set()
|
||||
for pathway in test_pathways:
|
||||
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
||||
reaction_to_educts = defaultdict(set)
|
||||
for pathway in pathways:
|
||||
for reaction in pathway.edges:
|
||||
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
||||
for e in reaction.edge_label.educts.all():
|
||||
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
||||
|
||||
train_reactions = []
|
||||
overlap = 0
|
||||
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
||||
# the test pathways
|
||||
for pathway in train_pathways:
|
||||
for reaction in pathway.edges:
|
||||
reaction = reaction.edge_label
|
||||
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
||||
overlap += 1
|
||||
continue
|
||||
educts = ".".join(
|
||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
||||
products = ".".join(
|
||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
|
||||
train_reactions.append(f"{educts}>>{products}")
|
||||
logging.debug(
|
||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
||||
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
||||
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
||||
multi_gen_results = []
|
||||
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
|
||||
# iteration instead of storing all trained models.
|
||||
for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
|
||||
train_pathways = [pathways[i] for i in train]
|
||||
test_pathways = [pathways[i] for i in test]
|
||||
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
|
||||
# Collect structures from test pathways
|
||||
test_educts = set()
|
||||
for pathway in test_pathways:
|
||||
for reaction in pathway.edges:
|
||||
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
||||
|
||||
train_reactions = []
|
||||
overlap = 0
|
||||
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
||||
# the test pathways
|
||||
for pathway in train_pathways:
|
||||
for reaction in pathway.edges:
|
||||
reaction = reaction.edge_label
|
||||
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
||||
overlap += 1
|
||||
continue
|
||||
educts = ".".join(
|
||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
||||
products = ".".join(
|
||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
|
||||
train_reactions.append(f"{educts}>>{products}")
|
||||
logging.debug(
|
||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
||||
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
||||
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
||||
|
||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
|
||||
|
||||
self.model_status = self.FINISHED
|
||||
self.save()
|
||||
|
||||
Reference in New Issue
Block a user