from __future__ import annotations import logging from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime from typing import List, Dict, Set, Tuple import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score from sklearn.multioutput import ClassifierChain from sklearn.preprocessing import StandardScaler logger = logging.getLogger(__name__) from dataclasses import dataclass, field from utilities.chem import FormatConverter, PredictionResult @dataclass class SCompound: smiles: str uuid: str = field(default=None, compare=False, hash=False) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash(( self.smiles )) return self._hash @dataclass class SReaction: educts: List[SCompound] products: List[SCompound] rule_uuid: SRule = field(default=None, compare=False, hash=False) reaction_uuid: str = field(default=None, compare=False, hash=False) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash(( tuple(sorted(self.educts, key=lambda x: x.smiles)), tuple(sorted(self.products, key=lambda x: x.smiles)), )) return self._hash def __eq__(self, other): if not isinstance(other, SReaction): return NotImplemented return ( sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles) ) @dataclass class SRule(ABC): @abstractmethod def apply(self): pass @dataclass class SSimpleRule: pass @dataclass class SParallelRule: pass class Dataset: def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None): self.columns: List[str] = columns self.num_labels: int = num_labels if data is None: self.data: List[List[str | int | float]] = list() else: self.data = data self.num_features: int = len(columns) - self.num_labels self._struct_features: Tuple[int, int] = self._block_indices('feature_') self._triggered: Tuple[int, int] = self._block_indices('trig_') self._observed: Tuple[int, int] = self._block_indices('obs_') def _block_indices(self, prefix) -> Tuple[int, int]: indices: List[int] = [] for i, feature in enumerate(self.columns): if feature.startswith(prefix): indices.append(i) return min(indices), max(indices) def structure_id(self): return self.data[0][0] def add_row(self, row: List[str | int | float]): if len(self.columns) != len(row): raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}") self.data.append(row) def times_triggered(self, rule_uuid) -> int: idx = self.columns.index(f'trig_{rule_uuid}') times_triggered = 0 for row in self.data: if row[idx] == 1: times_triggered += 1 return times_triggered def struct_features(self) -> Tuple[int, int]: return self._struct_features def triggered(self) -> Tuple[int, int]: return self._triggered def observed(self) -> Tuple[int, int]: return self._observed def at(self, position: int) -> Dataset: return Dataset(self.columns, self.num_labels, [self.data[position]]) def limit(self, limit: int) -> Dataset: return Dataset(self.columns, self.num_labels, self.data[:limit]) def __iter__(self): return (self.at(i) for i, _ in enumerate(self.data)) def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]: classify_data = [] classify_products = [] for struct in structures: if isinstance(struct, str): struct_id = None struct_smiles = struct else: struct_id = str(struct.uuid) struct_smiles = struct.smiles features = FormatConverter.maccs(struct_smiles) trig = [] prods = [] for rule in applicable_rules: products = rule.apply(struct_smiles) if len(products): trig.append(1) prods.append(products) else: trig.append(0) prods.append([]) classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) classify_products.append(prods) return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products @staticmethod def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset: _structures = set() for r in reactions: for e in r.educts.all(): _structures.add(e) if not educts_only: for e in r.products: _structures.add(e) compounds = sorted(_structures, key=lambda x: x.url) triggered: Dict[str, Set[str]] = defaultdict(set) observed: Set[str] = set() # Apply rules on collected compounds and store tps for i, comp in enumerate(compounds): logger.debug(f"{i + 1}/{len(compounds)}...") for rule in applicable_rules: product_sets = rule.apply(comp.smiles) if len(product_sets) == 0: continue key = f"{rule.uuid} + {comp.uuid}" if key in triggered: logger.info(f"{key} already present. Duplicate reaction?") for prod_set in product_sets: for smi in prod_set: try: smi = FormatConverter.standardize(smi) except Exception: # :shrug: logger.debug(f'Standardizing SMILES failed for {smi}') pass triggered[key].add(smi) for i, r in enumerate(reactions): logger.debug(f"{i + 1}/{len(reactions)}...") if len(r.educts.all()) != 1: logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") continue for comp in r.educts.all(): for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" if key not in triggered: continue # standardize products from reactions for comparison standardized_products = [] for cs in r.products.all(): smi = cs.smiles try: smi = FormatConverter.standardize(smi) except Exception as e: # :shrug: logger.debug(f'Standardizing SMILES failed for {smi}') pass standardized_products.append(smi) if len(set(standardized_products).difference(triggered[key])) == 0: observed.add(key) else: pass ds = None for i, comp in enumerate(compounds): # Features feat = FormatConverter.maccs(comp.smiles) trig = [] obs = [] for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" # Check triggered if key in triggered: trig.append(1) else: trig.append(0) # Check obs if key in observed: obs.append(1) elif key not in triggered: obs.append(None) else: obs.append(0) if ds is None: header = ['structure_id'] + \ [f'feature_{i}' for i, _ in enumerate(feat)] \ + [f'trig_{r.uuid}' for r in applicable_rules] \ + [f'obs_{r.uuid}' for r in applicable_rules] ds = Dataset(header, len(applicable_rules)) ds.add_row([str(comp.uuid)] + feat + trig + obs) return ds def X(self, exclude_id_col=True, na_replacement=0): res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def y(self, na_replacement=0): res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def __getitem__(self, key): if not isinstance(key, tuple): raise TypeError("Dataset must be indexed with dataset[rows, columns]") row_key, col_key = key # Normalize rows if isinstance(row_key, int): rows = [self.data[row_key]] else: rows = self.data[row_key] # Normalize columns if isinstance(col_key, int): res = [row[col_key] for row in rows] else: res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice) else [row[i] for i in col_key] for row in rows] return res def save(self, path: 'Path'): import pickle with open(path, "wb") as fh: pickle.dump(self, fh) @staticmethod def load(path: 'Path'): import pickle return pickle.load(open(path, "rb")) def to_arff(self, path: 'Path'): arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n" arff += "\n" for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]: if c == 'structure_id': arff += f"@attribute {c} string\n" else: arff += f"@attribute {c} {{0,1}}\n" arff += f"\n@data\n" for d in self.data: ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]]) xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]]) arff += f'{ys},{xs}\n' with open(path, "w") as fh: fh.write(arff) fh.flush() def __repr__(self): return f"" class SparseLabelECC(BaseEstimator, ClassifierMixin): """ Ensemble of Classifier Chains with sparse label removal. Removes labels that are constant across all samples in training. """ def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42), num_chains: int = 10): self.base_clf = base_clf self.num_chains = num_chains def fit(self, X, Y): y = np.array(Y) self.n_labels_ = y.shape[1] self.removed_labels_ = {} self.keep_columns_ = [] for col in range(self.n_labels_): unique_values = np.unique(y[:, col]) if len(unique_values) == 1: self.removed_labels_[col] = unique_values[0] else: self.keep_columns_.append(col) y_reduced = y[:, self.keep_columns_] self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)] for i, chain in enumerate(self.chains_): print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}") chain.fit(X, y_reduced) return self def predict(self, X, threshold=0.5): avg_preds = np.mean([chain.predict(X) for chain in self.chains_], axis=0) > threshold full_y = np.zeros((avg_preds.shape[0], self.n_labels_)) for idx, col in enumerate(self.keep_columns_): full_y[:, col] = avg_preds[:, idx] for col, value in self.removed_labels_.items(): full_y[:, col] = bool(value) return full_y def predict_proba(self, X): avg_proba = np.mean([chain.predict_proba(X) for chain in self.chains_], axis=0) full_y = np.zeros((avg_proba.shape[0], self.n_labels_)) for idx, col in enumerate(self.keep_columns_): full_y[:, col] = avg_proba[:, idx] for col, value in self.removed_labels_.items(): full_y[:, col] = float(value) return full_y def score(self, X, Y, sample_weight=None): """ Default scoring using subset accuracy (exact match). """ y_true = np.array(Y) y_pred = self.predict(X) return accuracy_score(y_true, y_pred, sample_weight=sample_weight) import copy import numpy as np from sklearn.dummy import DummyClassifier from sklearn.tree import DecisionTreeClassifier class BinaryRelevance: def __init__(self, baseline_clf): self.clf = baseline_clf self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] for l in range(len(Y[0])): X_l = X[~np.isnan(Y[:, l])] Y_l = (Y[~np.isnan(Y[:, l]), l]) if len(X_l) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy='constant', constant=0) clf.fit([X[0]], [0]) self.classifiers.append(clf) continue elif len(np.unique(Y_l)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy='most_frequent') else: clf = copy.deepcopy(self.clf) clf.fit(X_l, Y_l) self.classifiers.append(clf) def predict(self, X): labels = [] for clf in self.classifiers: labels.append(clf.predict(X)) return np.column_stack(labels) def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(X) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict([X[0]])[0] labels = np.column_stack((labels, pred)) return labels class MissingValuesClassifierChain: def __init__(self, base_clf): self.base_clf = base_clf self.permutation = None self.classifiers = None def fit(self, X, Y): X = np.array(X) Y = np.array(Y) if self.permutation is None: self.permutation = np.random.permutation(len(Y[0])) Y = Y[:, self.permutation] if self.classifiers is None: self.classifiers = [] for p in range(len(self.permutation)): X_p = X[~np.isnan(Y[:, p])] Y_p = Y[~np.isnan(Y[:, p]), p] if len(X_p) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy='constant', constant=0) self.classifiers.append(clf.fit([X[0]], [0])) elif len(np.unique(Y_p)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy='most_frequent') self.classifiers.append(clf.fit(X_p, Y_p)) else: clf = copy.deepcopy(self.base_clf) self.classifiers.append(clf.fit(X_p, Y_p)) newcol = Y[:, p] pred = clf.predict(X) newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions X = np.column_stack((X, newcol)) def predict(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict(np.column_stack((X, labels))) labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(np.column_stack((X, np.round(labels)))) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0] labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] class EnsembleClassifierChain: def __init__(self, base_clf, num_chains=10): self.base_clf = base_clf self.num_chains = num_chains self.num_labels = None self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] if self.num_labels is None: self.num_labels = len(Y[0]) for p in range(self.num_chains): print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}") clf = MissingValuesClassifierChain(self.base_clf) clf.fit(X, Y) self.classifiers.append(clf) def predict(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict(X) return np.round(labels / self.num_chains) def predict_proba(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict_proba(X) return labels / self.num_chains class ApplicabilityDomainPCA(PCA): def __init__(self, num_neighbours: int = 5): super().__init__(n_components=num_neighbours) self.scaler = StandardScaler() self.num_neighbours = num_neighbours self.min_vals = None self.max_vals = None def build(self, train_dataset: 'Dataset'): # transform X_scaled = self.scaler.fit_transform(train_dataset.X()) # fit pca X_pca = self.fit_transform(X_scaled) self.max_vals = np.max(X_pca, axis=0) self.min_vals = np.min(X_pca, axis=0) def __transform(self, instances): instances_scaled = self.scaler.transform(instances) instances_pca = self.transform(instances_scaled) return instances_pca def is_applicable(self, classify_instances: 'Dataset'): instances_pca = self.__transform(classify_instances.X()) is_applicable = [] for i, instance in enumerate(instances_pca): is_applicable.append(True) for min_v, max_v, new_v in zip(self.min_vals, self.max_vals, instance): if not min_v <= new_v <= max_v: is_applicable[i] = False return is_applicable def tanimoto_distance(a: List[int], b: List[int]): if len(a) != len(b): raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}") sum_a = sum(a) sum_b = sum(b) sum_c = sum(v1 and v2 for v1, v2 in zip(a, b)) if sum_a + sum_b - sum_c == 0: return 0.0 return 1 - (sum_c / (sum_a + sum_b - sum_c))