Files
enviPy-bayer/utilities/ml.py
2025-10-02 00:40:00 +13:00

885 lines
32 KiB
Python

from __future__ import annotations
import copy
import numpy as np
from numpy.random import default_rng
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Set, Tuple
import networkx as nx
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler
logger = logging.getLogger(__name__)
from dataclasses import dataclass, field
from utilities.chem import FormatConverter, PredictionResult
class Dataset:
    """Row-oriented table of per-compound features and rule-based labels.

    Rows are plain Python lists aligned with ``columns``. Datasets built by
    :meth:`generate_dataset` use the column layout::

        structure_id | feature_0 .. feature_k | trig_<rule uuid> ... | obs_<rule uuid> ...

    The last ``num_labels`` columns are the labels. ``num_features`` counts every
    column before them, which includes ``structure_id`` and the ``trig_*`` block.
    """

    def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
        self.columns: List[str] = columns
        self.num_labels: int = num_labels
        # Every instance gets its own row list unless the caller supplies one.
        self.data: List[List[str | int | float]] = list() if data is None else data
        self.num_features: int = len(columns) - self.num_labels
        # Inclusive (first, last) column indices of each prefixed column block.
        self._struct_features: Tuple[int, int] = self._block_indices('feature_')
        self._triggered: Tuple[int, int] = self._block_indices('trig_')
        self._observed: Tuple[int, int] = self._block_indices('obs_')

    def _block_indices(self, prefix) -> Tuple[int, int]:
        """Return the inclusive ``(first, last)`` indices of the columns starting with ``prefix``.

        Raises:
            ValueError: if no column starts with ``prefix``.
        """
        indices = [i for i, column in enumerate(self.columns) if column.startswith(prefix)]
        if not indices:
            raise ValueError(f"Dataset has no columns with prefix '{prefix}'")
        return min(indices), max(indices)

    @staticmethod
    def _fill_na(rows, na_replacement):
        """Replace ``None`` cells in ``rows`` with ``na_replacement``; ``None`` keeps them as-is."""
        if na_replacement is None:
            return rows
        return [[x if x is not None else na_replacement for x in row] for row in rows]

    @staticmethod
    def _standardize(smiles: str) -> str:
        """Standardize ``smiles`` without stereo information, falling back to the input on failure."""
        try:
            return FormatConverter.standardize(smiles, remove_stereo=True)
        except Exception:
            # Best effort: an unparsable SMILES is compared verbatim instead of aborting the build.
            logger.debug(f'Standardizing SMILES failed for {smiles}')
            return smiles

    def structure_id(self):
        """Return the ``structure_id`` cell of the first row."""
        return self.data[0][0]

    def add_row(self, row: List[str | int | float]):
        """Append ``row``.

        Raises:
            ValueError: if ``row`` does not have one cell per column.
        """
        if len(self.columns) != len(row):
            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
        self.data.append(row)

    def times_triggered(self, rule_uuid) -> int:
        """Count the rows whose ``trig_<rule_uuid>`` cell equals 1."""
        idx = self.columns.index(f'trig_{rule_uuid}')
        return sum(1 for row in self.data if row[idx] == 1)

    def struct_features(self) -> Tuple[int, int]:
        """Inclusive (first, last) column indices of the ``feature_*`` block."""
        return self._struct_features

    def triggered(self) -> Tuple[int, int]:
        """Inclusive (first, last) column indices of the ``trig_*`` block."""
        return self._triggered

    def observed(self) -> Tuple[int, int]:
        """Inclusive (first, last) column indices of the ``obs_*`` block."""
        return self._observed

    def at(self, position: int) -> Dataset:
        """Return a single-row dataset containing row ``position``."""
        return Dataset(self.columns, self.num_labels, [self.data[position]])

    def limit(self, limit: int) -> Dataset:
        """Return a dataset with at most the first ``limit`` rows."""
        return Dataset(self.columns, self.num_labels, self.data[:limit])

    def __iter__(self):
        """Yield one single-row dataset per row."""
        return (self.at(i) for i in range(len(self.data)))

    def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
        """Build an unlabeled dataset for ``structures`` using this dataset's columns.

        ``structures`` may be SMILES strings (row id ``None``) or objects exposing
        ``uuid`` and ``smiles``. Each rule in ``applicable_rules`` is applied; its
        ``trig`` cell is 1 when it produced anything. Label cells are filled with -1.

        Returns:
            The dataset and, per structure, the list of per-rule products (``[]``
            for rules that did not trigger).
        """
        classify_data = []
        classify_products = []
        for struct in structures:
            if isinstance(struct, str):
                struct_id = None
                struct_smiles = struct
            else:
                struct_id = str(struct.uuid)
                struct_smiles = struct.smiles
            features = FormatConverter.maccs(struct_smiles)
            trig = []
            prods = []
            for rule in applicable_rules:
                products = rule.apply(struct_smiles)
                if len(products):
                    trig.append(1)
                    prods.append(products)
                else:
                    trig.append(0)
                    prods.append([])
            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
            classify_products.append(prods)
        return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products

    @staticmethod
    def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
        """Build a training dataset from ``reactions`` and ``applicable_rules``.

        Compounds are the educts of all reactions (plus their products when
        ``educts_only`` is False), sorted by ``url``. For every compound and rule:

        * ``trig_<rule>`` is 1 when applying the rule yielded at least one product SMILES;
        * ``obs_<rule>`` is 1 when, for some single-educt reaction of that compound,
          every standardized reaction product is among the rule's standardized
          products; 0 when the rule triggered without that; ``None`` when it did
          not trigger.

        Returns:
            The dataset, or ``None`` when no compounds were collected.
        """
        compound_set = set()
        for r in reactions:
            compound_set.update(r.educts.all())
            if not educts_only:
                # ``products`` is a related manager just like ``educts``; query it via .all().
                compound_set.update(r.products.all())
        compounds = sorted(compound_set, key=lambda x: x.url)

        triggered: Dict[str, Set[str]] = defaultdict(set)
        observed: Set[str] = set()

        # Apply rules on collected compounds and store the standardized products.
        for i, comp in enumerate(compounds):
            logger.debug(f"{i + 1}/{len(compounds)}...")
            for rule in applicable_rules:
                product_sets = rule.apply(comp.smiles)
                if len(product_sets) == 0:
                    continue
                key = f"{rule.uuid} + {comp.uuid}"
                if key in triggered:
                    logger.info(f"{key} already present. Duplicate reaction?")
                for prod_set in product_sets:
                    for smi in prod_set:
                        triggered[key].add(Dataset._standardize(smi))

        # A (rule, compound) pair is "observed" if the rule reproduces all products of a reaction.
        for i, r in enumerate(reactions):
            logger.debug(f"{i + 1}/{len(reactions)}...")
            educts = r.educts.all()
            if len(educts) != 1:
                logger.debug(f"Skipping {r.url} as it has {len(educts)} substrates!")
                continue
            standardized_products = None
            for comp in educts:
                for rule in applicable_rules:
                    key = f"{rule.uuid} + {comp.uuid}"
                    if key not in triggered:
                        continue
                    if standardized_products is None:
                        # Reaction products do not depend on the rule: standardize them once.
                        standardized_products = {Dataset._standardize(cs.smiles) for cs in r.products.all()}
                    if standardized_products.issubset(triggered[key]):
                        observed.add(key)

        ds = None
        for comp in compounds:
            feat = FormatConverter.maccs(comp.smiles)
            trig = []
            obs = []
            for rule in applicable_rules:
                key = f"{rule.uuid} + {comp.uuid}"
                is_triggered = key in triggered
                trig.append(1 if is_triggered else 0)
                if key in observed:
                    obs.append(1)
                elif not is_triggered:
                    obs.append(None)
                else:
                    obs.append(0)
            if ds is None:
                # The header is derived from the first compound's fingerprint length.
                header = (['structure_id']
                          + [f'feature_{j}' for j in range(len(feat))]
                          + [f'trig_{rule.uuid}' for rule in applicable_rules]
                          + [f'obs_{rule.uuid}' for rule in applicable_rules])
                ds = Dataset(header, len(applicable_rules))
            ds.add_row([str(comp.uuid)] + feat + trig + obs)
        return ds

    def X(self, exclude_id_col=True, na_replacement=0):
        """Return all non-label columns (``feature_*`` and ``trig_*``, optionally ``structure_id``).

        ``None`` cells become ``na_replacement`` unless it is ``None``.
        """
        first = 1 if exclude_id_col else 0
        rows = self[:, first:len(self.columns) - self.num_labels]
        return self._fill_na(rows, na_replacement)

    def trig(self, na_replacement=0):
        """Return the ``trig_*`` columns.

        ``_triggered`` is an inclusive (first, last) pair, so the slice must end at
        ``last + 1``; ending at ``last`` would silently drop the final rule.
        """
        first, last = self._triggered
        return self._fill_na(self[:, first:last + 1], na_replacement)

    def y(self, na_replacement=0):
        """Return the label columns (the last ``num_labels`` columns)."""
        rows = self[:, len(self.columns) - self.num_labels:]
        return self._fill_na(rows, na_replacement)

    def __getitem__(self, key):
        """Index as ``dataset[rows, columns]``.

        ``rows`` is an int or slice; ``columns`` is an int (yielding a flat list of
        values), a slice, or an iterable of column indices (yielding row lists).

        Raises:
            TypeError: if ``key`` is not a ``(rows, columns)`` tuple.
        """
        if not isinstance(key, tuple):
            raise TypeError("Dataset must be indexed with dataset[rows, columns]")
        row_key, col_key = key
        rows = [self.data[row_key]] if isinstance(row_key, int) else self.data[row_key]
        if isinstance(col_key, int):
            return [row[col_key] for row in rows]
        if isinstance(col_key, slice):
            return [[row[i] for i in range(*col_key.indices(len(row)))] for row in rows]
        return [[row[i] for i in col_key] for row in rows]

    def save(self, path: 'Path'):
        """Pickle this dataset to ``path``."""
        import pickle
        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: 'Path') -> 'Dataset':
        """Unpickle a dataset from ``path``.

        Security: ``pickle`` can execute arbitrary code while loading; only load
        files written by this application.
        """
        import pickle
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def to_arff(self, path: 'Path'):
        """Write the dataset as an ARFF file to ``path``.

        Label columns come first, followed by all non-label columns; the relation
        name carries ``-C <num_labels>`` (presumably for MEKA, which reads the label
        count from it — confirm against the consumer). Missing values are written as ``?``.
        """
        def fmt(value):
            return str(value if value is not None else '?')

        lines = [f"@relation 'enviPy-dataset: -C {self.num_labels}'", ""]
        for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
            if c == 'structure_id':
                lines.append(f"@attribute {c} string")
            else:
                lines.append(f"@attribute {c} {{0,1}}")
        lines.extend(["", "@data"])
        for d in self.data:
            ys = ','.join(fmt(v) for v in d[-self.num_labels:])
            xs = ','.join(fmt(v) for v in d[:self.num_features])
            lines.append(f'{ys},{xs}')
        with open(path, "w") as fh:
            fh.write("\n".join(lines) + "\n")

    def __repr__(self):
        return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
