Files
enviPy-bayer/utilities/ml.py
2025-11-06 10:42:32 +13:00

997 lines
36 KiB
Python

from __future__ import annotations
import copy
import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
from abc import ABC, abstractmethod
import networkx as nx
import numpy as np
from envipy_plugins import Descriptor
from numpy.random import default_rng
import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler
from utilities.chem import FormatConverter, PredictionResult
# Module-level logger; the training and dataset-generation code below reports progress through it.
logger = logging.getLogger(__name__)
# These model classes are only needed for type annotations in this module. Importing them lazily
# presumably avoids a circular import with epdb (graph_from_pathway also imports epdb at call
# time) — TODO confirm.
if TYPE_CHECKING:
    from epdb.models import Rule, CompoundStructure, Reaction
class Dataset(ABC):
    """Abstract tabular dataset stored in a polars ``DataFrame`` (``self.df``).

    Table operations (indexing, ``filter``, ``select``, ``sort`` ...) delegate to polars and wrap a
    resulting DataFrame in a new instance of the same subclass. Subclasses must therefore accept
    ``data=<DataFrame>`` as a keyword argument in their constructor.
    """

    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
        """Create the dataset.

        :param columns: Column names; required unless ``data`` is already a DataFrame.
        :param data: A polars DataFrame (used as-is) or row-oriented data whose rows have
            ``len(columns)`` entries. ``None`` or an empty list produces an empty table.
        :raises ValueError: If ``columns`` is missing or does not match the row width.
        """
        if isinstance(data, pl.DataFrame):  # Allows re-creation of self, e.g. from __getitem__
            self.df = data
            return
        # Validate columns first; otherwise len(columns) below fails with an unhelpful TypeError.
        if columns is None:
            raise ValueError("Columns can't be None if data is not already a DataFrame")
        if data is not None and len(data) == 0:
            data = None  # an empty row list has no data[0] whose width could be checked
        if data is not None and len(columns) != len(data[0]):
            raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
        self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)

    def add_rows(self, rows: List[List[str | int | float]]):
        """Append ``rows`` (row-oriented, same width as ``self.columns``) to ``self.df`` in place.

        An empty ``rows`` list is a no-op.

        :raises ValueError: If the row width does not match the number of columns.
        """
        if len(rows) == 0:
            return
        if len(self.columns) != len(rows[0]):
            raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
        new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
        self.df.extend(new_rows)

    def add_row(self, row: List[str | int | float]):
        """Append a single row; see :meth:`add_rows`."""
        self.add_rows([row])

    def block_indices(self, prefix) -> List[int]:
        """Return the positions of all columns whose name starts with ``prefix``."""
        return [i for i, feature in enumerate(self.columns) if feature.startswith(prefix)]

    @property
    def columns(self) -> List[str]:
        """Column names of the underlying polars DataFrame."""
        return self.df.columns

    @property
    def shape(self):
        """``(rows, columns)`` of the underlying DataFrame."""
        return self.df.shape

    @abstractmethod
    def X(self, **kwargs):
        pass

    @abstractmethod
    def y(self, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def generate_dataset(reactions, *args, **kwargs):
        pass

    def __iter__(self):
        """Iterate over rows as tuples (polars ``iter_rows``)."""
        return self.df.iter_rows()

    def __getitem__(self, item):
        """Item is passed to polars allowing for advanced indexing.

        See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__
        A DataFrame result is wrapped in a new instance of this class; anything else (a scalar,
        a Series, ...) is returned unchanged.
        """
        res = self.df[item]
        if isinstance(res, pl.DataFrame):
            return self.__class__(data=res)
        return res

    def save(self, path: "Path | str"):
        """Pickle this dataset to ``path``."""
        import pickle

        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: "str | Path") -> "Dataset":
        """Load a dataset previously written with :meth:`save`.

        NOTE(security): unpickling can execute arbitrary code — only load trusted files.
        """
        import pickle

        # Context manager so the file handle is closed even if unpickling fails.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def to_numpy(self):
        return self.df.to_numpy()

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
        )

    def __len__(self):
        return len(self.df)

    def iter_rows(self, named=False):
        return self.df.iter_rows(named=named)

    def filter(self, *predicates, **constraints):
        return self.__class__(data=self.df.filter(*predicates, **constraints))

    def select(self, *exprs, **named_exprs):
        return self.__class__(data=self.df.select(*exprs, **named_exprs))

    def with_columns(self, *exprs, **name_exprs):
        return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))

    def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
        return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
                                                multithreaded=multithreaded, maintain_order=maintain_order))

    def item(self, row=None, column=None):
        return self.df.item(row, column)

    def fill_nan(self, value):
        return self.__class__(data=self.df.fill_nan(value))

    @property
    def height(self):
        return self.df.height
