[Feature] MultiGen Eval (Backend) (#117)

Fixes #16

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#117
This commit is contained in:
2025-09-18 18:40:45 +12:00
parent 762a6b7baf
commit 50db2fb372
24 changed files with 816 additions and 2137274 deletions

View File

@ -183,7 +183,7 @@ class FormatConverter(object):
return smiles
@staticmethod
def standardize(smiles):
def standardize(smiles, remove_stereo=False):
# Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
# follows the steps in
# https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/MolStandardize%20pieces.ipynb
@ -208,6 +208,9 @@ class FormatConverter(object):
# te = rdMolStandardize.TautomerEnumerator() # idem
# taut_uncharged_parent_clean_mol = te.Canonicalize(uncharged_parent_clean_mol)
if remove_stereo:
Chem.RemoveStereochemistry(uncharged_parent_clean_mol)
return Chem.MolToSmiles(uncharged_parent_clean_mol, kekuleSmiles=True)
@staticmethod

View File

@ -919,7 +919,7 @@ class PackageImporter:
name=edge_data['name'],
description=edge_data['description'],
kv=edge_data.get('kv', {}),
edge_label=None # Will be set later
edge_label=self._get_cached_object('Reaction', edge_data['edge_label']['uuid'])
)
# Set aliases if present

View File

@ -1,12 +1,19 @@
from __future__ import annotations
import copy
import numpy as np
from numpy.random import default_rng
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Set, Tuple
import numpy as np
import networkx as nx
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
@ -22,61 +29,6 @@ from dataclasses import dataclass, field
from utilities.chem import FormatConverter, PredictionResult
@dataclass
class SCompound:
    """Lightweight compound record identified by its SMILES string.

    The dataclass-generated ``__eq__`` compares ``smiles`` only, because
    ``uuid`` is declared with ``compare=False``; ``__hash__`` below follows
    the same rule so equal compounds always hash alike.
    """

    smiles: str
    # Optional identifier; deliberately excluded from equality and hashing.
    uuid: str = field(default=None, compare=False, hash=False)

    def __hash__(self):
        # Computed on every call instead of being memoised on the instance:
        # the dataclass is mutable, so a cached value would go stale as soon
        # as ``smiles`` is reassigned and would then disagree with __eq__.
        # Hashing a str is cheap (CPython caches it on the str object).
        return hash(self.smiles)
@dataclass
class SReaction:
    """Reaction record that compares educts and products as unordered groups.

    Two reactions are equal when their educts and their products match after
    sorting by SMILES; ``rule_uuid`` and ``reaction_uuid`` take no part in
    equality or hashing.
    """

    educts: List[SCompound]
    products: List[SCompound]
    # NOTE(review): named like an id but annotated as SRule - confirm intent.
    rule_uuid: SRule = field(default=None, compare=False, hash=False)
    reaction_uuid: str = field(default=None, compare=False, hash=False)

    def _key(self):
        """Return the order-independent identity used by __eq__ and __hash__."""
        return (
            tuple(sorted(self.educts, key=lambda x: x.smiles)),
            tuple(sorted(self.products, key=lambda x: x.smiles)),
        )

    def __hash__(self):
        # Not memoised: educts/products are mutable lists, so a cached hash
        # would go stale after an in-place change and disagree with __eq__.
        return hash(self._key())

    def __eq__(self, other):
        if not isinstance(other, SReaction):
            return NotImplemented
        return self._key() == other._key()
@dataclass
class SRule(ABC):
    """Abstract base for rule objects; subclasses must implement ``apply``."""

    @abstractmethod
    def apply(self):
        """Apply the rule. No behaviour is defined at this level."""
@dataclass
class SSimpleRule:
    """Placeholder dataclass without any fields."""
@dataclass
class SParallelRule:
    """Placeholder dataclass without any fields."""