class SparseLabelECC(BaseEstimator, ClassifierMixin):
    """
    Ensemble of Classifier Chains with sparse label removal.
    Removes labels that are constant across all samples in training.

    Each ensemble member is a scikit-learn ``ClassifierChain`` over the remaining
    labels with its own random label order (seeded by the chain index), so the
    members actually differ. Predictions are averaged over chains; constant labels
    are re-inserted with their training value.
    """

    def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42),
                 num_chains: int = 10):
        # NOTE: the default estimator instance is shared by all SparseLabelECC objects;
        # per the scikit-learn docs ClassifierChain fits clones of it, not the instance itself.
        self.base_clf = base_clf
        self.num_chains = num_chains

    def fit(self, X, Y):
        """Fit ``num_chains`` classifier chains on the non-constant label columns of ``Y``."""
        y = np.array(Y)
        self.n_labels_ = y.shape[1]
        self.removed_labels_ = {}
        self.keep_columns_ = []
        for col in range(self.n_labels_):
            unique_values = np.unique(y[:, col])
            if len(unique_values) == 1:
                # Constant column: nothing to learn, remember its value instead.
                self.removed_labels_[col] = unique_values[0]
            else:
                self.keep_columns_.append(col)
        self.chains_ = []
        if not self.keep_columns_:
            # Every label is constant; predictions come from removed_labels_ alone.
            return self
        y_reduced = y[:, self.keep_columns_]
        # Without a per-chain random order every chain would be an identical copy
        # (same label order, same cloned estimator), making the ensemble pointless.
        self.chains_ = [
            ClassifierChain(self.base_clf, order='random', random_state=i)
            for i in range(self.num_chains)
        ]
        for i, chain in enumerate(self.chains_):
            logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
            chain.fit(X, y_reduced)
        return self

    def _average_over_chains(self, X, method):
        """Average ``chain.<method>(X)`` over all chains; an (n_samples, 0) array when there are none."""
        if not self.chains_:
            return np.zeros((len(X), 0))
        return np.mean([getattr(chain, method)(X) for chain in self.chains_], axis=0)

    def predict(self, X, threshold=0.5):
        """Predict all ``n_labels_`` labels: a kept label is 1 when the mean chain vote exceeds ``threshold``."""
        avg_preds = self._average_over_chains(X, 'predict') > threshold
        full_y = np.zeros((avg_preds.shape[0], self.n_labels_))
        for idx, col in enumerate(self.keep_columns_):
            full_y[:, col] = avg_preds[:, idx]
        for col, value in self.removed_labels_.items():
            full_y[:, col] = bool(value)
        return full_y

    def predict_proba(self, X):
        """Return the mean positive-class probability per label; constant labels get their training value."""
        avg_proba = self._average_over_chains(X, 'predict_proba')
        full_y = np.zeros((avg_proba.shape[0], self.n_labels_))
        for idx, col in enumerate(self.keep_columns_):
            full_y[:, col] = avg_proba[:, idx]
        for col, value in self.removed_labels_.items():
            full_y[:, col] = float(value)
        return full_y

    def score(self, X, Y, sample_weight=None):
        """
        Default scoring using subset accuracy (exact match).
        """
        y_true = np.array(Y)
        y_pred = self.predict(X)
        return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