class RuleBasedDataset(Dataset):
    """Compound x rule dataset used to train rule-based multi-label models.

    Columns produced by :meth:`generate_dataset`:

    * ``structure_id`` - UUID of the compound structure (``None`` for ad-hoc SMILES rows).
    * ``feature_<n>`` - structural features (MACCS keys by default).
    * ``trig_<rule uuid>`` - 1 if the rule produced products for the compound, else 0.
    * ``obs_<rule uuid>`` - 1 if the products of a single-educt reaction of the compound are all
      among the rule's products, 0 if the rule triggered without that, ``None`` if it did not trigger.
    * ``prob_<rule uuid>`` - optional predicted probabilities appended by :meth:`add_probs`.
    """

    def __init__(self, num_labels=None, columns=None, data=None):
        super().__init__(columns, data)
        # Deriving num_labels from the columns lets base-class helpers (__getitem__, filter, ...)
        # rebuild instances from a bare DataFrame.
        self.num_labels: int = num_labels if num_labels else sum(1 for c in self.columns if "obs_" in c)
        # Column positions of each block, reused by X/trig/y.
        self._struct_features: List[int] = self.block_indices("feature_")
        self._triggered: List[int] = self.block_indices("trig_")
        self._observed: List[int] = self.block_indices("obs_")
        self.feature_cols: List[int] = self._struct_features + self._triggered
        self.num_features: int = len(self.feature_cols)
        # Derived from the columns rather than hard-coded to False. Instances rebuilt from a frame
        # that already carries prob_ columns (slicing, filter, classification_dataset) must still
        # emit rows of matching width in classification_dataset.
        self.has_probs: bool = any(c.startswith("prob_") for c in self.columns)

    @staticmethod
    def _fill_nulls(ds: "RuleBasedDataset", na_replacement) -> "RuleBasedDataset":
        """Replace nulls in ``ds`` with ``na_replacement`` (skipped when it is None)."""
        if na_replacement is not None:
            ds.df = ds.df.fill_null(na_replacement)
        return ds

    def times_triggered(self, rule_uuid) -> int:
        """Count the rows whose ``trig_<rule_uuid>`` column equals 1."""
        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height

    def struct_features(self) -> List[int]:
        return self._struct_features

    def triggered(self) -> List[int]:
        return self._triggered

    def observed(self) -> List[int]:
        return self._observed

    def structure_id(self, index: int):
        """Get the UUID of a compound"""
        return self.item(index, "structure_id")

    def X(self, exclude_id_col=True, na_replacement=0):
        """Return the feature and trig columns.

        Column 0 (``structure_id``) is prepended when ``exclude_id_col`` is False.
        """
        col_ids = self.feature_cols if exclude_id_col else [0] + self.feature_cols
        return self._fill_nulls(self[:, col_ids], na_replacement)

    def trig(self, na_replacement=0):
        """Get all the trig columns"""
        return self._fill_nulls(self[:, self._triggered], na_replacement)

    def y(self, na_replacement=0):
        """Get all the obs columns"""
        return self._fill_nulls(self[:, self._observed], na_replacement)

    @staticmethod
    def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"] = None):
        """Build a dataset from ``reactions`` and the rules that may apply to their compounds.

        :param reactions: Reactions whose educts (and products, unless ``educts_only``) become rows.
        :param applicable_rules: Rules applied to every compound; one trig_/obs_ column pair each.
        :param educts_only: Only use reaction educts as compounds.
        :param feat_funcs: Callables or ``Descriptor`` objects turning a SMILES into features;
            defaults to MACCS keys.
        :raises ValueError: If the reactions yield no compounds (feature columns cannot be derived).
        """
        if feat_funcs is None:
            feat_funcs = [FormatConverter.maccs]

        def _standardize(smiles):
            # Best effort: fall back to the raw SMILES if standardisation fails.
            try:
                return FormatConverter.standardize(smiles, remove_stereo=True)
            except Exception:
                logger.debug(f"Standardizing SMILES failed for {smiles}")
                return smiles

        def _featurize(smiles):
            feats = []
            for feat_func in feat_funcs:
                if isinstance(feat_func, Descriptor):
                    feats.extend(feat_func.get_molecule_descriptors(smiles))
                else:
                    feats.extend(feat_func(smiles))
            return feats

        _structures = set()  # Get all the structures
        for r in reactions:
            _structures.update(r.educts.all())
            if not educts_only:
                _structures.update(r.products.all())
        compounds = sorted(_structures, key=lambda x: x.url)
        if not compounds:
            raise ValueError("Cannot generate a RuleBasedDataset without any compounds")

        # Apply rules on collected compounds and store standardised products per "rule + compound"
        triggered: Dict[str, Set[str]] = defaultdict(set)
        for i, comp in enumerate(compounds):
            logger.debug(f"{i + 1}/{len(compounds)}...")
            for rule in applicable_rules:
                product_sets = rule.apply(comp.smiles)
                if len(product_sets) == 0:
                    continue
                key = f"{rule.uuid} + {comp.uuid}"
                if key in triggered:
                    logger.info(f"{key} already present. Duplicate reaction?")
                for prod_set in product_sets:
                    triggered[key].update(_standardize(smi) for smi in prod_set)

        # A rule is "observed" for a compound when a single-educt reaction of that compound only
        # has products the rule also predicts.
        observed: Set[str] = set()
        for i, r in enumerate(reactions):
            logger.debug(f"{i + 1}/{len(reactions)}...")
            educts = r.educts.all()
            if len(educts) != 1:
                logger.debug(f"Skipping {r.url} as it has {len(educts)} substrates!")
                continue
            standardized_products = None  # computed lazily, once per reaction
            for comp in educts:
                for rule in applicable_rules:
                    key = f"{rule.uuid} + {comp.uuid}"
                    if key not in triggered:
                        continue
                    if standardized_products is None:
                        standardized_products = {_standardize(cs.smiles) for cs in r.products.all()}
                    if standardized_products <= triggered[key]:
                        observed.add(key)

        feature_rows = [_featurize(comp.smiles) for comp in compounds]
        feat_columns = [f"feature_{i}" for i in range(len(feature_rows[0]))]
        ds_columns = (["structure_id"] +
                      feat_columns +
                      [f"trig_{r.uuid}" for r in applicable_rules] +
                      [f"obs_{r.uuid}" for r in applicable_rules])

        rows = []
        for comp, feats in zip(compounds, feature_rows):
            trig = []
            obs = []
            for rule in applicable_rules:
                key = f"{rule.uuid} + {comp.uuid}"
                is_triggered = key in triggered
                trig.append(1 if is_triggered else 0)
                if key in observed:
                    obs.append(1)
                elif not is_triggered:
                    obs.append(None)
                else:
                    obs.append(0)
            rows.append([str(comp.uuid)] + feats + trig + obs)
        return RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)

    def classification_dataset(
        self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
    ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
        """Build an unlabelled dataset with the same columns as ``self`` for ``structures``.

        The obs columns (and prob columns, if present) are filled with -1 placeholders.

        :returns: The dataset and, per structure, the products of every rule (``[]`` when the rule
            did not apply).
        """
        classify_data = []
        classify_products = []
        placeholder_blocks = 2 if self.has_probs else 1  # obs (+ prob) columns
        for struct in structures:
            if isinstance(struct, str):
                struct_id, struct_smiles = None, struct
            else:
                struct_id, struct_smiles = str(struct.uuid), struct.smiles
            # NOTE(review): features are always MACCS keys here, while generate_dataset accepts
            # arbitrary feat_funcs. Rows only line up with self.columns when MACCS was used.
            features = FormatConverter.maccs(struct_smiles)
            trig = []
            prods = []
            for rule in applicable_rules:
                products = rule.apply(struct_smiles)
                if len(products):
                    trig.append(1)
                    prods.append(products)
                else:
                    trig.append(0)
                    prods.append([])
            classify_data.append([struct_id] + features + trig + [-1] * (len(trig) * placeholder_blocks))
            classify_products.append(prods)
        # None (not []) for "no rows" so the base constructor never inspects data[0].
        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data or None)
        return ds, classify_products

    def add_probs(self, probs):
        """Append one ``prob_<rule uuid>`` column per obs column from the 2-D array ``probs``."""
        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
        self.has_probs = True

    def to_arff(self, path: "Path"):
        """Write the dataset as a MEKA-style multi-label ARFF file.

        Label (obs) attributes come first, followed by ``structure_id`` and the feature/trig
        attributes. Columns are selected by name so prob columns never end up as labels.
        """
        label_cols = list(self._observed)
        id_cols = [i for i, c in enumerate(self.columns) if c == "structure_id"]
        ordered = label_cols + id_cols + self.feature_cols

        lines = [f"@relation 'enviPy-dataset: -C {len(label_cols)}'", ""]
        for idx in ordered:
            c = self.columns[idx]
            if c == "structure_id":
                lines.append(f"@attribute {c} string")
            else:
                lines.append(f"@attribute {c} {{0,1}}")
        lines += ["", "@data"]
        for row in self:
            lines.append(",".join("?" if row[idx] is None else str(row[idx]) for idx in ordered))

        with open(path, "w", encoding="utf-8") as fh:
            fh.write("\n".join(lines) + "\n")
