from __future__ import annotations import copy import numpy as np from numpy.random import default_rng from sklearn.dummy import DummyClassifier from sklearn.tree import DecisionTreeClassifier import logging from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime from typing import List, Dict, Set, Tuple import networkx as nx from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score from sklearn.multioutput import ClassifierChain from sklearn.preprocessing import StandardScaler logger = logging.getLogger(__name__) from dataclasses import dataclass, field from utilities.chem import FormatConverter, PredictionResult class Dataset: def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None): self.columns: List[str] = columns self.num_labels: int = num_labels if data is None: self.data: List[List[str | int | float]] = list() else: self.data = data self.num_features: int = len(columns) - self.num_labels self._struct_features: Tuple[int, int] = self._block_indices('feature_') self._triggered: Tuple[int, int] = self._block_indices('trig_') self._observed: Tuple[int, int] = self._block_indices('obs_') def _block_indices(self, prefix) -> Tuple[int, int]: indices: List[int] = [] for i, feature in enumerate(self.columns): if feature.startswith(prefix): indices.append(i) return min(indices), max(indices) def structure_id(self): return self.data[0][0] def add_row(self, row: List[str | int | float]): if len(self.columns) != len(row): raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}") self.data.append(row) def times_triggered(self, rule_uuid) -> int: idx = self.columns.index(f'trig_{rule_uuid}') times_triggered = 0 for row in self.data: if row[idx] == 1: times_triggered += 1 return times_triggered def struct_features(self) -> Tuple[int, int]: return self._struct_features def triggered(self) -> Tuple[int, int]: return self._triggered def observed(self) -> Tuple[int, int]: return self._observed def at(self, position: int) -> Dataset: return Dataset(self.columns, self.num_labels, [self.data[position]]) def limit(self, limit: int) -> Dataset: return Dataset(self.columns, self.num_labels, self.data[:limit]) def __iter__(self): return (self.at(i) for i, _ in enumerate(self.data)) def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]: classify_data = [] classify_products = [] for struct in structures: if isinstance(struct, str): struct_id = None struct_smiles = struct else: struct_id = str(struct.uuid) struct_smiles = struct.smiles features = FormatConverter.maccs(struct_smiles) trig = [] prods = [] for rule in applicable_rules: products = rule.apply(struct_smiles) if len(products): trig.append(1) prods.append(products) else: trig.append(0) prods.append([]) classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) classify_products.append(prods) return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products @staticmethod def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset: _structures = set() for r in reactions: for e in r.educts.all(): _structures.add(e) if not educts_only: for e in r.products: _structures.add(e) compounds = sorted(_structures, key=lambda x: x.url) triggered: Dict[str, Set[str]] = defaultdict(set) observed: Set[str] = set() # Apply rules on collected compounds and store tps for i, comp in enumerate(compounds): logger.debug(f"{i + 1}/{len(compounds)}...") for rule in applicable_rules: product_sets = rule.apply(comp.smiles) if len(product_sets) == 0: continue key = f"{rule.uuid} + {comp.uuid}" if key in triggered: logger.info(f"{key} already present. Duplicate reaction?") for prod_set in product_sets: for smi in prod_set: try: smi = FormatConverter.standardize(smi) except Exception: # :shrug: logger.debug(f'Standardizing SMILES failed for {smi}') pass triggered[key].add(smi) for i, r in enumerate(reactions): logger.debug(f"{i + 1}/{len(reactions)}...") if len(r.educts.all()) != 1: logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") continue for comp in r.educts.all(): for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" if key not in triggered: continue # standardize products from reactions for comparison standardized_products = [] for cs in r.products.all(): smi = cs.smiles try: smi = FormatConverter.standardize(smi) except Exception as e: # :shrug: logger.debug(f'Standardizing SMILES failed for {smi}') pass standardized_products.append(smi) if len(set(standardized_products).difference(triggered[key])) == 0: observed.add(key) else: pass ds = None for i, comp in enumerate(compounds): # Features feat = FormatConverter.maccs(comp.smiles) trig = [] obs = [] for rule in applicable_rules: key = f"{rule.uuid} + {comp.uuid}" # Check triggered if key in triggered: trig.append(1) else: trig.append(0) # Check obs if key in observed: obs.append(1) elif key not in triggered: obs.append(None) else: obs.append(0) if ds is None: header = ['structure_id'] + \ [f'feature_{i}' for i, _ in enumerate(feat)] \ + [f'trig_{r.uuid}' for r in applicable_rules] \ + [f'obs_{r.uuid}' for r in applicable_rules] ds = Dataset(header, len(applicable_rules)) ds.add_row([str(comp.uuid)] + feat + trig + obs) return ds def X(self, exclude_id_col=True, na_replacement=0): res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def trig(self, na_replacement=0): res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1]))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def y(self, na_replacement=0): res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res def __getitem__(self, key): if not isinstance(key, tuple): raise TypeError("Dataset must be indexed with dataset[rows, columns]") row_key, col_key = key # Normalize rows if isinstance(row_key, int): rows = [self.data[row_key]] else: rows = self.data[row_key] # Normalize columns if isinstance(col_key, int): res = [row[col_key] for row in rows] else: res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice) else [row[i] for i in col_key] for row in rows] return res def save(self, path: 'Path'): import pickle with open(path, "wb") as fh: pickle.dump(self, fh) @staticmethod def load(path: 'Path') -> 'Dataset': import pickle return pickle.load(open(path, "rb")) def to_arff(self, path: 'Path'): arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n" arff += "\n" for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]: if c == 'structure_id': arff += f"@attribute {c} string\n" else: arff += f"@attribute {c} {{0,1}}\n" arff += f"\n@data\n" for d in self.data: ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]]) xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]]) arff += f'{ys},{xs}\n' with open(path, "w") as fh: fh.write(arff) fh.flush() def __repr__(self): return f"" class SparseLabelECC(BaseEstimator, ClassifierMixin): """ Ensemble of Classifier Chains with sparse label removal. Removes labels that are constant across all samples in training. """ def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42), num_chains: int = 10): self.base_clf = base_clf self.num_chains = num_chains def fit(self, X, Y): y = np.array(Y) self.n_labels_ = y.shape[1] self.removed_labels_ = {} self.keep_columns_ = [] for col in range(self.n_labels_): unique_values = np.unique(y[:, col]) if len(unique_values) == 1: self.removed_labels_[col] = unique_values[0] else: self.keep_columns_.append(col) y_reduced = y[:, self.keep_columns_] self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)] for i, chain in enumerate(self.chains_): logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}") chain.fit(X, y_reduced) return self def predict(self, X, threshold=0.5): avg_preds = np.mean([chain.predict(X) for chain in self.chains_], axis=0) > threshold full_y = np.zeros((avg_preds.shape[0], self.n_labels_)) for idx, col in enumerate(self.keep_columns_): full_y[:, col] = avg_preds[:, idx] for col, value in self.removed_labels_.items(): full_y[:, col] = bool(value) return full_y def predict_proba(self, X): avg_proba = np.mean([chain.predict_proba(X) for chain in self.chains_], axis=0) full_y = np.zeros((avg_proba.shape[0], self.n_labels_)) for idx, col in enumerate(self.keep_columns_): full_y[:, col] = avg_proba[:, idx] for col, value in self.removed_labels_.items(): full_y[:, col] = float(value) return full_y def score(self, X, Y, sample_weight=None): """ Default scoring using subset accuracy (exact match). """ y_true = np.array(Y) y_pred = self.predict(X) return accuracy_score(y_true, y_pred, sample_weight=sample_weight) class BinaryRelevance: def __init__(self, baseline_clf): self.clf = baseline_clf self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] for l in range(len(Y[0])): X_l = X[~np.isnan(Y[:, l])] Y_l = (Y[~np.isnan(Y[:, l]), l]) if len(X_l) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy='constant', constant=0) clf.fit([X[0]], [0]) self.classifiers.append(clf) continue elif len(np.unique(Y_l)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy='most_frequent') else: clf = copy.deepcopy(self.clf) clf.fit(X_l, Y_l) self.classifiers.append(clf) def predict(self, X): labels = [] for clf in self.classifiers: labels.append(clf.predict(X)) return np.column_stack(labels) def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(X) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict([X[0]])[0] labels = np.column_stack((labels, pred)) return labels class MissingValuesClassifierChain: def __init__(self, base_clf): self.base_clf = base_clf self.permutation = None self.classifiers = None def fit(self, X, Y): X = np.array(X) Y = np.array(Y) if self.permutation is None: rng = default_rng(42) self.permutation = rng.permutation(len(Y[0])) Y = Y[:, self.permutation] if self.classifiers is None: self.classifiers = [] for p in range(len(self.permutation)): X_p = X[~np.isnan(Y[:, p])] Y_p = Y[~np.isnan(Y[:, p]), p] if len(X_p) == 0: # all labels are nan -> predict 0 clf = DummyClassifier(strategy='constant', constant=0) self.classifiers.append(clf.fit([X[0]], [0])) elif len(np.unique(Y_p)) == 1: # only one class -> predict that class clf = DummyClassifier(strategy='most_frequent') self.classifiers.append(clf.fit(X_p, Y_p)) else: clf = copy.deepcopy(self.base_clf) self.classifiers.append(clf.fit(X_p, Y_p)) newcol = Y[:, p] pred = clf.predict(X) newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions X = np.column_stack((X, newcol)) def predict(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict(np.column_stack((X, labels))) labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] def predict_proba(self, X): labels = np.empty((len(X), 0)) for clf in self.classifiers: pred = clf.predict_proba(np.column_stack((X, np.round(labels)))) if pred.shape[1] > 1: pred = pred[:, 1] else: pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0] labels = np.column_stack((labels, pred)) return labels[:, np.argsort(self.permutation)] class EnsembleClassifierChain: def __init__(self, base_clf, num_chains=10): self.base_clf = base_clf self.num_chains = num_chains self.num_labels = None self.classifiers = None def fit(self, X, Y): if self.classifiers is None: self.classifiers = [] if self.num_labels is None: self.num_labels = len(Y[0]) for p in range(self.num_chains): logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}") clf = MissingValuesClassifierChain(self.base_clf) clf.fit(X, Y) self.classifiers.append(clf) def predict(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict(X) return np.round(labels / self.num_chains) def predict_proba(self, X): labels = np.zeros((len(X), self.num_labels)) for clf in self.classifiers: labels += clf.predict_proba(X) return labels / self.num_chains class RelativeReasoning: def __init__(self, start_index: int, end_index: int): self.start_index: int = start_index self.end_index: int = end_index self.winmap: Dict[int, List[int]] = defaultdict(list) self.min_count: int = 5 self.max_count: int = 0 def fit(self, X, Y): n_instances = len(Y) n_attributes = len(Y[0]) for i in range(n_attributes): for j in range(n_attributes): if i == j: continue countwin = 0 countloose = 0 countboth = 0 for k in range(n_instances): vi = Y[k][i] vj = Y[k][j] if vi is None or vj is None: continue if vi < vj: countwin += 1 elif vi > vj: countloose += 1 elif vi == vj and vi == 1: # tie countboth += 1 # We've seen more than self.min_count wins, more wins than loosing, no looses and no ties if ( countwin >= self.min_count and countwin > countloose and ( countloose <= self.max_count or self.max_count < 0 ) and countboth == 0 ): self.winmap[i].append(j) def predict(self, X): res = np.zeros((len(X), (self.end_index + 1 - self.start_index))) # Loop through all instances for inst_idx, inst in enumerate(X): # Loop through all "triggered" features for i, t in enumerate(inst[self.start_index: self.end_index + 1]): # Set label res[inst_idx][i] = t # If we predict a 1, check if the rule gets dominated by another if t: # Second loop to check other triggered rules for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]): if i != i2: # Check if rule idx is in "dominated by" list if i2 in self.winmap.get(i, []): # if thatat rule also triggered, it dominated the current # set label to 0 if X[inst_idx][i2]: res[inst_idx][i] = 0 return res def predict_proba(self, X): return self.predict(X) class ApplicabilityDomainPCA(PCA): def __init__(self, num_neighbours: int = 5): super().__init__(n_components=num_neighbours) self.scaler = StandardScaler() self.num_neighbours = num_neighbours self.min_vals = None self.max_vals = None def build(self, train_dataset: 'Dataset'): # transform X_scaled = self.scaler.fit_transform(train_dataset.X()) # fit pca X_pca = self.fit_transform(X_scaled) self.max_vals = np.max(X_pca, axis=0) self.min_vals = np.min(X_pca, axis=0) def __transform(self, instances): instances_scaled = self.scaler.transform(instances) instances_pca = self.transform(instances_scaled) return instances_pca def is_applicable(self, classify_instances: 'Dataset'): instances_pca = self.__transform(classify_instances.X()) is_applicable = [] for i, instance in enumerate(instances_pca): is_applicable.append(True) for min_v, max_v, new_v in zip(self.min_vals, self.max_vals, instance): if not min_v <= new_v <= max_v: is_applicable[i] = False return is_applicable def tanimoto_distance(a: List[int], b: List[int]): if len(a) != len(b): raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}") sum_a = sum(a) sum_b = sum(b) sum_c = sum(v1 and v2 for v1, v2 in zip(a, b)) if sum_a + sum_b - sum_c == 0: return 0.0 return 1 - (sum_c / (sum_a + sum_b - sum_c)) def graph_from_pathway(data): """Convert Pathway or SPathway to networkx""" from epdb.models import Pathway from epdb.logic import SPathway graph = nx.DiGraph() co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation def get_edges(): if isinstance(data, Pathway): return data.edges.all() elif isinstance(data, SPathway): return data.edges else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") def get_sources_targets(): if isinstance(data, Pathway): return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()] elif isinstance(data, SPathway): return edge.educts, edge.products else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") def get_smiles_depth(node): if isinstance(data, Pathway): return FormatConverter.standardize(node.default_node_label.smiles, True), node.depth elif isinstance(data, SPathway): return FormatConverter.standardize(node.smiles, True), node.depth else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") def get_probability(): try: if isinstance(data, Pathway): return edge.kv.get('probability') elif isinstance(data, SPathway): return edge.probability else: raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval") except AttributeError: return 1 root_smiles = {get_smiles_depth(n) for n in data.root_nodes} for root, depth in root_smiles: graph.add_node(root, depth=depth, smiles=root, root=True) for edge in get_edges(): sources, targets = get_sources_targets() probability = get_probability() for source in sources: source_smiles, source_depth = get_smiles_depth(source) if source_smiles not in graph: graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles, root=source_smiles in root_smiles) else: graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"]) for target in targets: target_smiles, target_depth = get_smiles_depth(target) if target_smiles not in graph and target_smiles not in co2: graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles, root=target_smiles in root_smiles) elif target_smiles not in co2: graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"]) if target_smiles not in co2 and target_smiles != source_smiles: graph.add_edge(source_smiles, target_smiles, probability=probability) return graph def get_shortest_path(pathway, in_start_node, in_end_node): try: pred = nx.shortest_path(pathway, source=in_start_node, target=in_end_node) except nx.NetworkXNoPath: return [] pred.remove(in_start_node) pred.remove(in_end_node) return pred def set_pathway_eval_weight(pathway): node_eval_weights = {} for node in pathway.nodes: # Scale score according to depth level node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0 return node_eval_weights def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates): if len(intermediates) < 1: return pred_pathway root_nodes = pred_pathway.graph["root_nodes"] for node in pred_pathway.nodes: if node in root_nodes: continue if node in intermediates and node not in data_pathway: pred_pathway.nodes[node]["depth"] = -99 else: shortest_path_list = [] for root_node in root_nodes: shortest_path_nodes = get_shortest_path(pred_pathway, root_node, node) if shortest_path_nodes: shortest_path_list.append(shortest_path_nodes) if shortest_path_list: shortest_path_nodes = min(shortest_path_list, key=len) num_ints = sum(1 for shortest_path_node in shortest_path_nodes if shortest_path_node in intermediates) pred_pathway.nodes[node]["depth"] -= num_ints return pred_pathway def initialise_pathway(pathway): """Convert pathway to networkx graph for evaluation""" pathway = graph_from_pathway(pathway) pathway.graph["root_nodes"] = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == 0} pathway = get_pathway_with_depth(pathway) return pathway def get_pathway_with_depth(pathway): """Recalculates depths in the pathway. Can fix incorrect depths from json parse if there were multiple nodes with the same SMILES at different depths that got merged.""" current_depth = 0 for node in pathway.nodes: if node in pathway.graph["root_nodes"]: pathway.nodes[node]["depth"] = current_depth else: pathway.nodes[node]["depth"] = -99 while assign_next_depth(pathway, current_depth): current_depth += 1 return pathway def assign_next_depth(pathway, current_depth): new_assigned_nodes = False current_depth_nodes = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth} for node in current_depth_nodes: successors = pathway.successors(node) for s in successors: if pathway.nodes[s]["depth"] < 0: pathway.nodes[s]["depth"] = current_depth + 1 new_assigned_nodes = True return new_assigned_nodes def find_intermediates(data_pathway, pred_pathway): """Find any intermediate nodes in the predicted pathway""" common_nodes = get_common_nodes(pred_pathway, data_pathway) intermediates = set() for node in common_nodes: down_stream_nodes = data_pathway.successors(node) for down_stream_node in down_stream_nodes: if down_stream_node in pred_pathway: all_ints = get_shortest_path(pred_pathway, node, down_stream_node) intermediates.update(all_ints) return intermediates def get_common_nodes(pred_pathway, data_pathway): """A node is a common node if it is in both pathways and is either a root in both or not a root in both.""" common_nodes = set() for node in data_pathway.nodes: is_pathway_root_node = node in data_pathway.graph["root_nodes"] is_this_root_node = node in pred_pathway.graph["root_nodes"] if node in pred_pathway.nodes: if is_pathway_root_node is False and is_this_root_node is False: common_nodes.add(node) elif is_pathway_root_node and is_this_root_node: common_nodes.add(node) return common_nodes def prune_graph(graph, threshold): """ Removes edges with probability below the threshold, then keep the subgraph reachable from the root node. """ while True: try: cycle = nx.find_cycle(graph) graph.remove_edge(*cycle[-1]) # Remove the last edge in the cycle except nx.NetworkXNoCycle: break for u, v, data in list(graph.edges(data=True)): # Remove edges below threshold if data["probability"] < threshold: graph.remove_edge(u, v) root_node = [n for n in graph.nodes if graph.nodes[n]["root"]][0] reachable = nx.descendants(graph, root_node) # Get all reachable nodes from root reachable.add(root_node) for node in list(graph.nodes): # Remove nodes not reachable from root if node not in reachable: graph.remove_node(node) def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False): """Compare two pathways for multi-gen evaluation. It is assumed the smiles in both pathways have been standardised in the same manner.""" data_pathway = initialise_pathway(data_pathway) pred_pathway = initialise_pathway(pred_pathway) if threshold is not None: prune_graph(pred_pathway, threshold) intermediates = find_intermediates(data_pathway, pred_pathway) if intermediates: pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates) test_pathway_eval_weights = set_pathway_eval_weight(data_pathway) pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway) common_nodes = get_common_nodes(pred_pathway, data_pathway) data_only_nodes = set(n for n in data_pathway.nodes if n not in common_nodes) pred_only_nodes = set(n for n in pred_pathway.nodes if n not in common_nodes) score_TP, score_FP, score_FN, final_score, precision, recall = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 for node in common_nodes: if pred_pathway.nodes[node]["depth"] > 0: score_TP += test_pathway_eval_weights[node] for node in data_only_nodes: if data_pathway.nodes[node]["depth"] > 0: score_FN += test_pathway_eval_weights[node] for node in pred_only_nodes: if pred_pathway.nodes[node]["depth"] > 0: score_FP += pred_pathway_eval_weights[node] final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0 precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0 recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0 if return_intermediates: return final_score, precision, recall, intermediates return final_score, precision, recall def node_subst_cost(node1, node2): if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]: return 0 return 1 / (2 ** max(node1["depth"], node2["depth"])) # Maybe could be min instead of max def node_ins_del_cost(node): return 1 / (2 ** node["depth"]) def pathway_edit_eval(data_pathway, pred_pathway): """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval""" data_pathway = initialise_pathway(data_pathway) pred_pathway = initialise_pathway(pred_pathway) roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0]) return nx.graph_edit_distance(data_pathway, pred_pathway, node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost, node_ins_cost=node_ins_del_cost, roots=roots)