class BinaryRelevance:
    """Binary-relevance multi-label classifier: one independent classifier per label.

    Missing labels (NaN, or ``None`` in list input) are ignored when training the
    classifier for that label. Labels that are missing everywhere get a constant-0
    classifier; labels with a single observed class get a most-frequent dummy.
    """

    def __init__(self, baseline_clf):
        self.clf = baseline_clf
        self.classifiers = None

    def fit(self, X, Y):
        """Fit one classifier per column of ``Y``; refitting replaces earlier classifiers."""
        X = np.asarray(X)
        # dtype=float maps None to NaN so missing labels can be masked with np.isnan.
        Y = np.asarray(Y, dtype=float)
        # Start from scratch: appending to a previous fit would add extra label columns.
        self.classifiers = []
        for label in range(Y.shape[1]):
            known = ~np.isnan(Y[:, label])
            X_l = X[known]
            Y_l = Y[known, label]
            if len(X_l) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy='constant', constant=0)
                clf.fit([X[0]], [0])
            else:
                if len(np.unique(Y_l)) == 1:  # only one class -> predict that class
                    clf = DummyClassifier(strategy='most_frequent')
                else:
                    clf = copy.deepcopy(self.clf)
                clf.fit(X_l, Y_l)
            self.classifiers.append(clf)

    def predict(self, X):
        """Return an (n_samples, n_labels) array of per-label predictions."""
        return np.column_stack([clf.predict(X) for clf in self.classifiers])

    def predict_proba(self, X):
        """Return an (n_samples, n_labels) array of positive-class probabilities."""
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            proba = clf.predict_proba(X)
            if proba.shape[1] > 1:
                column = proba[:, 1]
            else:
                # Single-class classifier: scale its all-ones column by that class value
                # (0 or 1 for binary labels); it always predicts classes_[0].
                column = proba[:, 0] * clf.classes_[0]
            labels = np.column_stack((labels, column))
        return labels