class EnviFormerDataset(Dataset):
    """Two-column dataset (``educts``, ``products``) of dot-joined, standardised reaction SMILES."""

    def __init__(self, columns=None, data=None):
        super().__init__(columns, data)

    def X(self):
        """The ``educts`` column."""
        return self["educts"]

    def y(self):
        """The ``products`` column."""
        return self["products"]

    @staticmethod
    def generate_dataset(reactions, *args, **kwargs):
        """Standardise every reaction into an ``[educts, products]`` row.

        Stereochemistry is removed unless ``stereo=True`` is passed as a keyword argument.
        """
        strip_stereo = not kwargs.get("stereo", False)

        def joined(structures):
            return ".".join(
                FormatConverter.standardize(structure.smiles, remove_stereo=strip_stereo)
                for structure in structures
            )

        table = [[joined(rxn.educts.all()), joined(rxn.products.all())] for rxn in reactions]
        return EnviFormerDataset(["educts", "products"], table)
class SparseLabelECC(BaseEstimator, ClassifierMixin):
    """
    Ensemble of Classifier Chains with sparse label removal.
    Removes labels that are constant across all samples in training.

    Each chain uses its own random label order (seeded with the chain's position), so the ensemble
    members differ. Predictions are averaged over the chains, and the removed constant labels are
    restored with their training value.
    """

    def __init__(
        self,
        base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
        num_chains: int = 10,
    ):
        # NOTE: the default estimator instance is shared by all SparseLabelECC objects. This is safe
        # because it is only handed to ClassifierChain, which clones it before fitting.
        self.base_clf = base_clf
        self.num_chains = num_chains

    def fit(self, X, Y):
        """Drop constant label columns and fit ``num_chains`` randomly ordered chains on the rest."""
        y = np.array(Y)
        self.n_labels_ = y.shape[1]
        self.removed_labels_ = {}
        self.keep_columns_ = []
        for col in range(self.n_labels_):
            unique_values = np.unique(y[:, col])
            if len(unique_values) == 1:
                self.removed_labels_[col] = unique_values[0]
            else:
                self.keep_columns_.append(col)
        y_reduced = y[:, self.keep_columns_]

        self.chains_ = []
        if not self.keep_columns_:
            return self  # every label is constant: nothing to learn
        for i in range(self.num_chains):
            logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
            # Without a per-chain random order every chain would be an identical copy.
            chain = ClassifierChain(self.base_clf, order="random", random_state=i)
            chain.fit(X, y_reduced)
            self.chains_.append(chain)
        return self

    def _to_full(self, reduced, n_samples, cast):
        """Scatter predictions for the kept labels back into an ``n_labels_``-wide matrix."""
        full_y = np.zeros((n_samples, self.n_labels_))
        if reduced is not None:
            full_y[:, self.keep_columns_] = reduced
        for col, value in self.removed_labels_.items():
            full_y[:, col] = cast(value)
        return full_y

    def predict(self, X, threshold=0.5):
        """Predict 0/1 labels: a label is set when the mean chain vote exceeds ``threshold``."""
        if not self.chains_:
            return self._to_full(None, len(X), bool)
        avg_preds = np.mean([chain.predict(X) for chain in self.chains_], axis=0) > threshold
        return self._to_full(avg_preds, avg_preds.shape[0], bool)

    def predict_proba(self, X):
        """Mean per-label probability over all chains."""
        if not self.chains_:
            return self._to_full(None, len(X), float)
        avg_proba = np.mean([chain.predict_proba(X) for chain in self.chains_], axis=0)
        return self._to_full(avg_proba, avg_proba.shape[0], float)

    def score(self, X, Y, sample_weight=None):
        """
        Default scoring using subset accuracy (exact match).
        """
        y_true = np.array(Y)
        y_pred = self.predict(X)
        return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
