From 8166df6f39d4cafaf70b0db8d85481e38f746fe7 Mon Sep 17 00:00:00 2001
From: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Date: Fri, 24 Oct 2025 14:40:26 +1300
Subject: [PATCH] work towards #120

---
 epdb/models.py  |  50 ++----------
 utilities/ml.py | 212 ++++++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 203 insertions(+), 59 deletions(-)

diff --git a/epdb/models.py b/epdb/models.py
index c74540a1..4611d187 100644
--- a/epdb/models.py
+++ b/epdb/models.py
@@ -28,7 +28,7 @@
 from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset
 
 logger = logging.getLogger(__name__)
 
@@ -3088,35 +3088,17 @@ class EnviFormer(PackageBasedModel):
         self.save()
 
         start = datetime.now()
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        ds = []
-        for reaction in self._get_reactions():
-            educts = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.educts.all()
-                ]
-            )
-            products = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.products.all()
-                ]
-            )
-            ds.append(f"{educts}>>{products}")
+        ds = EnviFormerDataset.generate_dataset(self._get_reactions())
         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
 
         f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(f, "w") as d_file:
-            json.dump(ds, d_file)
+        ds.save(f)
         return ds
 
     def load_dataset(self) -> "RuleBasedDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(ds_path) as d_file:
-            ds = json.load(d_file)
-        return ds
+        return EnviFormerDataset.load(ds_path)
 
     def _fit_model(self, ds):
         # Call to enviFormer's fine_tune function and return the model
@@ -3148,13 +3130,12 @@ class EnviFormer(PackageBasedModel):
 
         def evaluate_sg(test_reactions, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
+            assert len(test_reactions) == len(predictions)
             true_dict = {}
            for r in test_reactions:
                 reactant, true_product_set = r.split(">>")
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
-            assert len(test_reactions) == len(predictions)
-            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
 
             # Group the predicted products of reactions with the same reactant together
             pred_dict = {}
@@ -3274,24 +3255,9 @@ class EnviFormer(PackageBasedModel):
 
         # If there are eval packages perform single generation evaluation on them instead of random splits
         if self.eval_packages.count() > 0:
-            ds = []
-            for reaction in Reaction.objects.filter(
-                package__in=self.eval_packages.all()
-            ).distinct():
-                educts = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.educts.all()
-                    ]
-                )
-                products = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.products.all()
-                    ]
-                )
-                ds.append(f"{educts}>>{products}")
-            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
+                package__in=self.eval_packages.all()).distinct())
+            test_result = self.model.predict_batch(ds.educts)
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
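Illustration (not part of the patch): after this change the EnviFormer dataset is an EnviFormerDataset object rather than a plain list of SMIRKS strings, pickled via Dataset.save/Dataset.load (still under the historical "_ds.json" file name). A minimal sketch of the intended round trip; FakeModel and the /tmp path are hypothetical stand-ins, while ds.educts and the "educts>>products" iteration are as defined in utilities/ml.py below:

    from utilities.ml import EnviFormerDataset

    class FakeModel:
        # Hypothetical stub: "predicts" every educt unchanged, with confidence 1.0
        def predict_batch(self, educts):
            return [[(e, 1.0)] for e in educts]

    ds = EnviFormerDataset(["CCO"], ["CC=O"])
    ds.save("/tmp/demo_ds.json")                    # pickle under the hood, despite the suffix
    ds = EnviFormerDataset.load("/tmp/demo_ds.json")

    test_result = FakeModel().predict_batch(ds.educts)
    for smirks in ds:                               # yields "educts>>products" strings
        reactant, products = smirks.split(">>")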
diff --git a/utilities/ml.py b/utilities/ml.py
index 8d66f393..f75656b5 100644
--- a/utilities/ml.py
+++ b/utilities/ml.py
@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
 import networkx as nx
 import numpy as np
 from numpy.random import default_rng
+import polars as pl
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
 from sklearn.dummy import DummyClassifier
@@ -28,6 +29,37 @@ if TYPE_CHECKING:
 
 
 class Dataset(ABC):
+    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
+        if isinstance(data, pl.DataFrame):
+            self.df = data
+        else:
+            if columns is None:
+                raise ValueError("Columns can't be None if data is not already a DataFrame")
+            if data is not None and len(columns) != len(data[0]):
+                raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
+            # orient="row" because data is a list of rows, not a list of columns
+            self.df = pl.DataFrame(data=data, schema=columns, orient="row")
+
+    def add_rows(self, rows: List[List[str | int | float]]):
+        if len(self.columns) != len(rows[0]):
+            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
+        new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row")
+        if self.df.is_empty():
+            # A frame built from column names alone has Null dtypes, so let the
+            # first batch of rows define the schema
+            self.df = new_rows
+        else:
+            self.df.extend(new_rows)
+
+    def add_row(self, row: List[str | int | float]):
+        self.add_rows([row])
+
+    def _block_indices(self, prefix) -> Tuple[int, int]:
+        indices: List[int] = []
+        for i, feature in enumerate(self.columns):
+            if feature.startswith(prefix):
+                indices.append(i)
+
+        return min(indices), max(indices)
+
+    @property
+    def columns(self) -> List[str]:
+        return self.df.columns
+
     @abstractmethod
     def X(self):
         pass
@@ -36,7 +68,143 @@
     def y(self):
         pass
 
+    @staticmethod
     @abstractmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        pass
+
+    def at(self, position: int) -> RuleBasedDataset:
+        return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
+
+    def __iter__(self):
+        # RuleBasedDataset still stores rows in self.data; the polars-backed
+        # classes use self.df
+        rows = getattr(self, "data", None)
+        n = len(rows) if rows is not None else self.df.height
+        return (self.at(i) for i in range(n))
+
+    def save(self, path: "Path"):
+        import pickle
+
+        with open(path, "wb") as fh:
+            pickle.dump(self, fh)
+
+    @staticmethod
+    def load(path: "str | Path") -> "Dataset":
+        import pickle
+
+        with open(path, "rb") as fh:
+            return pickle.load(fh)
+
+
+class NewRuleBasedDataset(Dataset):
+    def __init__(self, num_labels, columns=None, data=None):
+        super().__init__(columns, data)
+        self.num_labels: int = num_labels
+        self.num_features: int = len(self.columns) - self.num_labels
+
+    def times_triggered(self, rule_uuid) -> int:
+        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
+
+    def struct_features(self) -> Tuple[int, int]:
+        return self._block_indices("feature_")
+
+    def triggered(self) -> Tuple[int, int]:
+        return self._block_indices("trig_")
+
+    def observed(self) -> Tuple[int, int]:
+        return self._block_indices("obs_")
+
+    def X(self):
+        pass
+
+    def y(self):
+        pass
+
+    @staticmethod
+    def generate_dataset(reactions, applicable_rules, educts_only=True):
+        _structures = set()
+        for r in reactions:
+            _structures.update(r.educts.all())
+            if not educts_only:
+                _structures.update(r.products.all())
+
+        compounds = sorted(_structures, key=lambda x: x.url)
+        triggered: Dict[str, Set[str]] = defaultdict(set)
+        observed: Set[str] = set()
+
+        # Apply rules to the collected compounds and record the transformation products
+        for i, comp in enumerate(compounds):
+            logger.debug(f"{i + 1}/{len(compounds)}...")
+
+            for rule in applicable_rules:
+                product_sets = rule.apply(comp.smiles)
+                if len(product_sets) == 0:
+                    continue
+
+                key = f"{rule.uuid} + {comp.uuid}"
+                if key in triggered:
+                    logger.info(f"{key} already present. Duplicate reaction?")
+
+                for prod_set in product_sets:
+                    for smi in prod_set:
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        triggered[key].add(smi)
+
+        for i, r in enumerate(reactions):
+            logger.debug(f"{i + 1}/{len(reactions)}...")
+
+            if len(r.educts.all()) != 1:
+                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
+                continue
+
+            for comp in r.educts.all():
+                for rule in applicable_rules:
+                    key = f"{rule.uuid} + {comp.uuid}"
+                    if key not in triggered:
+                        continue
+
+                    # Standardize products from reactions for comparison
+                    standardized_products = []
+                    for cs in r.products.all():
+                        smi = cs.smiles
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        standardized_products.append(smi)
+                    if len(set(standardized_products).difference(triggered[key])) == 0:
+                        observed.add(key)
+
+        ds_columns = (["structure_id"] +
+                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
+                      [f"trig_{r.uuid}" for r in applicable_rules] +
+                      [f"obs_{r.uuid}" for r in applicable_rules])
+        ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
+        rows = []
+
+        for comp in compounds:
+            # Features
+            feat = FormatConverter.maccs(comp.smiles)
+            trig = []
+            obs = []
+            for rule in applicable_rules:
+                key = f"{rule.uuid} + {comp.uuid}"
+                # Check triggered
+                if key in triggered:
+                    trig.append(1)
+                else:
+                    trig.append(0)
+                # Check observed; None marks a missing label when the rule never triggered
+                if key in observed:
+                    obs.append(1)
+                elif key not in triggered:
+                    obs.append(None)
+                else:
+                    obs.append(0)
+            rows.append([str(comp.uuid)] + feat + trig + obs)
+        ds.add_rows(rows)
+        return ds
+
     def __getitem__(self, item):
         pass
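Illustration (not part of the patch): the polars-backed Dataset base keeps everything in a single DataFrame and locates column blocks by name prefix. A small sketch of the mechanics with a hypothetical ToyDataset subclass; its columns and values are made up:

    from utilities.ml import Dataset

    class ToyDataset(Dataset):
        # Stub out the abstract hooks; only the storage mechanics matter here
        def X(self): pass
        def y(self): pass
        @staticmethod
        def generate_dataset(reactions, *args, **kwargs): pass

    ds = ToyDataset(columns=["structure_id", "feature_0", "feature_1", "trig_a", "obs_a"])
    ds.add_rows([
        ["c1", 0, 1, 1, 0],     # rule a triggered, predicted products not observed
        ["c2", 1, 1, 0, None],  # rule a never triggered, so the label is missing
    ])
    ds.add_row(["c3", 1, 0, 1, 1])
    print(ds.columns)                     # delegates to ds.df.columns
    print(ds._block_indices("feature_"))  # (1, 2): span of the feature_* block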
@@ -99,9 +267,6 @@ class RuleBasedDataset(Dataset):
     def limit(self, limit: int) -> RuleBasedDataset:
         return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
 
-    def __iter__(self):
-        return (self.at(i) for i, _ in enumerate(self.data))
-
     def classification_dataset(
         self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
@@ -297,18 +462,6 @@
 
         return res
 
-    def save(self, path: "Path"):
-        import pickle
-
-        with open(path, "wb") as fh:
-            pickle.dump(self, fh)
-
-    @staticmethod
-    def load(path: "Path") -> "RuleBasedDataset":
-        import pickle
-
-        return pickle.load(open(path, "rb"))
-
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
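Illustration (not part of the patch): NewRuleBasedDataset.generate_dataset above encodes, per compound, one trig_* and one obs_* column per rule: trig is 1 when the rule fires on the compound; obs is 1 when the reaction's observed products are covered by the rule's predictions, 0 when they are not, and None (a missing label) when the rule never fired, so learners can mask those cells instead of treating them as negatives. A sketch of slicing those blocks, with made-up rows and rule names:

    from utilities.ml import NewRuleBasedDataset

    cols = ["structure_id", "feature_0", "feature_1", "trig_r1", "obs_r1"]
    rows = [["c1", 0, 1, 1, 1],
            ["c2", 1, 0, 0, None]]   # rule r1 never fired for c2
    ds = NewRuleBasedDataset(1, cols, rows)

    lo, hi = ds.struct_features()    # column indices of the feature_* block
    X = ds.df[:, lo:hi + 1]          # positional slice over the feature block
    print(ds.times_triggered("r1"))  # 1: only c1 triggered rule r1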
@@ -335,8 +488,30 @@ class EnviFormerDataset(Dataset):
-    def __init__(self):
-        pass
+    def __init__(self, educts, products):
+        assert len(educts) == len(products), "Can't have unequal length educts and products"
+        self.educts = educts
+        self.products = products
+
+    @staticmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
+        educts = []
+        products = []
+        for reaction in reactions:
+            e = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.educts.all()
+                ]
+            )
+            p = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.products.all()
+                ]
+            )
+            educts.append(e)
+            products.append(p)
+        return EnviFormerDataset(educts, products)
 
     def X(self):
         pass
 
     def y(self):
         pass
@@ -347,6 +522,9 @@
     def __getitem__(self, item):
         pass
 
+    def __len__(self):
+        return len(self.educts)
+
+    def __iter__(self):
+        # evaluate_sg consumes "educts>>products" SMIRKS strings
+        return (f"{e}>>{p}" for e, p in zip(self.educts, self.products))
+
 
 class SparseLabelECC(BaseEstimator, ClassifierMixin):
     """
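Illustration (not part of the patch): evaluate_sg in epdb/models.py above scores per reactant, so reactions that share a reactant are first grouped together. A standalone sketch of that grouping step; group_true is a hypothetical name, the patch does this inline:

    def group_true(test_reactions):
        # "educts>>products" SMIRKS strings -> {reactant: [product set, ...]}
        true_dict = {}
        for r in test_reactions:
            reactant, products = r.split(">>")
            true_dict.setdefault(reactant, []).append(set(products.split(".")))
        return true_dict

    # A reactant measured in two reactions keeps both outcomes as separate sets
    print(group_true(["CCO>>CC=O", "CCO>>CC(=O)O", "c1ccccc1>>Oc1ccccc1"]))
    # {'CCO': [{'CC=O'}, {'CC(=O)O'}], 'c1ccccc1': [{'Oc1ccccc1'}]}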