class MissingValuesClassifierChain:
    """Classifier chain that tolerates missing labels.

    Labels are processed in the order ``self.permutation``, drawn from
    ``default_rng(42)`` on the first fit unless it was set beforehand. Each label's
    classifier is trained only on rows where that label is known (not NaN/None).
    The label column, with missing entries filled in by that classifier's
    predictions, is then appended to the inputs of every later classifier.
    """

    def __init__(self, base_clf):
        self.base_clf = base_clf
        self.permutation = None
        self.classifiers = None

    def fit(self, X, Y):
        """Fit one classifier per label along the chain; refitting replaces earlier classifiers."""
        X = np.array(X)
        # dtype=float maps None to NaN so missing labels can be detected with np.isnan.
        Y = np.array(Y, dtype=float)
        if self.permutation is None:
            rng = default_rng(42)
            self.permutation = rng.permutation(Y.shape[1])
        Y = Y[:, self.permutation]
        self.classifiers = []
        for p in range(len(self.permutation)):
            known = ~np.isnan(Y[:, p])
            X_p = X[known]
            Y_p = Y[known, p]
            if len(X_p) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy='constant', constant=0)
                clf.fit([X[0]], [0])
            elif len(np.unique(Y_p)) == 1:  # only one class -> predict that class
                clf = DummyClassifier(strategy='most_frequent')
                clf.fit(X_p, Y_p)
            else:
                clf = copy.deepcopy(self.base_clf)
                clf.fit(X_p, Y_p)
            # Store the estimator itself: not every fit() returns self.
            self.classifiers.append(clf)
            # Feed this label forward: known values as-is, missing ones predicted.
            newcol = Y[:, p].copy()
            missing = np.isnan(newcol)
            if missing.any():
                newcol[missing] = clf.predict(X[missing])
            X = np.column_stack((X, newcol))

    def predict(self, X):
        """Predict all labels along the chain and return them in the original label order."""
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict(np.column_stack((X, labels)))
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]

    def predict_proba(self, X):
        """Return positive-class probabilities in the original label order.

        Rounded probabilities of earlier labels are fed as inputs to later classifiers.
        """
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            proba = clf.predict_proba(np.column_stack((X, np.round(labels))))
            if proba.shape[1] > 1:
                column = proba[:, 1]
            else:
                # Single-class classifier: scale its all-ones column by that class value
                # (0 or 1 for binary labels); it always predicts classes_[0].
                column = proba[:, 0] * clf.classes_[0]
            labels = np.column_stack((labels, column))
        return labels[:, np.argsort(self.permutation)]
