forked from enviPath/enviPy
[Feature] MultiGen Eval (Backend) (#117)
Fixes #16 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#117
This commit is contained in:
@ -864,7 +864,7 @@ class PackageManager(object):
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def import_pacakge(data: Dict[str, Any], owner: User, preserve_uuids=False, add_import_timestamp=True,
|
||||
def import_package(data: Dict[str, Any], owner: User, preserve_uuids=False, add_import_timestamp=True,
|
||||
trust_reviewed=False) -> Package:
|
||||
|
||||
importer = PackageImporter(data, preserve_uuids, add_import_timestamp, trust_reviewed)
|
||||
|
||||
482
epdb/models.py
482
epdb/models.py
@ -7,7 +7,7 @@ import secrets
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Union, List, Optional, Dict, Tuple, Set
|
||||
from typing import Union, List, Optional, Dict, Tuple, Set, Any
|
||||
from uuid import uuid4
|
||||
|
||||
import joblib
|
||||
@ -588,33 +588,33 @@ class Package(EnviPathModel):
|
||||
return f"{self.name} (pk={self.pk})"
|
||||
|
||||
@property
|
||||
def compounds(self):
|
||||
def compounds(self) -> QuerySet:
|
||||
return self.compound_set.all()
|
||||
|
||||
@property
|
||||
def rules(self):
|
||||
def rules(self) -> QuerySet:
|
||||
return self.rule_set.all()
|
||||
|
||||
@property
|
||||
def reactions(self):
|
||||
def reactions(self) -> QuerySet:
|
||||
return self.reaction_set.all()
|
||||
|
||||
@property
|
||||
def pathways(self) -> 'Pathway':
|
||||
def pathways(self) -> QuerySet:
|
||||
return self.pathway_set.all()
|
||||
|
||||
@property
|
||||
def scenarios(self):
|
||||
def scenarios(self) -> QuerySet:
|
||||
return self.scenario_set.all()
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
def models(self) -> QuerySet:
|
||||
return self.epmodel_set.all()
|
||||
|
||||
def _url(self):
|
||||
return '{}/package/{}'.format(s.SERVER_URL, self.uuid)
|
||||
|
||||
def get_applicable_rules(self):
|
||||
def get_applicable_rules(self) -> List['Rule']:
|
||||
"""
|
||||
Returns an ordered set of rules where the following applies:
|
||||
1. All Composite will be added to result
|
||||
@ -650,11 +650,11 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
external_identifiers = GenericRelation('ExternalIdentifier')
|
||||
|
||||
@property
|
||||
def structures(self):
|
||||
def structures(self) -> QuerySet:
|
||||
return CompoundStructure.objects.filter(compound=self)
|
||||
|
||||
@property
|
||||
def normalized_structure(self):
|
||||
def normalized_structure(self) -> 'CompoundStructure' :
|
||||
return CompoundStructure.objects.get(compound=self, normalized_structure=True)
|
||||
|
||||
def _url(self):
|
||||
@ -1635,8 +1635,8 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
return new_pathway
|
||||
|
||||
@transaction.atomic
|
||||
def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None):
|
||||
return Node.create(self, smiles, 0)
|
||||
def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None, depth: Optional[int] = 0):
|
||||
return Node.create(self, smiles, depth, name=name, description=description)
|
||||
|
||||
@transaction.atomic
|
||||
def add_edge(self, start_nodes: List['Node'], end_nodes: List['Node'], rule: Optional['Rule'] = None,
|
||||
@ -1836,6 +1836,7 @@ class PackageBasedModel(EPModel):
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
|
||||
default=None)
|
||||
multigen_eval = models.BooleanField(null=False, blank=False, default=False)
|
||||
|
||||
INITIAL = "INITIAL"
|
||||
INITIALIZING = "INITIALIZING"
|
||||
@ -1861,6 +1862,24 @@ class PackageBasedModel(EPModel):
|
||||
def ready_for_prediction(self) -> bool:
|
||||
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
|
||||
|
||||
@property
def pr_curve(self):
    """Return the precision/recall curve recorded by the last evaluation.

    Returns:
        A list with one dict per threshold found in
        ``eval_results['average_precision_per_threshold']``. Each dict holds
        the keys 'precision', 'recall' and 'threshold', the latter converted
        from its string key to a float.

    Raises:
        ValueError: If the model status is not ``FINISHED``.
    """
    if self.model_status != self.FINISHED:
        raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")

    # Precision drives the iteration; the recall value is looked up under the
    # same threshold key.
    return [
        {
            'precision': precision_value,
            'recall': self.eval_results['average_recall_per_threshold'][key],
            'threshold': float(key),
        }
        for key, precision_value in self.eval_results['average_precision_per_threshold'].items()
    ]
|
||||
|
||||
@cached_property
|
||||
def applicable_rules(self) -> List['Rule']:
|
||||
"""
|
||||
@ -1897,14 +1916,6 @@ class PackageBasedModel(EPModel):
|
||||
# TODO
|
||||
return []
|
||||
|
||||
def _get_pathways(self):
|
||||
pathway_qs = Pathway.objects.none()
|
||||
for p in self.data_packages.all():
|
||||
pathway_qs |= p.pathways
|
||||
|
||||
pathway_qs = pathway_qs.distinct()
|
||||
return pathway_qs
|
||||
|
||||
def _get_reactions(self) -> QuerySet:
|
||||
return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()
|
||||
|
||||
@ -1937,9 +1948,277 @@ class PackageBasedModel(EPModel):
|
||||
self.build_model()
|
||||
|
||||
@abstractmethod
|
||||
def build_model(self):
|
||||
def _fit_model(self, ds: Dataset):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _model_args(self) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def build_model(self):
    """Fit the model on the stored dataset and persist it to disk.

    The status is set to ``BUILDING`` and saved before any work starts. The
    estimator returned by ``_fit_model`` is dumped with joblib to
    ``<MODEL_DIR>/<uuid>_mod.pkl``. If an applicability domain is attached it
    is built afterwards. Finally the status is set to ``BUILT_NOT_EVALUATED``
    and saved.
    """
    self.model_status = self.BUILDING
    self.save()

    dataset = self.load_dataset()
    fitted = self._fit_model(dataset)

    # One pickle per model, named after the model's uuid.
    target = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
    joblib.dump(fitted, target)

    domain = self.app_domain
    if domain is not None:
        logger.debug("Building applicability domain...")
        domain.build()
        logger.debug("Done building applicability domain.")

    self.model_status = self.BUILT_NOT_EVALUATED
    self.save()
|
||||
|
||||
def evaluate_model(self):
    """Evaluate the built model and store the metrics in ``eval_results``.

    Single-generation evaluation always runs: the dataset is shuffle-split
    20 times (25% test), one estimator is trained per split, and the Jaccard
    score at ``self.threshold`` plus precision/recall per threshold are
    averaged over all splits.

    If ``self.multigen_eval`` is set, a multi-generation evaluation is run as
    well and merged into ``eval_results`` under ``multigen_``-prefixed keys:

    1. No eval packages: the pathways of the data packages are shuffle-split,
       one estimator per split is trained on the educts that do not occur in
       that split's test pathways, and the test pathways are predicted and
       scored.
    2. Eval packages provided: the already built model is scored against the
       pathways of the eval packages.

    The status is set to ``EVALUATING`` at the start and to ``FINISHED`` at
    the end; both transitions are saved.

    Raises:
        ValueError: If the model is not in state ``BUILT_NOT_EVALUATED``.
        TypeError: If multigen evaluation is requested for a model class other
            than ``MLRelativeReasoning`` / ``RuleBasedRelativeReasoning``.
    """
    if self.model_status != self.BUILT_NOT_EVALUATED:
        raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

    self.model_status = self.EVALUATING
    self.save()

    def train_func(X, y, train_index, model_kwargs):
        # Work on a copy so the caller's dict is not stripped of its 'clz' entry.
        model_kwargs = dict(model_kwargs)
        # 'clz' selects the estimator class; everything else is constructor input.
        clz = model_kwargs.pop('clz')
        if clz == 'RuleBaseRelativeReasoning':
            mod = RelativeReasoning(**model_kwargs)
        else:
            mod = EnsembleClassifierChain(**model_kwargs)

        if train_index is not None:
            X, y = X[train_index], y[train_index]

        mod.fit(X, y)
        return mod

    def evaluate_sg(model, X, y, test_index, threshold):
        X_test = X[test_index]
        y_test = y[test_index]

        y_pred = model.predict_proba(X_test)
        y_thresholded = (y_pred >= threshold).astype(int)

        # Flatten them to get rid of np.nan
        y_test = np.asarray(y_test).flatten()
        y_pred = np.asarray(y_pred).flatten()
        y_thresholded = np.asarray(y_thresholded).flatten()

        # Only positions with a known label take part in the scores.
        mask = ~np.isnan(y_test)
        y_test_filtered = y_test[mask]
        y_pred_filtered = y_pred[mask]
        y_thresholded_filtered = y_thresholded[mask]

        acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)

        prec, rec = dict(), dict()
        for t in np.arange(0, 1.05, 0.05):
            temp_thresholded = (y_pred_filtered >= t).astype(int)
            prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
            rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)

        return acc, prec, rec

    def evaluate_mg(model, pathways: Union[QuerySet, List['Pathway']], threshold):
        from epdb.logic import SPathway
        from utilities.ml import multigen_eval

        # Materialise once so ground truth and predictions are paired in the same order.
        pathways = list(pathways)
        thresholds = np.arange(0.1, 1.1, 0.1)

        precision = {f"{t:.2f}": [] for t in thresholds}
        recall = {f"{t:.2f}": [] for t in thresholds}

        # Note: only one root compound supported at this time
        root_compounds = [
            [node.default_node_label.smiles for node in pathway.root_nodes][0]
            for pathway in pathways
        ]

        # As we need a Model Instance in our setting, get a fresh copy from db, overwrite the
        # serialized model and pass it to the setting used in prediction
        if isinstance(self, MLRelativeReasoning):
            mod = MLRelativeReasoning.objects.get(pk=self.pk)
        elif isinstance(self, RuleBasedRelativeReasoning):
            mod = RuleBasedRelativeReasoning.objects.get(pk=self.pk)
        else:
            raise TypeError(f"Multigen evaluation is not supported for {type(self).__name__}")

        mod.model = model

        # Named `setting` so the module-level settings alias `s` is not shadowed.
        setting = Setting()
        setting.model = mod
        # Predict with the lowest threshold; the per-threshold scoring happens below.
        setting.model_threshold = thresholds.min()
        setting.max_depth = 10
        setting.max_nodes = 50

        pred_pathways = []
        for i, root in enumerate(root_compounds):
            logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")

            spw = SPathway(root_nodes=root, prediction_setting=setting)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1

            pred_pathways.append(spw)

        # NOTE(review): mg_acc keeps the score of the last pathway evaluated at the
        # requested threshold rather than a mean over pathways - confirm this is intended.
        mg_acc = 0.0
        for t in thresholds:
            # Pair each predicted pathway with the pathway passed to this call, not with
            # a variable of the enclosing scope.
            for true, pred in zip(pathways, pred_pathways):
                acc, pre, rec = multigen_eval(true, pred, t)
                if abs(t - threshold) < 0.01:
                    mg_acc = acc
                precision[f"{t:.2f}"].append(pre)
                recall[f"{t:.2f}"].append(rec)

        precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
        recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
        return mg_acc, precision, recall

    def compute_averages(data):
        # `data` is a list of (accuracy, precision-per-threshold, recall-per-threshold).
        num_items = len(data)
        avg_first_item = sum(item[0] for item in data) / num_items

        sum_dict2 = defaultdict(float)
        sum_dict3 = defaultdict(float)

        for _, dict2, dict3 in data:
            for key in dict2:
                sum_dict2[key] += dict2[key]
            for key in dict3:
                sum_dict3[key] += dict3[key]

        avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
        avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}

        return {
            "average_accuracy": float(avg_first_item),
            "average_precision_per_threshold": avg_dict2,
            "average_recall_per_threshold": avg_dict3
        }

    def with_root_nodes(candidates):
        # There is one pathway with no root compounds, so this check is required
        kept = []
        for pathway in candidates:
            if len(pathway.root_nodes) > 0:
                kept.append(pathway)
            else:
                logger.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
        return kept

    ds = self.load_dataset()

    if isinstance(self, RuleBasedRelativeReasoning):
        X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
        y = np.array(ds.y(na_replacement=np.nan))
    else:
        X = np.array(ds.X(na_replacement=np.nan))
        y = np.array(ds.y(na_replacement=np.nan))

    n_splits = 20

    shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
    splits = list(shuff.split(X))

    from joblib import Parallel, delayed
    models = Parallel(n_jobs=10)(
        delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits
    )
    evaluations = Parallel(n_jobs=10)(
        delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
        for model, (_, test_index) in zip(models, splits)
    )

    self.eval_results = compute_averages(evaluations)

    if self.multigen_eval:
        # We have to consider 2 cases here:
        # 1. No eval packages provided -> Split Train data X times and train and evaluate model
        # 2. eval packages provided -> Use the already built model and do evaluation on the set provided.

        if self.eval_packages.count() > 0:
            pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
            # Keep the result and fall through to the final status update instead of
            # returning early, which would leave the model stuck in EVALUATING.
            multi_ret_vals = [evaluate_mg(self.model, with_root_nodes(pathway_qs), self.threshold)]
        else:
            pathway_qs = Pathway.objects.prefetch_related(
                'node_set',
                'node_set__out_edges',
                'node_set__default_node_label',
                'node_set__scenarios',
                'edge_set',
                'edge_set__start_nodes',
                'edge_set__end_nodes',
                'edge_set__edge_label',
                'edge_set__scenarios'
            ).filter(package__in=self.data_packages.all()).distinct()

            pathways = with_root_nodes(pathway_qs)

            # build lookup reaction -> {uuid1, uuid2} for overlap check
            reaction_to_educts = defaultdict(set)
            for pathway in pathways:
                for reaction in pathway.edges:
                    for e in reaction.edge_label.educts.all():
                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))

            # build lookup to avoid recalculation of features, labels
            id_to_index = {uuid: i for i, uuid in enumerate(ds[:, 0])}

            # Compute splits of the collected pathway
            pathway_splits = []
            pathway_shuffle = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
            for train, test in pathway_shuffle.split(pathways):
                train_pathways = [pathways[i] for i in train]
                held_out_pathways = [pathways[i] for i in test]

                # Collect structures from test pathways
                test_educts = set()
                for pathway in held_out_pathways:
                    for reaction in pathway.edges:
                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])

                split_ids = []
                overlap = 0
                # Collect indices of the structures contained in train pathways iff they're not
                # present in any of the test pathways
                for pathway in train_pathways:
                    for reaction in pathway.edges:
                        for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
                            # Ensure compounds in the training set do not appear in the test set
                            if educt in test_educts:
                                overlap += 1
                                continue
                            # Single lookup: each row index is added exactly once per occurrence.
                            row_index = id_to_index.get(educt)
                            if row_index is None:
                                logger.debug(f"Couldn't find features in X for compound {educt}")
                            else:
                                split_ids.append(row_index)

                logger.debug(
                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")

                # Get the rows from the dataset corresponding to compounds in the training set pathways
                split_x, split_y = X[split_ids], y[split_ids]
                pathway_splits.append([(split_x, split_y), held_out_pathways])

            # Build model on subsets obtained by pathway split
            trained_models = Parallel(n_jobs=10)(
                delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args())
                for (split_x, split_y), _ in pathway_splits
            )

            # Parallelizing multigen evaluate would be non-trivial, potentially possible but
            # requires a lot of work, so the splits are scored one after another.
            multi_ret_vals = [
                evaluate_mg(model, held_out_pathways, self.threshold)
                for model, (_, held_out_pathways) in zip(trained_models, pathway_splits)
            ]

        self.eval_results.update({f"multigen_{k}": v for k, v in compute_averages(multi_ret_vals).items()})

    self.model_status = self.FINISHED
    self.save()
