diff --git a/epdb/models.py b/epdb/models.py
index c8322899..206402a3 100644
--- a/epdb/models.py
+++ b/epdb/models.py
@@ -2130,116 +2130,125 @@ class PackageBasedModel(EPModel):
             recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
             return mg_acc, precision, recall
 
-        ds = self.load_dataset()
-
-        if isinstance(self, RuleBasedRelativeReasoning):
-            X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-            y = np.array(ds.y(na_replacement=np.nan))
+        # If there are eval packages perform single generation evaluation on them instead of random splits
+        if self.eval_packages.count() > 0:
+            eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct())
+            ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
+            if isinstance(self, RuleBasedRelativeReasoning):
+                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
+                y = np.array(ds.y(na_replacement=np.nan))
+            else:
+                X = np.array(ds.X(na_replacement=np.nan))
+                y = np.array(ds.y(na_replacement=np.nan))
+            single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
+            self.eval_results = self.compute_averages([single_gen_result])
         else:
-            X = np.array(ds.X(na_replacement=np.nan))
-            y = np.array(ds.y(na_replacement=np.nan))
+            ds = self.load_dataset()
 
-        n_splits = 20
+            if isinstance(self, RuleBasedRelativeReasoning):
+                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
+                y = np.array(ds.y(na_replacement=np.nan))
+            else:
+                X = np.array(ds.X(na_replacement=np.nan))
+                y = np.array(ds.y(na_replacement=np.nan))
 
-        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
-        splits = list(shuff.split(X))
+            n_splits = 20
 
-        from joblib import Parallel, delayed
-        models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
-        evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
-                                          for model, (_, test_index) in zip(models, splits))
+            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
+            splits = list(shuff.split(X))
 
-        self.eval_results = self.compute_averages(evaluations)
+            from joblib import Parallel, delayed
+            models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
+            evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
+                                              for model, (_, test_index) in zip(models, splits))
+
+            self.eval_results = self.compute_averages(evaluations)
 
         if self.multigen_eval:
-
-            # We have to consider 2 cases here:
-            # 1. No eval packages provided -> Split Train data X times and train and evaluate model
-            # 2. eval packages provided -> Use the already built model and do evaluation on the set provided.
-
+            # If there are eval packages perform multi generation evaluation on them instead of random splits
             if self.eval_packages.count() > 0:
                 pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
-                evaluate_mg(self.model, pathway_qs, self.threshold)
-                return
+                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
+            else:
+                pathway_qs = Pathway.objects.prefetch_related(
+                    'node_set',
+                    'node_set__out_edges',
+                    'node_set__default_node_label',
+                    'node_set__scenarios',
+                    'edge_set',
+                    'edge_set__start_nodes',
+                    'edge_set__end_nodes',
+                    'edge_set__edge_label',
+                    'edge_set__scenarios'
+                ).filter(package__in=self.data_packages.all()).distinct()
 
-            pathway_qs = Pathway.objects.prefetch_related(
-                'node_set',
-                'node_set__out_edges',
-                'node_set__default_node_label',
-                'node_set__scenarios',
-                'edge_set',
-                'edge_set__start_nodes',
-                'edge_set__end_nodes',
-                'edge_set__edge_label',
-                'edge_set__scenarios'
-            ).filter(package__in=self.data_packages.all()).distinct()
+                pathways = []
+                for pathway in pathway_qs:
+                    # There is one pathway with no root compounds, so this check is required
+                    if len(pathway.root_nodes) > 0:
+                        pathways.append(pathway)
+                    else:
+                        logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
 
-            pathways = []
-            for pathway in pathway_qs:
-                # There is one pathway with no root compounds, so this check is required
-                if len(pathway.root_nodes) > 0:
-                    pathways.append(pathway)
-                else:
-                    logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
-
-            # build lookup reaction -> {uuid1, uuid2} for overlap check
-            reaction_to_educts = defaultdict(set)
-            for pathway in pathways:
-                for reaction in pathway.edges:
-                    for e in reaction.edge_label.educts.all():
-                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
-
-            # build lookup to avoid recalculation of features, labels
-            id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
-
-            # Compute splits of the collected pathway
-            splits = []
-            for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
-                train_pathways = [pathways[i] for i in train]
-                test_pathways = [pathways[i] for i in test]
-
-                # Collect structures from test pathways
-                test_educts = set()
-                for pathway in test_pathways:
+                reaction_to_educts = defaultdict(set)
+                for pathway in pathways:
                     for reaction in pathway.edges:
-                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+                        for e in reaction.edge_label.educts.all():
+                            reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
 
-                split_ids = []
-                overlap = 0
-                # Collect indices of the structures contained in train pathways iff they're not present in any of
-                # the test pathways
-                for pathway in train_pathways:
-                    for reaction in pathway.edges:
-                        for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
-                            # Ensure compounds in the training set do not appear in the test set
-                            if educt not in test_educts:
-                                if educt in id_to_index:
-                                    split_ids.append(id_to_index[educt])
+                # build lookup to avoid recalculation of features, labels
+                id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
+
+                # Compute splits of the collected pathway
+                splits = []
+                for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
+                    train_pathways = [pathways[i] for i in train]
+                    test_pathways = [pathways[i] for i in test]
+
+                    # Collect structures from test pathways
+                    test_educts = set()
+                    for pathway in test_pathways:
+                        for reaction in pathway.edges:
+                            test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+
+                    split_ids = []
+                    overlap = 0
+                    # Collect indices of the structures contained in train pathways iff they're not present in any of
+                    # the test pathways
+                    for pathway in train_pathways:
+                        for reaction in pathway.edges:
+                            for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
+                                # Ensure compounds in the training set do not appear in the test set
+                                if educt not in test_educts:
+                                    if educt in id_to_index:
+                                        split_ids.append(id_to_index[educt])
+                                    else:
+                                        logger.debug(f"Couldn't find features in X for compound {educt}")
                                 else:
-                                    logger.debug(f"Couldn't find features in X for compound {educt}")
-                            else:
-                                overlap += 1
+                                    overlap += 1
 
-                logging.debug(
-                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
+                    logging.debug(
+                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
 
-                # Get the rows from the dataset corresponding to compounds in the training set pathways
-                split_x, split_y = X[split_ids], y[split_ids]
-                splits.append([(split_x, split_y), test_pathways])
+                    # Get the rows from the dataset corresponding to compounds in the training set pathways
+                    split_x, split_y = X[split_ids], y[split_ids]
+                    splits.append([(split_x, split_y), test_pathways])
 
-            # Build model on subsets obtained by pathway split
-            trained_models = Parallel(n_jobs=10)(
-                delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
-            )
+                # Build model on subsets obtained by pathway split
+                trained_models = Parallel(n_jobs=10)(
+                    delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
+                )
 
-            # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
-            multi_ret_vals = Parallel(n_jobs=1)(
-                delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
-                zip(trained_models, splits)
-            )
+                # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
+                multi_ret_vals = Parallel(n_jobs=1)(
+                    delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
+                    zip(trained_models, splits)
+                )
 
-            self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
 
         self.model_status = self.FINISHED
         self.save()
@@ -2761,7 +2770,7 @@ class EnviFormer(PackageBasedModel):
     def predict_batch(self, smiles_list):
         # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
         canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
-        logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
+        logger.info(f"Submitting {canon_smiles} to {self.name}")
         products_list = self.model.predict_batch(canon_smiles)
         logger.info(f"Got results {products_list}")
 
@@ -2943,89 +2952,107 @@ class EnviFormer(PackageBasedModel):
             recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
             return mg_acc, precision, recall
 
-        from enviformer.finetune import fine_tune
-        ds = self.load_dataset()
-        n_splits = 20
-        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
+        # If there are eval packages perform single generation evaluation on them instead of random splits
+        if self.eval_packages.count() > 0:
+            ds = []
+            for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct():
+                educts = ".".join(
+                    [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+                products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in
+                                     reaction.products.all()])
+                ds.append(f"{educts}>>{products}")
+            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
+            self.eval_results = self.compute_averages([single_gen_result])
+        else:
+            from enviformer.finetune import fine_tune
+            ds = self.load_dataset()
+            n_splits = 20
+            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
 
-        # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
-        # this helps reduce the memory footprint.
-        single_gen_results = []
-        for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
-            train = [ds[i] for i in train_index]
-            test = [ds[i] for i in test_index]
-            start = datetime.now()
-            model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
-            end = datetime.now()
-            logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
-            model.to(s.ENVIFORMER_DEVICE)
-            test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
-            single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
+            # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
+            # this helps reduce the memory footprint.
+            single_gen_results = []
+            for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
+                train = [ds[i] for i in train_index]
+                test = [ds[i] for i in test_index]
+                start = datetime.now()
+                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+                end = datetime.now()
+                logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
+                model.to(s.ENVIFORMER_DEVICE)
+                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+                single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
 
-        self.eval_results = self.compute_averages(single_gen_results)
+            self.eval_results = self.compute_averages(single_gen_results)
 
         if self.multigen_eval:
-            pathway_qs = Pathway.objects.prefetch_related(
-                'node_set',
-                'node_set__out_edges',
-                'node_set__default_node_label',
-                'node_set__scenarios',
-                'edge_set',
-                'edge_set__start_nodes',
-                'edge_set__end_nodes',
-                'edge_set__edge_label',
-                'edge_set__scenarios'
-            ).filter(package__in=self.data_packages.all()).distinct()
+            if self.eval_packages.count() > 0:
+                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
+                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()})
+            else:
+                pathway_qs = Pathway.objects.prefetch_related(
+                    'node_set',
+                    'node_set__out_edges',
+                    'node_set__default_node_label',
+                    'node_set__scenarios',
+                    'edge_set',
+                    'edge_set__start_nodes',
+                    'edge_set__end_nodes',
+                    'edge_set__edge_label',
+                    'edge_set__scenarios'
+                ).filter(package__in=self.data_packages.all()).distinct()
 
-            pathways = []
-            for pathway in pathway_qs:
-                # There is one pathway with no root compounds, so this check is required
-                if len(pathway.root_nodes) > 0:
-                    pathways.append(pathway)
-                else:
-                    logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
+                pathways = []
+                for pathway in pathway_qs:
+                    # There is one pathway with no root compounds, so this check is required
+                    if len(pathway.root_nodes) > 0:
+                        pathways.append(pathway)
+                    else:
+                        logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
 
-            # build lookup reaction -> {uuid1, uuid2} for overlap check
-            reaction_to_educts = defaultdict(set)
-            for pathway in pathways:
-                for reaction in pathway.edges:
-                    for e in reaction.edge_label.educts.all():
-                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
-
-            multi_gen_results = []
-            # Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
-            # iteration instead of storing all trained models.
-            for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
-                train_pathways = [pathways[i] for i in train]
-                test_pathways = [pathways[i] for i in test]
-
-                # Collect structures from test pathways
-                test_educts = set()
-                for pathway in test_pathways:
+                # build lookup reaction -> {uuid1, uuid2} for overlap check
+                reaction_to_educts = defaultdict(set)
+                for pathway in pathways:
                     for reaction in pathway.edges:
-                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+                        for e in reaction.edge_label.educts.all():
+                            reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
 
-                train_reactions = []
-                overlap = 0
-                # Collect indices of the structures contained in train pathways iff they're not present in any of
-                # the test pathways
-                for pathway in train_pathways:
-                    for reaction in pathway.edges:
-                        reaction = reaction.edge_label
-                        if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
-                            overlap += 1
-                            continue
-                        educts = ".".join(
-                            [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
-                        products = ".".join(
-                            [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
-                        train_reactions.append(f"{educts}>>{products}")
-                logging.debug(
-                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
-                model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
-                multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
+                multi_gen_results = []
+                # Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
+                # iteration instead of storing all trained models.
+                for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
+                    train_pathways = [pathways[i] for i in train]
+                    test_pathways = [pathways[i] for i in test]
 
-            self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
+                    # Collect structures from test pathways
+                    test_educts = set()
+                    for pathway in test_pathways:
+                        for reaction in pathway.edges:
+                            test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+
+                    train_reactions = []
+                    overlap = 0
+                    # Collect indices of the structures contained in train pathways iff they're not present in any of
+                    # the test pathways
+                    for pathway in train_pathways:
+                        for reaction in pathway.edges:
+                            reaction = reaction.edge_label
+                            if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
+                                overlap += 1
+                                continue
+                            educts = ".".join(
+                                [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+                            products = ".".join(
+                                [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
+                            train_reactions.append(f"{educts}>>{products}")
+                    logging.debug(
+                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
+                    model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+                    multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
+
+                self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
 
         self.model_status = self.FINISHED
         self.save()
diff --git a/tests/test_enviformer.py b/tests/test_enviformer.py
index 8b6b368d..536046ad 100644
--- a/tests/test_enviformer.py
+++ b/tests/test_enviformer.py
@@ -20,7 +20,7 @@ class EnviFormerTest(TestCase):
         with self.settings(MODEL_DIR=tmpdir):
             threshold = float(0.5)
             data_package_objs = [self.BBD_SUBSET]
-            eval_packages_objs = []
+            eval_packages_objs = [self.BBD_SUBSET]
             mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
 
             mod.build_dataset()
diff --git a/tests/test_model.py b/tests/test_model.py
index a4074cda..36a1fd39 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -24,7 +24,7 @@ class ModelTest(TestCase):
 
         rule_package_objs = [self.BBD_SUBSET]
         data_package_objs = [self.BBD_SUBSET]
-        eval_packages_objs = []
+        eval_packages_objs = [self.BBD_SUBSET]
 
         mod = MLRelativeReasoning.create(
             self.package,
@@ -52,7 +52,7 @@ class ModelTest(TestCase):
         mod.build_model()
         mod.multigen_eval = True
         mod.save()
-        # mod.evaluate_model()
+        mod.evaluate_model()
 
         results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
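
Editor's note (illustrative, not part of the patch): the multigen changes above split at the pathway level and then drop any training compound that also occurs as an educt in a test pathway, logging the dropped count as "overlap". Below is a minimal, self-contained sketch of that leakage check only. The data layout is a hypothetical stand-in (each pathway as a list of reaction dicts with an "educts" set and a feature-row "index"), not the Django models used in epdb/models.py.

import numpy as np
from sklearn.model_selection import ShuffleSplit

def pathway_level_splits(pathways, n_splits=20, test_size=0.25, seed=42):
    # Split whole pathways, mirroring the ShuffleSplit(random_state=42) calls in the patch.
    splitter = ShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=seed)
    for train_idx, test_idx in splitter.split(pathways):
        test_pathways = [pathways[i] for i in test_idx]
        # Every educt that occurs anywhere in a test pathway is off-limits for training.
        test_educts = {e for p in test_pathways for r in p for e in r["educts"]}

        train_rows, overlap = [], 0
        for i in train_idx:
            for reaction in pathways[i]:
                if any(e in test_educts for e in reaction["educts"]):
                    overlap += 1  # would leak into the test set, so skip it
                else:
                    train_rows.append(reaction["index"])
        yield np.array(train_rows), test_pathways, overlap

The yielded row indices play the role of split_ids (PackageBasedModel) or the kept train_reactions (EnviFormer) in the diff, and the overlap count corresponds to what the patch reports via logging.debug.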