class EnsembleClassifierChain:
    """Ensemble of ``MissingValuesClassifierChain`` models with distinct label orders.

    Each member receives its own permutation drawn from one ``default_rng(42)``
    stream, so the chains differ while staying reproducible. Predictions are
    averaged over members.
    """

    def __init__(self, base_clf, num_chains=10):
        self.base_clf = base_clf
        self.num_chains = num_chains
        self.num_labels = None
        self.classifiers = None

    def fit(self, X, Y):
        """Fit ``num_chains`` chains; refitting replaces the previous ensemble."""
        self.classifiers = []
        self.num_labels = len(Y[0])
        rng = default_rng(42)
        for p in range(self.num_chains):
            logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
            clf = MissingValuesClassifierChain(self.base_clf)
            # Pre-set the label order: left unset, every chain would seed default_rng(42)
            # itself and all members would share one permutation.
            clf.permutation = rng.permutation(self.num_labels)
            clf.fit(X, Y)
            self.classifiers.append(clf)

    def predict(self, X):
        """Return the mean member prediction rounded with ``np.round`` (ties go to the even value)."""
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict(X)
        return np.round(labels / len(self.classifiers))

    def predict_proba(self, X):
        """Return the mean member probability per label."""
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict_proba(X)
        return labels / len(self.classifiers)
class RelativeReasoning:
    """Suppress triggered rules that are dominated by another triggered rule.

    ``fit`` records rule ``j`` in ``winmap[i]`` when, over the rows of ``Y`` with
    both values present:

    * ``j`` is observed while ``i`` is not at least ``min_count`` times;
    * this happens more often than the reverse;
    * the reverse happens at most ``max_count`` times (any number if ``max_count`` < 0);
    * the two are never both 1.

    ``predict`` copies the triggered block ``X[row][start_index:end_index + 1]`` and
    zeroes rule ``i`` whenever a rule in ``winmap[i]`` is also triggered in that row.

    NOTE(review): start_index/end_index are positions inside each X row; confirm
    that callers account for ``Dataset.X()`` dropping the ``structure_id`` column.
    """

    def __init__(self, start_index: int, end_index: int):
        self.start_index: int = start_index
        self.end_index: int = end_index
        self.winmap: Dict[int, List[int]] = defaultdict(list)
        self.min_count: int = 5
        self.max_count: int = 0

    def fit(self, X, Y):
        """Learn the dominance map from the label matrix ``Y`` (``None``/NaN entries are skipped)."""
        # Start from scratch so a refit on new data leaves no stale dominance pairs.
        self.winmap = defaultdict(list)
        n_attributes = len(Y[0])
        for i in range(n_attributes):
            for j in range(n_attributes):
                if i == j:
                    continue
                countwin = 0
                countloose = 0
                countboth = 0
                for row in Y:
                    vi = row[i]
                    vj = row[j]
                    if vi is None or vj is None:
                        continue
                    if vi < vj:
                        countwin += 1
                    elif vi > vj:
                        countloose += 1
                    elif vi == vj and vi == 1:  # tie
                        countboth += 1
                # At least min_count wins, more wins than losses, few enough losses, no ties.
                if (
                    countwin >= self.min_count and
                    countwin > countloose and
                    (
                        countloose <= self.max_count or
                        self.max_count < 0
                    ) and
                    countboth == 0
                ):
                    self.winmap[i].append(j)

    def predict(self, X):
        """Return the triggered block of each row with dominated rules set to 0."""
        res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
        for inst_idx, inst in enumerate(X):
            triggered = inst[self.start_index: self.end_index + 1]
            for i, t in enumerate(triggered):
                res[inst_idx][i] = t
                if not t:
                    continue
                dominators = self.winmap.get(i, [])
                for i2, t2 in enumerate(triggered):
                    # i2 indexes the triggered block, so the dominating rule's flag is t2,
                    # not X[inst_idx][i2] (which would be a column outside that block).
                    if i != i2 and i2 in dominators and t2:
                        res[inst_idx][i] = 0
        return res

    def predict_proba(self, X):
        """Same as ``predict``: the output is already 0/1-valued."""
        return self.predict(X)