class BinaryRelevance:
    """One independent classifier per label; NaN labels are excluded from that label's training.

    Labels that are entirely NaN get a constant-0 dummy, and labels with a single observed class
    get a most-frequent dummy. All other labels get a deep copy of ``baseline_clf``.
    """

    def __init__(self, baseline_clf):
        self.clf = baseline_clf
        self.classifiers = None

    def fit(self, X, Y):
        """Fit one classifier per column of ``Y`` (NaN = unknown). Refitting replaces old models.

        :returns: self
        """
        X = np.asarray(X)
        Y = np.asarray(Y, dtype=float)  # also turns None entries into NaN
        # Start from scratch: appending to a previous fit would duplicate the output columns.
        self.classifiers = []
        for label in range(Y.shape[1]):
            known = ~np.isnan(Y[:, label])
            X_l = X[known]
            Y_l = Y[known, label]
            if len(X_l) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy="constant", constant=0)
                clf.fit([X[0]], [0])
            else:
                if len(np.unique(Y_l)) == 1:  # only one class -> predict that class
                    clf = DummyClassifier(strategy="most_frequent")
                else:
                    clf = copy.deepcopy(self.clf)
                clf.fit(X_l, Y_l)
            self.classifiers.append(clf)
        return self

    def predict(self, X):
        """Stack each label classifier's predictions into an ``(n_samples, n_labels)`` array."""
        return np.column_stack([clf.predict(X) for clf in self.classifiers])

    def predict_proba(self, X):
        """Return P(label == 1) per label.

        Single-class dummies return one probability column (for their sole class), so it is
        multiplied by that class's value to express P(label == 1).
        """
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict_proba(X)
            if pred.shape[1] > 1:
                pred = pred[:, 1]
            else:
                pred = pred * clf.predict([X[0]])[0]
            labels = np.column_stack((labels, pred))
        return labels
