from __future__ import annotations

import copy
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING

import networkx as nx
import numpy as np
import polars as pl
from numpy.random import default_rng
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler

from utilities.chem import FormatConverter, PredictionResult

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from epdb.models import Rule, CompoundStructure, Reaction


class Dataset(ABC):
    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
        if isinstance(data, pl.DataFrame):  # Allows for re-creation of self in cases like indexing with __getitem__
            self.df = data
        else:
            # Build either an empty dataframe with the given columns or fill it with list-of-lists data
            if columns is None:
                raise ValueError("Columns can't be None if data is not already a DataFrame")
            if data is not None and len(columns) != len(data[0]):
                raise ValueError(f"Header and data are not aligned: {len(columns)} columns vs. {len(data[0])} columns")
            self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)

    def add_rows(self, rows: List[List[str | int | float]]):
        """Add rows to the dataset. Extends the polars DataFrame stored in self."""
        if len(self.columns) != len(rows[0]):
            raise ValueError(f"Header and data are not aligned: {len(self.columns)} columns vs. {len(rows[0])} columns")
        new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
        self.df.extend(new_rows)

    def add_row(self, row: List[str | int | float]):
        """See add_rows."""
        self.add_rows([row])

    def block_indices(self, prefix) -> List[int]:
        """Return the indices of all columns whose label starts with the given prefix."""
        indices: List[int] = []
        for i, feature in enumerate(self.columns):
            if feature.startswith(prefix):
                indices.append(i)
        return indices

    @property
    def columns(self) -> List[str]:
        """Column labels of the underlying polars DataFrame."""
        return self.df.columns

    @property
    def shape(self):
        return self.df.shape

    @abstractmethod
    def X(self, **kwargs):
        pass

    @abstractmethod
    def y(self, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def generate_dataset(reactions, *args, **kwargs):
        pass

    def at(self, position: int) -> Dataset:
        """See __getitem__."""
        return self[position]

    def limit(self, limit: int) -> Dataset:
        """See __getitem__."""
        return self[:limit]

    def __iter__(self):
        """Use polars iter_rows for iterating over the dataset."""
        return self.df.iter_rows()

    def __getitem__(self, item):
        """Item is passed to polars, allowing for advanced indexing.

        See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__
        """
        res = self.df[item]
        if isinstance(res, pl.DataFrame):  # If indexing returns a DataFrame, wrap it in a new instance of this class
            return self.__class__(data=res)
        else:  # Otherwise (scalar, string, Series, etc.) return the item as-is
            return res

    def save(self, path: "Path | str"):
        import pickle

        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: "str | Path") -> "Dataset":
        import pickle

        with open(path, "rb") as fh:
            return pickle.load(fh)

    def to_numpy(self):
        return self.df.to_numpy()

    def __repr__(self):
        return f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"

    def __len__(self):
        return len(self.df)

    def iter_rows(self, named=False):
        return self.df.iter_rows(named=named)
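

# Usage sketch for the Dataset wrapper (illustrative only; file names are placeholders):
#   ds.save("dataset.pkl")                # pickle the wrapper including its polars DataFrame
#   ds = Dataset.load("dataset.pkl")      # load it back (returns the pickled subclass instance)
#   subset = ds[:100]                     # slicing returns a new instance of the same subclass
#   value = ds[0, "structure_id"]         # scalar indexing falls through to polars and returns the raw value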


class RuleBasedDataset(Dataset):
    def __init__(self, num_labels=None, columns=None, data=None):
        super().__init__(columns, data)
        # Deriving num_labels from the columns keeps this __init__ compatible with the base Dataset,
        # so methods like __getitem__ can re-create instances without passing the label count.
        self.num_labels: int = num_labels if num_labels else sum(1 for c in self.columns if c.startswith("obs_"))
        # Pre-calculate the column indices of features/labels, used later in X and y
        self._struct_features: List[int] = self.block_indices("feature_")
        self._triggered: List[int] = self.block_indices("trig_")
        self._observed: List[int] = self.block_indices("obs_")
        self.feature_cols: List[int] = self._struct_features + self._triggered
        self.num_features: int = len(self.feature_cols)
        self.has_probs = False

    def times_triggered(self, rule_uuid) -> int:
        """Count how often a rule was triggered, i.e. the number of rows with a one in the rule's trig column."""
        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height

    def struct_features(self) -> List[int]:
        return self._struct_features

    def triggered(self) -> List[int]:
        return self._triggered

    def observed(self) -> List[int]:
        return self._observed

    def structure_id(self, index: int):
        """Get the UUID of the compound structure in the given row."""
        return self.df.item(index, "structure_id")

    def X(self, exclude_id_col=True, na_replacement=0):
        """Get all the feature and trig columns."""
        _col_ids = self.feature_cols
        if not exclude_id_col:
            _col_ids = [0] + _col_ids
        res = self[:, _col_ids]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    def trig(self, na_replacement=0):
        """Get all the trig columns."""
        res = self[:, self._triggered]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    def y(self, na_replacement=0):
        """Get all the obs columns."""
        res = self[:, self._observed]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    @staticmethod
    def generate_dataset(reactions, applicable_rules, educts_only=True):
        _structures = set()  # Get all the structures
        for r in reactions:
            _structures.update(r.educts.all())
            if not educts_only:
                _structures.update(r.products.all())

        compounds = sorted(_structures, key=lambda x: x.url)
        triggered: Dict[str, Set[str]] = defaultdict(set)
        observed: Set[str] = set()

        # Apply rules on the collected compounds and store the predicted products
        for i, comp in enumerate(compounds):
            logger.debug(f"{i + 1}/{len(compounds)}...")

            for rule in applicable_rules:
                product_sets = rule.apply(comp.smiles)
                if len(product_sets) == 0:
                    continue

                key = f"{rule.uuid} + {comp.uuid}"
                if key in triggered:
                    logger.info(f"{key} already present. Duplicate reaction?")

                for prod_set in product_sets:
                    for smi in prod_set:
                        try:
                            smi = FormatConverter.standardize(smi, remove_stereo=True)
                        except Exception:
                            logger.debug(f"Standardizing SMILES failed for {smi}")
                        triggered[key].add(smi)

        for i, r in enumerate(reactions):
            logger.debug(f"{i + 1}/{len(reactions)}...")

            if len(r.educts.all()) != 1:
                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
                continue

            for comp in r.educts.all():
                for rule in applicable_rules:
                    key = f"{rule.uuid} + {comp.uuid}"
                    if key not in triggered:
                        continue

                    # Standardize the products of the reaction for comparison
                    standardized_products = []
                    for cs in r.products.all():
                        smi = cs.smiles
                        try:
                            smi = FormatConverter.standardize(smi, remove_stereo=True)
                        except Exception:
                            logger.debug(f"Standardizing SMILES failed for {smi}")
                        standardized_products.append(smi)
                    # The rule explains the reaction if all its products are among the rule's predicted products
                    if len(set(standardized_products).difference(triggered[key])) == 0:
                        observed.add(key)

        ds_columns = (
            ["structure_id"]
            + [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))]
            + [f"trig_{r.uuid}" for r in applicable_rules]
            + [f"obs_{r.uuid}" for r in applicable_rules]
        )
        rows = []

        for i, comp in enumerate(compounds):
            # Features
            feat = FormatConverter.maccs(comp.smiles)
            trig = []
            obs = []
            for rule in applicable_rules:
                key = f"{rule.uuid} + {comp.uuid}"
                # Check triggered
                if key in triggered:
                    trig.append(1)
                else:
                    trig.append(0)
                # Check observed; rules that never triggered for this compound get a missing label (None)
                if key in observed:
                    obs.append(1)
                elif key not in triggered:
                    obs.append(None)
                else:
                    obs.append(0)
            rows.append([str(comp.uuid)] + feat + trig + obs)

        ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
        return ds

    def classification_dataset(
        self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
    ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
        classify_data = []
        classify_products = []
        for struct in structures:
            if isinstance(struct, str):
                struct_id = None
                struct_smiles = struct
            else:
                struct_id = str(struct.uuid)
                struct_smiles = struct.smiles

            features = FormatConverter.maccs(struct_smiles)

            trig = []
            prods = []
            for rule in applicable_rules:
                products = rule.apply(struct_smiles)

                if len(products):
                    trig.append(1)
                    prods.append(products)
                else:
                    trig.append(0)
                    prods.append([])
            # The obs (and, if present, prob) columns are unknown at prediction time and padded with -1
            new_row = [struct_id] + features + trig + ([-1] * len(trig))
            if self.has_probs:
                new_row += [-1] * len(trig)
            classify_data.append(new_row)
            classify_products.append(prods)
        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
        return ds, classify_products

    def add_probs(self, probs):
        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
        self.has_probs = True

    def to_arff(self, path: "Path"):
        arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
        arff += "\n"
        for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]:
            if c == "structure_id":
                arff += f"@attribute {c} string\n"
            else:
                arff += f"@attribute {c} {{0,1}}\n"

        arff += "\n@data\n"
        for d in self:
            ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
            xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
            arff += f"{ys},{xs}\n"

        with open(path, "w") as fh:
            fh.write(arff)
            fh.flush()
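

# Usage sketch for RuleBasedDataset (illustrative only; `reactions` and `rules` are assumed to be
# iterables of epdb.models.Reaction and Rule objects, e.g. Django querysets, and `model` is a
# placeholder multi-label classifier):
#   ds = RuleBasedDataset.generate_dataset(reactions, rules)
#   X, y = ds.X().to_numpy(), ds.y().to_numpy()                   # MACCS + trig features, obs labels
#   clf_ds, products = ds.classification_dataset(["CCO"], rules)
#   clf_ds.add_probs(model.predict_proba(clf_ds.X().to_numpy()))  # attach per-rule probability columns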


class EnviFormerDataset(Dataset):
    def __init__(self, columns=None, data=None):
        super().__init__(columns, data)

    def X(self):
        """Return the educts."""
        return self["educts"]

    def y(self):
        """Return the products."""
        return self["products"]

    @staticmethod
    def generate_dataset(reactions, *args, **kwargs):
        # Standardise reactions for the training data
        stereo = kwargs.get("stereo", False)
        rows = []
        for reaction in reactions:
            e = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.educts.all()
                ]
            )
            p = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.products.all()
                ]
            )
            rows.append([e, p])
        ds = EnviFormerDataset(["educts", "products"], rows)
        return ds
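

# Usage sketch (illustrative only): each row pairs dot-joined educt and product SMILES, e.g.
#   ds = EnviFormerDataset.generate_dataset(reactions, stereo=False)
#   ds.X()[0]   # educt SMILES of the first reaction, e.g. "CCO"
#   ds.y()[0]   # its product SMILES, standardised without stereochemistry, e.g. "CC=O"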


class SparseLabelECC(BaseEstimator, ClassifierMixin):
    """
    Ensemble of Classifier Chains with sparse label removal.

    Removes labels that are constant across all samples in training.
    """

    def __init__(
        self,
        base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
        num_chains: int = 10,
    ):
        self.base_clf = base_clf
        self.num_chains = num_chains

    def fit(self, X, Y):
        y = np.array(Y)
        self.n_labels_ = y.shape[1]
        self.removed_labels_ = {}
        self.keep_columns_ = []

        for col in range(self.n_labels_):
            unique_values = np.unique(y[:, col])
            if len(unique_values) == 1:
                self.removed_labels_[col] = unique_values[0]
            else:
                self.keep_columns_.append(col)

        y_reduced = y[:, self.keep_columns_]
        # Use a random label order per chain so the ensemble members differ from each other
        self.chains_ = [ClassifierChain(self.base_clf, order="random", random_state=i) for i in range(self.num_chains)]

        for i, chain in enumerate(self.chains_):
            logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
            chain.fit(X, y_reduced)

        return self

    def predict(self, X, threshold=0.5):
        avg_preds = np.mean([chain.predict(X) for chain in self.chains_], axis=0) > threshold
        full_y = np.zeros((avg_preds.shape[0], self.n_labels_))

        for idx, col in enumerate(self.keep_columns_):
            full_y[:, col] = avg_preds[:, idx]

        for col, value in self.removed_labels_.items():
            full_y[:, col] = bool(value)

        return full_y

    def predict_proba(self, X):
        avg_proba = np.mean([chain.predict_proba(X) for chain in self.chains_], axis=0)
        full_y = np.zeros((avg_proba.shape[0], self.n_labels_))

        for idx, col in enumerate(self.keep_columns_):
            full_y[:, col] = avg_proba[:, idx]

        for col, value in self.removed_labels_.items():
            full_y[:, col] = float(value)

        return full_y

    def score(self, X, Y, sample_weight=None):
        """
        Default scoring using subset accuracy (exact match).
        """
        y_true = np.array(Y)
        y_pred = self.predict(X)
        return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
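

# Usage sketch (illustrative only; `ds` is assumed to be a RuleBasedDataset without missing labels):
#   ecc = SparseLabelECC(num_chains=10)
#   ecc.fit(ds.X().to_numpy(), ds.y().to_numpy())
#   probs = ecc.predict_proba(ds.X().to_numpy())   # shape (n_samples, n_labels), constant labels filled back in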


class BinaryRelevance:
    """Trains one classifier per label, skipping rows where that label is missing (NaN)."""

    def __init__(self, baseline_clf):
        self.clf = baseline_clf
        self.classifiers = None

    def fit(self, X, Y):
        X = np.asarray(X)
        Y = np.asarray(Y, dtype=float)  # float so missing labels are represented as NaN
        if self.classifiers is None:
            self.classifiers = []

        for label in range(len(Y[0])):
            X_l = X[~np.isnan(Y[:, label])]
            Y_l = Y[~np.isnan(Y[:, label]), label]
            if len(X_l) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy="constant", constant=0)
                clf.fit([X[0]], [0])
                self.classifiers.append(clf)
                continue
            elif len(np.unique(Y_l)) == 1:  # only one class -> predict that class
                clf = DummyClassifier(strategy="most_frequent")
            else:
                clf = copy.deepcopy(self.clf)
            clf.fit(X_l, Y_l)
            self.classifiers.append(clf)

    def predict(self, X):
        labels = []
        for clf in self.classifiers:
            labels.append(clf.predict(X))
        return np.column_stack(labels)

    def predict_proba(self, X):
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict_proba(X)
            if pred.shape[1] > 1:
                pred = pred[:, 1]
            else:
                # Single-class dummy: positive-class probability is 1 if the predicted class is 1, else 0
                pred = pred * clf.predict([X[0]])[0]
            labels = np.column_stack((labels, pred))
        return labels
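

# Usage sketch (illustrative only; NaN marks a missing label for a given sample):
#   br = BinaryRelevance(RandomForestClassifier())
#   br.fit(X_train, Y_train)              # Y_train may contain NaN entries
#   probs = br.predict_proba(X_test)      # one probability column per label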


class MissingValuesClassifierChain:
    """Classifier chain that tolerates missing labels by imputing them with its own predictions."""

    def __init__(self, base_clf):
        self.base_clf = base_clf
        self.permutation = None
        self.classifiers = None

    def fit(self, X, Y):
        X = np.array(X)
        Y = np.array(Y)
        if self.permutation is None:
            rng = default_rng(42)
            self.permutation = rng.permutation(len(Y[0]))

        Y = Y[:, self.permutation]

        if self.classifiers is None:
            self.classifiers = []

        for p in range(len(self.permutation)):
            X_p = X[~np.isnan(Y[:, p])]
            Y_p = Y[~np.isnan(Y[:, p]), p]
            if len(X_p) == 0:  # all labels are nan -> predict 0
                clf = DummyClassifier(strategy="constant", constant=0)
                self.classifiers.append(clf.fit([X[0]], [0]))
            elif len(np.unique(Y_p)) == 1:  # only one class -> predict that class
                clf = DummyClassifier(strategy="most_frequent")
                self.classifiers.append(clf.fit(X_p, Y_p))
            else:
                clf = copy.deepcopy(self.base_clf)
                self.classifiers.append(clf.fit(X_p, Y_p))
            newcol = Y[:, p]
            pred = clf.predict(X)
            newcol[np.isnan(newcol)] = pred[np.isnan(newcol)]  # fill in missing values with clf predictions
            X = np.column_stack((X, newcol))  # chain: the (imputed) label becomes a feature for later labels

    def predict(self, X):
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict(np.column_stack((X, labels)))
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]

    def predict_proba(self, X):
        labels = np.empty((len(X), 0))
        for clf in self.classifiers:
            pred = clf.predict_proba(np.column_stack((X, np.round(labels))))
            if pred.shape[1] > 1:
                pred = pred[:, 1]
            else:
                pred = pred * clf.predict(np.column_stack(([X[0]], np.round([labels[0]]))))[0]
            labels = np.column_stack((labels, pred))
        return labels[:, np.argsort(self.permutation)]
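

# Note (illustrative): with three labels and permutation [2, 0, 1], the classifier for label 0 sees the
# original features plus the (imputed) column for label 2, and the classifier for label 1 sees both.
# predict()/predict_proba() undo the permutation, so columns come back in the original label order.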


class EnsembleClassifierChain:
    def __init__(self, base_clf, num_chains=10):
        self.base_clf = base_clf
        self.num_chains = num_chains
        self.num_labels = None
        self.classifiers = None

    def fit(self, X, Y):
        if self.classifiers is None:
            self.classifiers = []

        if self.num_labels is None:
            self.num_labels = Y.shape[1]

        for p in range(self.num_chains):
            logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
            clf = MissingValuesClassifierChain(self.base_clf)
            clf.fit(X, Y)
            self.classifiers.append(clf)

    def predict(self, X):
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict(X)
        return np.round(labels / self.num_chains)

    def predict_proba(self, X):
        labels = np.zeros((len(X), self.num_labels))
        for clf in self.classifiers:
            labels += clf.predict_proba(X)
        return labels / self.num_chains
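

# Usage sketch (illustrative only):
#   ecc = EnsembleClassifierChain(RandomForestClassifier(), num_chains=10)
#   ecc.fit(X_train, Y_train)             # Y_train is a float array where NaN marks a missing label
#   probs = ecc.predict_proba(X_test)     # probabilities averaged over the individual chains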


class RelativeReasoning:
    def __init__(self, start_index: int, end_index: int):
        self.start_index: int = start_index
        self.end_index: int = end_index
        self.winmap: Dict[int, List[int]] = defaultdict(list)
        self.min_count: int = 5
        self.max_count: int = 0

    def fit(self, X, Y):
        n_instances = len(Y)
        n_attributes = Y.shape[1]

        for i in range(n_attributes):
            for j in range(n_attributes):
                if i == j:
                    continue

                countwin = 0
                countloose = 0
                countboth = 0

                for k in range(n_instances):
                    vi = Y[k, i]
                    vj = Y[k, j]

                    if vi is None or vj is None:
                        continue

                    if vi < vj:
                        countwin += 1
                    elif vi > vj:
                        countloose += 1
                    elif vi == vj and vi == 1:  # tie
                        countboth += 1

                # Rule j dominates rule i if we've seen at least min_count wins, more wins than losses,
                # at most max_count losses (or max_count < 0 to disable that check), and no ties
                if (
                    countwin >= self.min_count
                    and countwin > countloose
                    and (countloose <= self.max_count or self.max_count < 0)
                    and countboth == 0
                ):
                    self.winmap[i].append(j)

    def predict(self, X):
        res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))

        # Loop through all instances
        for inst_idx, inst in enumerate(X):
            # Loop through all "triggered" features
            for i, t in enumerate(inst[self.start_index : self.end_index + 1]):
                # Set label
                res[inst_idx][i] = t
                # If we predict a 1, check whether the rule gets dominated by another
                if t:
                    # Second loop to check the other triggered rules
                    for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]):
                        if i != i2:
                            # Check if the other rule is in this rule's "dominated by" list
                            if i2 in self.winmap.get(i, []):
                                # If that rule also triggered, it dominates the current one: set the label to 0
                                if t2:
                                    res[inst_idx][i] = 0

        return res

    def predict_proba(self, X):
        return self.predict(X)
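

# Usage sketch (illustrative only): predict() expects the trig_ block to sit at positions
# [start_index, end_index] of each row in X, matching the obs_ label order used in fit(), e.g.
#   rr = RelativeReasoning(start_index=len(ds.struct_features()), end_index=ds.num_features - 1)
#   rr.fit(ds.X().to_numpy(), Y_obs)       # Y_obs: obs_ labels, one column per rule
#   preds = rr.predict(ds.X().to_numpy())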


class ApplicabilityDomainPCA(PCA):
    def __init__(self, num_neighbours: int = 5):
        super().__init__(n_components=num_neighbours)
        self.scaler = StandardScaler()
        self.num_neighbours = num_neighbours
        self.min_vals = None
        self.max_vals = None

    def build(self, train_dataset: "RuleBasedDataset"):
        # transform
        X_scaled = self.scaler.fit_transform(train_dataset.X())
        # fit pca
        X_pca = self.fit_transform(X_scaled)

        self.max_vals = np.max(X_pca, axis=0)
        self.min_vals = np.min(X_pca, axis=0)

    def __transform(self, instances):
        instances_scaled = self.scaler.transform(instances)
        instances_pca = self.transform(instances_scaled)
        return instances_pca

    def is_applicable(self, classify_instances: "RuleBasedDataset"):
        instances_pca = self.__transform(classify_instances.X())

        is_applicable = []
        for i, instance in enumerate(instances_pca):
            is_applicable.append(True)
            for min_v, max_v, new_v in zip(self.min_vals, self.max_vals, instance):
                if not min_v <= new_v <= max_v:
                    is_applicable[i] = False

        return is_applicable
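

# Usage sketch (illustrative only): an instance is considered inside the applicability domain when all of
# its PCA components fall within the min/max box spanned by the training data.
#   ad = ApplicabilityDomainPCA(num_neighbours=5)
#   ad.build(train_ds)                      # train_ds is a RuleBasedDataset
#   flags = ad.is_applicable(classify_ds)   # list of booleans, one per instance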


def tanimoto_distance(a: List[int], b: List[int]):
    if len(a) != len(b):
        raise ValueError(f"Lists must be the same length {len(a)} != {len(b)}")

    sum_a = sum(a)
    sum_b = sum(b)
    sum_c = sum(v1 and v2 for v1, v2 in zip(a, b))

    if sum_a + sum_b - sum_c == 0:
        return 0.0

    return 1 - (sum_c / (sum_a + sum_b - sum_c))
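

# Worked example (illustrative): for a = [1, 1, 0, 1] and b = [1, 0, 0, 1] the intersection is 2 and the
# union is 3 + 2 - 2 = 3, so tanimoto_distance(a, b) = 1 - 2/3 ≈ 0.333.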


def graph_from_pathway(data):
    """Convert a Pathway or SPathway to a networkx DiGraph."""
    from epdb.models import Pathway
    from epdb.logic import SPathway

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    def get_edges():
        if isinstance(data, Pathway):
            return data.edges.all()
        elif isinstance(data, SPathway):
            return data.edges
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_sources_targets():
        if isinstance(data, Pathway):
            return [n.node for n in edge.start_nodes.constrained_target.all()], [
                n.node for n in edge.end_nodes.constrained_target.all()
            ]
        elif isinstance(data, SPathway):
            return edge.educts, edge.products
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_smiles_depth(node):
        if isinstance(data, Pathway):
            return FormatConverter.standardize(node.default_node_label.smiles, True), node.depth
        elif isinstance(data, SPathway):
            return FormatConverter.standardize(node.smiles, True), node.depth
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_probability():
        try:
            if isinstance(data, Pathway):
                return edge.kv.get("probability")
            elif isinstance(data, SPathway):
                return edge.probability
            else:
                raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
        except AttributeError:
            return 1

    root_smiles = {get_smiles_depth(n) for n in data.root_nodes}
    root_smiles_only = {smiles for smiles, _ in root_smiles}  # membership checks below compare SMILES only
    for root, depth in root_smiles:
        graph.add_node(root, depth=depth, smiles=root, root=True)

    for edge in get_edges():
        sources, targets = get_sources_targets()
        probability = get_probability()
        for source in sources:
            source_smiles, source_depth = get_smiles_depth(source)
            if source_smiles not in graph:
                graph.add_node(
                    source_smiles,
                    depth=source_depth,
                    smiles=source_smiles,
                    root=source_smiles in root_smiles_only,
                )
            else:
                graph.nodes[source_smiles]["depth"] = min(
                    source_depth, graph.nodes[source_smiles]["depth"]
                )
            for target in targets:
                target_smiles, target_depth = get_smiles_depth(target)
                if target_smiles not in graph and target_smiles not in co2:
                    graph.add_node(
                        target_smiles,
                        depth=target_depth,
                        smiles=target_smiles,
                        root=target_smiles in root_smiles_only,
                    )
                elif target_smiles not in co2:
                    graph.nodes[target_smiles]["depth"] = min(
                        target_depth, graph.nodes[target_smiles]["depth"]
                    )
                if target_smiles not in co2 and target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)
    return graph


def get_shortest_path(pathway, in_start_node, in_end_node):
    try:
        pred = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except nx.NetworkXNoPath:
        return []
    pred.remove(in_start_node)
    pred.remove(in_end_node)
    return pred


def set_pathway_eval_weight(pathway):
    node_eval_weights = {}
    for node in pathway.nodes:
        # Scale score according to depth level
        node_eval_weights[node] = (
            1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
        )
    return node_eval_weights


def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    if len(intermediates) < 1:
        return pred_pathway
    root_nodes = pred_pathway.graph["root_nodes"]
    for node in pred_pathway.nodes:
        if node in root_nodes:
            continue
        if node in intermediates and node not in data_pathway:
            pred_pathway.nodes[node]["depth"] = -99
        else:
            shortest_path_list = []
            for root_node in root_nodes:
                shortest_path_nodes = get_shortest_path(pred_pathway, root_node, node)
                if shortest_path_nodes:
                    shortest_path_list.append(shortest_path_nodes)
            if shortest_path_list:
                shortest_path_nodes = min(shortest_path_list, key=len)
                num_ints = sum(
                    1
                    for shortest_path_node in shortest_path_nodes
                    if shortest_path_node in intermediates
                )
                pred_pathway.nodes[node]["depth"] -= num_ints
    return pred_pathway


def initialise_pathway(pathway):
    """Convert pathway to networkx graph for evaluation"""
    pathway = graph_from_pathway(pathway)
    pathway.graph["root_nodes"] = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == 0}
    pathway = get_pathway_with_depth(pathway)
    return pathway


def get_pathway_with_depth(pathway):
    """Recalculates depths in the pathway.

    Can fix incorrect depths from json parse if there were multiple nodes with the same SMILES at
    different depths that got merged."""
    current_depth = 0
    for node in pathway.nodes:
        if node in pathway.graph["root_nodes"]:
            pathway.nodes[node]["depth"] = current_depth
        else:
            pathway.nodes[node]["depth"] = -99
    while assign_next_depth(pathway, current_depth):
        current_depth += 1
    return pathway


def assign_next_depth(pathway, current_depth):
    new_assigned_nodes = False
    current_depth_nodes = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth}
    for node in current_depth_nodes:
        successors = pathway.successors(node)
        for s in successors:
            if pathway.nodes[s]["depth"] < 0:
                pathway.nodes[s]["depth"] = current_depth + 1
                new_assigned_nodes = True
    return new_assigned_nodes


def find_intermediates(data_pathway, pred_pathway):
    """Find any intermediate nodes in the predicted pathway"""
    common_nodes = get_common_nodes(pred_pathway, data_pathway)
    intermediates = set()
    for node in common_nodes:
        down_stream_nodes = data_pathway.successors(node)
        for down_stream_node in down_stream_nodes:
            if down_stream_node in pred_pathway:
                all_ints = get_shortest_path(pred_pathway, node, down_stream_node)
                intermediates.update(all_ints)
    return intermediates


def get_common_nodes(pred_pathway, data_pathway):
    """A node is a common node if it is in both pathways and is either a root in both or not a root in both."""
    common_nodes = set()
    for node in data_pathway.nodes:
        is_pathway_root_node = node in data_pathway.graph["root_nodes"]
        is_this_root_node = node in pred_pathway.graph["root_nodes"]
        if node in pred_pathway.nodes:
            if is_pathway_root_node is False and is_this_root_node is False:
                common_nodes.add(node)
            elif is_pathway_root_node and is_this_root_node:
                common_nodes.add(node)
    return common_nodes


def prune_graph(graph, threshold):
    """
    Break cycles, remove edges with probability below the threshold, then keep only the subgraph
    reachable from the root node. The graph is modified in place.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
            graph.remove_edge(*cycle[-1])  # Remove the last edge in the cycle
        except nx.NetworkXNoCycle:
            break

    for u, v, data in list(graph.edges(data=True)):  # Remove edges below threshold
        if data["probability"] < threshold:
            graph.remove_edge(u, v)
    root_node = [n for n in graph.nodes if graph.nodes[n]["root"]][0]
    reachable = nx.descendants(graph, root_node)  # Get all reachable nodes from root
    reachable.add(root_node)

    for node in list(graph.nodes):  # Remove nodes not reachable from root
        if node not in reachable:
            graph.remove_node(node)


def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Compare two pathways for multi-gen evaluation.

    It is assumed the SMILES in both pathways have been standardised in the same manner."""
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_pathway, threshold)
    intermediates = find_intermediates(data_pathway, pred_pathway)

    if intermediates:
        pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates)

    test_pathway_eval_weights = set_pathway_eval_weight(data_pathway)
    pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway)

    common_nodes = get_common_nodes(pred_pathway, data_pathway)

    data_only_nodes = set(n for n in data_pathway.nodes if n not in common_nodes)
    pred_only_nodes = set(n for n in pred_pathway.nodes if n not in common_nodes)

    score_TP, score_FP, score_FN = 0.0, 0.0, 0.0

    for node in common_nodes:
        if pred_pathway.nodes[node]["depth"] > 0:
            score_TP += test_pathway_eval_weights[node]

    for node in data_only_nodes:
        if data_pathway.nodes[node]["depth"] > 0:
            score_FN += test_pathway_eval_weights[node]

    for node in pred_only_nodes:
        if pred_pathway.nodes[node]["depth"] > 0:
            score_FP += pred_pathway_eval_weights[node]

    final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0
    precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0
    recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0
    if return_intermediates:
        return final_score, precision, recall, intermediates
    return final_score, precision, recall
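

# Scoring note (illustrative): nodes are weighted by 1 / 2**depth, so a correctly predicted first-generation
# metabolite (depth 1, weight 0.5) counts twice as much as a second-generation one (depth 2, weight 0.25).
# Root nodes (depth 0) are excluded from TP/FP/FN because the scores only consider nodes with depth > 0.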


def node_subst_cost(node1, node2):
    if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]:
        return 0
    return 1 / (2 ** max(node1["depth"], node2["depth"]))  # Maybe could be min instead of max


def node_ins_del_cost(node):
    return 1 / (2 ** node["depth"])


def pathway_edit_eval(data_pathway, pred_pathway):
    """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval"""
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
    return nx.graph_edit_distance(
        data_pathway,
        pred_pathway,
        node_subst_cost=node_subst_cost,
        node_del_cost=node_ins_del_cost,
        node_ins_cost=node_ins_del_cost,
        roots=roots,
    )