class ApplicabilityDomainPCA(PCA):
    """Bounding-box applicability domain in PCA space.

    Training features are standard-scaled and projected with PCA using
    ``num_neighbours`` components. An instance is applicable when every one of
    its projected coordinates lies within the [min, max] range seen in training.
    """

    def __init__(self, num_neighbours: int = 5):
        super().__init__(n_components=num_neighbours)
        self.scaler = StandardScaler()
        self.num_neighbours = num_neighbours
        self.min_vals = None
        self.max_vals = None

    def build(self, train_dataset: 'Dataset'):
        """Fit scaler and PCA on ``train_dataset.X()`` and record per-component bounds."""
        projected = self.fit_transform(self.scaler.fit_transform(train_dataset.X()))
        self.max_vals = np.max(projected, axis=0)
        self.min_vals = np.min(projected, axis=0)

    def __transform(self, instances):
        # Apply the fitted scaler, then the fitted PCA projection.
        return self.transform(self.scaler.transform(instances))

    def is_applicable(self, classify_instances: 'Dataset'):
        """Return one bool per row of ``classify_instances.X()``: True when it lies inside the training box."""
        projected = self.__transform(classify_instances.X())
        within = (projected >= self.min_vals) & (projected <= self.max_vals)
        return [bool(flag) for flag in np.all(within, axis=1)]
def tanimoto_distance(a: List[int], b: List[int]):
    """Return the Tanimoto (Jaccard) distance between two equal-length bit vectors.

    Two all-zero vectors have distance 0.0.

    Raises:
        ValueError: if the vectors differ in length.
    """
    if len(a) != len(b):
        raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")
    shared = sum(x and y for x, y in zip(a, b))
    union = sum(a) + sum(b) - shared
    return 0.0 if union == 0 else 1 - shared / union
def graph_from_pathway(data):
    """Convert Pathway or SPathway to networkx.

    Nodes are keyed by standardized SMILES and carry ``depth``, ``smiles`` and
    ``root`` attributes; when a SMILES occurs more than once the smallest depth
    is kept. Every educt of a pathway edge is connected to every product.

    Each graph edge carries ``probability``: ``edge.kv.get('probability')`` for a
    ``Pathway`` (may be None if the key is absent) or ``edge.probability`` for an
    ``SPathway``, falling back to 1 when that attribute access raises
    AttributeError. CO2 products and self-loops are skipped.

    Raises:
        TypeError: if ``data`` is neither a ``Pathway`` nor an ``SPathway``.
    """
    from epdb.models import Pathway
    from epdb.logic import SPathway

    if isinstance(data, Pathway):
        is_db_pathway = True
    elif isinstance(data, SPathway):
        is_db_pathway = False
    else:
        raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    def get_edges():
        return data.edges.all() if is_db_pathway else data.edges

    def get_sources_targets(edge):
        if is_db_pathway:
            return ([n.node for n in edge.start_nodes.constrained_target.all()],
                    [n.node for n in edge.end_nodes.constrained_target.all()])
        return edge.educts, edge.products

    def get_smiles_depth(node):
        smiles = node.default_node_label.smiles if is_db_pathway else node.smiles
        return FormatConverter.standardize(smiles, True), node.depth

    def get_probability(edge):
        try:
            return edge.kv.get('probability') if is_db_pathway else edge.probability
        except AttributeError:
            return 1

    roots = {get_smiles_depth(n) for n in data.root_nodes}
    # ``roots`` holds (smiles, depth) pairs; membership tests below need the bare SMILES.
    root_smiles = {smiles for smiles, _ in roots}
    for root, depth in roots:
        graph.add_node(root, depth=depth, smiles=root, root=True)

    for edge in get_edges():
        sources, targets = get_sources_targets(edge)
        probability = get_probability(edge)
        for source in sources:
            source_smiles, source_depth = get_smiles_depth(source)
            if source_smiles not in graph:
                graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
                               root=source_smiles in root_smiles)
            else:
                graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
            for target in targets:
                target_smiles, target_depth = get_smiles_depth(target)
                if target_smiles in co2:
                    continue
                if target_smiles not in graph:
                    graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
                                   root=target_smiles in root_smiles)
                else:
                    graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
                if target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)
    return graph
def get_shortest_path(pathway, in_start_node, in_end_node):
    """Return the nodes strictly between start and end on a shortest path in ``pathway``.

    Returns an empty list when there is no path, when either node is not in the
    graph (e.g. it was removed by pruning), or when start and end coincide.
    """
    try:
        path = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return []
    # A shortest path starts at the source and ends at the target; strip both.
    return path[1:-1]