class MissingValuesClassifierChain:
    """Classifier chain that tolerates missing (NaN) labels.

    Labels are visited in a permutation (seeded with 42 unless ``permutation`` is set beforehand).
    Each label's classifier is trained only on rows where that label is known. Its predictions then
    fill the unknown entries before the column is appended to the features of the next classifier.
    """

    def __init__(self, base_clf):
        self.base_clf = base_clf
        self.permutation = None
        self.classifiers = None

    def fit(self, X, Y):
        """Fit the chain; ``Y`` may contain NaN/None for unknown labels. Refitting replaces old models.

        :returns: self
        """
        X = np.array(X)
        Y = np.array(Y, dtype=float)  # None -> NaN so np.isnan works on object-like input
        n_labels = Y.shape[1]
        # Keep a caller-provided permutation (e.g. from an ensemble) when it fits the label count.
        if self.permutation is None or len(self.permutation) != n_labels:
            self.permutation = default_rng(42).permutation(n_labels)
        Y = Y[:, self.permutation]
        # Start from scratch: appending to a previous fit would keep predicting with old models.
        self.classifiers = []
        for p in range(n_labels):
            known = ~np.isnan(Y[:, p])
            X_p = X[known]
            Y_p = Y[known, p]
            if len(X_p) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy="constant", constant=0)
                clf.fit([X[0]], [0])
            elif len(np.unique(Y_p)) == 1:  # only one class -> predict that class
                clf = DummyClassifier(strategy="most_frequent")
                clf.fit(X_p, Y_p)
            else:
                clf = copy.deepcopy(self.base_clf)
                clf.fit(X_p, Y_p)
            self.classifiers.append(clf)
            # fill in missing values with clf predictions before chaining the column
            newcol = Y[:, p].copy()
            missing = ~known
            newcol[missing] = clf.predict(X)[missing]
            X = np.column_stack((X, newcol))
        return self

    def predict(self, X):
        """Predict all labels along the chain; columns are returned in the original label order."""
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict(np.column_stack((X, labels)))
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]

    def predict_proba(self, X):
        """Return P(label == 1) per label in the original label order.

        Rounded probabilities of earlier labels are chained as features. Single-class dummies
        return one probability column, which is multiplied by that class's value.
        """
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict_proba(np.column_stack((X, np.round(labels))))
            if pred.shape[1] > 1:
                pred = pred[:, 1]
            else:
                pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0]
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]
class EnsembleClassifierChain:
    """Ensemble of ``MissingValuesClassifierChain`` models, each with its own label order.

    ``predict`` rounds the mean chain prediction (numpy rounds exact ties of .5 to the even value).
    ``predict_proba`` returns the mean per-label probability.
    """

    def __init__(self, base_clf, num_chains=10):
        self.base_clf = base_clf
        self.num_chains = num_chains
        self.num_labels = None
        self.classifiers = None

    def fit(self, X, Y):
        """Fit ``num_chains`` chains. Refitting replaces previously fitted chains.

        :returns: self
        """
        Y = np.asarray(Y)
        self.num_labels = Y.shape[1]
        # One seeded generator for the whole ensemble, so every chain draws a different label order.
        # Left alone, each chain would seed default_rng(42) itself and all chains would be identical.
        rng = default_rng(42)
        self.classifiers = []
        for p in range(self.num_chains):
            logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
            clf = MissingValuesClassifierChain(self.base_clf)
            clf.permutation = rng.permutation(self.num_labels)
            clf.fit(X, Y)
            self.classifiers.append(clf)
        return self

    def predict(self, X):
        """Majority vote (rounded mean) of the chains' 0/1 predictions."""
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict(X)
        return np.round(labels / len(self.classifiers))

    def predict_proba(self, X):
        """Mean per-label probability over the chains."""
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict_proba(X)
        return labels / len(self.classifiers)
class RelativeReasoning:
    """Suppress rule predictions that are dominated by another triggered rule.

    ``fit`` decides, for every label pair (i, j), whether j dominates i. This requires that in Y,
    j was 1 while i was 0 at least ``min_count`` times, and i "beat" j at most ``max_count`` times
    (a negative ``max_count`` disables that check). The two must also never both be 1.
    ``predict`` copies the trigger flags in columns ``start_index..end_index`` (inclusive) of X and
    clears a flag when one of its dominators is also triggered.
    """

    def __init__(self, start_index: int, end_index: int, min_count: int = 5, max_count: int = 0):
        self.start_index: int = start_index
        self.end_index: int = end_index
        # winmap[i] lists the label indices that dominate label i.
        self.winmap: Dict[int, List[int]] = defaultdict(list)
        self.min_count: int = min_count
        self.max_count: int = max_count

    def fit(self, X, Y):
        """Learn the dominance map from the label matrix ``Y`` (``X`` is unused).

        ``None`` entries in Y are skipped. Refitting replaces the previous map.

        :returns: self
        """
        self.winmap = defaultdict(list)
        n_instances = len(Y)
        n_attributes = Y.shape[1]
        for i in range(n_attributes):
            for j in range(n_attributes):
                if i == j:
                    continue
                countwin = 0
                countloose = 0
                countboth = 0
                for k in range(n_instances):
                    vi = Y[k, i]
                    vj = Y[k, j]
                    if vi is None or vj is None:
                        continue
                    if vi < vj:
                        countwin += 1
                    elif vi > vj:
                        countloose += 1
                    elif vi == vj and vi == 1:  # tie
                        countboth += 1
                # At least min_count wins, more wins than losses, losses within max_count, no ties
                if (
                    countwin >= self.min_count
                    and countwin > countloose
                    and (countloose <= self.max_count or self.max_count < 0)
                    and countboth == 0
                ):
                    self.winmap[i].append(j)
        return self

    def predict(self, X):
        """Return the trigger flags of each row, with dominated triggered rules set to 0."""
        width = self.end_index + 1 - self.start_index
        res = np.zeros((len(X), width))
        for inst_idx, inst in enumerate(X):
            trig = inst[self.start_index : self.end_index + 1]
            for i, t in enumerate(trig):
                res[inst_idx][i] = t
                if not t:
                    continue
                # Compare against the dominator's trigger flag (relative to start_index), not
                # the absolute column i2 of X, which would be a structure feature.
                if any(i2 != i and i2 < len(trig) and trig[i2] for i2 in self.winmap.get(i, [])):
                    res[inst_idx][i] = 0
        return res

    def predict_proba(self, X):
        """Same as :meth:`predict` (the method is rule-based and produces hard 0/1 values)."""
        return self.predict(X)
