from __future__ import annotations

import copy
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable

import networkx as nx
import numpy as np
import polars as pl
from envipy_plugins import Descriptor
from numpy.random import default_rng
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler

from utilities.chem import FormatConverter, PredictionResult

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from epdb.models import Rule, CompoundStructure, Reaction


class Dataset(ABC):
    def __init__(
        self,
        columns: List[str] = None,
        data: List[List[str | int | float]] | pl.DataFrame = None,
    ):
        if isinstance(data, pl.DataFrame):
            # Allows for re-creation of self in cases like indexing with __getitem__
            self.df = data
        else:
            # Build either an empty dataframe with columns or fill it with list-of-lists data.
            # Check for missing columns first so the intended ValueError is raised instead of a
            # TypeError from len(None).
            if columns is None:
                raise ValueError("Columns can't be None if data is not already a DataFrame")
            if data is not None and len(columns) != len(data[0]):
                raise ValueError(
                    f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns"
                )
            self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)

    def add_rows(self, rows: List[List[str | int | float]]):
        """Add rows to the dataset. Extends the polars dataframe stored in self."""
        if len(self.columns) != len(rows[0]):
            raise ValueError(
                f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns"
            )
        new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
        self.df.extend(new_rows)

    def add_row(self, row: List[str | int | float]):
        """See add_rows"""
        self.add_rows([row])

    def block_indices(self, prefix: str) -> List[int]:
        """Find the indices of the columns whose labels start with the given prefix."""
        indices: List[int] = []
        for i, feature in enumerate(self.columns):
            if feature.startswith(prefix):
                indices.append(i)
        return indices

    @property
    def columns(self) -> List[str]:
        """Use the polars dataframe columns"""
        return self.df.columns

    @property
    def shape(self):
        return self.df.shape

    @abstractmethod
    def X(self, **kwargs):
        pass

    @abstractmethod
    def y(self, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def generate_dataset(reactions, *args, **kwargs):
        pass

    def __iter__(self):
        """Use polars iter_rows for iterating over the dataset"""
        return self.df.iter_rows()

    def __getitem__(self, item):
        """Item is passed to polars, allowing for advanced indexing. See
        https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
        res = self.df[item]
        if isinstance(res, pl.DataFrame):
            # If we get a dataframe back from indexing, make a new self with the res dataframe
            return self.__class__(data=res)
        else:
            # If we don't get a dataframe back (likely a base type: int, str, float, etc.),
            # return the item as-is
            return res

    def save(self, path: "Path | str"):
        import pickle

        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: "str | Path") -> "Dataset":
        import pickle

        # Use a context manager so the file handle is closed after unpickling
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def to_numpy(self):
        return self.df.to_numpy()

    def __repr__(self):
        return f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"

    def __len__(self):
        return len(self.df)

    def iter_rows(self, named=False):
        return self.df.iter_rows(named=named)

    def filter(self, *predicates, **constraints):
        return self.__class__(data=self.df.filter(*predicates, **constraints))

    def select(self, *exprs, **named_exprs):
        return self.__class__(data=self.df.select(*exprs, **named_exprs))

    def with_columns(self, *exprs, **name_exprs):
        return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))

    def sort(
        self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False
    ):
        return self.__class__(
            data=self.df.sort(
                by,
                *more_by,
                descending=descending,
                nulls_last=nulls_last,
                multithreaded=multithreaded,
                maintain_order=maintain_order,
            )
        )

    def item(self, row=None, column=None):
        return self.df.item(row, column)

    def fill_nan(self, value):
        return self.__class__(data=self.df.fill_nan(value))

    @property
    def height(self):
        return self.df.height


class RuleBasedDataset(Dataset):
    def __init__(self, num_labels=None, columns=None, data=None):
        super().__init__(columns, data)
        # Calculating num_labels here allows functions like __getitem__ to live in the base Dataset,
        # as it unifies the init.
        self.num_labels: int = (
            num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
        )
        # Pre-calculate the ids of columns for features/labels, useful later in X and y
        self._struct_features: List[int] = self.block_indices("feature_")
        self._triggered: List[int] = self.block_indices("trig_")
        self._observed: List[int] = self.block_indices("obs_")
        self.feature_cols: List[int] = self._struct_features + self._triggered
        self.num_features: int = len(self.feature_cols)
        self.has_probs = False

    def times_triggered(self, rule_uuid) -> int:
        """Count how many times a rule was triggered, i.e. the number of rows with a 1 in the rule's trig column."""
        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height

    def struct_features(self) -> List[int]:
        return self._struct_features

    def triggered(self) -> List[int]:
        return self._triggered

    def observed(self) -> List[int]:
        return self._observed

    def structure_id(self, index: int):
        """Get the UUID of a compound"""
        return self.item(index, "structure_id")

    def X(self, exclude_id_col=True, na_replacement=0):
        """Get all the feature and trig columns"""
        _col_ids = self.feature_cols
        if not exclude_id_col:
            _col_ids = [0] + _col_ids
        res = self[:, _col_ids]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    def trig(self, na_replacement=0):
        """Get all the trig columns"""
        res = self[:, self._triggered]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    def y(self, na_replacement=0):
        """Get all the obs columns"""
        res = self[:, self._observed]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    @staticmethod
    def generate_dataset(
        reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"] = None
    ):
        if feat_funcs is None:
            feat_funcs = [FormatConverter.maccs]
        _structures = set()
        # Get all the structures
        for r in reactions:
            _structures.update(r.educts.all())
            if not educts_only:
                _structures.update(r.products.all())
        compounds = sorted(
            _structures, key=lambda x: x.url
        )
        triggered: Dict[str, Set[str]] = defaultdict(set)
        observed: Set[str] = set()
        # Apply rules to the collected compounds and store the resulting transformation products
        for i, comp in enumerate(compounds):
            logger.debug(f"{i + 1}/{len(compounds)}...")
            for rule in applicable_rules:
                product_sets = rule.apply(comp.smiles)
                if len(product_sets) == 0:
                    continue
                key = f"{rule.uuid} + {comp.uuid}"
                if key in triggered:
                    logger.info(f"{key} already present. Duplicate reaction?")
                for prod_set in product_sets:
                    for smi in prod_set:
                        try:
                            smi = FormatConverter.standardize(smi, remove_stereo=True)
                        except Exception:
                            logger.debug(f"Standardizing SMILES failed for {smi}")
                        triggered[key].add(smi)
        for i, r in enumerate(reactions):
            logger.debug(f"{i + 1}/{len(reactions)}...")
            if len(r.educts.all()) != 1:
                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
                continue
            for comp in r.educts.all():
                for rule in applicable_rules:
                    key = f"{rule.uuid} + {comp.uuid}"
                    if key not in triggered:
                        continue
                    # Standardize products from reactions for comparison
                    standardized_products = []
                    for cs in r.products.all():
                        smi = cs.smiles
                        try:
                            smi = FormatConverter.standardize(smi, remove_stereo=True)
                        except Exception as e:
                            logger.debug(f"Standardizing SMILES failed for {smi}: {e}")
                        standardized_products.append(smi)
                    if len(set(standardized_products).difference(triggered[key])) == 0:
                        observed.add(key)
        # Name the feature columns based on the features of the first compound
        feat_columns = []
        for feat_func in feat_funcs:
            if isinstance(feat_func, Descriptor):
                feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
            else:
                feats = feat_func(compounds[0].smiles)
            start_i = len(feat_columns)
            feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
        ds_columns = (
            ["structure_id"]
            + feat_columns
            + [f"trig_{r.uuid}" for r in applicable_rules]
            + [f"obs_{r.uuid}" for r in applicable_rules]
        )
        rows = []
        for i, comp in enumerate(compounds):
            # Features
            feats = []
            for feat_func in feat_funcs:
                if isinstance(feat_func, Descriptor):
                    feat = feat_func.get_molecule_descriptors(comp.smiles)
                else:
                    feat = feat_func(comp.smiles)
                feats.extend(feat)
            trig = []
            obs = []
            for rule in applicable_rules:
                key = f"{rule.uuid} + {comp.uuid}"
                # Check triggered
                if key in triggered:
                    trig.append(1)
                else:
                    trig.append(0)
                # Check obs
                if key in observed:
                    obs.append(1)
                elif key not in triggered:
                    obs.append(None)
                else:
                    obs.append(0)
            rows.append([str(comp.uuid)] + feats + trig + obs)
        ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
        return ds

    def classification_dataset(
        self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
    ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
        classify_data = []
        classify_products = []
        for struct in structures:
            if isinstance(struct, str):
                struct_id = None
                struct_smiles = struct
            else:
                struct_id = str(struct.uuid)
                struct_smiles = struct.smiles
            features = FormatConverter.maccs(struct_smiles)
            trig = []
            prods = []
            for rule in applicable_rules:
                products = rule.apply(struct_smiles)
                if len(products):
                    trig.append(1)
                    prods.append(products)
                else:
                    trig.append(0)
                    prods.append([])
            new_row = [struct_id] + features + trig + ([-1] * len(trig))
            if self.has_probs:
                new_row += [-1] * len(trig)
            classify_data.append(new_row)
            classify_products.append(prods)
        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
        return ds, classify_products

    def add_probs(self, probs):
        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
        self.df = self.df.with_columns(
            *[pl.Series(name, probs[:, j]) for j, name in
              enumerate(col_names)]
        )
        self.has_probs = True

    def to_arff(self, path: "Path"):
        arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
        arff += "\n"
        # num_features counts only the feature_ and trig_ columns, so slice one further
        # to also include the leading structure_id column
        for c in self.columns[-self.num_labels:] + self.columns[: self.num_features + 1]:
            if c == "structure_id":
                arff += f"@attribute {c} string\n"
            else:
                arff += f"@attribute {c} {{0,1}}\n"
        arff += "\n@data\n"
        for d in self:
            ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels:]])
            xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features + 1]])
            arff += f"{ys},{xs}\n"
        with open(path, "w") as fh:
            fh.write(arff)
            fh.flush()


class EnviFormerDataset(Dataset):
    def __init__(self, columns=None, data=None):
        super().__init__(columns, data)

    def X(self):
        """Return the educts"""
        return self["educts"]

    def y(self):
        """Return the products"""
        return self["products"]

    @staticmethod
    def generate_dataset(reactions, *args, **kwargs):
        # Standardise reactions for the training data
        stereo = kwargs.get("stereo", False)
        rows = []
        for reaction in reactions:
            e = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.educts.all()
                ]
            )
            p = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.products.all()
                ]
            )
            rows.append([e, p])
        ds = EnviFormerDataset(["educts", "products"], rows)
        return ds


class SparseLabelECC(BaseEstimator, ClassifierMixin):
    """
    Ensemble of Classifier Chains with sparse label removal.
    Removes labels that are constant across all samples in training.
    """

    def __init__(
        self,
        base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
        num_chains: int = 10,
    ):
        self.base_clf = base_clf
        self.num_chains = num_chains

    def fit(self, X, Y):
        y = np.array(Y)
        self.n_labels_ = y.shape[1]
        self.removed_labels_ = {}
        self.keep_columns_ = []
        for col in range(self.n_labels_):
            unique_values = np.unique(y[:, col])
            if len(unique_values) == 1:
                self.removed_labels_[col] = unique_values[0]
            else:
                self.keep_columns_.append(col)
        y_reduced = y[:, self.keep_columns_]
        # Give each chain its own random label order, otherwise all chains in the ensemble are identical
        self.chains_ = [
            ClassifierChain(self.base_clf, order="random", random_state=i)
            for i in range(self.num_chains)
        ]
        for i, chain in enumerate(self.chains_):
            logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
            chain.fit(X, y_reduced)
        return self

    def predict(self, X, threshold=0.5):
        avg_preds = np.mean([chain.predict(X) for chain in self.chains_], axis=0) > threshold
        full_y = np.zeros((avg_preds.shape[0], self.n_labels_))
        for idx, col in enumerate(self.keep_columns_):
            full_y[:, col] = avg_preds[:, idx]
        for col, value in self.removed_labels_.items():
            full_y[:, col] = bool(value)
        return full_y

    def predict_proba(self, X):
        avg_proba = np.mean([chain.predict_proba(X) for chain in self.chains_], axis=0)
        full_y = np.zeros((avg_proba.shape[0], self.n_labels_))
        for idx, col in enumerate(self.keep_columns_):
            full_y[:, col] = avg_proba[:, idx]
        for col, value in self.removed_labels_.items():
            full_y[:, col] = float(value)
        return full_y

    def score(self, X, Y, sample_weight=None):
        """
        Default scoring using subset accuracy (exact match).
""" y_true = np.array(Y) y_pred = self.predict(X) return accuracy_score(y_true, y_pred, sample_weight=sample_weight) class BinaryRelevance: def __init__(self, baseline_clf): self.clf = baseline_clf self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] for label in range(len(Y[0])): X_l = X[~np.isnan(Y[:, label])] Y_l = Y[~np.isnan(Y[:, label]), label] if len(X_l) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy="constant", constant=0) clf.fit([X[0]], [0]) self.classifiers.append(clf) continue elif len(np.unique(Y_l)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy="most_frequent") else: clf = copy.deepcopy(self.clf) clf.fit(X_l, Y_l) self.classifiers.append(clf) def predict(self, X): labels = [] for clf in self.classifiers: labels.append(clf.predict(X)) return np.column_stack(labels) def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(X) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict([X[0]])[0] labels = np.column_stack((labels, pred)) return labels class MissingValuesClassifierChain: def __init__(self, base_clf): self.base_clf = base_clf self.permutation = None self.classifiers = None def fit(self, X, Y): X = np.array(X) Y = np.array(Y) if self.permutation is None: rng = default_rng(42) self.permutation = rng.permutation(len(Y[0])) Y = Y[:, self.permutation] if self.classifiers is None: self.classifiers = [] for p in range(len(self.permutation)): X_p = X[~np.isnan(Y[:, p])] Y_p = Y[~np.isnan(Y[:, p]), p] if len(X_p) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy="constant", constant=0) self.classifiers.append(clf.fit([X[0]], [0])) elif len(np.unique(Y_p)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy="most_frequent") self.classifiers.append(clf.fit(X_p, Y_p)) else: clf = copy.deepcopy(self.base_clf) self.classifiers.append(clf.fit(X_p, Y_p)) newcol = Y[:, p] pred = clf.predict(X) newcol[np.isnan(newcol)] = pred[ np.isnan(newcol) ] # fill in missing values with clf predictions X = np.column_stack((X, newcol)) def predict(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict(np.column_stack((X, labels))) labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(np.column_stack((X, np.round(labels)))) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0] labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] class EnsembleClassifierChain: def __init__(self, base_clf, num_chains=10): self.base_clf = base_clf self.num_chains = num_chains self.num_labels = None self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] if self.num_labels is None: self.num_labels = Y.shape[1] for p in range(self.num_chains): logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}") clf = MissingValuesClassifierChain(self.base_clf) clf.fit(X, Y) self.classifiers.append(clf) def predict(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict(X) return np.round(labels / self.num_chains) def predict_proba(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict_proba(X) 
        return labels / self.num_chains


class RelativeReasoning:
    def __init__(self, start_index: int, end_index: int):
        self.start_index: int = start_index
        self.end_index: int = end_index
        # winmap[i] holds the indices of rules that dominate rule i
        self.winmap: Dict[int, List[int]] = defaultdict(list)
        self.min_count: int = 5
        self.max_count: int = 0

    def fit(self, X, Y):
        n_instances = len(Y)
        n_attributes = Y.shape[1]
        for i in range(n_attributes):
            for j in range(n_attributes):
                if i == j:
                    continue
                countwin = 0
                countloose = 0
                countboth = 0
                for k in range(n_instances):
                    vi = Y[k, i]
                    vj = Y[k, j]
                    if vi is None or vj is None:
                        continue
                    if vi < vj:
                        countwin += 1
                    elif vi > vj:
                        countloose += 1
                    elif vi == vj and vi == 1:
                        # tie
                        countboth += 1
                # j dominates i if we've seen at least min_count wins, more wins than losses,
                # at most max_count losses (or max_count < 0 to disable that check) and no ties
                if (
                    countwin >= self.min_count
                    and countwin > countloose
                    and (countloose <= self.max_count or self.max_count < 0)
                    and countboth == 0
                ):
                    self.winmap[i].append(j)

    def predict(self, X):
        res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
        # Loop through all instances
        for inst_idx, inst in enumerate(X):
            # Loop through all "triggered" features
            for i, t in enumerate(inst[self.start_index : self.end_index + 1]):
                # Set label
                res[inst_idx][i] = t
                # If we predict a 1, check if the rule gets dominated by another
                if t:
                    # Second loop to check other triggered rules
                    for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]):
                        if i != i2:
                            # Check if rule i2 is in the "dominated by" list of rule i
                            if i2 in self.winmap.get(i, []):
                                # If that rule also triggered, it dominates the current one,
                                # so set the label to 0
                                if X[inst_idx][i2]:
                                    res[inst_idx][i] = 0
        return res

    def predict_proba(self, X):
        return self.predict(X)


class ApplicabilityDomainPCA(PCA):
    def __init__(self, num_neighbours: int = 5):
        super().__init__(n_components=num_neighbours)
        self.scaler = StandardScaler()
        self.num_neighbours = num_neighbours
        self.min_vals = None
        self.max_vals = None

    def build(self, train_dataset: "RuleBasedDataset"):
        # Convert the Dataset wrapper to a plain numpy array and scale the training features
        X_scaled = self.scaler.fit_transform(train_dataset.X().to_numpy())
        # fit pca
        X_pca = self.fit_transform(X_scaled)
        self.max_vals = np.max(X_pca, axis=0)
        self.min_vals = np.min(X_pca, axis=0)

    def __transform(self, instances):
        instances_scaled = self.scaler.transform(instances)
        instances_pca = self.transform(instances_scaled)
        return instances_pca

    def is_applicable(self, classify_instances: "RuleBasedDataset"):
        instances_pca = self.__transform(classify_instances.X().to_numpy())
        is_applicable = []
        for i, instance in enumerate(instances_pca):
            is_applicable.append(True)
            for min_v, max_v, new_v in zip(self.min_vals, self.max_vals, instance):
                if not min_v <= new_v <= max_v:
                    is_applicable[i] = False
        return is_applicable


def tanimoto_distance(a: List[int], b: List[int]):
    if len(a) != len(b):
        raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")
    sum_a = sum(a)
    sum_b = sum(b)
    # For 0/1 fingerprints, `v1 and v2` acts as a logical AND, so sum_c counts shared on-bits
    sum_c = sum(v1 and v2 for v1, v2 in zip(a, b))
    if sum_a + sum_b - sum_c == 0:
        return 0.0
    return 1 - (sum_c / (sum_a + sum_b - sum_c))


def graph_from_pathway(data):
    """Convert Pathway or SPathway to networkx"""
    from epdb.models import Pathway
    from epdb.logic import SPathway

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    def get_edges():
        if isinstance(data, Pathway):
            return data.edges.all()
        elif isinstance(data, SPathway):
            return data.edges
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_sources_targets():
        if isinstance(data, Pathway):
            return [n.node for n in edge.start_nodes.constrained_target.all()], [
                n.node
                for n
                in edge.end_nodes.constrained_target.all()
            ]
        elif isinstance(data, SPathway):
            return edge.educts, edge.products
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_smiles_depth(node):
        if isinstance(data, Pathway):
            return FormatConverter.standardize(node.default_node_label.smiles, True), node.depth
        elif isinstance(data, SPathway):
            return FormatConverter.standardize(node.smiles, True), node.depth
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_probability():
        try:
            if isinstance(data, Pathway):
                return edge.kv.get("probability")
            elif isinstance(data, SPathway):
                return edge.probability
            else:
                raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
        except AttributeError:
            return 1

    root_smiles = {get_smiles_depth(n) for n in data.root_nodes}
    # root_smiles holds (smiles, depth) tuples; keep a plain set of SMILES strings for the
    # membership checks below
    root_smiles_strings = {smiles for smiles, _ in root_smiles}
    for root, depth in root_smiles:
        graph.add_node(root, depth=depth, smiles=root, root=True)
    for edge in get_edges():
        sources, targets = get_sources_targets()
        probability = get_probability()
        for source in sources:
            source_smiles, source_depth = get_smiles_depth(source)
            if source_smiles not in graph:
                graph.add_node(
                    source_smiles,
                    depth=source_depth,
                    smiles=source_smiles,
                    root=source_smiles in root_smiles_strings,
                )
            else:
                graph.nodes[source_smiles]["depth"] = min(
                    source_depth, graph.nodes[source_smiles]["depth"]
                )
            for target in targets:
                target_smiles, target_depth = get_smiles_depth(target)
                if target_smiles not in graph and target_smiles not in co2:
                    graph.add_node(
                        target_smiles,
                        depth=target_depth,
                        smiles=target_smiles,
                        root=target_smiles in root_smiles_strings,
                    )
                elif target_smiles not in co2:
                    graph.nodes[target_smiles]["depth"] = min(
                        target_depth, graph.nodes[target_smiles]["depth"]
                    )
                if target_smiles not in co2 and target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)
    return graph


def get_shortest_path(pathway, in_start_node, in_end_node):
    try:
        pred = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except nx.NetworkXNoPath:
        return []
    pred.remove(in_start_node)
    pred.remove(in_end_node)
    return pred


def set_pathway_eval_weight(pathway):
    node_eval_weights = {}
    for node in pathway.nodes:
        # Scale score according to depth level
        node_eval_weights[node] = (
            1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
        )
    return node_eval_weights


def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    if len(intermediates) < 1:
        return pred_pathway
    root_nodes = pred_pathway.graph["root_nodes"]
    for node in pred_pathway.nodes:
        if node in root_nodes:
            continue
        if node in intermediates and node not in data_pathway:
            pred_pathway.nodes[node]["depth"] = -99
        else:
            shortest_path_list = []
            for root_node in root_nodes:
                shortest_path_nodes = get_shortest_path(pred_pathway, root_node, node)
                if shortest_path_nodes:
                    shortest_path_list.append(shortest_path_nodes)
            if shortest_path_list:
                shortest_path_nodes = min(shortest_path_list, key=len)
                num_ints = sum(
                    1
                    for shortest_path_node in shortest_path_nodes
                    if shortest_path_node in intermediates
                )
                pred_pathway.nodes[node]["depth"] -= num_ints
    return pred_pathway


def initialise_pathway(pathway):
    """Convert pathway to networkx graph for evaluation"""
    pathway = graph_from_pathway(pathway)
    pathway.graph["root_nodes"] = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == 0}
    pathway = get_pathway_with_depth(pathway)
    return pathway


def get_pathway_with_depth(pathway):
    """Recalculates depths in the pathway.
    Can fix incorrect depths from JSON parsing if there were multiple nodes with the same
    SMILES at different depths that got merged."""
    current_depth = 0
    for node in pathway.nodes:
        if node in pathway.graph["root_nodes"]:
            pathway.nodes[node]["depth"] = current_depth
        else:
            pathway.nodes[node]["depth"] = -99
    while assign_next_depth(pathway, current_depth):
        current_depth += 1
    return pathway


def assign_next_depth(pathway, current_depth):
    new_assigned_nodes = False
    current_depth_nodes = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth}
    for node in current_depth_nodes:
        successors = pathway.successors(node)
        for s in successors:
            if pathway.nodes[s]["depth"] < 0:
                pathway.nodes[s]["depth"] = current_depth + 1
                new_assigned_nodes = True
    return new_assigned_nodes


def find_intermediates(data_pathway, pred_pathway):
    """Find any intermediate nodes in the predicted pathway"""
    common_nodes = get_common_nodes(pred_pathway, data_pathway)
    intermediates = set()
    for node in common_nodes:
        down_stream_nodes = data_pathway.successors(node)
        for down_stream_node in down_stream_nodes:
            if down_stream_node in pred_pathway:
                all_ints = get_shortest_path(pred_pathway, node, down_stream_node)
                intermediates.update(all_ints)
    return intermediates


def get_common_nodes(pred_pathway, data_pathway):
    """A node is a common node if it is in both pathways and is either a root in both or not a root in both."""
    common_nodes = set()
    for node in data_pathway.nodes:
        is_pathway_root_node = node in data_pathway.graph["root_nodes"]
        is_this_root_node = node in pred_pathway.graph["root_nodes"]
        if node in pred_pathway.nodes:
            if is_pathway_root_node is False and is_this_root_node is False:
                common_nodes.add(node)
            elif is_pathway_root_node and is_this_root_node:
                common_nodes.add(node)
    return common_nodes


def prune_graph(graph, threshold):
    """
    Removes edges with probability below the threshold, then keeps only the subgraph reachable
    from the root node.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
            graph.remove_edge(*cycle[-1])  # Remove the last edge in the cycle
        except nx.NetworkXNoCycle:
            break
    for u, v, data in list(graph.edges(data=True)):
        # Remove edges below threshold
        if data["probability"] < threshold:
            graph.remove_edge(u, v)
    root_node = [n for n in graph.nodes if graph.nodes[n]["root"]][0]
    reachable = nx.descendants(graph, root_node)  # Get all reachable nodes from root
    reachable.add(root_node)
    for node in list(graph.nodes):
        # Remove nodes not reachable from root
        if node not in reachable:
            graph.remove_node(node)


def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Compare two pathways for multi-gen evaluation.
    It is assumed the smiles in both pathways have been standardised in the same manner."""
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_pathway, threshold)
    intermediates = find_intermediates(data_pathway, pred_pathway)
    if intermediates:
        pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates)
    test_pathway_eval_weights = set_pathway_eval_weight(data_pathway)
    pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway)
    common_nodes = get_common_nodes(pred_pathway, data_pathway)
    data_only_nodes = set(n for n in data_pathway.nodes if n not in common_nodes)
    pred_only_nodes = set(n for n in pred_pathway.nodes if n not in common_nodes)
    score_TP, score_FP, score_FN = 0.0, 0.0, 0.0
    for node in common_nodes:
        if pred_pathway.nodes[node]["depth"] > 0:
            score_TP += test_pathway_eval_weights[node]
    for node in data_only_nodes:
        if data_pathway.nodes[node]["depth"] > 0:
            score_FN += test_pathway_eval_weights[node]
    for node in pred_only_nodes:
        if pred_pathway.nodes[node]["depth"] > 0:
            score_FP += pred_pathway_eval_weights[node]
    final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0
    precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0
    recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0
    if return_intermediates:
        return final_score, precision, recall, intermediates
    return final_score, precision, recall


def node_subst_cost(node1, node2):
    if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]:
        return 0
    return 1 / (2 ** max(node1["depth"], node2["depth"]))  # Maybe could be min instead of max


def node_ins_del_cost(node):
    return 1 / (2 ** node["depth"])


def pathway_edit_eval(data_pathway, pred_pathway):
    """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval"""
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
    return nx.graph_edit_distance(
        data_pathway,
        pred_pathway,
        node_subst_cost=node_subst_cost,
        node_del_cost=node_ins_del_cost,
        node_ins_cost=node_ins_del_cost,
        roots=roots,
    )
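

# ---------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It shows how the multi-label
# classifiers above are typically driven: fit SparseLabelECC on a small synthetic 0/1 feature
# matrix with one constant label (which the estimator drops internally), score it, and compare
# two toy fingerprints with tanimoto_distance. The arrays below are made up purely for
# demonstration; real usage would derive X and Y from a RuleBasedDataset, e.g. via
# ds.X().to_numpy() and ds.y().to_numpy().
if __name__ == "__main__":
    demo_rng = default_rng(0)
    X_demo = demo_rng.integers(0, 2, size=(40, 16))
    Y_demo = np.column_stack(
        [
            X_demo[:, 0],                 # label fully determined by the first feature
            X_demo[:, 1] | X_demo[:, 2],  # label depending on two features
            np.ones(40, dtype=int),       # constant label -> removed by SparseLabelECC.fit
        ]
    )
    ecc = SparseLabelECC(num_chains=3)
    ecc.fit(X_demo, Y_demo)
    print("subset accuracy on training data:", ecc.score(X_demo, Y_demo))
    print("per-label probabilities for the first instance:", ecc.predict_proba(X_demo[:1])[0])
    print("tanimoto distance:", tanimoto_distance([1, 1, 0, 1], [1, 0, 0, 1]))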