class Dataset:
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
@ -385,7 +337,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)]
for i, chain in enumerate(self.chains_):
print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
chain.fit(X, y_reduced)
return self
@ -423,14 +375,6 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
import copy
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
class BinaryRelevance:
def __init__(self, baseline_clf):
self.clf = baseline_clf
@ -483,7 +427,8 @@ class MissingValuesClassifierChain:
X = np.array(X)
Y = np.array(Y)
if self.permutation is None:
self.permutation = np.random.permutation(len(Y[0]))
rng = default_rng(42)
self.permutation = rng.permutation(len(Y[0]))
Y = Y[:, self.permutation]
@ -541,7 +486,7 @@ class EnsembleClassifierChain:
self.num_labels = len(Y[0])
for p in range(self.num_chains):
print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
clf = MissingValuesClassifierChain(self.base_clf)
clf.fit(X, Y)
self.classifiers.append(clf)
@ -609,13 +554,23 @@ class RelativeReasoning:
def predict(self, X):
    """Predict labels from the "triggered" feature columns of each instance.

    The columns ``start_index..end_index`` (inclusive) of an instance are
    copied to the output. A triggered rule is then reset to 0 when any rule
    listed for it in ``self.winmap`` (its "dominated by" list) triggered for
    the same instance.

    Args:
        X: Sequence of instances; each instance must support slicing.

    Returns:
        Array of shape ``(len(X), end_index + 1 - start_index)``.
    """
    num_rules = self.end_index + 1 - self.start_index
    res = np.zeros((len(X), num_rules))
    for inst_idx, inst in enumerate(X):
        # Rule indices used below are relative to this slice, not to `inst`.
        triggered = inst[self.start_index: self.end_index + 1]
        for i, t in enumerate(triggered):
            res[inst_idx][i] = t
            if not t:
                continue
            dominators = self.winmap.get(i, [])
            # Test the dominating rule through the slice value `t2`.
            # Indexing the raw instance with the slice-relative `i2` would
            # read the wrong column whenever start_index != 0.
            if any(t2 and i2 != i and i2 in dominators
                   for i2, t2 in enumerate(triggered)):
                res[inst_idx][i] = 0
    return res
@ -671,3 +626,259 @@ def tanimoto_distance(a: List[int], b: List[int]):
return 0.0
return 1 - (sum_c / (sum_a + sum_b - sum_c))
def graph_from_pathway(data):
    """Convert a ``Pathway`` or ``SPathway`` into a ``networkx.DiGraph``.

    Nodes are keyed by the standardised SMILES (standardised with
    ``remove_stereo=True``) and carry the attributes ``depth``, ``smiles``
    and ``root``. When the same SMILES is seen again, the smaller depth is
    kept. Edges carry a ``probability`` attribute. CO2 targets are ignored
    and self-loops are never added.

    Args:
        data: A ``Pathway`` or ``SPathway`` instance.

    Returns:
        The populated ``nx.DiGraph``.

    Raises:
        TypeError: If ``data`` is neither a ``Pathway`` nor an ``SPathway``.
    """
    from epdb.models import Pathway
    from epdb.logic import SPathway

    is_pathway = isinstance(data, Pathway)
    if not is_pathway and not isinstance(data, SPathway):
        raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    def get_edges():
        return data.edges.all() if is_pathway else data.edges

    def get_sources_targets(edge):
        if is_pathway:
            return ([n.node for n in edge.start_nodes.constrained_target.all()],
                    [n.node for n in edge.end_nodes.constrained_target.all()])
        return edge.educts, edge.products

    def get_smiles_depth(node):
        raw_smiles = node.default_node_label.smiles if is_pathway else node.smiles
        return FormatConverter.standardize(raw_smiles, True), node.depth

    def get_probability(edge):
        # A missing attribute falls back to 1. For a Pathway the value is
        # whatever ``kv`` holds, which may be None if the key is absent.
        try:
            return edge.kv.get('probability') if is_pathway else edge.probability
        except AttributeError:
            return 1

    roots = {get_smiles_depth(n) for n in data.root_nodes}
    # Membership must be tested against SMILES strings; ``roots`` itself
    # holds (smiles, depth) tuples and would never match a bare string.
    root_smiles = {smiles for smiles, _ in roots}
    for root, depth in roots:
        graph.add_node(root, depth=depth, smiles=root, root=True)

    def add_or_update(smiles, depth):
        if smiles not in graph:
            graph.add_node(smiles, depth=depth, smiles=smiles, root=smiles in root_smiles)
        else:
            graph.nodes[smiles]["depth"] = min(depth, graph.nodes[smiles]["depth"])

    for edge in get_edges():
        sources, targets = get_sources_targets(edge)
        probability = get_probability(edge)
        # Standardise targets once per edge instead of once per source.
        target_info = [get_smiles_depth(target) for target in targets]
        target_info = [(smiles, depth) for smiles, depth in target_info if smiles not in co2]
        for source in sources:
            source_smiles, source_depth = get_smiles_depth(source)
            add_or_update(source_smiles, source_depth)
            for target_smiles, target_depth in target_info:
                add_or_update(target_smiles, target_depth)
                if target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)
    return graph
def get_shortest_path(pathway, in_start_node, in_end_node):
    """Return the nodes strictly between two nodes on a shortest path.

    Args:
        pathway: networkx graph to search.
        in_start_node: Source node.
        in_end_node: Target node.

    Returns:
        The interior nodes of one shortest path (both endpoints excluded).
        The list is empty when the nodes are adjacent, identical,
        unconnected, or when either node is not in the graph.
    """
    try:
        path = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # A node missing from the graph is treated like "no path".
        return []
    # Slicing drops both endpoints and also copes with the one-node path
    # returned for start == end, where removing twice would raise ValueError.
    return path[1:-1]