class ApplicabilityDomainPCA(PCA):
    """PCA bounding-box applicability domain.

    ``build`` standardises the training features, fits a PCA and records the per-component
    min/max of the projected training data. An instance is applicable when every PCA coordinate
    lies inside that [min, max] box.
    """

    def __init__(self, num_neighbours: int = 5):
        # NOTE(review): num_neighbours is used as the number of PCA components.
        super().__init__(n_components=num_neighbours)
        self.scaler = StandardScaler()
        self.num_neighbours = num_neighbours
        self.min_vals = None
        self.max_vals = None

    @staticmethod
    def _as_array(data):
        """Return ``data`` as a numpy array.

        ``RuleBasedDataset.X()`` returns a dataset wrapper rather than an array, so objects with a
        ``to_numpy`` method are converted through it; anything else goes through ``np.asarray``.
        """
        if hasattr(data, "to_numpy"):
            return data.to_numpy()
        return np.asarray(data)

    def build(self, train_dataset: "RuleBasedDataset"):
        """Fit scaler and PCA on ``train_dataset.X()`` and store the bounds of the projection."""
        X_scaled = self.scaler.fit_transform(self._as_array(train_dataset.X()))
        X_pca = self.fit_transform(X_scaled)
        self.max_vals = np.max(X_pca, axis=0)
        self.min_vals = np.min(X_pca, axis=0)

    def __transform(self, instances):
        """Scale with the fitted scaler, then project with the fitted PCA."""
        instances_scaled = self.scaler.transform(self._as_array(instances))
        return self.transform(instances_scaled)

    def is_applicable(self, classify_instances: "RuleBasedDataset"):
        """Return one bool per instance: True if all PCA coordinates lie within the training box."""
        instances_pca = self.__transform(classify_instances.X())
        inside = (instances_pca >= self.min_vals) & (instances_pca <= self.max_vals)
        return np.all(inside, axis=1).tolist()
def tanimoto_distance(a: List[int], b: List[int]):
    """Tanimoto (Jaccard) distance between two equally long fingerprint bit lists.

    Returns ``1 - |a AND b| / |a OR b|``, and 0.0 when neither list has any bit set.

    :raises ValueError: If the lists differ in length.
    """
    if len(a) != len(b):
        raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")
    shared = 0
    for bit_a, bit_b in zip(a, b):
        shared += bit_a and bit_b
    union = sum(a) + sum(b) - shared
    if union == 0:
        return 0.0
    return 1 - shared / union
def graph_from_pathway(data):
    """Convert a ``Pathway`` or ``SPathway`` into a ``networkx.DiGraph`` for multigen evaluation.

    Nodes are keyed by standardised SMILES, so nodes with equal SMILES are merged and keep the
    smallest depth. Nodes carry ``depth``, ``smiles`` and ``root`` attributes. Edges carry
    ``probability``, which is 1 when the edge has no probability attribute. CO2 targets and
    self-loops are skipped.

    :raises TypeError: If ``data`` is neither a ``Pathway`` nor an ``SPathway``.
    """
    from epdb.models import Pathway
    from epdb.logic import SPathway

    is_db_pathway = isinstance(data, Pathway)
    if not is_db_pathway and not isinstance(data, SPathway):
        raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    # Accessors take the edge/node explicitly instead of closing over a loop variable.
    def sources_targets(edge):
        if is_db_pathway:
            return (
                [n.node for n in edge.start_nodes.constrained_target.all()],
                [n.node for n in edge.end_nodes.constrained_target.all()],
            )
        return edge.educts, edge.products

    def smiles_depth(node):
        smiles = node.default_node_label.smiles if is_db_pathway else node.smiles
        return FormatConverter.standardize(smiles, True), node.depth

    def edge_probability(edge):
        try:
            return edge.kv.get("probability") if is_db_pathway else edge.probability
        except AttributeError:
            return 1

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    root_entries = {smiles_depth(n) for n in data.root_nodes}
    # Membership tests need the bare SMILES, not the (smiles, depth) pairs.
    root_smiles = {smiles for smiles, _ in root_entries}
    for root, depth in root_entries:
        graph.add_node(root, depth=depth, smiles=root, root=True)

    def add_or_update(smiles, depth):
        if smiles in graph:
            graph.nodes[smiles]["depth"] = min(depth, graph.nodes[smiles]["depth"])
        else:
            graph.add_node(smiles, depth=depth, smiles=smiles, root=smiles in root_smiles)

    edges = data.edges.all() if is_db_pathway else data.edges
    for edge in edges:
        sources, targets = sources_targets(edge)
        probability = edge_probability(edge)
        for source in sources:
            source_smiles, source_depth = smiles_depth(source)
            add_or_update(source_smiles, source_depth)
            for target in targets:
                target_smiles, target_depth = smiles_depth(target)
                if target_smiles in co2:
                    continue
                add_or_update(target_smiles, target_depth)
                if target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)
    return graph
