forked from enviPath/enviPy
[Feature] MultiGen Eval (Backend) (#117)
Fixes #16
Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#117
@@ -864,7 +864,7 @@ class PackageManager(object):
     @staticmethod
     @transaction.atomic
-    def import_pacakge(data: Dict[str, Any], owner: User, preserve_uuids=False, add_import_timestamp=True,
+    def import_package(data: Dict[str, Any], owner: User, preserve_uuids=False, add_import_timestamp=True,
                        trust_reviewed=False) -> Package:

         importer = PackageImporter(data, preserve_uuids, add_import_timestamp, trust_reviewed)
epdb/models.py (476 changed lines)
@@ -7,7 +7,7 @@ import secrets
 from abc import abstractmethod
 from collections import defaultdict
 from datetime import datetime
-from typing import Union, List, Optional, Dict, Tuple, Set
+from typing import Union, List, Optional, Dict, Tuple, Set, Any
 from uuid import uuid4

 import joblib
@@ -588,33 +588,33 @@ class Package(EnviPathModel):
         return f"{self.name} (pk={self.pk})"

     @property
-    def compounds(self):
+    def compounds(self) -> QuerySet:
         return self.compound_set.all()

     @property
-    def rules(self):
+    def rules(self) -> QuerySet:
         return self.rule_set.all()

     @property
-    def reactions(self):
+    def reactions(self) -> QuerySet:
         return self.reaction_set.all()

     @property
-    def pathways(self) -> 'Pathway':
+    def pathways(self) -> QuerySet:
         return self.pathway_set.all()

     @property
-    def scenarios(self):
+    def scenarios(self) -> QuerySet:
         return self.scenario_set.all()

     @property
-    def models(self):
+    def models(self) -> QuerySet:
         return self.epmodel_set.all()

     def _url(self):
         return '{}/package/{}'.format(s.SERVER_URL, self.uuid)

-    def get_applicable_rules(self):
+    def get_applicable_rules(self) -> List['Rule']:
         """
         Returns an ordered set of rules where the following applies:
         1. All Composite will be added to result
@@ -650,11 +650,11 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
     external_identifiers = GenericRelation('ExternalIdentifier')

     @property
-    def structures(self):
+    def structures(self) -> QuerySet:
         return CompoundStructure.objects.filter(compound=self)

     @property
-    def normalized_structure(self):
+    def normalized_structure(self) -> 'CompoundStructure':
         return CompoundStructure.objects.get(compound=self, normalized_structure=True)

     def _url(self):
@@ -1635,8 +1635,8 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
         return new_pathway

     @transaction.atomic
-    def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None):
-        return Node.create(self, smiles, 0)
+    def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None, depth: Optional[int] = 0):
+        return Node.create(self, smiles, depth, name=name, description=description)

     @transaction.atomic
     def add_edge(self, start_nodes: List['Node'], end_nodes: List['Node'], rule: Optional['Rule'] = None,
@@ -1836,6 +1836,7 @@ class PackageBasedModel(EPModel):
     eval_results = JSONField(null=True, blank=True, default=dict)
     app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
                                    default=None)
+    multigen_eval = models.BooleanField(null=False, blank=False, default=False)

     INITIAL = "INITIAL"
     INITIALIZING = "INITIALIZING"
@@ -1861,6 +1862,24 @@ class PackageBasedModel(EPModel):
     def ready_for_prediction(self) -> bool:
         return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]

+    @property
+    def pr_curve(self):
+        if self.model_status != self.FINISHED:
+            raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
+
+        res = []
+
+        thresholds = self.eval_results['average_precision_per_threshold'].keys()
+
+        for t in thresholds:
+            res.append({
+                'precision': self.eval_results['average_precision_per_threshold'][t],
+                'recall': self.eval_results['average_recall_per_threshold'][t],
+                'threshold': float(t)
+            })
+
+        return res
+
     @cached_property
     def applicable_rules(self) -> List['Rule']:
         """
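A minimal sketch of consuming the new pr_curve property; the fitted instance `mod` and the matplotlib dependency are assumptions for illustration, not part of this commit:

    # Sketch: plot the averaged precision/recall pairs once model_status == FINISHED.
    import matplotlib.pyplot as plt

    points = sorted(mod.pr_curve, key=lambda p: p['threshold'])
    plt.plot([p['recall'] for p in points], [p['precision'] for p in points], marker='o')
    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.show()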
@@ -1897,14 +1916,6 @@ class PackageBasedModel(EPModel):
         # TODO
         return []

-    def _get_pathways(self):
-        pathway_qs = Pathway.objects.none()
-        for p in self.data_packages.all():
-            pathway_qs |= p.pathways
-
-        pathway_qs = pathway_qs.distinct()
-        return pathway_qs
-
     def _get_reactions(self) -> QuerySet:
         return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()
@@ -1937,9 +1948,277 @@ class PackageBasedModel(EPModel):
         self.build_model()

     @abstractmethod
-    def build_model(self):
+    def _fit_model(self, ds: Dataset):
         pass

+    @abstractmethod
+    def _model_args(self) -> Dict[str, Any]:
+        pass
+
+    def build_model(self):
+        self.model_status = self.BUILDING
+        self.save()
+
+        ds = self.load_dataset()
+
+        mod = self._fit_model(ds)
+
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
+        joblib.dump(mod, f)
+
+        if self.app_domain is not None:
+            logger.debug("Building applicability domain...")
+            self.app_domain.build()
+            logger.debug("Done building applicability domain.")
+
+        self.model_status = self.BUILT_NOT_EVALUATED
+        self.save()
+
+    def evaluate_model(self):
+
+        if self.model_status != self.BUILT_NOT_EVALUATED:
+            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
+
+        self.model_status = self.EVALUATING
+        self.save()
+
+        def train_func(X, y, train_index, model_kwargs):
+            clz = model_kwargs.pop('clz')
+            if clz == 'RuleBaseRelativeReasoning':
+                mod = RelativeReasoning(
+                    **model_kwargs
+                )
+            else:
+                mod = EnsembleClassifierChain(
+                    **model_kwargs
+                )
+
+            if train_index is not None:
+                X, y = X[train_index], y[train_index]
+
+            mod.fit(X, y)
+            return mod
+
+        def evaluate_sg(model, X, y, test_index, threshold):
+            X_test = X[test_index]
+            y_test = y[test_index]
+
+            y_pred = model.predict_proba(X_test)
+            y_thresholded = (y_pred >= threshold).astype(int)
+
+            # Flatten them to get rid of np.nan
+            y_test = np.asarray(y_test).flatten()
+            y_pred = np.asarray(y_pred).flatten()
+            y_thresholded = np.asarray(y_thresholded).flatten()
+
+            mask = ~np.isnan(y_test)
+            y_test_filtered = y_test[mask]
+            y_pred_filtered = y_pred[mask]
+            y_thresholded_filtered = y_thresholded[mask]
+
+            acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
+
+            prec, rec = dict(), dict()
+
+            for t in np.arange(0, 1.05, 0.05):
+                temp_thresholded = (y_pred_filtered >= t).astype(int)
+                prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
+                rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
+
+            return acc, prec, rec
+
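A self-contained toy run of the threshold sweep evaluate_sg performs; the arrays are invented for illustration:

    import numpy as np
    from sklearn.metrics import precision_score, recall_score

    y_true = np.array([1, 0, 1, 1])
    y_prob = np.array([0.9, 0.4, 0.6, 0.2])
    for t in (0.25, 0.50, 0.75):
        y_hat = (y_prob >= t).astype(int)                 # binarise at threshold t
        print(f"{t:.2f}", precision_score(y_true, y_hat, zero_division=0),
              recall_score(y_true, y_hat, zero_division=0))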
+        def evaluate_mg(model, pathways: Union[QuerySet['Pathway'], List['Pathway']], threshold):
+            thresholds = np.arange(0.1, 1.1, 0.1)
+
+            precision = {f"{t:.2f}": [] for t in thresholds}
+            recall = {f"{t:.2f}": [] for t in thresholds}
+
+            # Note: only one root compound supported at this time
+            root_compounds = [[n.default_node_label.smiles for n in p.root_nodes][0] for p in pathways]
+
+            # As we need a Model Instance in our setting, get a fresh copy from db, overwrite the serialized model and
+            # pass it to the setting used in prediction
+            if isinstance(self, MLRelativeReasoning):
+                mod = MLRelativeReasoning.objects.get(pk=self.pk)
+            elif isinstance(self, RuleBasedRelativeReasoning):
+                mod = RuleBasedRelativeReasoning.objects.get(pk=self.pk)
+
+            mod.model = model
+
+            s = Setting()
+            s.model = mod
+            s.model_threshold = thresholds.min()
+            s.max_depth = 10
+            s.max_nodes = 50
+
+            from epdb.logic import SPathway
+            from utilities.ml import multigen_eval
+
+            pred_pathways = []
+            for i, root in enumerate(root_compounds):
+                logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
+
+                spw = SPathway(root_nodes=root, prediction_setting=s)
+                level = 0
+
+                while not spw.done:
+                    spw.predict_step(from_depth=level)
+                    level += 1
+
+                pred_pathways.append(spw)
+
+            mg_acc = 0.0
+            for t in thresholds:
+                for true, pred in zip(pathways, pred_pathways):
+                    acc, pre, rec = multigen_eval(true, pred, t)
+                    if abs(t - threshold) < 0.01:
+                        mg_acc = acc
+                    precision[f"{t:.2f}"].append(pre)
+                    recall[f"{t:.2f}"].append(rec)
+
+            precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
+            recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
+            return mg_acc, precision, recall
+
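Design note on the loop above: each pathway is expanded once at thresholds.min(), and the per-threshold metrics are then derived by multigen_eval at each t, so the expensive expansion is not repeated per threshold. Presumably a stricter threshold is emulated by filtering predicted edges on their probability, roughly as in this hypothetical one-liner (the probability attribute access is an assumption):

    # Hypothetical post-hoc filter instead of re-running predict_step per threshold
    kept_edges = [e for e in spw.edges if e.probability >= t]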
+        ds = self.load_dataset()
+
+        if isinstance(self, RuleBasedRelativeReasoning):
+            X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
+            y = np.array(ds.y(na_replacement=np.nan))
+        else:
+            X = np.array(ds.X(na_replacement=np.nan))
+            y = np.array(ds.y(na_replacement=np.nan))
+
+        n_splits = 20
+
+        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
+        splits = list(shuff.split(X))
+
+        from joblib import Parallel, delayed
+        models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
+        evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
+                                          for model, (_, test_index) in zip(models, splits))
+
+        def compute_averages(data):
+            num_items = len(data)
+            avg_first_item = sum(item[0] for item in data) / num_items
+
+            sum_dict2 = defaultdict(float)
+            sum_dict3 = defaultdict(float)
+
+            for _, dict2, dict3 in data:
+                for key in dict2:
+                    sum_dict2[key] += dict2[key]
+                for key in dict3:
+                    sum_dict3[key] += dict3[key]
+
+            avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
+            avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
+
+            return {
+                "average_accuracy": float(avg_first_item),
+                "average_precision_per_threshold": avg_dict2,
+                "average_recall_per_threshold": avg_dict3
+            }
+
+        self.eval_results = compute_averages(evaluations)
+
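A toy check of compute_averages over two folds, with invented values:

    data = [(0.8, {'0.50': 0.9}, {'0.50': 0.6}),
            (0.6, {'0.50': 0.7}, {'0.50': 0.8})]
    n = len(data)
    print(sum(item[0] for item in data) / n)  # average_accuracy -> 0.7
    # The per-threshold dicts average the same way:
    # precision -> {'0.50': 0.8}, recall -> {'0.50': 0.7}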
+        if self.multigen_eval:
+
+            # We have to consider 2 cases here:
+            # 1. No eval packages provided -> Split train data X times and train and evaluate model
+            # 2. Eval packages provided -> Use the already built model and do evaluation on the set provided.
+
+            if self.eval_packages.count() > 0:
+                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
+                evaluate_mg(self.model, pathway_qs, self.threshold)
+                return
+
+            pathway_qs = Pathway.objects.prefetch_related(
+                'node_set',
+                'node_set__out_edges',
+                'node_set__default_node_label',
+                'node_set__scenarios',
+                'edge_set',
+                'edge_set__start_nodes',
+                'edge_set__end_nodes',
+                'edge_set__edge_label',
+                'edge_set__scenarios'
+            ).filter(package__in=self.data_packages.all()).distinct()
+
+            pathways = []
+            for pathway in pathway_qs:
+                # There is one pathway with no root compounds, so this check is required
+                if len(pathway.root_nodes) > 0:
+                    pathways.append(pathway)
+                else:
+                    logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
+
+            # build lookup reaction -> {uuid1, uuid2} for overlap check
+            reaction_to_educts = defaultdict(set)
+            for pathway in pathways:
+                for reaction in pathway.edges:
+                    for e in reaction.edge_label.educts.all():
+                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
+
+            # build lookup to avoid recalculation of features, labels
+            id_to_index = {uuid: i for i, uuid in enumerate(ds[:, 0])}
+
+            # Compute splits of the collected pathways
+            splits = []
+            for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
+                train_pathways = [pathways[i] for i in train]
+                test_pathways = [pathways[i] for i in test]
+
+                # Collect structures from test pathways
+                test_educts = set()
+                for pathway in test_pathways:
+                    for reaction in pathway.edges:
+                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+
+                split_ids = []
+                overlap = 0
+                # Collect indices of the structures contained in train pathways iff they're not present in any of
+                # the test pathways
+                for pathway in train_pathways:
+                    for reaction in pathway.edges:
+                        for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
+                            # Ensure compounds in the training set do not appear in the test set
+                            if educt not in test_educts:
+                                try:
+                                    split_ids.append(id_to_index[str(educt)])
+                                except KeyError:
+                                    logger.debug(f"Couldn't find features in X for compound {educt}")
+                            else:
+                                overlap += 1
+
+                logging.debug(
+                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
+
+                # Get the rows from the dataset corresponding to compounds in the training set pathways
+                split_x, split_y = X[split_ids], y[split_ids]
+                splits.append([(split_x, split_y), test_pathways])
+
+            # Build model on subsets obtained by pathway split
+            trained_models = Parallel(n_jobs=10)(
+                delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
+            )
+
+            # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
+            multi_ret_vals = Parallel(n_jobs=1)(
+                delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
+                zip(trained_models, splits)
+            )
+
+            self.eval_results.update({f"multigen_{k}": v for k, v in compute_averages(multi_ret_vals).items()})
+
+        self.model_status = self.FINISHED
+        self.save()
+
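The pathway-level split above exists to avoid leakage: any educt that occurs in a held-out pathway is dropped from the training rows, and the multigen metrics land under "multigen_"-prefixed keys next to the single-generation ones. The filter reduces to a set check, as in this toy sketch with invented UUID strings:

    train_educts = ['u1', 'u2', 'u3']
    test_educts = {'u2'}
    kept = [e for e in train_educts if e not in test_educts]  # ['u1', 'u3']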
     @staticmethod
     def combine_products_and_probs(rules: List['Rule'], probabilities, products):
         res = []
@@ -2011,21 +2290,22 @@ class RuleBasedRelativeReasoning(PackageBasedModel):

         return rbrr

-    def build_model(self):
-        self.model_status = self.BUILDING
-        self.save()
-
-        ds = self.load_dataset()
-        labels = ds.y(na_replacement=None)
-
-        mod = RelativeReasoning(*ds.triggered())
-        mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
-
-        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
-        joblib.dump(mod, f)
-
-        self.model_status = self.BUILT_NOT_EVALUATED
-        self.save()
+    def _fit_model(self, ds: Dataset):
+        X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
+        model = RelativeReasoning(
+            start_index=ds.triggered()[0],
+            end_index=ds.triggered()[1],
+        )
+        model.fit(X, y)
+        return model
+
+    def _model_args(self):
+        ds = self.load_dataset()
+        return {
+            'clz': 'RuleBaseRelativeReasoning',
+            'start_index': ds.triggered()[0],
+            'end_index': ds.triggered()[1],
+        }

     @cached_property
     def model(self) -> 'RelativeReasoning':
@@ -2038,7 +2318,6 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-
         mod = self.model

         pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))

         res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
@@ -2102,118 +2381,23 @@ class MLRelativeReasoning(PackageBasedModel):

         return mlrr

-    def build_model(self):
-        self.model_status = self.BUILDING
-        self.save()
-
-        start = datetime.now()
-
-        ds = self.load_dataset()
+    def _fit_model(self, ds: Dataset):
         X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)

-        mod = EnsembleClassifierChain(
-            **s.DEFAULT_MODEL_PARAMS
-        )
-        mod.fit(X, y)
-
-        end = datetime.now()
-        logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")
-
-        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
-        joblib.dump(mod, f)
-
-        if self.app_domain is not None:
-            logger.debug("Building applicability domain...")
-            self.app_domain.build()
-            logger.debug("Done building applicability domain.")
-
-        self.model_status = self.BUILT_NOT_EVALUATED
-        self.save()
-
-    def evaluate_model(self):
-
-        if self.model_status != self.BUILT_NOT_EVALUATED:
-            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
-
-        self.model_status = self.EVALUATING
-        self.save()
-
-        ds = self.load_dataset()
-
-        X = np.array(ds.X(na_replacement=np.nan))
-        y = np.array(ds.y(na_replacement=np.nan))
-
-        n_splits = 20
-
-        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
-
-        def train_and_evaluate(X, y, train_index, test_index, threshold):
-            X_train, X_test = X[train_index], X[test_index]
-            y_train, y_test = y[train_index], y[test_index]
-
         model = EnsembleClassifierChain(
             **s.DEFAULT_MODEL_PARAMS
         )
-        model.fit(X_train, y_train)
-
-        y_pred = model.predict_proba(X_test)
-        y_thresholded = (y_pred >= threshold).astype(int)
-
-        # Flatten them to get rid of np.nan
-        y_test = np.asarray(y_test).flatten()
-        y_pred = np.asarray(y_pred).flatten()
-        y_thresholded = np.asarray(y_thresholded).flatten()
-
-        mask = ~np.isnan(y_test)
-        y_test_filtered = y_test[mask]
-        y_pred_filtered = y_pred[mask]
-        y_thresholded_filtered = y_thresholded[mask]
-
-        acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
-
-        prec, rec = dict(), dict()
-
-        for t in np.arange(0, 1.05, 0.05):
-            temp_thresholded = (y_pred_filtered >= t).astype(int)
-            prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
-            rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
-
-        return acc, prec, rec
-
-        from joblib import Parallel, delayed
-        ret_vals = Parallel(n_jobs=10)(
-            delayed(train_and_evaluate)(X, y, train_index, test_index, self.threshold)
-            for train_index, test_index in shuff.split(X)
-        )
-
-        def compute_averages(data):
-            num_items = len(data)
-            avg_first_item = sum(item[0] for item in data) / num_items
-
-            sum_dict2 = defaultdict(float)
-            sum_dict3 = defaultdict(float)
-
-            for _, dict2, dict3 in data:
-                for key in dict2:
-                    sum_dict2[key] += dict2[key]
-                for key in dict3:
-                    sum_dict3[key] += dict3[key]
-
-            avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
-            avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
-
+        model.fit(X, y)
+        return model
+
+    def _model_args(self):
         return {
-            "average_accuracy": float(avg_first_item),
-            "average_precision_per_threshold": avg_dict2,
-            "average_recall_per_threshold": avg_dict3
+            'clz': 'MLRelativeReasoning',
+            **s.DEFAULT_MODEL_PARAMS,
         }

-        self.eval_results = compute_averages(ret_vals)
-        self.model_status = self.FINISHED
-        self.save()
-
     @cached_property
-    def model(self):
+    def model(self) -> 'EnsembleClassifierChain':
         mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
         mod.base_clf.n_jobs = -1
         return mod
@@ -2230,24 +2414,6 @@ class MLRelativeReasoning(PackageBasedModel):
         logger.info(f"Full predict took {(end - start).total_seconds()}s")
         return res

-    @property
-    def pr_curve(self):
-        if self.model_status != self.FINISHED:
-            raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
-
-        res = []
-
-        thresholds = self.eval_results['average_precision_per_threshold'].keys()
-
-        for t in thresholds:
-            res.append({
-                'precision': self.eval_results['average_precision_per_threshold'][t],
-                'recall': self.eval_results['average_recall_per_threshold'][t],
-                'threshold': float(t)
-            })
-
-        return res
-

 class ApplicabilityDomain(EnviPathModel):
     model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)
@@ -1,6 +1,5 @@
 import logging
 import os
-import shutil

 from django.conf import settings as s
 from django.db import transaction
@@ -35,4 +34,4 @@ def delete_epmodel_files(sender, instance, **kwargs):
     for f in os.listdir(s.MODEL_DIR):
         if f.startswith(mod_uuid):
             logger.info(f"Deleting {os.path.join(s.MODEL_DIR, f)}")
-            shutil.rmtree(os.path.join(s.MODEL_DIR, f))
+            os.remove(os.path.join(s.MODEL_DIR, f))
@@ -1,12 +1,10 @@
 import logging
 from typing import Optional

-from celery.signals import worker_process_init
 from celery import shared_task
 from epdb.models import Pathway, Node, Edge, EPModel, Setting
 from epdb.logic import SPathway

-from utilities.chem import FormatConverter

 logger = logging.getLogger(__name__)
@@ -294,7 +294,7 @@ def packages(request):
             if hidden == 'import-legacy-package-json':
                 pack = PackageManager.import_legacy_package(data, current_user)
             else:
-                pack = PackageManager.import_pacakge(data, current_user)
+                pack = PackageManager.import_package(data, current_user)

             return redirect(pack.url)
         except UnicodeDecodeError:
@@ -772,10 +772,13 @@ def package_model(request, package_uuid, model_uuid):
             if hidden == 'delete':
                 current_model.delete()
                 return redirect(current_package.url + '/model')
+            elif hidden == 'evaluate':
+                from .tasks import evaluate_model
+                evaluate_model.delay(current_model.pk)
+                return redirect(current_model.url)
             else:
                 return HttpResponseBadRequest()
         else:
             name = request.POST.get('model-name', '').strip()
             description = request.POST.get('model-description', '').strip()
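The diff only shows the call site; a plausible shape for the Celery task it dispatches, sketched here as an assumption rather than the committed tasks.py:

    from celery import shared_task
    from epdb.models import EPModel

    @shared_task
    def evaluate_model(model_pk):
        # Resolve the model and run the potentially long evaluation off the request thread
        EPModel.objects.get(pk=model_pk).evaluate_model()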
fixtures/EAWAG-BBD.json (439315 lines): file diff suppressed because it is too large
fixtures/EAWAG-SLUDGE.json (125432 lines): file diff suppressed because it is too large
fixtures/EAWAG-SOIL.json (1572257 lines): file diff suppressed because one or more lines are too long
fixtures/test_fixtures.jsonl.gz (new binary file): binary file not shown
@@ -18,6 +18,7 @@ dependencies = [
     "envipy-plugins",
     "epam-indigo>=1.30.1",
     "gunicorn>=23.0.0",
+    "networkx>=3.4.2",
    "psycopg2-binary>=2.9.10",
     "python-dotenv>=1.1.0",
     "rdkit>=2025.3.2",
@@ -10,7 +10,7 @@
             <h4 class="modal-title">Evaluate Model</h4>
         </div>
         <div class="modal-body">
-            <form id="evaluate_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
+            <form id="evaluate_model_form" accept-charset="UTF-8" action="{{ current_object.url }}"
                   data-remote="true" method="post">
                 {% csrf_token %}
                 <div class="jumbotron">
@@ -35,6 +35,7 @@
                     {% endif %}
                 {% endfor %}
                 </select>
+                <input type="hidden" name="hidden" value="evaluate">
             </form>
         </div>
         <div class="modal-footer">
@@ -5,7 +5,7 @@ from epdb.models import Compound, User, CompoundStructure

 class CompoundTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
+    fixtures = ["test_fixtures.jsonl.gz"]

     def setUp(self):
         pass

@@ -7,7 +7,7 @@ from epdb.models import Compound, User, Reaction

 class CopyTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
+    fixtures = ["test_fixtures.jsonl.gz"]

     @classmethod
     def setUpClass(cls):

@@ -6,7 +6,7 @@ from utilities.ml import Dataset

 class DatasetTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
+    fixtures = ["test_fixtures.jsonl.gz"]

     def setUp(self):
         self.cs1 = Compound.create(
@@ -1,13 +1,14 @@
-from django.test import TestCase
+from tempfile import TemporaryDirectory
+
+import numpy as np
+from django.test import TestCase

 from epdb.logic import PackageManager
-from epdb.models import User, MLRelativeReasoning, Package
+from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package


 class ModelTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
+    fixtures = ["test_fixtures.jsonl.gz"]

     @classmethod
     def setUpClass(cls):
@@ -17,9 +18,10 @@ class ModelTest(TestCase):
         cls.BBD_SUBSET = Package.objects.get(name='Fixtures')

     def test_smoke(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
                 threshold = float(0.5)

-        # get Package objects from urls
                 rule_package_objs = [self.BBD_SUBSET]
                 data_package_objs = [self.BBD_SUBSET]
                 eval_packages_objs = []
@@ -29,16 +31,42 @@ class ModelTest(TestCase):
                     rule_package_objs,
                     data_package_objs,
                     eval_packages_objs,
-                    threshold,
-                    'ECC - BBD - 0.5',
-                    'Created MLRelativeReasoning in Testcase',
+                    threshold=threshold,
+                    name='ECC - BBD - 0.5',
+                    description='Created MLRelativeReasoning in Testcase',
                 )

+                # mod = RuleBasedRelativeReasoning.create(
+                #     self.package,
+                #     rule_package_objs,
+                #     data_package_objs,
+                #     eval_packages_objs,
+                #     threshold=threshold,
+                #     min_count=5,
+                #     max_count=0,
+                #     name='ECC - BBD - 0.5',
+                #     description='Created MLRelativeReasoning in Testcase',
+                # )

                 mod.build_dataset()
                 mod.build_model()
-                print("Model built!")
-                mod.evaluate_model()
-                print("Model Evaluated")
+                mod.multigen_eval = True
+                mod.save()
+                # mod.evaluate_model()

                 results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
-                print(results)
+                products = dict()
+                for r in results:
+                    for ps in r.product_sets:
+                        products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
+
+                expected = {
+                    ('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
+                    ('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
+                }
+
+                self.assertEqual(products, expected)
+
+                # from pprint import pprint
+                # pprint(mod.eval_results)
tests/test_multigen_eval.py (new file, 137 lines)
@@ -0,0 +1,137 @@
+import json
+
+from django.test import TestCase
+from networkx.utils.misc import graphs_equal
+
+from epdb.logic import PackageManager, SPathway
+from epdb.models import Pathway, User, Package
+from utilities.ml import multigen_eval, pathway_edit_eval, graph_from_pathway
+
+
+class MultiGenTest(TestCase):
+    fixtures = ["test_fixtures.jsonl.gz"]
+
+    @classmethod
+    def setUpClass(cls):
+        super(MultiGenTest, cls).setUpClass()
+        cls.user: 'User' = User.objects.get(username='anonymous')
+        cls.package: 'Package' = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
+        cls.BBD_SUBSET: 'Package' = Package.objects.get(name='Fixtures')
+
+    def test_equal_pathways(self):
+        """Test that two identical pathways return a precision and recall of 1.0"""
+        pathways = self.BBD_SUBSET.pathways.all()
+        for pathway in pathways:
+            if len(pathway.edge_set.all()) == 0:  # Do not test pathways with no edges
+                continue
+            score, precision, recall = multigen_eval(pathway, pathway)
+            self.assertEqual(precision, 1.0, f"Precision should be one for identical pathways. "
+                                             f"Failed on pathway: {pathway.name}")
+            self.assertEqual(recall, 1.0, f"Recall should be one for identical pathways. "
+                                          f"Failed on pathway: {pathway.name}")
+
+    def test_intermediates(self):
+        """Test that an intermediate can be correctly identified and the metrics are correctly adjusted"""
+        score, precision, recall, intermediates = multigen_eval(*self.intermediate_case(), return_intermediates=True)
+        self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
+        self.assertEqual(precision, 1, "Precision should be 1")
+        self.assertEqual(recall, 1, "Recall should be 1")
+
+    def test_fp(self):
+        """Test that a false-positive (extra compound) is correctly penalised"""
+        score, precision, recall = multigen_eval(*self.fp_case())
+        self.assertAlmostEqual(precision, 0.75, 3, "Precision should be 0.75")
+        self.assertEqual(recall, 1, "Recall should be 1")
+
+    def test_fn(self):
+        """Test that a false-negative (missed compound) is correctly penalised"""
+        score, precision, recall = multigen_eval(*self.fn_case())
+        self.assertEqual(precision, 1, "Precision should be 1.0")
+        self.assertAlmostEqual(recall, 0.667, 3, "Recall should be 0.667")
+
+    def test_all(self):
+        """Test an intermediate, false-positive and false-negative together"""
+        score, precision, recall, intermediates = multigen_eval(*self.all_case(), return_intermediates=True)
+        self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
+        self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6")
+        self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75")
+
+    def test_shallow_pathway(self):
+        pathways = self.BBD_SUBSET.pathways.all()
+        for pathway in pathways:
+            pathway_name = pathway.name
+            if len(pathway.edge_set.all()) == 0:  # Do not test pathways with no edges
+                continue
+            shallow_pathway = graph_from_pathway(SPathway.from_pathway(pathway))
+            pathway = graph_from_pathway(pathway)
+            if not graphs_equal(shallow_pathway, pathway):
+                print('\n\nS', shallow_pathway.adj)
+                print('\n\nPW', pathway.adj)
+                # print(shallow_pathway.nodes, pathway.nodes)
+                # print(shallow_pathway.graph, pathway.graph)
+
+            self.assertTrue(graphs_equal(shallow_pathway, pathway), f"Networkx graph from shallow pathway not "
+                                                                    f"equal to pathway for pathway {pathway_name}")
+
+    def test_graph_edit_eval(self):
+        """Performs all the previous tests but with graph_edit_eval
+
+        Unlike multigen_eval, these test cases have not been hand verified"""
+        pathways = self.BBD_SUBSET.pathways.all()
+        for pathway in pathways:
+            if len(pathway.edge_set.all()) == 0:  # Do not test pathways with no edges
+                continue
+            score = pathway_edit_eval(pathway, pathway)
+            self.assertEqual(score, 0.0, "Pathway edit distance should be zero for identical pathways. "
+                                         f"Failed on pathway: {pathway.name}")
+        inter_score = pathway_edit_eval(*self.intermediate_case())
+        self.assertAlmostEqual(inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case")
+        fp_score = pathway_edit_eval(*self.fp_case())
+        self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case")
+        fn_score = pathway_edit_eval(*self.fn_case())
+        self.assertAlmostEqual(fn_score, 1.25, 3, "Pathway edit distance failed on fn case")
+        all_score = pathway_edit_eval(*self.all_case())
+        self.assertAlmostEqual(all_score, 1.0, 3, "Pathway edit distance failed on all case")
+
+    def intermediate_case(self):
+        """Create an example with an intermediate in the predicted pathway"""
+        true_pathway = Pathway.create(self.package, "CCO")
+        true_pathway.add_edge([true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)])
+        pred_pathway = Pathway.create(self.package, "CCO")
+        pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]],
+                              [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)])
+        pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
+        return true_pathway, pred_pathway
+
+    def fp_case(self):
+        """Create an example with an extra compound in the predicted pathway"""
+        true_pathway = Pathway.create(self.package, "CCO")
+        true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
+                              [acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
+        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
+        pred_pathway = Pathway.create(self.package, "CCO")
+        pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]],
+                              [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)])
+        pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
+        pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)])
+        return true_pathway, pred_pathway
+
+    def fn_case(self):
+        """Create an example with a missing compound in the predicted pathway"""
+        true_pathway = Pathway.create(self.package, "CCO")
+        true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
+                              [acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
+        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
+        pred_pathway = Pathway.create(self.package, "CCO")
+        pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)])
+        return true_pathway, pred_pathway
+
+    def all_case(self):
+        """Create an example with an intermediate, extra compound and missing compound"""
+        true_pathway = Pathway.create(self.package, "CCO")
+        true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
+                              [acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
+        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)])
+        true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
+        pred_pathway = Pathway.create(self.package, "CCO")
+        pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)])
+        pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)])
+        pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)])
+        return true_pathway, pred_pathway
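Hand-check of the test_fp numbers, assuming multigen_eval compares standardized node sets roughly as follows (CO2 handling aside):

    true_nodes = {"CCO", "CC=O", "CC(=O)O"}
    pred_nodes = {"CCO", "CC=O", "CC(=O)O", "C"}
    precision = len(true_nodes & pred_nodes) / len(pred_nodes)  # 3/4 = 0.75
    recall = len(true_nodes & pred_nodes) / len(true_nodes)     # 3/3 = 1.0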
@@ -5,7 +5,7 @@ from epdb.models import Compound, User, Reaction, Rule

 class ReactionTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
+    fixtures = ["test_fixtures.jsonl.gz"]

     @classmethod
     def setUpClass(cls):

@@ -5,10 +5,7 @@ from epdb.models import Rule, User

 class RuleTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
-
-    def setUp(self):
-        pass
+    fixtures = ["test_fixtures.jsonl.gz"]

     @classmethod
     def setUpClass(cls):

@@ -7,7 +7,7 @@ from epdb.models import User, SimpleAmbitRule

 class SimpleAmbitRuleTest(TestCase):
-    fixtures = ["test_fixtures.json.gz"]
+    fixtures = ["test_fixtures.jsonl.gz"]

     @classmethod
     def setUpClass(cls):
@@ -183,7 +183,7 @@ class FormatConverter(object):
         return smiles

     @staticmethod
-    def standardize(smiles):
+    def standardize(smiles, remove_stereo=False):
         # Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
         # follows the steps in
         # https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/MolStandardize%20pieces.ipynb
@@ -208,6 +208,9 @@ class FormatConverter(object):
         # te = rdMolStandardize.TautomerEnumerator() # idem
         # taut_uncharged_parent_clean_mol = te.Canonicalize(uncharged_parent_clean_mol)

+        if remove_stereo:
+            Chem.RemoveStereochemistry(uncharged_parent_clean_mol)
+
         return Chem.MolToSmiles(uncharged_parent_clean_mol, kekuleSmiles=True)

     @staticmethod
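Usage sketch for the new remove_stereo flag; Chem.RemoveStereochemistry is an existing RDKit call, and the behavior noted in the comments is indicative rather than exact output:

    from utilities.chem import FormatConverter

    FormatConverter.standardize("C[C@H](O)CC")                      # stereocentre preserved
    FormatConverter.standardize("C[C@H](O)CC", remove_stereo=True)  # stereo marker dropped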
@@ -919,7 +919,7 @@ class PackageImporter:
                 name=edge_data['name'],
                 description=edge_data['description'],
                 kv=edge_data.get('kv', {}),
-                edge_label=None  # Will be set later
+                edge_label=self._get_cached_object('Reaction', edge_data['edge_label']['uuid'])
             )

             # Set aliases if present
utilities/ml.py (347 changed lines)
@@ -1,12 +1,19 @@
 from __future__ import annotations

+import copy
+
+import numpy as np
+from numpy.random import default_rng
+from sklearn.dummy import DummyClassifier
+from sklearn.tree import DecisionTreeClassifier
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from datetime import datetime
 from typing import List, Dict, Set, Tuple

-import numpy as np
+import networkx as nx

 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
 from sklearn.ensemble import RandomForestClassifier
@@ -22,61 +29,6 @@ from dataclasses import dataclass, field
 from utilities.chem import FormatConverter, PredictionResult


-@dataclass
-class SCompound:
-    smiles: str
-    uuid: str = field(default=None, compare=False, hash=False)
-
-    def __hash__(self):
-        if not hasattr(self, '_hash'):
-            self._hash = hash((
-                self.smiles
-            ))
-        return self._hash
-
-
-@dataclass
-class SReaction:
-    educts: List[SCompound]
-    products: List[SCompound]
-    rule_uuid: SRule = field(default=None, compare=False, hash=False)
-    reaction_uuid: str = field(default=None, compare=False, hash=False)
-
-    def __hash__(self):
-        if not hasattr(self, '_hash'):
-            self._hash = hash((
-                tuple(sorted(self.educts, key=lambda x: x.smiles)),
-                tuple(sorted(self.products, key=lambda x: x.smiles)),
-            ))
-        return self._hash
-
-    def __eq__(self, other):
-        if not isinstance(other, SReaction):
-            return NotImplemented
-        return (
-            sorted(self.educts, key=lambda x: x.smiles) == sorted(other.educts, key=lambda x: x.smiles) and
-            sorted(self.products, key=lambda x: x.smiles) == sorted(other.products, key=lambda x: x.smiles)
-        )
-
-
-@dataclass
-class SRule(ABC):
-
-    @abstractmethod
-    def apply(self):
-        pass
-
-
-@dataclass
-class SSimpleRule:
-    pass
-
-
-@dataclass
-class SParallelRule:
-    pass
-
-
 class Dataset:

     def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
@@ -385,7 +337,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
         self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)]

         for i, chain in enumerate(self.chains_):
-            print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
+            logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
             chain.fit(X, y_reduced)

         return self
@@ -423,14 +375,6 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
         return accuracy_score(y_true, y_pred, sample_weight=sample_weight)


-
-import copy
-
-import numpy as np
-from sklearn.dummy import DummyClassifier
-from sklearn.tree import DecisionTreeClassifier
-
-
 class BinaryRelevance:
     def __init__(self, baseline_clf):
         self.clf = baseline_clf
@@ -483,7 +427,8 @@ class MissingValuesClassifierChain:
         X = np.array(X)
         Y = np.array(Y)
         if self.permutation is None:
-            self.permutation = np.random.permutation(len(Y[0]))
+            rng = default_rng(42)
+            self.permutation = rng.permutation(len(Y[0]))

         Y = Y[:, self.permutation]
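The switch from np.random.permutation to a seeded Generator makes the chain's label order reproducible across runs:

    from numpy.random import default_rng

    # Both calls yield the identical permutation because each starts from seed 42.
    print(default_rng(42).permutation(5))
    print(default_rng(42).permutation(5))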
@@ -541,7 +486,7 @@ class EnsembleClassifierChain:
         self.num_labels = len(Y[0])

         for p in range(self.num_chains):
-            print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
+            logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
             clf = MissingValuesClassifierChain(self.base_clf)
             clf.fit(X, Y)
             self.classifiers.append(clf)
@@ -609,12 +554,22 @@ class RelativeReasoning:
     def predict(self, X):
         res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))

+        # Loop through all instances
         for inst_idx, inst in enumerate(X):
+            # Loop through all "triggered" features
             for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
+                # Set label
                 res[inst_idx][i] = t
+                # If we predict a 1, check if the rule gets dominated by another
                 if t:
+                    # Second loop to check other triggered rules
                     for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
-                        if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
+                        if i != i2:
+                            # Check if rule idx is in "dominated by" list
+                            if i2 in self.winmap.get(i, []):
+                                # If that rule also triggered, it dominates the current one:
+                                # set label to 0
+                                if X[inst_idx][i2]:
                                     res[inst_idx][i] = 0

         return res
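A toy illustration of the domination check the new comments describe; winmap maps a rule index to the indices that dominate it:

    import numpy as np

    winmap = {0: [1]}            # rule 0 is dominated by rule 1
    inst = np.array([1, 1])      # triggered flags for rules 0 and 1
    res = inst.copy()
    for i, t in enumerate(inst):
        if t and any(inst[i2] for i2 in winmap.get(i, [])):
            res[i] = 0           # a triggered dominating rule suppresses this one
    print(res)                   # [0 1]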
@ -671,3 +626,259 @@ def tanimoto_distance(a: List[int], b: List[int]):
|
|||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
return 1 - (sum_c / (sum_a + sum_b - sum_c))
|
return 1 - (sum_c / (sum_a + sum_b - sum_c))
|
||||||
|
|
||||||
|
def graph_from_pathway(data):
    """Convert Pathway or SPathway to networkx"""
    from epdb.models import Pathway
    from epdb.logic import SPathway

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    def get_edges():
        if isinstance(data, Pathway):
            return data.edges.all()
        elif isinstance(data, SPathway):
            return data.edges
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_sources_targets():
        if isinstance(data, Pathway):
            return ([n.node for n in edge.start_nodes.constrained_target.all()],
                    [n.node for n in edge.end_nodes.constrained_target.all()])
        elif isinstance(data, SPathway):
            return edge.educts, edge.products
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_smiles_depth(node):
        if isinstance(data, Pathway):
            return FormatConverter.standardize(node.default_node_label.smiles, True), node.depth
        elif isinstance(data, SPathway):
            return FormatConverter.standardize(node.smiles, True), node.depth
        else:
            raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    def get_probability():
        try:
            if isinstance(data, Pathway):
                return edge.kv.get('probability')
            elif isinstance(data, SPathway):
                return edge.probability
            else:
                raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")
        except AttributeError:
            return 1

    root_smiles = {get_smiles_depth(n) for n in data.root_nodes}
    for root, depth in root_smiles:
        graph.add_node(root, depth=depth, smiles=root, root=True)

    for edge in get_edges():
        sources, targets = get_sources_targets()
        probability = get_probability()
        for source in sources:
            source_smiles, source_depth = get_smiles_depth(source)
            if source_smiles not in graph:
                graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
                               root=source_smiles in root_smiles)
            else:
                graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
            for target in targets:
                target_smiles, target_depth = get_smiles_depth(target)
                if target_smiles not in graph and target_smiles not in co2:
                    graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
                                   root=target_smiles in root_smiles)
                elif target_smiles not in co2:
                    graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
                if target_smiles not in co2 and target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)

    return graph
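For orientation, a hand-built sketch of the DiGraph this produces for a one-step pathway (the SMILES and probability are made up):

import networkx as nx

g = nx.DiGraph()
g.add_node("Oc1ccccc1", depth=0, smiles="Oc1ccccc1", root=True)         # root compound
g.add_node("Oc1ccc(O)cc1", depth=1, smiles="Oc1ccc(O)cc1", root=False)  # transformation product
g.add_edge("Oc1ccccc1", "Oc1ccc(O)cc1", probability=0.8)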
def get_shortest_path(pathway, in_start_node, in_end_node):
    try:
        pred = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except nx.NetworkXNoPath:
        return []
    pred.remove(in_start_node)
    pred.remove(in_end_node)
    return pred
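Note the helper returns only the interior nodes of the shortest path, or an empty list when no path exists:

g = nx.DiGraph([("A", "B"), ("B", "C"), ("C", "D")])
print(get_shortest_path(g, "A", "D"))   # ['B', 'C']
print(get_shortest_path(g, "D", "A"))   # [] -- no directed path back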
def set_pathway_eval_weight(pathway):
    node_eval_weights = {}
    for node in pathway.nodes:
        # Scale score according to depth level
        node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
    return node_eval_weights
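Each node's weight halves per generation, and nodes carrying the negative "unreachable" depth marker contribute nothing:

for depth in (0, 1, 2, -99):
    print(depth, 1 / (2 ** depth) if depth >= 0 else 0)
# 0 -> 1.0, 1 -> 0.5, 2 -> 0.25, -99 -> 0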
def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    if len(intermediates) < 1:
        return pred_pathway
    root_nodes = pred_pathway.graph["root_nodes"]
    for node in pred_pathway.nodes:
        if node in root_nodes:
            continue
        if node in intermediates and node not in data_pathway:
            pred_pathway.nodes[node]["depth"] = -99
        else:
            shortest_path_list = []
            for root_node in root_nodes:
                shortest_path_nodes = get_shortest_path(pred_pathway, root_node, node)
                if shortest_path_nodes:
                    shortest_path_list.append(shortest_path_nodes)
            if shortest_path_list:
                shortest_path_nodes = min(shortest_path_list, key=len)
                num_ints = sum(1 for shortest_path_node in shortest_path_nodes
                               if shortest_path_node in intermediates)
                pred_pathway.nodes[node]["depth"] -= num_ints
    return pred_pathway
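The effect is that extra hops through intermediates do not penalise a correctly predicted product. A small sketch with hypothetical node names:

pred = nx.DiGraph([("root", "I"), ("I", "X")])
pred.graph["root_nodes"] = {"root"}
for n, d in (("root", 0), ("I", 1), ("X", 2)):
    pred.nodes[n]["depth"] = d

data = nx.DiGraph([("root", "X")])
data.graph["root_nodes"] = {"root"}

adjusted = get_depth_adjusted_pathway(data, pred, {"I"})
print(adjusted.nodes["X"]["depth"])   # 1 -- one intermediate discounted
print(adjusted.nodes["I"]["depth"])   # -99 -- the intermediate itself is excluded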
def initialise_pathway(pathway):
    """Convert pathway to networkx graph for evaluation"""
    pathway = graph_from_pathway(pathway)
    pathway.graph["root_nodes"] = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == 0}
    pathway = get_pathway_with_depth(pathway)
    return pathway
def get_pathway_with_depth(pathway):
    """Recalculates depths in the pathway.

    Can fix incorrect depths from the JSON parse if there were multiple nodes with the same SMILES at
    different depths that got merged."""
    current_depth = 0
    for node in pathway.nodes:
        if node in pathway.graph["root_nodes"]:
            pathway.nodes[node]["depth"] = current_depth
        else:
            pathway.nodes[node]["depth"] = -99
    while assign_next_depth(pathway, current_depth):
        current_depth += 1
    return pathway
def assign_next_depth(pathway, current_depth):
    new_assigned_nodes = False
    current_depth_nodes = {n for n in pathway.nodes if pathway.nodes[n]["depth"] == current_depth}
    for node in current_depth_nodes:
        successors = pathway.successors(node)
        for s in successors:
            if pathway.nodes[s]["depth"] < 0:
                pathway.nodes[s]["depth"] = current_depth + 1
                new_assigned_nodes = True
    return new_assigned_nodes
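get_pathway_with_depth thus relabels depths breadth-first from the roots; assign_next_depth advances one generation per call and reports whether anything changed:

g = nx.DiGraph([("A", "B"), ("B", "C"), ("A", "C")])
g.graph["root_nodes"] = {"A"}
for n in g.nodes:
    g.nodes[n]["depth"] = 0 if n == "A" else -99
depth = 0
while assign_next_depth(g, depth):
    depth += 1
print([g.nodes[n]["depth"] for n in ("A", "B", "C")])   # [0, 1, 1]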
def find_intermediates(data_pathway, pred_pathway):
    """Find any intermediate nodes in the predicted pathway"""
    common_nodes = get_common_nodes(pred_pathway, data_pathway)
    intermediates = set()
    for node in common_nodes:
        down_stream_nodes = data_pathway.successors(node)
        for down_stream_node in down_stream_nodes:
            if down_stream_node in pred_pathway:
                all_ints = get_shortest_path(pred_pathway, node, down_stream_node)
                intermediates.update(all_ints)
    return intermediates
def get_common_nodes(pred_pathway, data_pathway):
    """A node is a common node if it is in both pathways and is either a root in both or not a root in both."""
    common_nodes = set()
    for node in data_pathway.nodes:
        is_pathway_root_node = node in data_pathway.graph["root_nodes"]
        is_this_root_node = node in pred_pathway.graph["root_nodes"]
        if node in pred_pathway.nodes:
            if is_pathway_root_node is False and is_this_root_node is False:
                common_nodes.add(node)
            elif is_pathway_root_node and is_this_root_node:
                common_nodes.add(node)
    return common_nodes
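Combining the two helpers, a predicted detour through an unobserved node shows up as an intermediate rather than as a missing match (hypothetical labels):

data = nx.DiGraph([("A", "B")])
data.graph["root_nodes"] = {"A"}
pred = nx.DiGraph([("A", "I"), ("I", "B")])
pred.graph["root_nodes"] = {"A"}
print(get_common_nodes(pred, data))    # {'A', 'B'}
print(find_intermediates(data, pred))  # {'I'} -- the extra hop in the prediction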
def prune_graph(graph, threshold):
    """
    Removes edges with probability below the threshold, then keeps only the subgraph reachable from the root node.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
            graph.remove_edge(*cycle[-1])  # Remove the last edge in the cycle
        except nx.NetworkXNoCycle:
            break

    for u, v, data in list(graph.edges(data=True)):  # Remove edges below threshold
        if data["probability"] < threshold:
            graph.remove_edge(u, v)

    root_node = [n for n in graph.nodes if graph.nodes[n]["root"]][0]
    reachable = nx.descendants(graph, root_node)  # Get all reachable nodes from root
    reachable.add(root_node)

    for node in list(graph.nodes):  # Remove nodes not reachable from root
        if node not in reachable:
            graph.remove_node(node)
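For example, with a 0.5 threshold the low-probability edge below is cut and the node it fed becomes unreachable:

g = nx.DiGraph()
g.add_node("A", root=True)
g.add_node("B", root=False)
g.add_node("C", root=False)
g.add_edge("A", "B", probability=0.9)
g.add_edge("B", "C", probability=0.1)
prune_graph(g, 0.5)
print(list(g.nodes))   # ['A', 'B'] -- 'C' was only reachable via the pruned edge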
def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Compare two pathways for multi-gen evaluation.

    It is assumed the SMILES in both pathways have been standardised in the same manner."""
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_pathway, threshold)
    intermediates = find_intermediates(data_pathway, pred_pathway)

    if intermediates:
        pred_pathway = get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates)

    test_pathway_eval_weights = set_pathway_eval_weight(data_pathway)
    pred_pathway_eval_weights = set_pathway_eval_weight(pred_pathway)

    common_nodes = get_common_nodes(pred_pathway, data_pathway)

    data_only_nodes = set(n for n in data_pathway.nodes if n not in common_nodes)
    pred_only_nodes = set(n for n in pred_pathway.nodes if n not in common_nodes)

    score_TP, score_FP, score_FN, final_score, precision, recall = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

    for node in common_nodes:
        if pred_pathway.nodes[node]["depth"] > 0:
            score_TP += test_pathway_eval_weights[node]

    for node in data_only_nodes:
        if data_pathway.nodes[node]["depth"] > 0:
            score_FN += test_pathway_eval_weights[node]

    for node in pred_only_nodes:
        if pred_pathway.nodes[node]["depth"] > 0:
            score_FP += pred_pathway_eval_weights[node]

    final_score = score_TP / denom if (denom := score_TP + score_FP + score_FN) > 0 else 0.0
    precision = score_TP / denom if (denom := score_TP + score_FP) > 0 else 0.0
    recall = score_TP / denom if (denom := score_TP + score_FN) > 0 else 0.0

    if return_intermediates:
        return final_score, precision, recall, intermediates

    return final_score, precision, recall
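A worked example of the arithmetic, assuming a data pathway root -> M1 (depth 1) -> M2 (depth 2) where the prediction recovers M1, misses M2, and adds a spurious node S at depth 1:

score_TP = 1 / 2 ** 1   # M1 matched, weighted by its depth in the data pathway
score_FN = 1 / 2 ** 2   # M2 missed
score_FP = 1 / 2 ** 1   # S predicted only
print(score_TP / (score_TP + score_FP + score_FN))   # 0.4   (final score)
print(score_TP / (score_TP + score_FP))              # 0.5   (precision)
print(score_TP / (score_TP + score_FN))              # ~0.667 (recall)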
def node_subst_cost(node1, node2):
    if node1["smiles"] == node2["smiles"] and node1["depth"] == node2["depth"]:
        return 0
    return 1 / (2 ** max(node1["depth"], node2["depth"]))  # Maybe could be min instead of max


def node_ins_del_cost(node):
    return 1 / (2 ** node["depth"])
def pathway_edit_eval(data_pathway, pred_pathway):
    """Compute the graph edit distance for two pathways, a potential alternative to multigen_eval"""
    data_pathway = initialise_pathway(data_pathway)
    pred_pathway = initialise_pathway(pred_pathway)
    roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
    return nx.graph_edit_distance(data_pathway, pred_pathway,
                                  node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost,
                                  node_ins_cost=node_ins_del_cost, roots=roots)
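These cost callbacks plug straight into nx.graph_edit_distance: matching nodes cost nothing, while each insertion, deletion, or substitution costs its depth-scaled weight (edge edits keep networkx's default cost of 1). A small sketch with hand-built graphs:

a = nx.DiGraph()
a.add_node("X", smiles="X", depth=0)
a.add_node("Y", smiles="Y", depth=1)
a.add_edge("X", "Y")

b = nx.DiGraph()
b.add_node("X", smiles="X", depth=0)

print(nx.graph_edit_distance(a, b, node_subst_cost=node_subst_cost,
                             node_del_cost=node_ins_del_cost,
                             node_ins_cost=node_ins_del_cost, roots=("X", "X")))
# 1.5 -- deleting "Y" costs 1 / 2**1 = 0.5 plus 1 for its incoming edge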
2
uv.lock
generated
@ -557,6 +557,7 @@ dependencies = [
    { name = "envipy-plugins" },
    { name = "epam-indigo" },
    { name = "gunicorn" },
    { name = "networkx" },
    { name = "psycopg2-binary" },
    { name = "python-dotenv" },
    { name = "rdkit" },
@ -588,6 +589,7 @@ requires-dist = [
    { name = "epam-indigo", specifier = ">=1.30.1" },
    { name = "gunicorn", specifier = ">=23.0.0" },
    { name = "msal", marker = "extra == 'ms-login'", specifier = ">=1.33.0" },
    { name = "networkx", specifier = ">=3.4.2" },
    { name = "psycopg2-binary", specifier = ">=2.9.10" },
    { name = "python-dotenv", specifier = ">=1.1.0" },
    { name = "rdkit", specifier = ">=2025.3.2" },