forked from enviPath/enviPy
[Feature] MultiGen Eval (Backend) (#117)
Fixes #16 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#117
This commit is contained in:
@ -183,7 +183,7 @@ class FormatConverter(object):
|
||||
return smiles
|
||||
|
||||
@staticmethod
|
||||
def standardize(smiles):
|
||||
def standardize(smiles, remove_stereo=False):
|
||||
# Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
|
||||
# follows the steps in
|
||||
# https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/MolStandardize%20pieces.ipynb
|
||||
@ -208,6 +208,9 @@ class FormatConverter(object):
|
||||
# te = rdMolStandardize.TautomerEnumerator() # idem
|
||||
# taut_uncharged_parent_clean_mol = te.Canonicalize(uncharged_parent_clean_mol)
|
||||
|
||||
if remove_stereo:
|
||||
Chem.RemoveStereochemistry(uncharged_parent_clean_mol)
|
||||
|
||||
return Chem.MolToSmiles(uncharged_parent_clean_mol, kekuleSmiles=True)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -919,7 +919,7 @@ class PackageImporter:
|
||||
name=edge_data['name'],
|
||||
description=edge_data['description'],
|
||||
kv=edge_data.get('kv', {}),
|
||||
edge_label=None # Will be set later
|
||||
edge_label=self._get_cached_object('Reaction', edge_data['edge_label']['uuid'])
|
||||
)
|
||||
|
||||
# Set aliases if present
|
||||
|
||||
349
utilities/ml.py
349
utilities/ml.py
@ -1,12 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import default_rng
|
||||
from sklearn.dummy import DummyClassifier
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
@ -22,61 +29,6 @@ from dataclasses import dataclass, field
|
||||
from utilities.chem import FormatConverter, PredictionResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class SCompound:
|
||||
smiles: str
|
||||
uuid: str = field(default=None, compare=False, hash=False)
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash'):
|
||||
self._hash = hash((
|
||||
self.smiles
|
||||
))
|
||||
return self._hash
|
||||
|
||||
|
||||
@dataclass
|
||||
class SReaction:
|
||||
educts: List[SCompound]
|
||||
products: List[SCompound]
|
||||
rule_uuid: SRule = field(default=None, compare=False, hash=False)
|
||||
reaction_uuid: str = field(default=None, compare=False, hash=False)
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash'):
|
||||
self._hash = hash((
|
||||
tuple(sorted(self.educts, key=lambda x: x.smiles)),
|
||||
tuple(sorted(self.products, key=lambda x: x.smiles)),
|
||||
))
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, SReaction):
|
||||
return NotImplemented
|
||||
return (
|
||||
sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and
|
||||
sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SRule(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def apply(self):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSimpleRule:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SParallelRule:
|
||||
pass
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
||||
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
|
||||
@ -385,7 +337,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||
self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)]
|
||||
|
||||
for i, chain in enumerate(self.chains_):
|
||||
print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
|
||||
logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
|
||||
chain.fit(X, y_reduced)
|
||||
|
||||
return self
|
||||
@ -423,14 +375,6 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||
return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
|
||||
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from sklearn.dummy import DummyClassifier
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
|
||||
class BinaryRelevance:
|
||||
def __init__(self, baseline_clf):
|
||||
self.clf = baseline_clf
|
||||
@ -483,7 +427,8 @@ class MissingValuesClassifierChain:
|
||||
X = np.array(X)
|
||||
Y = np.array(Y)
|
||||
if self.permutation is None:
|
||||
self.permutation = np.random.permutation(len(Y[0]))
|
||||
rng = default_rng(42)
|
||||
self.permutation = rng.permutation(len(Y[0]))
|
||||
|
||||
Y = Y[:, self.permutation]
|
||||
|
||||
@ -541,7 +486,7 @@ class EnsembleClassifierChain:
|
||||
self.num_labels = len(Y[0])
|
||||
|
||||
for p in range(self.num_chains):
|
||||
print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||
clf = MissingValuesClassifierChain(self.base_clf)
|
||||
clf.fit(X, Y)
|
||||
self.classifiers.append(clf)
|
||||
@ -609,13 +554,23 @@ class RelativeReasoning:
|
||||
def predict(self, X):
|
||||
res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
|
||||
|
||||
# Loop through all instances
|
||||
for inst_idx, inst in enumerate(X):
|
||||
# Loop through all "triggered" features
|
||||
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
# Set label
|
||||
res[inst_idx][i] = t
|
||||
# If we predict a 1, check if the rule gets dominated by another
|
||||
if t:
|
||||
# Second loop to check other triggered rules
|
||||
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
|
||||
res[inst_idx][i] = 0
|
||||
if i != i2:
|
||||
# Check if rule idx is in "dominated by" list
|
||||
if i2 in self.winmap.get(i, []):
|
||||
# if thatat rule also triggered, it dominated the current
|
||||
# set label to 0
|
||||
if X[inst_idx][i2]:
|
||||
res[inst_idx][i] = 0
|
||||
|
||||
return res
|
||||
|
||||
@ -671,3 +626,259 @@ def tanimoto_distance(a: List[int], b: List[int]):
|
||||
return 0.0
|
||||
|
||||
return 1 - (sum_c / (sum_a + sum_b - sum_c))
|
||||
|
||||
|
||||
def graph_from_pathway(data):
|
||||
"""Convert Pathway or SPathway to networkx"""
|
||||
from epdb.models import Pathway
|
||||
from epdb.logic import SPathway
|
||||
graph = nx.DiGraph()
|
||||
co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation
|
||||
|
||||
def get_edges():
|
||||
if isinstance(data, Pathway):
|
||||
return data.edges.all()
|
||||
elif isinstance(data, SPathway):
|
||||
return data.edges
|
||||
else:
|
||||
raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
|
||||
|
||||
def get_sources_targets():
|
||||
if isinstance(data, Pathway):
|
||||
return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()]
|
||||
elif isinstance(data, SPathway):
|
||||
return edge.educts, edge.products
|
||||
else:
|
||||
raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
|
||||
|
||||
def get_smiles_depth(node):
|
||||
if isinstance(data, Pathway):
|
||||
return FormatConverter.standardize(node.default_node_label.smiles, True), node.depth
|
||||
elif isinstance(data, SPathway):
|
||||
return FormatConverter.standardize(node.smiles, True), node.depth
|
||||
else:
|
||||
raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
|
||||
|
||||
def get_probability():
|
||||
try:
|
||||
if isinstance(data, Pathway):
|
||||
return edge.kv.get('probability')
|
||||
elif isinstance(data, SPathway):
|
||||
return edge.probability
|
||||
else:
|
||||
raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
|
||||
except AttributeError:
|
||||
return 1
|
||||
|
||||
root_smiles = {get_smiles_depth(n) for n in data.root_nodes}
|
||||
for root, depth in root_smiles:
|
||||
graph.add_node(root, depth=depth, smiles=root, root=True)
|
||||
|
||||
for edge in get_edges():
|
||||
sources, targets = get_sources_targets()
|
||||
probability = get_probability()
|
||||
for source in sources:
|
||||
source_smiles, source_depth = get_smiles_depth(source)
|
||||
if source_smiles not in graph:
|
||||
graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
|
||||
root=source_smiles in root_smiles)
|
||||
else:
|
||||
graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
|
||||
for target in targets:
|
||||
target_smiles, target_depth = get_smiles_depth(target)
|
||||
if target_smiles not in graph and target_smiles not in co2:
|
||||
graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
|
||||
root=target_smiles in root_smiles)
|
||||
elif target_smiles not in co2:
|
||||
graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
|
||||
if target_smiles not in co2 and target_smiles != source_smiles:
|
||||
graph.add_edge(source_smiles, target_smiles, probability=probability)
|
||||
return graph
|
||||
|
||||
|
||||
def get_shortest_path(pathway, in_start_node, in_end_node):
|
||||
try:
|
||||
pred = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
|
||||
except nx.NetworkXNoPath:
|
||||
return []
|
||||
pred.remove(in_start_node)
|
||||
pred.remove(in_end_node)
|
||||
return pred
|
||||
|
||||
|
||||
def set_pathway_eval_weight(pathway):
|
||||
node_eval_weights = {}
|
||||
for node in pathway.nodes:
|
||||
# Scale score according to depth level
|
||||
node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
|
||||
return node_eval_weights
|
||||
|
||||
|
||||
def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
|
||||
if len(intermediates) < 1:
|
||||
return pred_pathway
|
||||
root_nodes = pred_pathway.graph["root_nodes"]
|
||||
for node in pred_pathway.nodes:
|
||||
if node in root_nodes:
|
||||
continue
|
||||
if node in intermediates and node not in data_pathway:
|
||||
pred_pathway.nodes[node]["depth"] = -99
|
||||
else:
|
||||
shortest_path_list = []
|
||||
for root_node in root_nodes:
|
||||
shortest_path_nodes = get_shortest_path(pred_pathway, root_node, node)
|
||||
if shortest_path_nodes:
|
||||
shortest_path_list.append(shortest_path_nodes)
|
||||
if shortest_path_list:
|
||||
shortest_path_nodes = min(shortest_path_list, key=len)
|
||||
num_ints = sum(1 for shortest_path_node in shortest_path_nodes if
|
||||
shortest_path_node in intermediates)
|
||||
pred_pathway.nodes[node]["depth"] -= num_ints
|
||||
return pred_pathway
|
||||
|
||||
|
||||
def initialise_pathway(pathway):
|
||||
"""Convert pathway to networkx graph for evaluation"""
|
||||
pathway = graph_from_pathway(pathway)
|
||||
pathway.graph["root_nodes"] = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == 0}
|
||||
pathway = get_pathway_with_depth(pathway)
|
||||
return pathway
|
||||
|
||||
|
||||
def get_pathway_with_depth(pathway):
|
||||
"""Recalculates depths in the pathway.
|
||||
Can fix incorrect depths from json parse if there were multiple nodes with the same SMILES at
|
||||
different depths that got merged."""
|
||||
current_depth = 0
|
||||
for node in pathway.nodes:
|
||||
if node in pathway.graph["root_nodes"]:
|
||||
pathway.nodes[node]["depth"] = current_depth
|
||||
else:
|
||||
pathway.nodes[node]["depth"] = -99
|
||||
while assign_next_depth(pathway, current_depth):
|
||||
current_depth += 1
|
||||
return pathway
|
||||
|
||||
|
||||
def assign_next_depth(pathway, current_depth):
|
||||
new_assigned_nodes = False
|
||||
current_depth_nodes = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth}
|
||||
for node in current_depth_nodes:
|
||||
successors = pathway.successors(node)
|
||||
for s in successors:
|
||||
if pathway.nodes[s]["depth"] < 0:
|
||||
pathway.nodes[s]["depth"] = current_depth + 1
|
||||
new_assigned_nodes = True
|
||||
return new_assigned_nodes
|
||||
|
||||
|
||||
def find_intermediates(data_pathway, pred_pathway):
|
||||
"""Find any intermediate nodes in the predicted pathway"""
|
||||
common_nodes = get_common_nodes(pred_pathway, data_pathway)
|
||||
intermediates = set()
|
||||
for node in common_nodes:
|
||||
down_stream_nodes = data_pathway.successors(node)
|
||||
for down_stream_node in down_stream_nodes:
|
||||
if down_stream_node in pred_pathway:
|
||||
all_ints = get_shortest_path(pred_pathway, node, down_stream_node)
|
||||
intermediates.update(all_ints)
|
||||
return intermediates
|
||||
|
||||
|
||||
def get_common_nodes(pred_pathway, data_pathway):
|
||||
"""A node is a common node if it is in both pathways and is either a root in both or not a root in both."""
|
||||
common_nodes = set()
|
||||
for node in data_pathway.nodes:
|
||||
is_pathway_root_node = node in data_pathway.graph["root_nodes"]
|
||||
is_this_root_node = node in pred_pathway.graph["root_nodes"]
|
||||
if node in pred_pathway.nodes:
|
||||
if is_pathway_root_node is False and is_this_root_node is False:
|
||||
common_nodes.add(node)
|
||||
elif is_pathway_root_node and is_this_root_node:
|
||||
common_nodes.add(node)
|
||||
return common_nodes
|
||||
|
||||
|
||||
def prune_graph(graph, threshold):
|
||||
"""
|
||||
Removes edges with probability below the threshold, then keep the subgraph reachable from the root node.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
cycle = nx.find_cycle(graph)
|
||||
graph.remove_edge(*cycle[-1]) # Remove the last edge in the cycle
|
||||
except nx.NetworkXNoCycle:
|
||||
break
|
||||
|
||||
for u, v, data in list(graph.edges(data=True)): # Remove edges below threshold
|
||||
if data["probability"] < threshold:
|
||||
graph.remove_edge(u, v)
|
||||
root_node = [n for n in graph.nodes if graph.nodes[n]["root"]][0]
|
||||
reachable = nx.descendants(graph, root_node) # Get all reachable nodes from root
|
||||
reachable.add(root_node)
|
||||
|
||||
for node in list(graph.nodes): # Remove nodes not reachable from root
|
||||
if node not in reachable:
|
||||
graph.remove_node(node)
|
||||
|
||||
|
||||
def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
|
||||
"""Compare two pathways for multi-gen evaluation.
|
||||
It is assumed the smiles in both pathways have been standardised in the same manner."""
|
||||
data_pathway = initialise_pathway(data_pathway)
|
||||
pred_pathway = initialise_pathway(pred_pathway)
|
||||
if threshold is not None:
|
||||
prune_graph(pred_pathway, threshold)
|
||||
intermediates = find_intermediates(data_pathway, pred_pathway)
|
||||
|
||||
if intermediates:
|
||||
pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates)
|
||||
|
||||
test_pathway_eval_weights = set_pathway_eval_weight(data_pathway)
|
||||
pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway)
|
||||
|
||||
common_nodes = get_common_nodes(pred_pathway, data_pathway)
|
||||
|
||||
data_only_nodes = set(n for n in data_pathway.nodes if n not in common_nodes)
|
||||
pred_only_nodes = set(n for n in pred_pathway.nodes if n not in common_nodes)
|
||||
|
||||
score_TP, score_FP, score_FN, final_score, precision, recall = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
||||
|
||||
for node in common_nodes:
|
||||
if pred_pathway.nodes[node]["depth"] > 0:
|
||||
score_TP += test_pathway_eval_weights[node]
|
||||
|
||||
for node in data_only_nodes:
|
||||
if data_pathway.nodes[node]["depth"] > 0:
|
||||
score_FN += test_pathway_eval_weights[node]
|
||||
|
||||
for node in pred_only_nodes:
|
||||
if pred_pathway.nodes[node]["depth"] > 0:
|
||||
score_FP += pred_pathway_eval_weights[node]
|
||||
|
||||
final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0
|
||||
precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0
|
||||
recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0
|
||||
if return_intermediates:
|
||||
return final_score, precision, recall, intermediates
|
||||
return final_score, precision, recall
|
||||
|
||||
|
||||
def node_subst_cost(node1, node2):
|
||||
if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]:
|
||||
return 0
|
||||
return 1 / (2 ** max(node1["depth"], node2["depth"])) # Maybe could be min instead of max
|
||||
|
||||
|
||||
def node_ins_del_cost(node):
|
||||
return 1 / (2 ** node["depth"])
|
||||
|
||||
|
||||
def pathway_edit_eval(data_pathway, pred_pathway):
|
||||
"""Compute the graph edit distance for two pathways, a potential alternative to multigen_eval"""
|
||||
data_pathway = initialise_pathway(data_pathway)
|
||||
pred_pathway = initialise_pathway(pred_pathway)
|
||||
roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
|
||||
return nx.graph_edit_distance(data_pathway, pred_pathway,
|
||||
node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost,
|
||||
node_ins_cost=node_ins_del_cost, roots=roots)
|
||||
|
||||
Reference in New Issue
Block a user