def get_shortest_path(pathway, in_start_node, in_end_node):
    """Return the inner nodes of a shortest path from start to end (both endpoints excluded).

    Returns ``[]`` when no path exists or when start and end are the same node.
    """
    try:
        path = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except nx.NetworkXNoPath:
        return []
    # A shortest path never repeats nodes, so slicing off both ends equals removing start and end.
    # Unlike list.remove twice, this also works for the single-node path start == end.
    return path[1:-1]
def set_pathway_eval_weight(pathway):
    """Map every node to its evaluation weight ``1 / 2**depth``.

    Nodes with a negative depth (unreachable or excluded nodes) get weight 0.
    """
    def weight_for(depth):
        if depth < 0:
            return 0
        return 1 / (2 ** depth)

    return {node: weight_for(pathway.nodes[node]["depth"]) for node in pathway.nodes}
def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    """Correct predicted node depths for intermediates the prediction inserted.

    Intermediates that are absent from the reference pathway get the sentinel depth -99. Every
    other non-root node has its depth reduced by the number of intermediates on its shortest
    path from a root (the first shortest one among all roots). ``pred_pathway`` is modified in
    place and returned.
    """
    if not intermediates:
        return pred_pathway

    roots = pred_pathway.graph["root_nodes"]
    for node in pred_pathway.nodes:
        if node in roots:
            continue
        attrs = pred_pathway.nodes[node]
        if node in intermediates and node not in data_pathway:
            attrs["depth"] = -99
            continue
        candidate_paths = [
            path for path in (get_shortest_path(pred_pathway, root, node) for root in roots) if path
        ]
        if not candidate_paths:
            continue
        best_path = min(candidate_paths, key=len)
        attrs["depth"] -= sum(1 for step in best_path if step in intermediates)
    return pred_pathway
def initialise_pathway(pathway):
    """Build the evaluation graph for ``pathway``.

    The depth-0 nodes are recorded as ``graph["root_nodes"]``, then all depths are recomputed
    from those roots.
    """
    graph = graph_from_pathway(pathway)
    graph.graph["root_nodes"] = {node for node, attrs in graph.nodes.items() if attrs["depth"] == 0}
    return get_pathway_with_depth(graph)
def get_pathway_with_depth(pathway):
    """Recompute node depths level by level, starting from ``graph["root_nodes"]``.

    Roots get depth 0 and every other node starts at -99. Depths then propagate along successors
    until no new node is reached, so unreachable nodes keep -99. This fixes incorrect depths from
    the JSON parse, where nodes with the same SMILES at different depths were merged.
    """
    roots = pathway.graph["root_nodes"]
    for node, attrs in pathway.nodes.items():
        attrs["depth"] = 0 if node in roots else -99
    level = 0
    while assign_next_depth(pathway, level):
        level += 1
    return pathway
def assign_next_depth(pathway, current_depth):
    """Give unassigned successors (negative depth) of the ``current_depth`` nodes depth + 1.

    :returns: True if at least one node received a depth, else False.
    """
    frontier = [node for node in pathway.nodes if pathway.nodes[node]["depth"] == current_depth]
    newly_reached = {
        successor
        for node in frontier
        for successor in pathway.successors(node)
        if pathway.nodes[successor]["depth"] < 0
    }
    for successor in newly_reached:
        pathway.nodes[successor]["depth"] = current_depth + 1
    return bool(newly_reached)
def find_intermediates(data_pathway, pred_pathway):
    """Collect the nodes the prediction inserted between a common node and its reference successor.

    For every common node, each reference successor that also exists in the prediction is
    checked. The inner nodes of the predicted shortest path between the two are intermediates.
    """
    found = set()
    for node in get_common_nodes(pred_pathway, data_pathway):
        shared_successors = [succ for succ in data_pathway.successors(node) if succ in pred_pathway]
        for succ in shared_successors:
            found.update(get_shortest_path(pred_pathway, node, succ))
    return found
