From ac5d370b18201e4bfbb7b0c19fa33388c6e2d95b Mon Sep 17 00:00:00 2001 From: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:58:16 +1300 Subject: [PATCH] new RuleBasedDataset and EnviFormer dataset working for respective models #120 --- epdb/models.py | 64 ++++++++++++++----------------------- tests/test_dataset.py | 45 +++++++++++++++++--------- tests/test_enviformer.py | 2 +- tests/test_model.py | 47 ++++++++++++++++++--------- utilities/ml.py | 69 +++++++++++++++++++++++----------------- 5 files changed, 126 insertions(+), 101 deletions(-) diff --git a/epdb/models.py b/epdb/models.py index e3e66157..fec94028 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -2225,7 +2225,7 @@ class PackageBasedModel(EPModel): self.model_status = self.BUILT_NOT_EVALUATED self.save() - def evaluate_model(self): + def evaluate_model(self, **kwargs): if self.model_status != self.BUILT_NOT_EVALUATED: raise ValueError(f"Can't evaluate a model in state {self.model_status}!") @@ -2354,18 +2354,18 @@ class PackageBasedModel(EPModel): X = np.array(ds.X(na_replacement=np.nan)) y = np.array(ds.y(na_replacement=np.nan)) - n_splits = 20 + n_splits = kwargs.get("n_splits", 20) shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42) splits = list(shuff.split(X)) from joblib import Parallel, delayed - models = Parallel(n_jobs=10)( + models = Parallel(n_jobs=min(10, len(splits)))( delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits ) - evaluations = Parallel(n_jobs=10)( + evaluations = Parallel(n_jobs=min(10, len(splits)))( delayed(evaluate_sg)(model, X, y, test_index, self.threshold) for model, (_, test_index) in zip(models, splits) ) @@ -2716,7 +2716,7 @@ class MLRelativeReasoning(PackageBasedModel): start = datetime.now() ds = self.load_dataset() classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules) - pred = self.model.predict_proba(classify_ds.X()) + pred = self.model.predict_proba(np.array(classify_ds.X())) res = MLRelativeReasoning.combine_products_and_probs( self.applicable_rules, pred[0], classify_prods[0] @@ -3096,7 +3096,7 @@ class EnviFormer(PackageBasedModel): ds.save(f) return ds - def load_dataset(self) -> "RuleBasedDataset": + def load_dataset(self): ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json") return EnviFormerDataset.load(ds_path) @@ -3105,7 +3105,7 @@ class EnviFormer(PackageBasedModel): from enviformer.finetune import fine_tune start = datetime.now() - model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE) + model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE) end = datetime.now() logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds") return model @@ -3121,19 +3121,19 @@ class EnviFormer(PackageBasedModel): args = {"clz": "EnviFormer"} return args - def evaluate_model(self): + def evaluate_model(self, **kwargs): if self.model_status != self.BUILT_NOT_EVALUATED: raise ValueError(f"Can't evaluate a model in state {self.model_status}!") self.model_status = self.EVALUATING self.save() - def evaluate_sg(test_reactions, predictions, model_thresh): + def evaluate_sg(test_ds, predictions, model_thresh): # Group the true products of reactions with the same reactant together - assert len(test_reactions) == len(predictions) + assert len(test_ds) == len(predictions) true_dict = {} - for r in test_reactions: - reactant, true_product_set = r.split(">>") + for r in test_ds: + 
reactant, true_product_set = r true_product_set = {p for p in true_product_set.split(".")} true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set] @@ -3141,7 +3141,7 @@ class EnviFormer(PackageBasedModel): pred_dict = {} for k, pred in enumerate(predictions): pred_smiles, pred_proba = zip(*pred.items()) - reactant, true_product = test_reactions[k].split(">>") + reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"] pred_dict.setdefault(reactant, {"predict": [], "scores": []}) for smiles, proba in zip(pred_smiles, pred_proba): smiles = set(smiles.split(".")) @@ -3176,7 +3176,7 @@ class EnviFormer(PackageBasedModel): break # Recall is TP (correct) / TP + FN (len(test_reactions)) - rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()} + rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()} # Precision is TP (correct) / TP + FP (predicted) prec = { f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items() @@ -3257,30 +3257,30 @@ class EnviFormer(PackageBasedModel): if self.eval_packages.count() > 0: ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter( package__in=self.eval_packages.all()).distinct()) - test_result = self.model.predict_batch(ds) + test_result = self.model.predict_batch(ds.X()) single_gen_result = evaluate_sg(ds, test_result, self.threshold) self.eval_results = self.compute_averages([single_gen_result]) else: from enviformer.finetune import fine_tune ds = self.load_dataset() - n_splits = 20 + n_splits = kwargs.get("n_splits", 20) shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42) # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models # this helps reduce the memory footprint. 
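            # A rough standalone sketch (not EPDB code) of the train-then-evaluate-per-split
            # pattern the loop below implements; train_model() and evaluate() are hypothetical
            # placeholders standing in for fine_tune() and evaluate_sg(), and X, y are NumPy arrays.
            #
            #   import numpy as np
            #   from sklearn.model_selection import ShuffleSplit
            #
            #   def eval_per_split(X, y, train_model, evaluate, n_splits=20):
            #       shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
            #       results = []
            #       for train_idx, test_idx in shuff.split(X):
            #           model = train_model(X[train_idx], y[train_idx])            # fit on this split only
            #           results.append(evaluate(model, X[test_idx], y[test_idx]))  # keep metrics, not models
            #           del model                                                  # trained model is discarded
            #       return results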
single_gen_results = [] for split_id, (train_index, test_index) in enumerate(shuff.split(ds)): - train = [ds[i] for i in train_index] - test = [ds[i] for i in test_index] + train = ds[train_index] + test = ds[test_index] start = datetime.now() - model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE) + model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE) end = datetime.now() logger.debug( f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds" ) model.to(s.ENVIFORMER_DEVICE) - test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test]) + test_result = model.predict_batch(test.X()) single_gen_results.append(evaluate_sg(test, test_result, self.threshold)) self.eval_results = self.compute_averages(single_gen_results) @@ -3351,31 +3351,15 @@ class EnviFormer(PackageBasedModel): for pathway in train_pathways: for reaction in pathway.edges: reaction = reaction.edge_label - if any( - [ - educt in test_educts - for educt in reaction_to_educts[str(reaction.uuid)] - ] - ): + if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]): overlap += 1 continue - educts = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.educts.all() - ] - ) - products = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.products.all() - ] - ) - train_reactions.append(f"{educts}>>{products}") + train_reactions.append(reaction) + train_ds = EnviFormerDataset.generate_dataset(train_reactions) logging.debug( f"{overlap} compounds had to be removed from multigen split due to overlap within pathways" ) - model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}") + model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}") multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold)) self.eval_results.update( diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 962ce400..9bdb29ed 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,8 +1,9 @@ +import os.path +from tempfile import TemporaryDirectory from django.test import TestCase - from epdb.logic import PackageManager from epdb.models import Reaction, Compound, User, Rule, Package -from utilities.ml import RuleBasedDataset +from utilities.ml import RuleBasedDataset, EnviFormerDataset class DatasetTest(TestCase): @@ -45,11 +46,11 @@ class DatasetTest(TestCase): def test_generate_dataset(self): """Test generating dataset does not crash""" - self.generate_dataset() + self.generate_rule_dataset() def test_indexing(self): """Test indexing a few different ways to check for crashes""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() print(ds[5]) print(ds[2, 5]) print(ds[3:6, 2:8]) @@ -57,45 +58,45 @@ class DatasetTest(TestCase): def test_add_rows(self): """Test adding one row and adding multiple rows""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() ds.add_row(list(ds.df.row(1))) ds.add_rows([list(ds.df.row(i)) for i in range(5)]) def test_times_triggered(self): """Check getting times triggered for a rule id""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() print(ds.times_triggered(rules[0].uuid)) def test_block_indices(self): """Test the usages of _block_indices""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, 
rules = self.generate_rule_dataset() print(ds.struct_features()) print(ds.triggered()) print(ds.observed()) def test_structure_id(self): """Check getting a structure id from row index""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() print(ds.structure_id(0)) def test_x(self): """Test getting X portion of the dataframe""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() print(ds.X().df.head()) def test_trig(self): """Test getting the triggered portion of the dataframe""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() print(ds.trig().df.head()) def test_y(self): """Test getting the Y portion of the dataframe""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() print(ds.y().df.head()) def test_classification_dataset(self): """Test making the classification dataset""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)] class_ds, products = ds.classification_dataset(compounds, rules) print(class_ds.df.head(5)) @@ -103,12 +104,16 @@ class DatasetTest(TestCase): def test_to_arff(self): """Test exporting the arff version of the dataset""" - ds, reactions, rules = self.generate_dataset() + ds, reactions, rules = self.generate_rule_dataset() ds.to_arff("dataset_arff_test.arff") def test_save_load(self): """Test saving and loading dataset""" - ds, reactions, rules = self.generate_dataset() + with TemporaryDirectory() as tmpdir: + ds, reactions, rules = self.generate_rule_dataset() + ds.save(os.path.join(tmpdir, "save_dataset.pkl")) + ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl")) + self.assertTrue(ds.df.equals(ds_loaded.df)) def test_dataset_example(self): """Test with a concrete example checking dataset size""" @@ -120,9 +125,19 @@ class DatasetTest(TestCase): self.assertEqual(len(ds.y()), 1) self.assertEqual(ds.y().df.item(), 1) - def generate_dataset(self): + def test_enviformer_dataset(self): + ds, reactions = self.generate_enviformer_dataset() + print(ds.X().head()) + print(ds.y().head()) + + def generate_rule_dataset(self): """Generate a RuleBasedDataset from test package data""" reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)] applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)] ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules) return ds, reactions, applicable_rules + + def generate_enviformer_dataset(self): + reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)] + ds = EnviFormerDataset.generate_dataset(reactions) + return ds, reactions diff --git a/tests/test_enviformer.py b/tests/test_enviformer.py index b81ca2ca..28f24126 100644 --- a/tests/test_enviformer.py +++ b/tests/test_enviformer.py @@ -50,7 +50,7 @@ class EnviFormerTest(TestCase): mod.build_model() mod.multigen_eval = True mod.save() - mod.evaluate_model() + mod.evaluate_model(n_splits=2) mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") diff --git a/tests/test_model.py b/tests/test_model.py index e46046ec..f3a5f3c7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,7 +4,7 @@ import numpy as np from django.test import TestCase from epdb.logic import PackageManager -from epdb.models import User, MLRelativeReasoning, Package +from epdb.models import User, 
MLRelativeReasoning, Package, RuleBasedRelativeReasoning class ModelTest(TestCase): @@ -17,7 +17,7 @@ class ModelTest(TestCase): cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") cls.BBD_SUBSET = Package.objects.get(name="Fixtures") - def test_smoke(self): + def test_mlrr(self): with TemporaryDirectory() as tmpdir: with self.settings(MODEL_DIR=tmpdir): threshold = float(0.5) @@ -36,23 +36,11 @@ class ModelTest(TestCase): description="Created MLRelativeReasoning in Testcase", ) - # mod = RuleBasedRelativeReasoning.create( - # self.package, - # rule_package_objs, - # data_package_objs, - # eval_packages_objs, - # threshold=threshold, - # min_count=5, - # max_count=0, - # name='ECC - BBD - 0.5', - # description='Created MLRelativeReasoning in Testcase', - # ) - mod.build_dataset() mod.build_model() mod.multigen_eval = True mod.save() - mod.evaluate_model() + mod.evaluate_model(n_splits=2) results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") @@ -73,3 +61,32 @@ class ModelTest(TestCase): # from pprint import pprint # pprint(mod.eval_results) + + def test_rbrr(self): + with TemporaryDirectory() as tmpdir: + with self.settings(MODEL_DIR=tmpdir): + threshold = float(0.5) + + rule_package_objs = [self.BBD_SUBSET] + data_package_objs = [self.BBD_SUBSET] + eval_packages_objs = [self.BBD_SUBSET] + + mod = RuleBasedRelativeReasoning.create( + self.package, + rule_package_objs, + data_package_objs, + eval_packages_objs, + threshold=threshold, + min_count=5, + max_count=0, + name='ECC - BBD - 0.5', + description='Created MLRelativeReasoning in Testcase', + ) + + mod.build_dataset() + mod.build_model() + mod.multigen_eval = True + mod.save() + mod.evaluate_model(n_splits=2) + + results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") diff --git a/utilities/ml.py b/utilities/ml.py index 9521cc39..e680046b 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -65,6 +65,10 @@ class Dataset(ABC): """Use the polars dataframe columns""" return self.df.columns + @property + def shape(self): + return self.df.shape + @abstractmethod def X(self, **kwargs): pass @@ -123,13 +127,15 @@ class Dataset(ABC): class RuleBasedDataset(Dataset): def __init__(self, num_labels=None, columns=None, data=None): super().__init__(columns, data) - self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "trig_" in c]) + # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init. 
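        # A standalone illustration (hypothetical column names, not EPDB code) of the prefix
        # convention the next line relies on: label columns carry the "obs_" prefix, so the
        # label count and the feature/trig/obs blocks can be derived from the column names alone.
        #
        #   import polars as pl
        #   df = pl.DataFrame({
        #       "structure_id": ["s1", "s2"],
        #       "feature_0": [1, 0],
        #       "trig_r1": [1, 1],
        #       "obs_r1": [0, 1],
        #   })
        #   num_labels = sum(1 for c in df.columns if "obs_" in c)               # -> 1
        #   X = df.select([c for c in df.columns if not c.startswith("obs_")])   # id + feature + trig block
        #   y = df.select([c for c in df.columns if c.startswith("obs_")])       # obs (label) block
        #   hits = df.filter(pl.col("trig_r1") == 1).height                      # rows where the rule fired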
+ self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c]) self.num_features: int = len(self.columns) - self.num_labels self._struct_features: Tuple[int, int] = self._block_indices("feature_") self._triggered: Tuple[int, int] = self._block_indices("trig_") self._observed: Tuple[int, int] = self._block_indices("obs_") def times_triggered(self, rule_uuid) -> int: + """Count how many times a rule is triggered by the number of rows with one in the rules trig column""" return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height def struct_features(self) -> Tuple[int, int]: @@ -142,21 +148,25 @@ class RuleBasedDataset(Dataset): return self._observed def structure_id(self, index: int): + """Get the UUID of a compound""" return self.df.item(index, "structure_id") def X(self, exclude_id_col=True, na_replacement=0): + """Get all the feature and trig columns""" res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels] if na_replacement is not None: res.df = res.df.fill_null(na_replacement) return res def trig(self, na_replacement=0): + """Get all the trig columns""" res = self[:, self._triggered[0]: self._triggered[1]] if na_replacement is not None: res.df = res.df.fill_null(na_replacement) return res def y(self, na_replacement=0): + """Get all the obs columns""" res = self[:, len(self.columns) - self.num_labels:] if na_replacement is not None: res.df = res.df.fill_null(na_replacement) @@ -164,7 +174,7 @@ class RuleBasedDataset(Dataset): @staticmethod def generate_dataset(reactions, applicable_rules, educts_only=True): - _structures = set() + _structures = set() # Get all the structures for r in reactions: _structures.update(r.educts.all()) if not educts_only: @@ -254,8 +264,12 @@ class RuleBasedDataset(Dataset): ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]: classify_data = [] classify_products = [] + if isinstance(structures[0], str): + struct_smiles = structures[0] + else: + struct_smiles = structures[0].smiles ds_columns = (["structure_id"] + - [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] + + [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] + [f"trig_{r.uuid}" for r in applicable_rules] + [f"obs_{r.uuid}" for r in applicable_rules]) for struct in structures: @@ -306,43 +320,38 @@ class RuleBasedDataset(Dataset): class EnviFormerDataset(Dataset): - def __init__(self, educts, products): - super().__init__() - assert len(educts) == len(products), "Can't have unequal length educts and products" + def __init__(self, columns=None, data=None): + super().__init__(columns, data) + + def X(self): + """Return the educts""" + return self["educts"] + + def y(self): + """Return the products""" + return self["products"] @staticmethod def generate_dataset(reactions, *args, **kwargs): - # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently - educts = [] - products = [] + # Standardise reactions for the training data + stereo = kwargs.get("stereo", False) + rows = [] for reaction in reactions: e = ".".join( [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) + FormatConverter.standardize(smile.smiles, remove_stereo=not stereo) for smile in reaction.educts.all() ] ) p = ".".join( [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) + FormatConverter.standardize(smile.smiles, remove_stereo=not stereo) for smile in reaction.products.all() ] ) - educts.append(e) - products.append(p) - return 
EnviFormerDataset(educts, products) - - def X(self): - pass - - def y(self): - pass - - def __getitem__(self, item): - pass - - def __len__(self): - pass + rows.append([e, p]) + ds = EnviFormerDataset(["educts", "products"], rows) + return ds class SparseLabelECC(BaseEstimator, ClassifierMixin): @@ -524,7 +533,7 @@ class EnsembleClassifierChain: self.classifiers = [] if self.num_labels is None: - self.num_labels = len(Y[0]) + self.num_labels = Y.shape[1] for p in range(self.num_chains): logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}") @@ -555,7 +564,7 @@ class RelativeReasoning: def fit(self, X, Y): n_instances = len(Y) - n_attributes = len(Y[0]) + n_attributes = Y.shape[1] for i in range(n_attributes): for j in range(n_attributes): @@ -567,8 +576,8 @@ class RelativeReasoning: countboth = 0 for k in range(n_instances): - vi = Y[k][i] - vj = Y[k][j] + vi = Y[k, i] + vj = Y[k, j] if vi is None or vj is None: continue
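Note on the RelativeReasoning changes above: Y.shape[1] and Y[k, i] assume the label matrix is a
2-D NumPy array (for example, built via np.array(ds.y(na_replacement=np.nan)) as in
PackageBasedModel.evaluate_model). A minimal sketch of that assumption with toy data, using NaN
where fit() skips missing (None) labels:

    import numpy as np

    Y = np.array([[1.0, 0.0], [1.0, 1.0], [np.nan, 1.0]])  # rows = instances, columns = labels
    n_instances, n_attributes = Y.shape                     # shape[1] replaces len(Y[0])
    both = 0
    for k in range(n_instances):
        vi, vj = Y[k, 0], Y[k, 1]                           # ndarray indexing; Y[k, i] == Y[k][i]
        if np.isnan(vi) or np.isnan(vj):                    # skip instances with a missing label
            continue
        both += int(vi == 1 and vj == 1)                    # co-occurrence count for this label pair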
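For reference, a standalone sketch (toy SMILES, plain polars rather than the EnviFormerDataset
class itself) of the two-column layout generate_dataset now builds, where X() yields the educt
strings passed to fine_tune/predict_batch and y() yields the product strings:

    import polars as pl

    rows = [["CCO", "CC=O"], ["CC(=O)O.O", "CCO"]]                    # one [educts, products] pair per reaction
    ds = pl.DataFrame(rows, schema=["educts", "products"], orient="row")
    educts = ds.get_column("educts")                                  # inputs, analogous to X()
    products = ds.get_column("products")                              # targets, analogous to y()
    first = ds.slice(0, 1)                                            # row slicing, in the spirit of ds[train_index]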