def set_pathway_eval_weight(pathway):
    """Return ``{node: weight}`` with weight ``1 / 2**depth``.

    Nodes with a negative depth (e.g. the -99 "not reached" marker) get weight 0.
    """
    def weight_for(depth):
        # Scale score according to depth level
        return 0 if depth < 0 else 1 / (2 ** depth)

    return {node: weight_for(pathway.nodes[node]["depth"]) for node in pathway.nodes}
def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    """Reduce node depths in ``pred_pathway`` to discount predicted intermediates.

    * Root nodes are left unchanged.
    * Intermediates absent from ``data_pathway`` get depth -99, which excludes
      them from scoring.
    * Every other node loses one depth level per intermediate on its shortest
      path from any root.

    A direct root->node edge counts as a path with zero intermediates. Roots no
    longer in the graph (e.g. pruned) are ignored. ``pred_pathway`` is modified in
    place and returned.
    """
    if not intermediates:
        return pred_pathway
    root_nodes = pred_pathway.graph["root_nodes"]
    for node in pred_pathway.nodes:
        if node in root_nodes:
            continue
        if node in intermediates and node not in data_pathway:
            pred_pathway.nodes[node]["depth"] = -99
            continue
        inner_paths = []
        for root_node in root_nodes:
            try:
                path = nx.shortest_path(pred_pathway, source=root_node, target=node)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                # Unreachable from this root, or the root itself is gone.
                continue
            # Keep direct children (empty inner path) as candidates: the nearest root
            # defines the node's depth, so its path must win the min() below.
            inner_paths.append(path[1:-1])
        if inner_paths:
            shortest_inner = min(inner_paths, key=len)
            num_ints = sum(1 for path_node in shortest_inner if path_node in intermediates)
            pred_pathway.nodes[node]["depth"] -= num_ints
    return pred_pathway
def initialise_pathway(pathway):
    """Convert pathway to networkx graph for evaluation.

    The roots are the nodes whose depth is 0 after conversion; all depths are
    then recomputed breadth-first from those roots.
    """
    graph = graph_from_pathway(pathway)
    graph.graph["root_nodes"] = {
        node for node, depth in graph.nodes(data="depth") if depth == 0
    }
    return get_pathway_with_depth(graph)
def get_pathway_with_depth(pathway):
    """Recalculate depths in the pathway breadth-first from ``graph["root_nodes"]``.

    Can fix incorrect depths from json parse if there were multiple nodes with
    the same SMILES at different depths that got merged. Roots get depth 0;
    nodes not reachable from any root keep the marker depth -99.
    """
    for node in pathway.nodes:
        pathway.nodes[node]["depth"] = 0 if node in pathway.graph["root_nodes"] else -99
    level = 0
    while assign_next_depth(pathway, level):
        level += 1
    return pathway
def assign_next_depth(pathway, current_depth):
    """Give every unassigned (negative-depth) successor of a ``current_depth`` node depth ``current_depth + 1``.

    Returns True when at least one node received a depth.
    """
    frontier = [n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth]
    assigned_any = False
    for node in frontier:
        for successor in pathway.successors(node):
            if pathway.nodes[successor]["depth"] >= 0:
                continue
            pathway.nodes[successor]["depth"] = current_depth + 1
            assigned_any = True
    return assigned_any
def find_intermediates(data_pathway, pred_pathway):
    """Find any intermediate nodes in the predicted pathway.

    For each common node and each of its successors in ``data_pathway`` that also
    exists in ``pred_pathway``, the nodes strictly between them on a shortest
    predicted path are collected.
    """
    found = set()
    for node in get_common_nodes(pred_pathway, data_pathway):
        for successor in data_pathway.successors(node):
            if successor not in pred_pathway:
                continue
            found.update(get_shortest_path(pred_pathway, node, successor))
    return found
def get_common_nodes(pred_pathway, data_pathway):
    """A node is a common node if it is in both pathways and is either a root in both or not a root in both."""
    def has_same_role(node):
        in_data_roots = node in data_pathway.graph["root_nodes"]
        in_pred_roots = node in pred_pathway.graph["root_nodes"]
        return node in pred_pathway.nodes and in_data_roots == in_pred_roots

    return {node for node in data_pathway.nodes if has_same_role(node)}