def get_common_nodes(pred_pathway, data_pathway):
    """Return nodes present in both pathways whose root status agrees.

    A node is common if it is in both graphs and is either a root in both or a non-root in both.
    Roots are read from each graph's ``graph["root_nodes"]``.
    """
    data_roots = data_pathway.graph["root_nodes"]
    pred_roots = pred_pathway.graph["root_nodes"]
    return {
        node
        for node in data_pathway.nodes
        if node in pred_pathway.nodes and (node in data_roots) == (node in pred_roots)
    }
def prune_graph(graph, threshold):
    """Prune ``graph`` in place.

    Cycles are broken by repeatedly removing the last edge of a found cycle. Then edges whose
    ``probability`` is below ``threshold`` are removed. Finally, only nodes reachable from at
    least one root node (node attribute ``root``) are kept.

    :raises ValueError: If the graph has no root node.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
        except nx.NetworkXNoCycle:
            break
        graph.remove_edge(*cycle[-1])  # Remove the last edge in the cycle

    low_probability = [
        (u, v) for u, v, data in graph.edges(data=True) if data["probability"] < threshold
    ]
    graph.remove_edges_from(low_probability)

    roots = [n for n in graph.nodes if graph.nodes[n]["root"]]
    if not roots:
        raise ValueError("Cannot prune a pathway graph without a root node")
    # Keep everything reachable from any root, not just the first one: multi-root pathways would
    # otherwise lose every branch that hangs off the other roots.
    reachable = set(roots)
    for root in roots:
        reachable |= nx.descendants(graph, root)
    graph.remove_nodes_from([n for n in graph.nodes if n not in reachable])
def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Compare two pathways for multi-gen evaluation.

    It is assumed the smiles in both pathways have been standardised in the same manner.
    When ``threshold`` is given, the predicted graph is pruned first. Depths of the predicted
    graph are adjusted for intermediates. Nodes at depth > 0 then contribute their depth
    weights to TP (common nodes, reference weights), FN (reference-only) and FP
    (prediction-only, predicted weights).

    :returns: ``(score, precision, recall)`` with score = TP / (TP + FP + FN), plus the set of
        intermediates when ``return_intermediates`` is True. A ratio is 0.0 if its denominator is 0.
    """
    data_graph = initialise_pathway(data_pathway)
    pred_graph = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_graph, threshold)

    intermediates = find_intermediates(data_graph, pred_graph)
    if intermediates:
        pred_graph = get_depth_adjusted_pathway(data_graph, pred_graph, intermediates)

    data_weights = set_pathway_eval_weight(data_graph)
    pred_weights = set_pathway_eval_weight(pred_graph)
    common = get_common_nodes(pred_graph, data_graph)
    data_only = {n for n in data_graph.nodes if n not in common}
    pred_only = {n for n in pred_graph.nodes if n not in common}

    def weighted_total(nodes, depth_source, weights):
        total = 0.0
        for n in nodes:
            if depth_source.nodes[n]["depth"] > 0:
                total += weights[n]
        return total

    def safe_ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    tp = weighted_total(common, pred_graph, data_weights)
    fn = weighted_total(data_only, data_graph, data_weights)
    fp = weighted_total(pred_only, pred_graph, pred_weights)

    score = safe_ratio(tp, tp + fp + fn)
    precision = safe_ratio(tp, tp + fp)
    recall = safe_ratio(tp, tp + fn)
    if return_intermediates:
        return score, precision, recall, intermediates
    return score, precision, recall
def node_subst_cost(node1, node2):
    """Graph-edit substitution cost between two pathway nodes.

    Identical SMILES at identical depth cost 0. Otherwise the cost is ``1 / 2**d`` with d the
    larger of the two depths. A negative d (the -99 sentinel for unreachable nodes) costs 0,
    matching the weighting in ``set_pathway_eval_weight``; otherwise it would cost 2**99.
    """
    if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]:
        return 0
    depth = max(node1["depth"], node2["depth"])  # Maybe could be min instead of max
    return 1 / (2 ** depth) if depth >= 0 else 0
def node_ins_del_cost(node):
    """Graph-edit insertion/deletion cost of a pathway node: ``1 / 2**depth``.

    Nodes with a negative depth (the -99 sentinel for unreachable nodes) cost 0, matching
    ``set_pathway_eval_weight``, instead of an overwhelming 2**99.
    """
    depth = node["depth"]
    return 1 / (2 ** depth) if depth >= 0 else 0
def pathway_edit_eval(data_pathway, pred_pathway):
    """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval.

    Node costs are depth-weighted (``node_subst_cost`` / ``node_ins_del_cost``). The first
    element of each graph's ``root_nodes`` set is passed to networkx as the root pair.
    """
    data_graph, pred_graph = (initialise_pathway(p) for p in (data_pathway, pred_pathway))
    root_pair = tuple(list(g.graph["root_nodes"])[0] for g in (data_graph, pred_graph))
    return nx.graph_edit_distance(
        data_graph,
        pred_graph,
        node_subst_cost=node_subst_cost,
        node_del_cost=node_ins_del_cost,
        node_ins_cost=node_ins_del_cost,
        roots=root_pair,
    )