def set_pathway_eval_weight(pathway):
    """Map every node to its evaluation weight.

    A node at depth ``d >= 0`` weighs ``1 / 2**d``; nodes with a negative
    depth weigh 0.
    """
    def weight_for(depth):
        # Halve the score with every additional depth level.
        if depth < 0:
            return 0
        return 1 / (2 ** depth)

    return {name: weight_for(pathway.nodes[name]["depth"]) for name in pathway.nodes}
def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    """Adjust node depths in the predicted pathway for intermediate nodes.

    Root nodes are left alone. An intermediate that does not occur in
    ``data_pathway`` gets depth -99. Every other node has its depth reduced
    by the number of intermediates on the shortest non-empty interior path
    from any root. ``pred_pathway`` is modified in place and returned.
    """
    if len(intermediates) == 0:
        return pred_pathway

    roots = pred_pathway.graph["root_nodes"]
    for current in pred_pathway.nodes:
        if current in roots:
            continue
        if current in intermediates and current not in data_pathway:
            pred_pathway.nodes[current]["depth"] = -99
            continue

        # Interior paths from each root; empty ones (adjacent or unreachable)
        # are dropped.
        candidate_paths = []
        for root in roots:
            interior = get_shortest_path(pred_pathway, root, current)
            if interior:
                candidate_paths.append(interior)
        if not candidate_paths:
            continue

        best_path = min(candidate_paths, key=len)
        skipped = sum(1 for step in best_path if step in intermediates)
        pred_pathway.nodes[current]["depth"] -= skipped
    return pred_pathway
def initialise_pathway(pathway):
    """Build the networkx graph used for evaluation from a pathway.

    The nodes that come out of the conversion with depth 0 are stored as
    ``graph["root_nodes"]``; all depths are then recomputed from those roots.
    """
    graph = graph_from_pathway(pathway)
    depth_zero = set()
    for name in graph.nodes:
        if graph.nodes[name]["depth"] == 0:
            depth_zero.add(name)
    graph.graph["root_nodes"] = depth_zero
    return get_pathway_with_depth(graph)
def get_pathway_with_depth(pathway):
    """Recompute every node depth, starting from the root nodes.

    This can repair wrong depths from the JSON parse, where several nodes
    with the same SMILES at different depths were merged into one. Roots get
    depth 0, all other nodes start at -99 and are then filled in level by
    level; nodes never reached keep -99.
    """
    for name in pathway.nodes:
        is_root = name in pathway.graph["root_nodes"]
        pathway.nodes[name]["depth"] = 0 if is_root else -99

    level = 0
    # Each successful pass assigns depth level + 1 to the next layer.
    while assign_next_depth(pathway, level):
        level += 1
    return pathway
def assign_next_depth(pathway, current_depth):
    """Give depth ``current_depth + 1`` to unassigned successors.

    Only successors of nodes at ``current_depth`` whose depth is still
    negative are touched. Returns True if at least one node was assigned.
    """
    frontier = [name for name in pathway.nodes
                if pathway.nodes[name]["depth"] == current_depth]
    changed = False
    for parent in frontier:
        for child in pathway.successors(parent):
            if pathway.nodes[child]["depth"] >= 0:
                continue  # already has a depth
            pathway.nodes[child]["depth"] = current_depth + 1
            changed = True
    return changed
def find_intermediates(data_pathway, pred_pathway):
    """Collect the intermediate nodes of the predicted pathway.

    For every common node and each of its data-pathway successors that also
    exists in the prediction, the interior nodes of the predicted shortest
    path between the two are gathered into one set.
    """
    found = set()
    for shared in get_common_nodes(pred_pathway, data_pathway):
        for child in data_pathway.successors(shared):
            if child not in pred_pathway:
                continue
            found.update(get_shortest_path(pred_pathway, shared, child))
    return found
def get_common_nodes(pred_pathway, data_pathway):
    """Return the nodes shared by both pathways with matching root status.

    A node qualifies when it is present in both pathways and is either a
    root in both of them or a root in neither.
    """
    return {
        name
        for name in data_pathway.nodes
        if (name in data_pathway.graph["root_nodes"])
        == (name in pred_pathway.graph["root_nodes"])
        and name in pred_pathway.nodes
    }
