from __future__ import annotations import copy import logging from collections import defaultdict from datetime import datetime from pathlib import Path from typing import List, Dict, Set, Tuple, TYPE_CHECKING from abc import ABC, abstractmethod import networkx as nx import numpy as np from numpy.random import default_rng import polars as pl from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.decomposition import PCA from sklearn.dummy import DummyClassifier from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score from sklearn.multioutput import ClassifierChain from sklearn.preprocessing import StandardScaler from utilities.chem import FormatConverter, PredictionResult logger = logging.getLogger(__name__) if TYPE_CHECKING: from epdb.models import Rule, CompoundStructure, Reaction class Dataset(ABC): def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None): if isinstance(data, pl.DataFrame): self.df = data else: if data is not None and len(columns) != len(data[0]): raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}") if columns is None: raise ValueError("Columns can't be None if data is not already a DataFrame") self.df = pl.DataFrame(data=data, schema=columns) def add_rows(self, rows: List[List[str | int | float]]): if len(self.columns) != len(rows[0]): raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}") new_rows = pl.DataFrame(data=rows, schema=self.columns) self.df.extend(new_rows) def add_row(self, row: List[str | int | float]): self.add_rows([row]) def _block_indices(self, prefix) -> Tuple[int, int]: indices: List[int] = [] for i, feature in enumerate(self.columns): if feature.startswith(prefix): indices.append(i) return min(indices), max(indices) @property def columns(self) -> List[str]: return self.df.columns @abstractmethod def X(self): pass @abstractmethod def y(self): pass @staticmethod @abstractmethod def generate_dataset(reactions, *args, **kwargs): pass def at(self, position: int) -> RuleBasedDataset: return RuleBasedDataset(self.columns, self.num_labels, self.df[position]) def __iter__(self): return (self.at(i) for i, _ in enumerate(self.data)) def save(self, path: "Path"): import pickle with open(path, "wb") as fh: pickle.dump(self, fh) @staticmethod def load(path: "str | Path") -> "RuleBasedDataset": import pickle return pickle.load(open(path, "rb")) class NewRuleBasedDataset(Dataset): def __init__(self, num_labels, columns=None, data=None): super().__init__(columns, data) self.num_labels: int = num_labels self.num_features: int = len(self.columns) - self.num_labels def times_triggered(self, rule_uuid) -> int: return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height def struct_features(self) -> Tuple[int, int]: return self._block_indices("feature_") def triggered(self) -> Tuple[int, int]: return self._block_indices("trig_") def observed(self) -> Tuple[int, int]: return self._block_indices("obs_") def X(self): pass def y(self): pass @staticmethod def generate_dataset(reactions, applicable_rules, educts_only=True): _structures = set() for r in reactions: _structures.update(r.educts.all()) if not educts_only: _structures.update(r.products.all()) compounds = sorted(_structures, key=lambda x: x.url) triggered: Dict[str, Set[str]] = defaultdict(set) observed: Set[str] = set() # Apply rules on collected compounds and store tps for i, comp in enumerate(compounds): logger.debug(f"{i + 1}/{len(compounds)}...") for rule in applicable_rules: product_sets = rule.apply(comp.smiles) if len(product_sets) == 0: continue key = f"{rule.uuid} + {comp.uuid}" if key in triggered: logger.info(f"{key} already present. Duplicate reaction?") for prod_set in product_sets: for smi in prod_set: try: smi = FormatConverter.standardize(smi, remove_stereo=True) except Exception: logger.debug(f"Standardizing SMILES failed for {smi}") triggered[key].add(smi) for i, r in enumerate(reactions): logger.debug(f"{i + 1}/{len(reactions)}...") if len(r.educts.all()) != 1: logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") continue for comp in r.educts.all(): for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" if key not in triggered: continue # standardize products from reactions for comparison standardized_products = [] for cs in r.products.all(): smi = cs.smiles try: smi = FormatConverter.standardize(smi, remove_stereo=True) except Exception as e: logger.debug(f"Standardizing SMILES failed for {smi}") standardized_products.append(smi) if len(set(standardized_products).difference(triggered[key])) == 0: observed.add(key) ds_columns = (["structure_id"] + [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] + [f"trig_{r.uuid}" for r in applicable_rules] + [f"obs_{r.uuid}" for r in applicable_rules]) ds = NewRuleBasedDataset(len(applicable_rules), ds_columns) rows = [] for i, comp in enumerate(compounds): # Features feat = FormatConverter.maccs(comp.smiles) trig = [] obs = [] for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" # Check triggered if key in triggered: trig.append(1) else: trig.append(0) # Check obs if key in observed: obs.append(1) elif key not in triggered: obs.append(None) else: obs.append(0) rows.append([str(comp.uuid)] + feat + trig + obs) ds.add_rows(rows) return ds def __getitem__(self, item): pass class RuleBasedDataset(Dataset): def __init__( self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None ): self.columns: List[str] = columns self.num_labels: int = num_labels if data is None: self.data: List[List[str | int | float]] = list() else: self.data = data self.num_features: int = len(columns) - self.num_labels self._struct_features: Tuple[int, int] = self._block_indices("feature_") self._triggered: Tuple[int, int] = self._block_indices("trig_") self._observed: Tuple[int, int] = self._block_indices("obs_") def _block_indices(self, prefix) -> Tuple[int, int]: indices: List[int] = [] for i, feature in enumerate(self.columns): if feature.startswith(prefix): indices.append(i) return min(indices), max(indices) def structure_id(self): return self.data[0][0] def add_row(self, row: List[str | int | float]): if len(self.columns) != len(row): raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}") self.data.append(row) def times_triggered(self, rule_uuid) -> int: idx = self.columns.index(f"trig_{rule_uuid}") times_triggered = 0 for row in self.data: if row[idx] == 1: times_triggered += 1 return times_triggered def struct_features(self) -> Tuple[int, int]: return self._struct_features def triggered(self) -> Tuple[int, int]: return self._triggered def observed(self) -> Tuple[int, int]: return self._observed def at(self, position: int) -> RuleBasedDataset: return RuleBasedDataset(self.columns, self.num_labels, [self.data[position]]) def limit(self, limit: int) -> RuleBasedDataset: return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit]) def classification_dataset( self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"] ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]: classify_data = [] classify_products = [] for struct in structures: if isinstance(struct, str): struct_id = None struct_smiles = struct else: struct_id = str(struct.uuid) struct_smiles = struct.smiles features = FormatConverter.maccs(struct_smiles) trig = [] prods = [] for rule in applicable_rules: products = rule.apply(struct_smiles) if len(products): trig.append(1) prods.append(products) else: trig.append(0) prods.append([]) classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) classify_products.append(prods) return RuleBasedDataset( columns=self.columns, num_labels=self.num_labels, data=classify_data ), classify_products @staticmethod def generate_dataset( reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True ) -> RuleBasedDataset: _structures = set() for r in reactions: for e in r.educts.all(): _structures.add(e) if not educts_only: for e in r.products: _structures.add(e) compounds = sorted(_structures, key=lambda x: x.url) triggered: Dict[str, Set[str]] = defaultdict(set) observed: Set[str] = set() # Apply rules on collected compounds and store tps for i, comp in enumerate(compounds): logger.debug(f"{i + 1}/{len(compounds)}...") for rule in applicable_rules: product_sets = rule.apply(comp.smiles) if len(product_sets) == 0: continue key = f"{rule.uuid} + {comp.uuid}" if key in triggered: logger.info(f"{key} already present. Duplicate reaction?") for prod_set in product_sets: for smi in prod_set: try: smi = FormatConverter.standardize(smi, remove_stereo=True) except Exception: # :shrug: logger.debug(f"Standardizing SMILES failed for {smi}") pass triggered[key].add(smi) for i, r in enumerate(reactions): logger.debug(f"{i + 1}/{len(reactions)}...") if len(r.educts.all()) != 1: logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") continue for comp in r.educts.all(): for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" if key not in triggered: continue # standardize products from reactions for comparison standardized_products = [] for cs in r.products.all(): smi = cs.smiles try: smi = FormatConverter.standardize(smi, remove_stereo=True) except Exception as e: # :shrug: logger.debug(f"Standardizing SMILES failed for {smi}") pass standardized_products.append(smi) if len(set(standardized_products).difference(triggered[key])) == 0: observed.add(key) else: pass ds = None for i, comp in enumerate(compounds): # Features feat = FormatConverter.maccs(comp.smiles) trig = [] obs = [] for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" # Check triggered if key in triggered: trig.append(1) else: trig.append(0) # Check obs if key in observed: obs.append(1) elif key not in triggered: obs.append(None) else: obs.append(0) if ds is None: header = ( ["structure_id"] + [f"feature_{i}" for i, _ in enumerate(feat)] + [f"trig_{r.uuid}" for r in applicable_rules] + [f"obs_{r.uuid}" for r in applicable_rules] ) ds = RuleBasedDataset(header, len(applicable_rules)) ds.add_row([str(comp.uuid)] + feat + trig + obs) return ds def X(self, exclude_id_col=True, na_replacement=0): res = self.__getitem__( (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)) ) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def trig(self, na_replacement=0): res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1]))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def y(self, na_replacement=0): res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def __getitem__(self, key): if not isinstance(key, tuple): raise TypeError("Dataset must be indexed with dataset[rows, columns]") row_key, col_key = key # Normalize rows if isinstance(row_key, int): rows = [self.data[row_key]] else: rows = self.data[row_key] # Normalize columns if isinstance(col_key, int): res = [row[col_key] for row in rows] else: res = [ [row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice) else [row[i] for i in col_key] for row in rows ] return res def to_arff(self, path: "Path"): arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n" arff += "\n" for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]: if c == "structure_id": arff += f"@attribute {c} string\n" else: arff += f"@attribute {c} {{0,1}}\n" arff += "\n@data\n" for d in self.data: ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]]) xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]]) arff += f"{ys},{xs}\n" with open(path, "w") as fh: fh.write(arff) fh.flush() def __repr__(self): return ( f"" ) class EnviFormerDataset(Dataset): def __init__(self, educts, products): assert len(educts) == len(products), "Can't have unequal length educts and products" @staticmethod def generate_dataset(reactions, *args, **kwargs): # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently educts = [] products = [] for reaction in reactions: e = ".".join( [ FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all() ] ) p = ".".join( [ FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all() ] ) educts.append(e) products.append(p) return EnviFormerDataset(educts, products) def X(self): pass def y(self): pass def __getitem__(self, item): pass def __len__(self): pass class SparseLabelECC(BaseEstimator, ClassifierMixin): """ Ensemble of Classifier Chains with sparse label removal. Removes labels that are constant across all samples in training. """ def __init__( self, base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42), num_chains: int = 10, ): self.base_clf = base_clf self.num_chains = num_chains def fit(self, X, Y): y = np.array(Y) self.n_labels_ = y.shape[1] self.removed_labels_ = {} self.keep_columns_ = [] for col in range(self.n_labels_): unique_values = np.unique(y[:, col]) if len(unique_values) == 1: self.removed_labels_[col] = unique_values[0] else: self.keep_columns_.append(col) y_reduced = y[:, self.keep_columns_] self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)] for i, chain in enumerate(self.chains_): logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}") chain.fit(X, y_reduced) return self def predict(self, X, threshold=0.5): avg_preds = np.mean([chain.predict(X) for chain in self.chains_], axis=0) > threshold full_y = np.zeros((avg_preds.shape[0], self.n_labels_)) for idx, col in enumerate(self.keep_columns_): full_y[:, col] = avg_preds[:, idx] for col, value in self.removed_labels_.items(): full_y[:, col] = bool(value) return full_y def predict_proba(self, X): avg_proba = np.mean([chain.predict_proba(X) for chain in self.chains_], axis=0) full_y = np.zeros((avg_proba.shape[0], self.n_labels_)) for idx, col in enumerate(self.keep_columns_): full_y[:, col] = avg_proba[:, idx] for col, value in self.removed_labels_.items(): full_y[:, col] = float(value) return full_y def score(self, X, Y, sample_weight=None): """ Default scoring using subset accuracy (exact match). """ y_true = np.array(Y) y_pred = self.predict(X) return accuracy_score(y_true, y_pred, sample_weight=sample_weight) class BinaryRelevance: def __init__(self, baseline_clf): self.clf = baseline_clf self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] for label in range(len(Y[0])): X_l = X[~np.isnan(Y[:, label])] Y_l = Y[~np.isnan(Y[:, label]), label] if len(X_l) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy="constant", constant=0) clf.fit([X[0]], [0]) self.classifiers.append(clf) continue elif len(np.unique(Y_l)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy="most_frequent") else: clf = copy.deepcopy(self.clf) clf.fit(X_l, Y_l) self.classifiers.append(clf) def predict(self, X): labels = [] for clf in self.classifiers: labels.append(clf.predict(X)) return np.column_stack(labels) def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(X) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict([X[0]])[0] labels = np.column_stack((labels, pred)) return labels class MissingValuesClassifierChain: def __init__(self, base_clf): self.base_clf = base_clf self.permutation = None self.classifiers = None def fit(self, X, Y): X = np.array(X) Y = np.array(Y) if self.permutation is None: rng = default_rng(42) self.permutation = rng.permutation(len(Y[0])) Y = Y[:, self.permutation] if self.classifiers is None: self.classifiers = [] for p in range(len(self.permutation)): X_p = X[~np.isnan(Y[:, p])] Y_p = Y[~np.isnan(Y[:, p]), p] if len(X_p) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy="constant", constant=0) self.classifiers.append(clf.fit([X[0]], [0])) elif len(np.unique(Y_p)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy="most_frequent") self.classifiers.append(clf.fit(X_p, Y_p)) else: clf = copy.deepcopy(self.base_clf) self.classifiers.append(clf.fit(X_p, Y_p)) newcol = Y[:, p] pred = clf.predict(X) newcol[np.isnan(newcol)] = pred[ np.isnan(newcol) ] # fill in missing values with clf predictions X = np.column_stack((X, newcol)) def predict(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict(np.column_stack((X, labels))) labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(np.column_stack((X, np.round(labels)))) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0] labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] class EnsembleClassifierChain: def __init__(self, base_clf, num_chains=10): self.base_clf = base_clf self.num_chains = num_chains self.num_labels = None self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] if self.num_labels is None: self.num_labels = len(Y[0]) for p in range(self.num_chains): logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}") clf = MissingValuesClassifierChain(self.base_clf) clf.fit(X, Y) self.classifiers.append(clf) def predict(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict(X) return np.round(labels / self.num_chains) def predict_proba(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict_proba(X) return labels / self.num_chains class RelativeReasoning: def __init__(self, start_index: int, end_index: int): self.start_index: int = start_index self.end_index: int = end_index self.winmap: Dict[int, List[int]] = defaultdict(list) self.min_count: int = 5 self.max_count: int = 0 def fit(self, X, Y): n_instances = len(Y) n_attributes = len(Y[0]) for i in range(n_attributes): for j in range(n_attributes): if i == j: continue countwin = 0 countloose = 0 countboth = 0 for k in range(n_instances): vi = Y[k][i] vj = Y[k][j] if vi is None or vj is None: continue if vi < vj: countwin += 1 elif vi > vj: countloose += 1 elif vi == vj and vi == 1: # tie countboth += 1 # We've seen more than self.min_count wins, more wins than loosing, no looses and no ties if ( countwin >= self.min_count and countwin > countloose and (countloose <= self.max_count or self.max_count < 0) and countboth == 0 ): self.winmap[i].append(j) def predict(self, X): res = np.zeros((len(X), (self.end_index + 1 - self.start_index))) # Loop through all instances for inst_idx, inst in enumerate(X): # Loop through all "triggered" features for i, t in enumerate(inst[self.start_index : self.end_index + 1]): # Set label res[inst_idx][i] = t # If we predict a 1, check if the rule gets dominated by another if t: # Second loop to check other triggered rules for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]): if i != i2: # Check if rule idx is in "dominated by" list if i2 in self.winmap.get(i, []): # if thatat rule also triggered, it dominated the current # set label to 0 if X[inst_idx][i2]: res[inst_idx][i] = 0 return res def predict_proba(self, X): return self.predict(X) class ApplicabilityDomainPCA(PCA): def __init__(self, num_neighbours: int = 5): super().__init__(n_components=num_neighbours) self.scaler = StandardScaler() self.num_neighbours = num_neighbours self.min_vals = None self.max_vals = None def build(self, train_dataset: "RuleBasedDataset"): # transform X_scaled = self.scaler.fit_transform(train_dataset.X()) # fit pca X_pca = self.fit_transform(X_scaled) self.max_vals = np.max(X_pca, axis=0) self.min_vals = np.min(X_pca, axis=0) def __transform(self, instances): instances_scaled = self.scaler.transform(instances) instances_pca = self.transform(instances_scaled) return instances_pca def is_applicable(self, classify_instances: "RuleBasedDataset"): instances_pca = self.__transform(classify_instances.X()) is_applicable = [] for i, instance in enumerate(instances_pca): is_applicable.append(True) for min_v, max_v, new_v in zip(self.min_vals, self.max_vals, instance): if not min_v <= new_v <= max_v: is_applicable[i] = False return is_applicable def tanimoto_distance(a: List[int], b: List[int]): if len(a) != len(b): raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}") sum_a = sum(a) sum_b = sum(b) sum_c = sum(v1 and v2 for v1, v2 in zip(a, b)) if sum_a + sum_b - sum_c == 0: return 0.0 return 1 - (sum_c / (sum_a + sum_b - sum_c)) def graph_from_pathway(data): """Convert Pathway or SPathway to networkx""" from epdb.models import Pathway from epdb.logic import SPathway graph = nx.DiGraph() co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation def get_edges(): if isinstance(data, Pathway): return data.edges.all() elif isinstance(data, SPathway): return data.edges else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") def get_sources_targets(): if isinstance(data, Pathway): return [n.node for n in edge.start_nodes.constrained_target.all()], [ n.node for n in edge.end_nodes.constrained_target.all() ] elif isinstance(data, SPathway): return edge.educts, edge.products else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") def get_smiles_depth(node): if isinstance(data, Pathway): return FormatConverter.standardize(node.default_node_label.smiles, True), node.depth elif isinstance(data, SPathway): return FormatConverter.standardize(node.smiles, True), node.depth else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") def get_probability(): try: if isinstance(data, Pathway): return edge.kv.get("probability") elif isinstance(data, SPathway): return edge.probability else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") except AttributeError: return 1 root_smiles = {get_smiles_depth(n) for n in data.root_nodes} for root, depth in root_smiles: graph.add_node(root, depth=depth, smiles=root, root=True) for edge in get_edges(): sources, targets = get_sources_targets() probability = get_probability() for source in sources: source_smiles, source_depth = get_smiles_depth(source) if source_smiles not in graph: graph.add_node( source_smiles, depth=source_depth, smiles=source_smiles, root=source_smiles in root_smiles, ) else: graph.nodes[source_smiles]["depth"] = min( source_depth, graph.nodes[source_smiles]["depth"] ) for target in targets: target_smiles, target_depth = get_smiles_depth(target) if target_smiles not in graph and target_smiles not in co2: graph.add_node( target_smiles, depth=target_depth, smiles=target_smiles, root=target_smiles in root_smiles, ) elif target_smiles not in co2: graph.nodes[target_smiles]["depth"] = min( target_depth, graph.nodes[target_smiles]["depth"] ) if target_smiles not in co2 and target_smiles != source_smiles: graph.add_edge(source_smiles, target_smiles, probability=probability) return graph def get_shortest_path(pathway, in_start_node, in_end_node): try: pred = nx.shortest_path(pathway, source=in_start_node, target=in_end_node) except nx.NetworkXNoPath: return [] pred.remove(in_start_node) pred.remove(in_end_node) return pred def set_pathway_eval_weight(pathway): node_eval_weights = {} for node in pathway.nodes: # Scale score according to depth level node_eval_weights[node] = ( 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0 ) return node_eval_weights def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates): if len(intermediates) < 1: return pred_pathway root_nodes = pred_pathway.graph["root_nodes"] for node in pred_pathway.nodes: if node in root_nodes: continue if node in intermediates and node not in data_pathway: pred_pathway.nodes[node]["depth"] = -99 else: shortest_path_list = [] for root_node in root_nodes: shortest_path_nodes = get_shortest_path(pred_pathway, root_node, node) if shortest_path_nodes: shortest_path_list.append(shortest_path_nodes) if shortest_path_list: shortest_path_nodes = min(shortest_path_list, key=len) num_ints = sum( 1 for shortest_path_node in shortest_path_nodes if shortest_path_node in intermediates ) pred_pathway.nodes[node]["depth"] -= num_ints return pred_pathway def initialise_pathway(pathway): """Convert pathway to networkx graph for evaluation""" pathway = graph_from_pathway(pathway) pathway.graph["root_nodes"] = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == 0} pathway = get_pathway_with_depth(pathway) return pathway def get_pathway_with_depth(pathway): """Recalculates depths in the pathway. Can fix incorrect depths from json parse if there were multiple nodes with the same SMILES at different depths that got merged.""" current_depth = 0 for node in pathway.nodes: if node in pathway.graph["root_nodes"]: pathway.nodes[node]["depth"] = current_depth else: pathway.nodes[node]["depth"] = -99 while assign_next_depth(pathway, current_depth): current_depth += 1 return pathway def assign_next_depth(pathway, current_depth): new_assigned_nodes = False current_depth_nodes = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth} for node in current_depth_nodes: successors = pathway.successors(node) for s in successors: if pathway.nodes[s]["depth"] < 0: pathway.nodes[s]["depth"] = current_depth + 1 new_assigned_nodes = True return new_assigned_nodes def find_intermediates(data_pathway, pred_pathway): """Find any intermediate nodes in the predicted pathway""" common_nodes = get_common_nodes(pred_pathway, data_pathway) intermediates = set() for node in common_nodes: down_stream_nodes = data_pathway.successors(node) for down_stream_node in down_stream_nodes: if down_stream_node in pred_pathway: all_ints = get_shortest_path(pred_pathway, node, down_stream_node) intermediates.update(all_ints) return intermediates def get_common_nodes(pred_pathway, data_pathway): """A node is a common node if it is in both pathways and is either a root in both or not a root in both.""" common_nodes = set() for node in data_pathway.nodes: is_pathway_root_node = node in data_pathway.graph["root_nodes"] is_this_root_node = node in pred_pathway.graph["root_nodes"] if node in pred_pathway.nodes: if is_pathway_root_node is False and is_this_root_node is False: common_nodes.add(node) elif is_pathway_root_node and is_this_root_node: common_nodes.add(node) return common_nodes def prune_graph(graph, threshold): """ Removes edges with probability below the threshold, then keep the subgraph reachable from the root node. """ while True: try: cycle = nx.find_cycle(graph) graph.remove_edge(*cycle[-1]) # Remove the last edge in the cycle except nx.NetworkXNoCycle: break for u, v, data in list(graph.edges(data=True)): # Remove edges below threshold if data["probability"] < threshold: graph.remove_edge(u, v) root_node = [n for n in graph.nodes if graph.nodes[n]["root"]][0] reachable = nx.descendants(graph, root_node) # Get all reachable nodes from root reachable.add(root_node) for node in list(graph.nodes): # Remove nodes not reachable from root if node not in reachable: graph.remove_node(node) def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False): """Compare two pathways for multi-gen evaluation. It is assumed the smiles in both pathways have been standardised in the same manner.""" data_pathway = initialise_pathway(data_pathway) pred_pathway = initialise_pathway(pred_pathway) if threshold is not None: prune_graph(pred_pathway, threshold) intermediates = find_intermediates(data_pathway, pred_pathway) if intermediates: pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates) test_pathway_eval_weights = set_pathway_eval_weight(data_pathway) pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway) common_nodes = get_common_nodes(pred_pathway, data_pathway) data_only_nodes = set(n for n in data_pathway.nodes if n not in common_nodes) pred_only_nodes = set(n for n in pred_pathway.nodes if n not in common_nodes) score_TP, score_FP, score_FN, final_score, precision, recall = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 for node in common_nodes: if pred_pathway.nodes[node]["depth"] > 0: score_TP += test_pathway_eval_weights[node] for node in data_only_nodes: if data_pathway.nodes[node]["depth"] > 0: score_FN += test_pathway_eval_weights[node] for node in pred_only_nodes: if pred_pathway.nodes[node]["depth"] > 0: score_FP += pred_pathway_eval_weights[node] final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0 precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0 recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0 if return_intermediates: return final_score, precision, recall, intermediates return final_score, precision, recall def node_subst_cost(node1, node2): if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]: return 0 return 1 / (2 ** max(node1["depth"], node2["depth"])) # Maybe could be min instead of max def node_ins_del_cost(node): return 1 / (2 ** node["depth"]) def pathway_edit_eval(data_pathway, pred_pathway): """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval""" data_pathway = initialise_pathway(data_pathway) pred_pathway = initialise_pathway(pred_pathway) roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0]) return nx.graph_edit_distance( data_pathway, pred_pathway, node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost, node_ins_cost=node_ins_del_cost, roots=roots, )