def prune_graph(graph, threshold):
    """
    Remove cycles and edges with probability below the threshold, then keep only
    the subgraph reachable from the root node(s).

    Cycles are broken first by repeatedly deleting the last edge of a found cycle.
    Nodes flagged ``root`` are always kept, together with everything reachable
    from any of them. ``graph`` is modified in place.

    Raises:
        ValueError: if no node is flagged as root.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
        except nx.NetworkXNoCycle:
            break
        graph.remove_edge(*cycle[-1])  # Remove the last edge in the cycle

    for u, v, attrs in list(graph.edges(data=True)):  # Remove edges below threshold
        if attrs["probability"] < threshold:
            graph.remove_edge(u, v)

    roots = [n for n in graph.nodes if graph.nodes[n].get("root")]
    if not roots:
        raise ValueError("Cannot prune a pathway graph without a root node")
    # Keep every root's subtree; keeping only the first root would silently drop
    # the others (and their products) from multi-root pathways.
    reachable = set(roots)
    for root in roots:
        reachable |= nx.descendants(graph, root)
    graph.remove_nodes_from([n for n in graph.nodes if n not in reachable])
def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Compare two pathways for multi-gen evaluation.

    It is assumed the smiles in both pathways have been standardised in the same manner.

    Both pathways are converted to graphs. When ``threshold`` is given, the
    predicted graph is pruned first. Depths of the predicted graph are then
    adjusted for intermediates. Non-root nodes are scored with depth weights
    ``1 / 2**depth``:

    * TP: common nodes (weighted in the data pathway);
    * FN: data-only nodes;
    * FP: prediction-only nodes (weighted in the predicted pathway).

    Returns:
        ``(score, precision, recall)`` where score = TP / (TP + FP + FN), plus the
        set of intermediates when ``return_intermediates`` is True. Each ratio is
        0.0 when its denominator is 0.
    """
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_pathway, threshold)
        # Pruning can delete nodes recorded as roots; drop them so later shortest-path
        # queries never start from a node that is no longer in the graph.
        pred_pathway.graph["root_nodes"] = {
            n for n in pred_pathway.graph["root_nodes"] if n in pred_pathway
        }
    intermediates = find_intermediates(data_pathway, pred_pathway)
    if intermediates:
        pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates)
    test_pathway_eval_weights = set_pathway_eval_weight(data_pathway)
    pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway)
    common_nodes = get_common_nodes(pred_pathway, data_pathway)
    data_only_nodes = {n for n in data_pathway.nodes if n not in common_nodes}
    pred_only_nodes = {n for n in pred_pathway.nodes if n not in common_nodes}

    # Roots (depth 0) and excluded nodes (negative depth) do not contribute.
    score_TP = sum((test_pathway_eval_weights[n] for n in common_nodes
                    if pred_pathway.nodes[n]["depth"] > 0), 0.0)
    score_FN = sum((test_pathway_eval_weights[n] for n in data_only_nodes
                    if data_pathway.nodes[n]["depth"] > 0), 0.0)
    score_FP = sum((pred_pathway_eval_weights[n] for n in pred_only_nodes
                    if pred_pathway.nodes[n]["depth"] > 0), 0.0)

    final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0
    precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0
    recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0
    if return_intermediates:
        return final_score, precision, recall, intermediates
    return final_score, precision, recall
def node_subst_cost(node1, node2):
    """Graph-edit cost of substituting ``node1`` by ``node2``.

    Identical nodes (same SMILES and depth) cost 0. Otherwise the cost is
    ``1 / 2**max(depth)``, with 0 for a negative depth (the -99 "not reached"
    marker), matching ``set_pathway_eval_weight``. Without that guard the marker
    would cost 2**99 and dominate the edit distance.
    """
    if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]:
        return 0
    depth = max(node1["depth"], node2["depth"])  # Maybe could be min instead of max
    return 1 / (2 ** depth) if depth >= 0 else 0
def node_ins_del_cost(node):
    """Graph-edit cost of inserting or deleting ``node``: ``1 / 2**depth``.

    A negative depth (the -99 "not reached" marker) costs 0, matching
    ``set_pathway_eval_weight``, instead of 2**99.
    """
    depth = node["depth"]
    return 1 / (2 ** depth) if depth >= 0 else 0
def pathway_edit_eval(data_pathway, pred_pathway):
    """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval.

    One (arbitrary) root of each graph is passed as the anchored ``roots`` pair;
    node costs come from ``node_subst_cost`` and ``node_ins_del_cost``.
    """
    reference = initialise_pathway(data_pathway)
    predicted = initialise_pathway(pred_pathway)
    reference_root = tuple(reference.graph["root_nodes"])[0]
    predicted_root = tuple(predicted.graph["root_nodes"])[0]
    return nx.graph_edit_distance(
        reference,
        predicted,
        node_subst_cost=node_subst_cost,
        node_del_cost=node_ins_del_cost,
        node_ins_cost=node_ins_del_cost,
        roots=(reference_root, predicted_root),
    )