def prune_graph(graph, threshold):
    """Prune a pathway graph in place.

    Steps, in order:
      1. Break cycles by repeatedly removing the last edge of a found cycle.
      2. Remove edges whose ``probability`` is below ``threshold``. Edges
         without a usable probability (missing or None) are kept.
      3. Remove every node that cannot be reached from a node flagged
         ``root``. All roots are considered; without any root the graph
         ends up empty.

    Args:
        graph: ``nx.DiGraph`` whose nodes carry a ``root`` attribute.
        threshold: Minimum edge probability to keep.

    Returns:
        None; ``graph`` is modified in place.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
        except nx.NetworkXNoCycle:
            break
        graph.remove_edge(*cycle[-1])  # Remove the last edge in the cycle

    # Snapshot the edges first: the graph must not change while iterating.
    for u, v, data in list(graph.edges(data=True)):
        probability = data.get("probability")
        if probability is not None and probability < threshold:
            graph.remove_edge(u, v)

    # Keep everything reachable from any root, not only the first one, so
    # other roots recorded for the pathway do not vanish from the graph.
    root_nodes = [n for n in graph.nodes if graph.nodes[n]["root"]]
    reachable = set(root_nodes)
    for root_node in root_nodes:
        reachable |= nx.descendants(graph, root_node)

    for node in list(graph.nodes):
        if node not in reachable:
            graph.remove_node(node)
def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Score a predicted pathway against a data pathway (multi-gen evaluation).

    It is assumed the smiles in both pathways have been standardised in the
    same manner.

    Args:
        data_pathway: Reference pathway.
        pred_pathway: Predicted pathway.
        threshold: If given, the predicted graph is pruned with it first.
        return_intermediates: Also return the set of intermediates found.

    Returns:
        ``(score, precision, recall)``, extended by the intermediates when
        ``return_intermediates`` is true.
    """
    data_graph = initialise_pathway(data_pathway)
    pred_graph = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_graph, threshold)

    intermediates = find_intermediates(data_graph, pred_graph)
    if intermediates:
        pred_graph = get_depth_adjusted_pathway(data_graph, pred_graph, intermediates)

    data_weights = set_pathway_eval_weight(data_graph)
    pred_weights = set_pathway_eval_weight(pred_graph)

    shared = get_common_nodes(pred_graph, data_graph)
    only_in_data = {n for n in data_graph.nodes if n not in shared}
    only_in_pred = {n for n in pred_graph.nodes if n not in shared}

    def weighted_total(names, depth_graph, weights):
        # Only nodes below the root level (depth > 0) contribute.
        total = 0.0
        for name in names:
            if depth_graph.nodes[name]["depth"] > 0:
                total += weights[name]
        return total

    def ratio(numerator, denominator):
        if denominator > 0:
            return numerator / denominator
        return 0.0

    # True positives are gated on the predicted depth but weighted with the
    # data pathway's weights.
    true_pos = weighted_total(shared, pred_graph, data_weights)
    false_neg = weighted_total(only_in_data, data_graph, data_weights)
    false_pos = weighted_total(only_in_pred, pred_graph, pred_weights)

    final_score = ratio(true_pos, true_pos + false_pos + false_neg)
    precision = ratio(true_pos, true_pos + false_pos)
    recall = ratio(true_pos, true_pos + false_neg)

    if return_intermediates:
        return final_score, precision, recall, intermediates
    return final_score, precision, recall
def node_subst_cost(node1, node2):
    """Cost of substituting one node for another in the graph edit distance.

    Nodes with the same SMILES and the same depth cost 0; any other pair
    costs ``1 / 2**d`` with ``d`` the larger of the two depths.
    """
    depth1, depth2 = node1["depth"], node2["depth"]
    if node1["smiles"] != node2["smiles"] or depth1 != depth2:
        # Maybe could be min instead of max
        return 1 / (2 ** max(depth1, depth2))
    return 0
def node_ins_del_cost(node):
    """Cost of inserting or deleting a node: ``1 / 2**depth``."""
    depth = node["depth"]
    return 1 / (2 ** depth)
def pathway_edit_eval(data_pathway, pred_pathway):
    """Graph edit distance between two pathways.

    A potential alternative to ``multigen_eval``. Both pathways are converted
    to graphs, the first root of each is used as the pair of root nodes, and
    the depth-based node cost functions are applied.
    """
    data_graph = initialise_pathway(data_pathway)
    pred_graph = initialise_pathway(pred_pathway)
    data_root = list(data_graph.graph["root_nodes"])[0]
    pred_root = list(pred_graph.graph["root_nodes"])[0]
    return nx.graph_edit_distance(
        data_graph,
        pred_graph,
        node_subst_cost=node_subst_cost,
        node_del_cost=node_ins_del_cost,
        node_ins_cost=node_ins_del_cost,
        roots=(data_root, pred_root),
    )