forked from enviPath/enviPy
[Feature] Eval package evaluation
`evaluate_model` in `PackageBasedModel` and `EnviFormer` now uses the evaluation packages, if any are present, instead of random splits. Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com> Reviewed-on: enviPath/enviPy#148 Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz> Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
357
epdb/models.py
357
epdb/models.py
@ -2130,116 +2130,125 @@ class PackageBasedModel(EPModel):
|
|||||||
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
||||||
return mg_acc, precision, recall
|
return mg_acc, precision, recall
|
||||||
|
|
||||||
ds = self.load_dataset()
|
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||||
|
if self.eval_packages.count() > 0:
|
||||||
if isinstance(self, RuleBasedRelativeReasoning):
|
eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
|
||||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
if isinstance(self, RuleBasedRelativeReasoning):
|
||||||
|
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||||
|
y = np.array(ds.y(na_replacement=np.nan))
|
||||||
|
else:
|
||||||
|
X = np.array(ds.X(na_replacement=np.nan))
|
||||||
|
y = np.array(ds.y(na_replacement=np.nan))
|
||||||
|
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
|
||||||
|
self.eval_results = self.compute_averages([single_gen_result])
|
||||||
else:
|
else:
|
||||||
X = np.array(ds.X(na_replacement=np.nan))
|
ds = self.load_dataset()
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
|
||||||
|
|
||||||
n_splits = 20
|
if isinstance(self, RuleBasedRelativeReasoning):
|
||||||
|
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||||
|
y = np.array(ds.y(na_replacement=np.nan))
|
||||||
|
else:
|
||||||
|
X = np.array(ds.X(na_replacement=np.nan))
|
||||||
|
y = np.array(ds.y(na_replacement=np.nan))
|
||||||
|
|
||||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
n_splits = 20
|
||||||
splits = list(shuff.split(X))
|
|
||||||
|
|
||||||
from joblib import Parallel, delayed
|
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||||
models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
|
splits = list(shuff.split(X))
|
||||||
evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
|
||||||
for model, (_, test_index) in zip(models, splits))
|
|
||||||
|
|
||||||
self.eval_results = self.compute_averages(evaluations)
|
from joblib import Parallel, delayed
|
||||||
|
models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
|
||||||
|
evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
||||||
|
for model, (_, test_index) in zip(models, splits))
|
||||||
|
|
||||||
|
self.eval_results = self.compute_averages(evaluations)
|
||||||
|
|
||||||
if self.multigen_eval:
|
if self.multigen_eval:
|
||||||
|
# If there are eval packages perform multi generation evaluation on them instead of random splits
|
||||||
# We have to consider 2 cases here:
|
|
||||||
# 1. No eval packages provided -> Split Train data X times and train and evaluate model
|
|
||||||
# 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
|
|
||||||
|
|
||||||
if self.eval_packages.count() > 0:
|
if self.eval_packages.count() > 0:
|
||||||
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||||
evaluate_mg(self.model, pathway_qs, self.threshold)
|
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||||
return
|
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
|
||||||
|
else:
|
||||||
|
pathway_qs = Pathway.objects.prefetch_related(
|
||||||
|
'node_set',
|
||||||
|
'node_set__out_edges',
|
||||||
|
'node_set__default_node_label',
|
||||||
|
'node_set__scenarios',
|
||||||
|
'edge_set',
|
||||||
|
'edge_set__start_nodes',
|
||||||
|
'edge_set__end_nodes',
|
||||||
|
'edge_set__edge_label',
|
||||||
|
'edge_set__scenarios'
|
||||||
|
).filter(package__in=self.data_packages.all()).distinct()
|
||||||
|
|
||||||
pathway_qs = Pathway.objects.prefetch_related(
|
pathways = []
|
||||||
'node_set',
|
for pathway in pathway_qs:
|
||||||
'node_set__out_edges',
|
# There is one pathway with no root compounds, so this check is required
|
||||||
'node_set__default_node_label',
|
if len(pathway.root_nodes) > 0:
|
||||||
'node_set__scenarios',
|
pathways.append(pathway)
|
||||||
'edge_set',
|
else:
|
||||||
'edge_set__start_nodes',
|
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
||||||
'edge_set__end_nodes',
|
|
||||||
'edge_set__edge_label',
|
|
||||||
'edge_set__scenarios'
|
|
||||||
).filter(package__in=self.data_packages.all()).distinct()
|
|
||||||
|
|
||||||
pathways = []
|
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
||||||
for pathway in pathway_qs:
|
reaction_to_educts = defaultdict(set)
|
||||||
# There is one pathway with no root compounds, so this check is required
|
for pathway in pathways:
|
||||||
if len(pathway.root_nodes) > 0:
|
|
||||||
pathways.append(pathway)
|
|
||||||
else:
|
|
||||||
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
|
||||||
|
|
||||||
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
|
||||||
reaction_to_educts = defaultdict(set)
|
|
||||||
for pathway in pathways:
|
|
||||||
for reaction in pathway.edges:
|
|
||||||
for e in reaction.edge_label.educts.all():
|
|
||||||
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
|
||||||
|
|
||||||
# build lookup to avoid recalculation of features, labels
|
|
||||||
id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
|
|
||||||
|
|
||||||
# Compute splits of the collected pathway
|
|
||||||
splits = []
|
|
||||||
for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
|
|
||||||
train_pathways = [pathways[i] for i in train]
|
|
||||||
test_pathways = [pathways[i] for i in test]
|
|
||||||
|
|
||||||
# Collect structures from test pathways
|
|
||||||
test_educts = set()
|
|
||||||
for pathway in test_pathways:
|
|
||||||
for reaction in pathway.edges:
|
for reaction in pathway.edges:
|
||||||
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
for e in reaction.edge_label.educts.all():
|
||||||
|
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
||||||
|
|
||||||
split_ids = []
|
# build lookup to avoid recalculation of features, labels
|
||||||
overlap = 0
|
id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
|
||||||
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
|
||||||
# the test pathways
|
# Compute splits of the collected pathway
|
||||||
for pathway in train_pathways:
|
splits = []
|
||||||
for reaction in pathway.edges:
|
for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
|
||||||
for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
|
train_pathways = [pathways[i] for i in train]
|
||||||
# Ensure compounds in the training set do not appear in the test set
|
test_pathways = [pathways[i] for i in test]
|
||||||
if educt not in test_educts:
|
|
||||||
if educt in id_to_index:
|
# Collect structures from test pathways
|
||||||
split_ids.append(id_to_index[educt])
|
test_educts = set()
|
||||||
|
for pathway in test_pathways:
|
||||||
|
for reaction in pathway.edges:
|
||||||
|
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
||||||
|
|
||||||
|
split_ids = []
|
||||||
|
overlap = 0
|
||||||
|
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
||||||
|
# the test pathways
|
||||||
|
for pathway in train_pathways:
|
||||||
|
for reaction in pathway.edges:
|
||||||
|
for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
|
||||||
|
# Ensure compounds in the training set do not appear in the test set
|
||||||
|
if educt not in test_educts:
|
||||||
|
if educt in id_to_index:
|
||||||
|
split_ids.append(id_to_index[educt])
|
||||||
|
else:
|
||||||
|
logger.debug(f"Couldn't find features in X for compound {educt}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Couldn't find features in X for compound {educt}")
|
overlap += 1
|
||||||
else:
|
|
||||||
overlap += 1
|
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
||||||
|
|
||||||
# Get the rows from the dataset corresponding to compounds in the training set pathways
|
# Get the rows from the dataset corresponding to compounds in the training set pathways
|
||||||
split_x, split_y = X[split_ids], y[split_ids]
|
split_x, split_y = X[split_ids], y[split_ids]
|
||||||
splits.append([(split_x, split_y), test_pathways])
|
splits.append([(split_x, split_y), test_pathways])
|
||||||
|
|
||||||
|
|
||||||
# Build model on subsets obtained by pathway split
|
# Build model on subsets obtained by pathway split
|
||||||
trained_models = Parallel(n_jobs=10)(
|
trained_models = Parallel(n_jobs=10)(
|
||||||
delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
|
delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
|
# Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
|
||||||
multi_ret_vals = Parallel(n_jobs=1)(
|
multi_ret_vals = Parallel(n_jobs=1)(
|
||||||
delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
|
delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
|
||||||
zip(trained_models, splits)
|
zip(trained_models, splits)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
|
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
|
||||||
|
|
||||||
self.model_status = self.FINISHED
|
self.model_status = self.FINISHED
|
||||||
self.save()
|
self.save()
|
||||||
@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
def predict_batch(self, smiles_list):
|
def predict_batch(self, smiles_list):
|
||||||
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
|
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
|
||||||
canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
|
canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
|
||||||
logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
|
logger.info(f"Submitting {canon_smiles} to {self.name}")
|
||||||
products_list = self.model.predict_batch(canon_smiles)
|
products_list = self.model.predict_batch(canon_smiles)
|
||||||
logger.info(f"Got results {products_list}")
|
logger.info(f"Got results {products_list}")
|
||||||
|
|
||||||
@ -2943,89 +2952,107 @@ class EnviFormer(PackageBasedModel):
|
|||||||
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
|
||||||
return mg_acc, precision, recall
|
return mg_acc, precision, recall
|
||||||
|
|
||||||
from enviformer.finetune import fine_tune
|
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||||
ds = self.load_dataset()
|
if self.eval_packages.count() > 0:
|
||||||
n_splits = 20
|
ds = []
|
||||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
|
||||||
|
educts = ".".join(
|
||||||
|
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
||||||
|
products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
|
||||||
|
reaction.products.all()])
|
||||||
|
ds.append(f"{educts}>>{products}")
|
||||||
|
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
|
||||||
|
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
||||||
|
self.eval_results = self.compute_averages([single_gen_result])
|
||||||
|
else:
|
||||||
|
from enviformer.finetune import fine_tune
|
||||||
|
ds = self.load_dataset()
|
||||||
|
n_splits = 20
|
||||||
|
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||||
|
|
||||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||||
# this helps reduce the memory footprint.
|
# this helps reduce the memory footprint.
|
||||||
single_gen_results = []
|
single_gen_results = []
|
||||||
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
||||||
train = [ds[i] for i in train_index]
|
train = [ds[i] for i in train_index]
|
||||||
test = [ds[i] for i in test_index]
|
test = [ds[i] for i in test_index]
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
||||||
model.to(s.ENVIFORMER_DEVICE)
|
model.to(s.ENVIFORMER_DEVICE)
|
||||||
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
||||||
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
||||||
|
|
||||||
self.eval_results = self.compute_averages(single_gen_results)
|
self.eval_results = self.compute_averages(single_gen_results)
|
||||||
|
|
||||||
if self.multigen_eval:
|
if self.multigen_eval:
|
||||||
pathway_qs = Pathway.objects.prefetch_related(
|
if self.eval_packages.count() > 0:
|
||||||
'node_set',
|
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||||
'node_set__out_edges',
|
multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
|
||||||
'node_set__default_node_label',
|
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
|
||||||
'node_set__scenarios',
|
else:
|
||||||
'edge_set',
|
pathway_qs = Pathway.objects.prefetch_related(
|
||||||
'edge_set__start_nodes',
|
'node_set',
|
||||||
'edge_set__end_nodes',
|
'node_set__out_edges',
|
||||||
'edge_set__edge_label',
|
'node_set__default_node_label',
|
||||||
'edge_set__scenarios'
|
'node_set__scenarios',
|
||||||
).filter(package__in=self.data_packages.all()).distinct()
|
'edge_set',
|
||||||
|
'edge_set__start_nodes',
|
||||||
|
'edge_set__end_nodes',
|
||||||
|
'edge_set__edge_label',
|
||||||
|
'edge_set__scenarios'
|
||||||
|
).filter(package__in=self.data_packages.all()).distinct()
|
||||||
|
|
||||||
pathways = []
|
pathways = []
|
||||||
for pathway in pathway_qs:
|
for pathway in pathway_qs:
|
||||||
# There is one pathway with no root compounds, so this check is required
|
# There is one pathway with no root compounds, so this check is required
|
||||||
if len(pathway.root_nodes) > 0:
|
if len(pathway.root_nodes) > 0:
|
||||||
pathways.append(pathway)
|
pathways.append(pathway)
|
||||||
else:
|
else:
|
||||||
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
|
||||||
|
|
||||||
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
# build lookup reaction -> {uuid1, uuid2} for overlap check
|
||||||
reaction_to_educts = defaultdict(set)
|
reaction_to_educts = defaultdict(set)
|
||||||
for pathway in pathways:
|
for pathway in pathways:
|
||||||
for reaction in pathway.edges:
|
|
||||||
for e in reaction.edge_label.educts.all():
|
|
||||||
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
|
||||||
|
|
||||||
multi_gen_results = []
|
|
||||||
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
|
|
||||||
# iteration instead of storing all trained models.
|
|
||||||
for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
|
|
||||||
train_pathways = [pathways[i] for i in train]
|
|
||||||
test_pathways = [pathways[i] for i in test]
|
|
||||||
|
|
||||||
# Collect structures from test pathways
|
|
||||||
test_educts = set()
|
|
||||||
for pathway in test_pathways:
|
|
||||||
for reaction in pathway.edges:
|
for reaction in pathway.edges:
|
||||||
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
for e in reaction.edge_label.educts.all():
|
||||||
|
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
|
||||||
|
|
||||||
train_reactions = []
|
multi_gen_results = []
|
||||||
overlap = 0
|
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
|
||||||
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
# iteration instead of storing all trained models.
|
||||||
# the test pathways
|
for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
|
||||||
for pathway in train_pathways:
|
train_pathways = [pathways[i] for i in train]
|
||||||
for reaction in pathway.edges:
|
test_pathways = [pathways[i] for i in test]
|
||||||
reaction = reaction.edge_label
|
|
||||||
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
|
||||||
overlap += 1
|
|
||||||
continue
|
|
||||||
educts = ".".join(
|
|
||||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
|
||||||
products = ".".join(
|
|
||||||
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
|
|
||||||
train_reactions.append(f"{educts}>>{products}")
|
|
||||||
logging.debug(
|
|
||||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
|
||||||
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
|
||||||
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
|
||||||
|
|
||||||
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
|
# Collect structures from test pathways
|
||||||
|
test_educts = set()
|
||||||
|
for pathway in test_pathways:
|
||||||
|
for reaction in pathway.edges:
|
||||||
|
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
|
||||||
|
|
||||||
|
train_reactions = []
|
||||||
|
overlap = 0
|
||||||
|
# Collect indices of the structures contained in train pathways iff they're not present in any of
|
||||||
|
# the test pathways
|
||||||
|
for pathway in train_pathways:
|
||||||
|
for reaction in pathway.edges:
|
||||||
|
reaction = reaction.edge_label
|
||||||
|
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
||||||
|
overlap += 1
|
||||||
|
continue
|
||||||
|
educts = ".".join(
|
||||||
|
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
|
||||||
|
products = ".".join(
|
||||||
|
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
|
||||||
|
train_reactions.append(f"{educts}>>{products}")
|
||||||
|
logging.debug(
|
||||||
|
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
|
||||||
|
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
||||||
|
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
||||||
|
|
||||||
|
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
|
||||||
|
|
||||||
self.model_status = self.FINISHED
|
self.model_status = self.FINISHED
|
||||||
self.save()
|
self.save()
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class EnviFormerTest(TestCase):
|
|||||||
with self.settings(MODEL_DIR=tmpdir):
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
threshold = float(0.5)
|
threshold = float(0.5)
|
||||||
data_package_objs = [self.BBD_SUBSET]
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
eval_packages_objs = []
|
eval_packages_objs = [self.BBD_SUBSET]
|
||||||
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
|
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
|
||||||
|
|
||||||
mod.build_dataset()
|
mod.build_dataset()
|
||||||
|
|||||||
@ -24,7 +24,7 @@ class ModelTest(TestCase):
|
|||||||
|
|
||||||
rule_package_objs = [self.BBD_SUBSET]
|
rule_package_objs = [self.BBD_SUBSET]
|
||||||
data_package_objs = [self.BBD_SUBSET]
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
eval_packages_objs = []
|
eval_packages_objs = [self.BBD_SUBSET]
|
||||||
|
|
||||||
mod = MLRelativeReasoning.create(
|
mod = MLRelativeReasoning.create(
|
||||||
self.package,
|
self.package,
|
||||||
@ -52,7 +52,7 @@ class ModelTest(TestCase):
|
|||||||
mod.build_model()
|
mod.build_model()
|
||||||
mod.multigen_eval = True
|
mod.multigen_eval = True
|
||||||
mod.save()
|
mod.save()
|
||||||
# mod.evaluate_model()
|
mod.evaluate_model()
|
||||||
|
|
||||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user