|
||||
|
||||
@staticmethod
|
||||
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
|
||||
res = []
|
||||
@ -2011,21 +2290,22 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
|
||||
return rbrr
|
||||
|
||||
def build_model(self):
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
def _fit_model(self, ds: Dataset):
    """Fit a ``RelativeReasoning`` estimator on the given dataset.

    Features are taken with the id column included and labels without NA
    replacement. The estimator's start and end index are the first and second
    element of ``ds.triggered()``.

    Returns:
        The fitted ``RelativeReasoning`` instance.
    """
    features = ds.X(exclude_id_col=False, na_replacement=None)
    labels = ds.y(na_replacement=None)

    index_window = {
        'start_index': ds.triggered()[0],
        'end_index': ds.triggered()[1],
    }
    reasoner = RelativeReasoning(**index_window)
    reasoner.fit(features, labels)
    return reasoner
|
||||
|
||||
def _model_args(self):
|
||||
ds = self.load_dataset()
|
||||
labels = ds.y(na_replacement=None)
|
||||
|
||||
mod = RelativeReasoning(*ds.triggered())
|
||||
mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
|
||||
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
|
||||
joblib.dump(mod, f)
|
||||
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
return {
|
||||
'clz': 'RuleBaseRelativeReasoning',
|
||||
'start_index': ds.triggered()[0],
|
||||
'end_index': ds.triggered()[1],
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def model(self) -> 'RelativeReasoning':
|
||||
@ -2038,7 +2318,6 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||
|
||||
mod = self.model
|
||||
|
||||
pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
|
||||
|
||||
res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
|
||||
@ -2102,118 +2381,23 @@ class MLRelativeReasoning(PackageBasedModel):
|
||||
|
||||
return mlrr
|
||||
|
||||
def build_model(self):
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
|
||||
start = datetime.now()
|
||||
|
||||
ds = self.load_dataset()
|
||||
def _fit_model(self, ds: Dataset):
|
||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||
|
||||
mod = EnsembleClassifierChain(
|
||||
model = EnsembleClassifierChain(
|
||||
**s.DEFAULT_MODEL_PARAMS
|
||||
)
|
||||
mod.fit(X, y)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
|
||||
end = datetime.now()
|
||||
logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")
|
||||
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
|
||||
joblib.dump(mod, f)
|
||||
|
||||
if self.app_domain is not None:
|
||||
logger.debug("Building applicability domain...")
|
||||
self.app_domain.build()
|
||||
logger.debug("Done building applicability domain.")
|
||||
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def evaluate_model(self):
|
||||
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
ds = self.load_dataset()
|
||||
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
|
||||
n_splits = 20
|
||||
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
|
||||
def train_and_evaluate(X, y, train_index, test_index, threshold):
|
||||
X_train, X_test = X[train_index], X[test_index]
|
||||
y_train, y_test = y[train_index], y[test_index]
|
||||
|
||||
model = EnsembleClassifierChain(
|
||||
**s.DEFAULT_MODEL_PARAMS
|
||||
)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
y_pred = model.predict_proba(X_test)
|
||||
y_thresholded = (y_pred >= threshold).astype(int)
|
||||
|
||||
# Flatten them to get rid of np.nan
|
||||
y_test = np.asarray(y_test).flatten()
|
||||
y_pred = np.asarray(y_pred).flatten()
|
||||
y_thresholded = np.asarray(y_thresholded).flatten()
|
||||
|
||||
mask = ~np.isnan(y_test)
|
||||
y_test_filtered = y_test[mask]
|
||||
y_pred_filtered = y_pred[mask]
|
||||
y_thresholded_filtered = y_thresholded[mask]
|
||||
|
||||
acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
|
||||
|
||||
prec, rec = dict(), dict()
|
||||
|
||||
for t in np.arange(0, 1.05, 0.05):
|
||||
temp_thresholded = (y_pred_filtered >= t).astype(int)
|
||||
prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
|
||||
rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
|
||||
|
||||
return acc, prec, rec
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
ret_vals = Parallel(n_jobs=10)(
|
||||
delayed(train_and_evaluate)(X, y, train_index, test_index, self.threshold)
|
||||
for train_index, test_index in shuff.split(X)
|
||||
)
|
||||
|
||||
def compute_averages(data):
|
||||
num_items = len(data)
|
||||
avg_first_item = sum(item[0] for item in data) / num_items
|
||||
|
||||
sum_dict2 = defaultdict(float)
|
||||
sum_dict3 = defaultdict(float)
|
||||
|
||||
for _, dict2, dict3 in data:
|
||||
for key in dict2:
|
||||
sum_dict2[key] += dict2[key]
|
||||
for key in dict3:
|
||||
sum_dict3[key] += dict3[key]
|
||||
|
||||
avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
|
||||
avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
|
||||
|
||||
return {
|
||||
"average_accuracy": float(avg_first_item),
|
||||
"average_precision_per_threshold": avg_dict2,
|
||||
"average_recall_per_threshold": avg_dict3
|
||||
}
|
||||
|
||||
self.eval_results = compute_averages(ret_vals)
|
||||
self.model_status = self.FINISHED
|
||||
self.save()
|
||||
def _model_args(self):
    """Return the estimator arguments for this model.

    The dict starts with the 'clz' marker 'MLRelativeReasoning' and is then
    filled with the entries of ``s.DEFAULT_MODEL_PARAMS``. A fresh dict is
    built on every call.
    """
    model_args = {'clz': 'MLRelativeReasoning'}
    model_args.update(s.DEFAULT_MODEL_PARAMS)
    return model_args
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
def model(self) -> 'EnsembleClassifierChain':
|
||||
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
|
||||
mod.base_clf.n_jobs = -1
|
||||
return mod
|
||||
@ -2230,24 +2414,6 @@ class MLRelativeReasoning(PackageBasedModel):
|
||||
logger.info(f"Full predict took {(end - start).total_seconds()}s")
|
||||
return res
|
||||
|
||||
@property
|
||||
def pr_curve(self):
|
||||
if self.model_status != self.FINISHED:
|
||||
raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
|
||||
|
||||
res = []
|
||||
|
||||
thresholds = self.eval_results['average_precision_per_threshold'].keys()
|
||||
|
||||
for t in thresholds:
|
||||
res.append({
|
||||
'precision': self.eval_results['average_precision_per_threshold'][t],
|
||||
'recall': self.eval_results['average_recall_per_threshold'][t],
|
||||
'threshold': float(t)
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class ApplicabilityDomain(EnviPathModel):
|
||||
model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from django.conf import settings as s
|
||||
from django.db import transaction
|
||||
@ -35,4 +34,4 @@ def delete_epmodel_files(sender, instance, **kwargs):
|
||||
for f in os.listdir(s.MODEL_DIR):
|
||||
if f.startswith(mod_uuid):
|
||||
logger.info(f"Deleting {os.path.join(s.MODEL_DIR, f)}")
|
||||
shutil.rmtree(os.path.join(s.MODEL_DIR, f))
|
||||
os.remove(os.path.join(s.MODEL_DIR, f))
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from celery.signals import worker_process_init
|
||||
from celery import shared_task
|
||||
from epdb.models import Pathway, Node, Edge, EPModel, Setting
|
||||
from epdb.logic import SPathway
|
||||
|
||||
from utilities.chem import FormatConverter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -294,7 +294,7 @@ def packages(request):
|
||||
if hidden == 'import-legacy-package-json':
|
||||
pack = PackageManager.import_legacy_package(data, current_user)
|
||||
else:
|
||||
pack = PackageManager.import_pacakge(data, current_user)
|
||||
pack = PackageManager.import_package(data, current_user)
|
||||
|
||||
return redirect(pack.url)
|
||||
except UnicodeDecodeError:
|
||||
@ -772,10 +772,13 @@ def package_model(request, package_uuid, model_uuid):
|
||||
if hidden == 'delete':
|
||||
current_model.delete()
|
||||
return redirect(current_package.url + '/model')
|
||||
elif hidden == 'evaluate':
|
||||
from .tasks import evaluate_model
|
||||
evaluate_model.delay(current_model.pk)
|
||||
return redirect(current_model.url)
|
||||
else:
|
||||
return HttpResponseBadRequest()
|
||||
else:
|
||||
|
||||
name = request.POST.get('model-name', '').strip()
|
||||
description = request.POST.get('model-description', '').strip()
|
||||
|
||||
|
||||
439315
fixtures/EAWAG-BBD.json
439315
fixtures/EAWAG-BBD.json
File diff suppressed because it is too large
Load Diff
125432
fixtures/EAWAG-SLUDGE.json
125432
fixtures/EAWAG-SLUDGE.json
File diff suppressed because it is too large
Load Diff
1572257
fixtures/EAWAG-SOIL.json
1572257
fixtures/EAWAG-SOIL.json
File diff suppressed because one or more lines are too long
Binary file not shown.
BIN
fixtures/test_fixtures.jsonl.gz
Normal file
BIN
fixtures/test_fixtures.jsonl.gz
Normal file
Binary file not shown.
@ -18,6 +18,7 @@ dependencies = [
|
||||
"envipy-plugins",
|
||||
"epam-indigo>=1.30.1",
|
||||
"gunicorn>=23.0.0",
|
||||
"networkx>=3.4.2",
|
||||
"psycopg2-binary>=2.9.10",
|
||||
"python-dotenv>=1.1.0",
|
||||
"rdkit>=2025.3.2",
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
<h4 class="modal-title">Evaluate Model</h4>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<form id="evaluate_model_form" accept-charset="UTF-8" action="{{ meta.current_package.url }}/model"
|
||||
<form id="evaluate_model_form" accept-charset="UTF-8" action="{{ current_object.url }}"
|
||||
data-remote="true" method="post">
|
||||
{% csrf_token %}
|
||||
<div class="jumbotron">
|
||||
@ -35,6 +35,7 @@
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</select>
|
||||
<input type="hidden" name="hidden" value="evaluate">
|
||||
</form>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
|
||||
@ -5,7 +5,7 @@ from epdb.models import Compound, User, CompoundStructure
|
||||
|
||||
|
||||
class CompoundTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
@ -7,7 +7,7 @@ from epdb.models import Compound, User, Reaction
|
||||
|
||||
|
||||
class CopyTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -6,7 +6,7 @@ from utilities.ml import Dataset
|
||||
|
||||
|
||||
class DatasetTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
def setUp(self):
|
||||
self.cs1 = Compound.create(
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
from django.test import TestCase
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import User, MLRelativeReasoning, Package
|
||||
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package
|
||||
|
||||
|
||||
class ModelTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -17,28 +18,55 @@ class ModelTest(TestCase):
|
||||
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
|
||||
|
||||
def test_smoke(self):
|
||||
threshold = float(0.5)
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = []
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold,
|
||||
'ECC - BBD - 0.5',
|
||||
'Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
eval_packages_objs,
|
||||
threshold=threshold,
|
||||
name='ECC - BBD - 0.5',
|
||||
description='Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
print("Model built!")
|
||||
mod.evaluate_model()
|
||||
print("Model Evaluated")
|
||||
# mod = RuleBasedRelativeReasoning.create(
|
||||
# self.package,
|
||||
# rule_package_objs,
|
||||
# data_package_objs,
|
||||
# eval_packages_objs,
|
||||
# threshold=threshold,
|
||||
# min_count=5,
|
||||
# max_count=0,
|
||||
# name='ECC - BBD - 0.5',
|
||||
# description='Created MLRelativeReasoning in Testcase',
|
||||
# )
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
print(results)
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.multigen_eval = True
|
||||
mod.save()
|
||||
# mod.evaluate_model()
|
||||
|
||||
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
|
||||
|
||||
products = dict()
|
||||
for r in results:
|
||||
for ps in r.product_sets:
|
||||
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
|
||||
|
||||
expected = {
|
||||
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
|
||||
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
|
||||
}
|
||||
|
||||
self.assertEqual(products, expected)
|
||||
|
||||
# from pprint import pprint
|
||||
# pprint(mod.eval_results)
|
||||
|
||||
137
tests/test_multigen_eval.py
Normal file
137
tests/test_multigen_eval.py
Normal file
@ -0,0 +1,137 @@
|
||||
import json
|
||||
from django.test import TestCase
|
||||
from networkx.utils.misc import graphs_equal
|
||||
from epdb.logic import PackageManager, SPathway
|
||||
from epdb.models import Pathway, User, Package
|
||||
from utilities.ml import multigen_eval, pathway_edit_eval, graph_from_pathway
|
||||
|
||||
|
||||
class MultiGenTest(TestCase):
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
def setUpClass(cls):
    """Load the shared fixtures once for all tests of this class.

    Stores the 'anonymous' user, a freshly created test package owned by that
    user, and the package named 'Fixtures' as class attributes.
    """
    super(MultiGenTest, cls).setUpClass()

    anonymous: 'User' = User.objects.get(username='anonymous')
    cls.user = anonymous
    cls.package = PackageManager.create_package(anonymous, 'Anon Test Package', 'No Desc')
    cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
|
||||
|
||||
def test_equal_pathways(self):
    """Comparing a pathway with itself must give precision 1.0 and recall 1.0."""
    for pathway in self.BBD_SUBSET.pathways.all():
        # Do not test pathways with no edges
        if len(pathway.edge_set.all()) == 0:
            continue

        _, precision, recall = multigen_eval(pathway, pathway)

        failed_on = f"Failed on pathway: {pathway.name}"
        self.assertEqual(
            precision, 1.0,
            msg="Precision should be one for identical pathways. " + failed_on,
        )
        self.assertEqual(
            recall, 1.0,
            msg="Recall should be one for identical pathways. " + failed_on,
        )
|
||||
|
||||
def test_intermediates(self):
    """The intermediate case yields one intermediate and leaves precision and recall at 1."""
    case_args = self.intermediate_case()
    _, precision, recall, intermediates = multigen_eval(*case_args, return_intermediates=True)

    self.assertEqual(len(intermediates), 1, msg="There should be 1 found intermediate")
    self.assertEqual(precision, 1, msg="Precision should be 1")
    self.assertEqual(recall, 1, msg="Recall should be 1")
|
||||
|
||||
def test_fp(self):
    """A false positive (extra compound) lowers precision to 0.75; recall stays 1."""
    case_args = self.fp_case()
    _, precision, recall = multigen_eval(*case_args)

    self.assertAlmostEqual(precision, 0.75, places=3, msg="Precision should be 0.75")
    self.assertEqual(recall, 1, msg="Recall should be 1")
|
||||
|
||||
def test_fn(self):
    """A false negative (missed compound) lowers recall to 0.667; precision stays 1."""
    case_args = self.fn_case()
    _, precision, recall = multigen_eval(*case_args)

    self.assertEqual(precision, 1, msg="Precision should be 1.0")
    self.assertAlmostEqual(recall, 0.667, places=3, msg="Recall should be 0.667")
|
||||
|
||||
def test_all(self):
    """Combined case: one intermediate, one false positive and one false negative."""
    true_pathway, pred_pathway = self.all_case()
    result = multigen_eval(true_pathway, pred_pathway, return_intermediates=True)
    _, precision, recall, found = result
    self.assertEqual(len(found), 1, "There should be 1 found intermediate")
    self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6")
    self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75")
|
||||
|
||||
def test_shallow_pathway(self):
    """The graph built from an SPathway must equal the graph built from its source Pathway.

    Pathways from the fixture package that have no edges are skipped.
    """
    for pathway in self.BBD_SUBSET.pathways.all():
        if not pathway.edge_set.all():  # Do not test pathways with no edges
            continue
        # Keep the Pathway object bound to ``pathway`` so that the failure
        # message reports the pathway's own name rather than the (empty)
        # name of the derived networkx graph.
        shallow_graph = graph_from_pathway(SPathway.from_pathway(pathway))
        full_graph = graph_from_pathway(pathway)
        self.assertTrue(
            graphs_equal(shallow_graph, full_graph),
            f"Networkx graph from shallow pathway not "
            f"equal to pathway for pathway {pathway.name}"
            f"\n\nS {shallow_graph.adj}\n\nPW {full_graph.adj}",
        )
|
||||
|
||||
def test_graph_edit_eval(self):
    """Run the multigen scenarios through pathway_edit_eval instead.

    Unlike the multigen_eval expectations, the expected edit distances here
    have not been hand verified.
    """
    for candidate in self.BBD_SUBSET.pathways.all():
        if not candidate.edge_set.all():  # Do not test pathways with no edges
            continue
        distance = pathway_edit_eval(candidate, candidate)
        self.assertEqual(distance, 0.0, "Pathway edit distance should be zero for identical pathways. "
                                        f"Failed on pathway: {candidate.name}")

    # (case builder, expected distance, failure message), evaluated in order.
    scenarios = (
        (self.intermediate_case, 1.75, "Pathway edit distance failed on intermediate case"),
        (self.fp_case, 1.25, "Pathway edit distance failed on fp case"),
        (self.fn_case, 1.25, "Pathway edit distance failed on fn case"),
        (self.all_case, 1.0, "Pathway edit distance failed on all case"),
    )
    for build_case, expected, message in scenarios:
        observed = pathway_edit_eval(*build_case())
        self.assertAlmostEqual(observed, expected, 3, message)
|
||||
|
||||
def intermediate_case(self):
    """Build (true, predicted) pathways where the prediction has an extra intermediate.

    True:      CCO -> CC(=O)O
    Predicted: CCO -> CC=O -> CC(=O)O
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_acid = true_pathway.add_node("CC(=O)O", depth=1)
    true_pathway.add_edge([true_root], [true_acid])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    pred_aldehyde = pred_pathway.add_node("CC=O", depth=1)
    pred_pathway.add_edge([pred_root], [pred_aldehyde])
    pred_acid = pred_pathway.add_node("CC(=O)O", depth=2)
    pred_pathway.add_edge([pred_aldehyde], [pred_acid])
    return true_pathway, pred_pathway
|
||||
|
||||
def fp_case(self):
    """Build (true, predicted) pathways where the prediction has one extra compound.

    True:      CCO -> CC=O -> CC(=O)O
    Predicted: CCO -> CC=O -> CC(=O)O, plus CC=O -> C
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_aldehyde = true_pathway.add_node("CC=O", depth=1)
    true_pathway.add_edge([true_root], [true_aldehyde])
    true_acid = true_pathway.add_node("CC(=O)O", depth=2)
    true_pathway.add_edge([true_aldehyde], [true_acid])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    pred_aldehyde = pred_pathway.add_node("CC=O", depth=1)
    pred_pathway.add_edge([pred_root], [pred_aldehyde])
    pred_acid = pred_pathway.add_node("CC(=O)O", depth=2)
    pred_pathway.add_edge([pred_aldehyde], [pred_acid])
    pred_methane = pred_pathway.add_node("C", depth=2)
    pred_pathway.add_edge([pred_aldehyde], [pred_methane])
    return true_pathway, pred_pathway
|
||||
|
||||
def fn_case(self):
    """Build (true, predicted) pathways where the prediction misses one compound.

    True:      CCO -> CC=O -> CC(=O)O
    Predicted: CCO -> CC=O
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_root = true_pathway.root_nodes.all()[0]
    true_aldehyde = true_pathway.add_node("CC=O", depth=1)
    true_pathway.add_edge([true_root], [true_aldehyde])
    true_acid = true_pathway.add_node("CC(=O)O", depth=2)
    true_pathway.add_edge([true_aldehyde], [true_acid])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_root = pred_pathway.root_nodes.all()[0]
    pred_aldehyde = pred_pathway.add_node("CC=O", depth=1)
    pred_pathway.add_edge([pred_root], [pred_aldehyde])
    return true_pathway, pred_pathway
|
||||
|
||||
def all_case(self):
    """Create an example with an intermediate, extra compound and missing compound.

    True:      CCO -> CC=O, CC=O -> C, CC=O -> CC(=O)O
    Predicted: CCO -> C,    C -> CC=O, C -> c1ccccc1
    """
    true_pathway = Pathway.create(self.package, "CCO")
    true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
                          [acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
    true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)])
    true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])

    pred_pathway = Pathway.create(self.package, "CCO")
    pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)])
    # The predicted pathway's children must be created on the predicted
    # pathway itself; creating them on true_pathway would add stray nodes
    # to the ground truth.
    pred_pathway.add_edge([methane], [pred_pathway.add_node("CC=O", depth=2)])
    pred_pathway.add_edge([methane], [pred_pathway.add_node("c1ccccc1", depth=2)])
    return true_pathway, pred_pathway
|
||||
@ -5,7 +5,7 @@ from epdb.models import Compound, User, Reaction, Rule
|
||||
|
||||
|
||||
class ReactionTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -5,10 +5,7 @@ from epdb.models import Rule, User
|
||||
|
||||
|
||||
class RuleTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -7,7 +7,7 @@ from epdb.models import User, SimpleAmbitRule
|
||||
|
||||
|
||||
class SimpleAmbitRuleTest(TestCase):
|
||||
fixtures = ["test_fixtures.json.gz"]
|
||||
fixtures = ["test_fixtures.jsonl.gz"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@ -183,7 +183,7 @@ class FormatConverter(object):
|
||||
return smiles
|
||||
|
||||
@staticmethod
|
||||
def standardize(smiles):
|
||||
def standardize(smiles, remove_stereo=False):
|
||||
# Taken from https://bitsilla.com/blog/2021/06/standardizing-a-molecule-using-rdkit/
|
||||
# follows the steps in
|
||||
# https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/MolStandardize%20pieces.ipynb
|
||||
@ -208,6 +208,9 @@ class FormatConverter(object):
|
||||
# te = rdMolStandardize.TautomerEnumerator() # idem
|
||||
# taut_uncharged_parent_clean_mol = te.Canonicalize(uncharged_parent_clean_mol)
|
||||
|
||||
if remove_stereo:
|
||||
Chem.RemoveStereochemistry(uncharged_parent_clean_mol)
|
||||
|
||||
return Chem.MolToSmiles(uncharged_parent_clean_mol, kekuleSmiles=True)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -919,7 +919,7 @@ class PackageImporter:
|
||||
name=edge_data['name'],
|
||||
description=edge_data['description'],
|
||||
kv=edge_data.get('kv', {}),
|
||||
edge_label=None # Will be set later
|
||||
edge_label=self._get_cached_object('Reaction', edge_data['edge_label']['uuid'])
|
||||
)
|
||||
|
||||
# Set aliases if present
|
||||
|
||||
349
utilities/ml.py
349
utilities/ml.py
@ -1,12 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import default_rng
|
||||
from sklearn.dummy import DummyClassifier
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
@ -22,61 +29,6 @@ from dataclasses import dataclass, field
|
||||
from utilities.chem import FormatConverter, PredictionResult
|
||||
|
||||
|
||||
@dataclass
class SCompound:
    """Lightweight compound identified by its SMILES string.

    ``uuid`` is carried along but excluded from equality comparison; the hash
    is derived from ``smiles`` only and cached on first use.
    """
    smiles: str
    uuid: str = field(default=None, compare=False, hash=False)

    def __hash__(self):
        try:
            return self._hash
        except AttributeError:
            # First call: compute and memoise the hash of the SMILES string.
            self._hash = hash(self.smiles)
            return self._hash
|
||||
|
||||
|
||||
@dataclass
class SReaction:
    """Lightweight reaction defined by its educts and products.

    Equality and hashing ignore the order of educts/products (both are
    compared after sorting by SMILES) and ignore ``rule_uuid`` and
    ``reaction_uuid``. The hash is cached on first use.
    """
    educts: List[SCompound]
    products: List[SCompound]
    rule_uuid: SRule = field(default=None, compare=False, hash=False)
    reaction_uuid: str = field(default=None, compare=False, hash=False)

    @staticmethod
    def _by_smiles(compounds):
        """Return the compounds as a list ordered by their SMILES."""
        return sorted(compounds, key=lambda compound: compound.smiles)

    def __hash__(self):
        try:
            return self._hash
        except AttributeError:
            self._hash = hash((
                tuple(self._by_smiles(self.educts)),
                tuple(self._by_smiles(self.products)),
            ))
            return self._hash

    def __eq__(self, other):
        if not isinstance(other, SReaction):
            return NotImplemented
        if self._by_smiles(self.educts) != self._by_smiles(other.educts):
            return False
        return self._by_smiles(self.products) == self._by_smiles(other.products)
|
||||
|
||||
|
||||
@dataclass
class SRule(ABC):
    """Abstract base for lightweight rules; subclasses must implement ``apply``."""

    @abstractmethod
    def apply(self):
        """Apply the rule. The base implementation does nothing and returns None."""
|
||||
|
||||
|
||||
@dataclass
class SSimpleRule:
    """Placeholder dataclass for a simple rule; it declares no fields."""
|
||||
|
||||
|
||||
@dataclass
class SParallelRule:
    """Placeholder dataclass for a parallel rule; it declares no fields."""
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
||||
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
|
||||
@ -385,7 +337,7 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||
self.chains_ = [ClassifierChain(self.base_clf) for i in range(self.num_chains)]
|
||||
|
||||
for i, chain in enumerate(self.chains_):
|
||||
print(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
|
||||
logger.debug(f"{datetime.now()} fitting {i + 1}/{self.num_chains}")
|
||||
chain.fit(X, y_reduced)
|
||||
|
||||
return self
|
||||
@ -423,14 +375,6 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||
return accuracy_score(y_true, y_pred, sample_weight=sample_weight)
|
||||
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from sklearn.dummy import DummyClassifier
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
|
||||
class BinaryRelevance:
|
||||
def __init__(self, baseline_clf):
|
||||
self.clf = baseline_clf
|
||||
@ -483,7 +427,8 @@ class MissingValuesClassifierChain:
|
||||
X = np.array(X)
|
||||
Y = np.array(Y)
|
||||
if self.permutation is None:
|
||||
self.permutation = np.random.permutation(len(Y[0]))
|
||||
rng = default_rng(42)
|
||||
self.permutation = rng.permutation(len(Y[0]))
|
||||
|
||||
Y = Y[:, self.permutation]
|
||||
|
||||
@ -541,7 +486,7 @@ class EnsembleClassifierChain:
|
||||
self.num_labels = len(Y[0])
|
||||
|
||||
for p in range(self.num_chains):
|
||||
print(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||
clf = MissingValuesClassifierChain(self.base_clf)
|
||||
clf.fit(X, Y)
|
||||
self.classifiers.append(clf)
|
||||
@ -609,13 +554,23 @@ class RelativeReasoning:
|
||||
def predict(self, X):
|
||||
res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
|
||||
|
||||
# Loop through all instances
|
||||
for inst_idx, inst in enumerate(X):
|
||||
# Loop through all "triggered" features
|
||||
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
# Set label
|
||||
res[inst_idx][i] = t
|
||||
# If we predict a 1, check if the rule gets dominated by another
|
||||
if t:
|
||||
# Second loop to check other triggered rules
|
||||
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
|
||||
res[inst_idx][i] = 0
|
||||
if i != i2:
|
||||
# Check if rule idx is in "dominated by" list
|
||||
if i2 in self.winmap.get(i, []):
|
||||
# if that rule also triggered, it dominates the current one
|
||||
# set label to 0
|
||||
if X[inst_idx][i2]:
|
||||
res[inst_idx][i] = 0
|
||||
|
||||
return res
|
||||
|
||||
@ -671,3 +626,259 @@ def tanimoto_distance(a: List[int], b: List[int]):
|
||||
return 0.0
|
||||
|
||||
return 1 - (sum_c / (sum_a + sum_b - sum_c))
|
||||
|
||||
|
||||
def graph_from_pathway(data):
    """Convert a Pathway or SPathway to a networkx DiGraph.

    Nodes are keyed by their standardised, stereo-free SMILES and carry the
    attributes ``depth``, ``smiles`` and ``root``. When the same SMILES is
    seen more than once, the smallest depth is kept. Edges carry a
    ``probability`` attribute. CO2 targets and self-loops are left out.

    :param data: a ``Pathway`` (database model) or an ``SPathway``.
    :return: the resulting ``nx.DiGraph``.
    :raises TypeError: if ``data`` is neither a ``Pathway`` nor an ``SPathway``.
    """
    from epdb.models import Pathway
    from epdb.logic import SPathway

    # Resolve the input flavour once instead of re-checking in every helper.
    if isinstance(data, Pathway):
        is_db_pathway = True
    elif isinstance(data, SPathway):
        is_db_pathway = False
    else:
        raise TypeError(f"Can't convert type {type(data)} to networkx for multigen eval")

    graph = nx.DiGraph()
    co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation

    def smiles_and_depth(node):
        """Return (standardised SMILES, depth) for a pathway node."""
        raw_smiles = node.default_node_label.smiles if is_db_pathway else node.smiles
        return FormatConverter.standardize(raw_smiles, True), node.depth

    def sources_and_targets(edge):
        """Return the (educt nodes, product nodes) of an edge."""
        if is_db_pathway:
            return ([n.node for n in edge.start_nodes.constrained_target.all()],
                    [n.node for n in edge.end_nodes.constrained_target.all()])
        return edge.educts, edge.products

    def edge_probability(edge):
        """Return the edge probability, defaulting to 1 when none is stored."""
        try:
            if is_db_pathway:
                # A missing key falls back to 1, the same as a missing attribute.
                return edge.kv.get('probability', 1)
            return edge.probability
        except AttributeError:
            return 1

    roots = {smiles_and_depth(n) for n in data.root_nodes}
    # Plain SMILES strings, so that the ``root=`` membership tests below
    # compare like with like (``roots`` holds (smiles, depth) tuples).
    root_smiles = {smiles for smiles, _ in roots}
    for root, depth in roots:
        graph.add_node(root, depth=depth, smiles=root, root=True)

    def upsert_node(smiles, depth):
        """Add the node, or lower the depth of an already known node."""
        if smiles not in graph:
            graph.add_node(smiles, depth=depth, smiles=smiles, root=smiles in root_smiles)
        else:
            graph.nodes[smiles]["depth"] = min(depth, graph.nodes[smiles]["depth"])

    edges = data.edges.all() if is_db_pathway else data.edges
    for edge in edges:
        sources, targets = sources_and_targets(edge)
        probability = edge_probability(edge)
        target_info = None  # standardised lazily, once per edge
        for source in sources:
            source_smiles, source_depth = smiles_and_depth(source)
            upsert_node(source_smiles, source_depth)
            if target_info is None:
                target_info = [smiles_and_depth(target) for target in targets]
            for target_smiles, target_depth in target_info:
                if target_smiles in co2:
                    continue
                upsert_node(target_smiles, target_depth)
                if target_smiles != source_smiles:
                    graph.add_edge(source_smiles, target_smiles, probability=probability)
    return graph
|
||||
|
||||
|
||||
def get_shortest_path(pathway, in_start_node, in_end_node):
    """Return the interior nodes of the shortest path between two nodes.

    :param pathway: networkx graph to search.
    :param in_start_node: node the path starts at.
    :param in_end_node: node the path ends at.
    :return: the nodes strictly between start and end on a shortest path.
        The list is empty when the nodes are adjacent or identical, when no
        path exists, or when either node is not in the graph.
    """
    try:
        path = nx.shortest_path(pathway, source=in_start_node, target=in_end_node)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # NodeNotFound can occur when a node was removed from the graph,
        # for example by pruning, but is still referenced by the caller.
        return []
    # Slicing drops both endpoints and also copes with start == end, where
    # the path is a single node and a second remove() would raise.
    return path[1:-1]
|
||||
|
||||
|
||||
def set_pathway_eval_weight(pathway):
    """Map every node to its evaluation weight.

    The weight halves with each depth level (``1 / 2 ** depth``); nodes with a
    negative depth get weight 0.
    """
    def weight_for(depth):
        # Scale score according to depth level
        if depth >= 0:
            return 1 / (2 ** depth)
        return 0

    return {node: weight_for(pathway.nodes[node]["depth"]) for node in pathway.nodes}
|
||||
|
||||
|
||||
def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
    """Adjust node depths of the predicted pathway for the given intermediates.

    Root nodes are left alone. An intermediate that is absent from
    ``data_pathway`` gets depth -99. Every other node has its depth reduced by
    the number of intermediates on the shortest root-to-node path (the
    shortest over all roots that yield a non-empty interior path).

    The graph is modified in place and also returned.
    """
    if len(intermediates) < 1:
        return pred_pathway

    roots = pred_pathway.graph["root_nodes"]
    for node in pred_pathway.nodes:
        if node in roots:
            continue
        if node in intermediates and node not in data_pathway:
            pred_pathway.nodes[node]["depth"] = -99
            continue
        # Interior paths from each root; empty ones are dropped.
        candidate_paths = []
        for root in roots:
            interior = get_shortest_path(pred_pathway, root, node)
            if interior:
                candidate_paths.append(interior)
        if not candidate_paths:
            continue
        best_path = min(candidate_paths, key=len)
        skipped = sum(1 for step in best_path if step in intermediates)
        pred_pathway.nodes[node]["depth"] -= skipped
    return pred_pathway
|
||||
|
||||
|
||||
def initialise_pathway(pathway):
    """Turn a pathway into a networkx graph ready for evaluation.

    The graph's ``root_nodes`` attribute is set to all nodes at depth 0, and
    depths are then recomputed from those roots.
    """
    graph = graph_from_pathway(pathway)
    depth_zero = set()
    for node in graph.nodes:
        if graph.nodes[node]["depth"] == 0:
            depth_zero.add(node)
    graph.graph["root_nodes"] = depth_zero
    return get_pathway_with_depth(graph)
|
||||
|
||||
|
||||
def get_pathway_with_depth(pathway):
    """Recompute all node depths starting from the graph's root nodes.

    This can repair wrong depths from a json parse where several nodes with
    the same SMILES but different depths were merged. Roots get depth 0, all
    other nodes are reset to -99 and then filled in level by level; nodes
    that are never reached keep -99.
    """
    roots = pathway.graph["root_nodes"]
    for node in pathway.nodes:
        pathway.nodes[node]["depth"] = 0 if node in roots else -99

    level = 0
    # Keep expanding one level at a time until a level assigns nothing new.
    while assign_next_depth(pathway, level):
        level += 1
    return pathway
|
||||
|
||||
|
||||
def assign_next_depth(pathway, current_depth):
    """Give depth ``current_depth + 1`` to unassigned successors of the current level.

    A node counts as unassigned when its depth is negative. Returns True if at
    least one node received a depth, False otherwise.
    """
    # Snapshot the frontier first so that assignments below do not affect it.
    frontier = {node for node in pathway.nodes if pathway.nodes[node]["depth"] == current_depth}
    assigned_any = False
    for parent in frontier:
        for child in pathway.successors(parent):
            child_attrs = pathway.nodes[child]
            if child_attrs["depth"] >= 0:
                continue
            child_attrs["depth"] = current_depth + 1
            assigned_any = True
    return assigned_any
|
||||
|
||||
|
||||
def find_intermediates(data_pathway, pred_pathway):
    """Collect the intermediate nodes of the predicted pathway.

    For every common node and each of its successors in ``data_pathway`` that
    also exists in ``pred_pathway``, the interior nodes of the shortest path
    between the two in ``pred_pathway`` are counted as intermediates.
    """
    found = set()
    for shared_node in get_common_nodes(pred_pathway, data_pathway):
        for child in data_pathway.successors(shared_node):
            if child not in pred_pathway:
                continue
            found.update(get_shortest_path(pred_pathway, shared_node, child))
    return found
|
||||
|
||||
|
||||
def get_common_nodes(pred_pathway, data_pathway):
    """Return the nodes shared by both pathways with matching root status.

    A node qualifies when it is present in both graphs and is either a root in
    both or a root in neither.
    """
    data_roots = data_pathway.graph["root_nodes"]
    pred_roots = pred_pathway.graph["root_nodes"]
    shared = set()
    for node in data_pathway.nodes:
        if node not in pred_pathway.nodes:
            continue
        # Both membership tests yield booleans; equal means same root status.
        if (node in data_roots) == (node in pred_roots):
            shared.add(node)
    return shared
|
||||
|
||||
|
||||
def prune_graph(graph, threshold):
    """
    Prune the graph in place.

    First breaks every cycle by removing the last edge of each cycle found,
    then removes edges whose probability is below the threshold, and finally
    keeps only the subgraph reachable from the root node. Only the first node
    flagged as root is used as the starting point.
    """
    while True:
        try:
            cycle = nx.find_cycle(graph)
        except nx.NetworkXNoCycle:
            break
        last_u, last_v = cycle[-1][0], cycle[-1][1]
        graph.remove_edge(last_u, last_v)  # Remove the last edge in the cycle

    # Collect first, then remove, so the edge view is not mutated mid-iteration.
    weak_edges = [(u, v) for u, v, attrs in graph.edges(data=True) if attrs["probability"] < threshold]
    for u, v in weak_edges:
        graph.remove_edge(u, v)

    flagged_roots = [n for n in graph.nodes if graph.nodes[n]["root"]]
    root_node = flagged_roots[0]
    keep = nx.descendants(graph, root_node)  # Get all reachable nodes from root
    keep.add(root_node)

    unreachable = [n for n in graph.nodes if n not in keep]
    for node in unreachable:  # Remove nodes not reachable from root
        graph.remove_node(node)
|
||||
|
||||
|
||||
def multigen_eval(data_pathway, pred_pathway, threshold=None, return_intermediates=False):
    """Score a predicted pathway against a reference pathway (multi-gen evaluation).

    The SMILES in both pathways are assumed to be standardised the same way.

    :param data_pathway: the reference pathway.
    :param pred_pathway: the predicted pathway.
    :param threshold: if given, the predicted graph is pruned with it first.
    :param return_intermediates: also return the set of found intermediates.
    :return: ``(score, precision, recall)``, plus the intermediates if requested.
    """
    data_graph = initialise_pathway(data_pathway)
    pred_graph = initialise_pathway(pred_pathway)
    if threshold is not None:
        prune_graph(pred_graph, threshold)

    intermediates = find_intermediates(data_graph, pred_graph)
    if intermediates:
        pred_graph = get_depth_adjusted_pathway(data_graph, pred_graph, intermediates)

    data_weights = set_pathway_eval_weight(data_graph)
    pred_weights = set_pathway_eval_weight(pred_graph)

    shared = get_common_nodes(pred_graph, data_graph)
    only_in_data = {n for n in data_graph.nodes if n not in shared}
    only_in_pred = {n for n in pred_graph.nodes if n not in shared}

    # Only nodes below the root (depth > 0) contribute to the scores.
    true_pos = 0.0
    for node in shared:
        if pred_graph.nodes[node]["depth"] > 0:
            true_pos += data_weights[node]

    false_neg = 0.0
    for node in only_in_data:
        if data_graph.nodes[node]["depth"] > 0:
            false_neg += data_weights[node]

    false_pos = 0.0
    for node in only_in_pred:
        if pred_graph.nodes[node]["depth"] > 0:
            false_pos += pred_weights[node]

    def ratio(total):
        """Share of the true-positive weight in ``total``; 0.0 if total is not positive."""
        if total > 0:
            return true_pos / total
        return 0.0

    final_score = ratio(true_pos + false_pos + false_neg)
    precision = ratio(true_pos + false_pos)
    recall = ratio(true_pos + false_neg)

    if return_intermediates:
        return final_score, precision, recall, intermediates
    return final_score, precision, recall
|
||||
|
||||
|
||||
def node_subst_cost(node1, node2):
    """Cost of substituting one node for another in the graph edit distance.

    Zero when both SMILES and depth match; otherwise ``1 / 2 ** d`` where ``d``
    is the larger of the two depths.
    """
    same_smiles = node1["smiles"] == node2["smiles"]
    if same_smiles and node1["depth"] == node2["depth"]:
        return 0
    deeper = max(node1["depth"], node2["depth"])  # Maybe could be min instead of max
    return 1 / (2 ** deeper)
|
||||
|
||||
|
||||
def node_ins_del_cost(node):
    """Cost of inserting or deleting a node: ``1 / 2 ** depth`` of that node."""
    depth = node["depth"]
    return 1 / (2 ** depth)
|
||||
|
||||
|
||||
def pathway_edit_eval(data_pathway, pred_pathway):
    """Graph edit distance between two pathways, a possible alternative to multigen_eval.

    Both pathways are converted to graphs first; the first root node of each
    graph is passed to networkx as the pair of roots to match.
    """
    data_graph = initialise_pathway(data_pathway)
    pred_graph = initialise_pathway(pred_pathway)
    data_root = list(data_graph.graph["root_nodes"])[0]
    pred_root = list(pred_graph.graph["root_nodes"])[0]
    cost_functions = {
        "node_subst_cost": node_subst_cost,
        "node_del_cost": node_ins_del_cost,
        "node_ins_cost": node_ins_del_cost,
    }
    return nx.graph_edit_distance(data_graph, pred_graph, roots=(data_root, pred_root), **cost_functions)
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@ -557,6 +557,7 @@ dependencies = [
|
||||
{ name = "envipy-plugins" },
|
||||
{ name = "epam-indigo" },
|
||||
{ name = "gunicorn" },
|
||||
{ name = "networkx" },
|
||||
{ name = "psycopg2-binary" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "rdkit" },
|
||||
@ -588,6 +589,7 @@ requires-dist = [
|
||||
{ name = "epam-indigo", specifier = ">=1.30.1" },
|
||||
{ name = "gunicorn", specifier = ">=23.0.0" },
|
||||
{ name = "msal", marker = "extra == 'ms-login'", specifier = ">=1.33.0" },
|
||||
{ name = "networkx", specifier = ">=3.4.2" },
|
||||
{ name = "psycopg2-binary", specifier = ">=2.9.10" },
|
||||
{ name = "python-dotenv", specifier = ">=1.1.0" },
|
||||
{ name = "rdkit", specifier = ">=2025.3.2" },
|
||||
|
||||
Reference in New Issue
Block a user