[Feature] MultiGen Eval (Backend) (#117)

Fixes #16

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#117
2025-09-18 18:40:45 +12:00
parent 762a6b7baf
commit 50db2fb372
24 changed files with 816 additions and 2137274 deletions


@@ -7,7 +7,7 @@ import secrets
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import Union, List, Optional, Dict, Tuple, Set
from typing import Union, List, Optional, Dict, Tuple, Set, Any
from uuid import uuid4
import joblib
@@ -588,33 +588,33 @@ class Package(EnviPathModel):
return f"{self.name} (pk={self.pk})"
@property
def compounds(self):
def compounds(self) -> QuerySet:
return self.compound_set.all()
@property
def rules(self):
def rules(self) -> QuerySet:
return self.rule_set.all()
@property
def reactions(self):
def reactions(self) -> QuerySet:
return self.reaction_set.all()
@property
def pathways(self) -> 'Pathway':
def pathways(self) -> QuerySet:
return self.pathway_set.all()
@property
def scenarios(self):
def scenarios(self) -> QuerySet:
return self.scenario_set.all()
@property
def models(self):
def models(self) -> QuerySet:
return self.epmodel_set.all()
def _url(self):
return '{}/package/{}'.format(s.SERVER_URL, self.uuid)
def get_applicable_rules(self):
def get_applicable_rules(self) -> List['Rule']:
"""
Returns an ordered set of rules where the following applies:
1. All Composite rules will be added to the result
@@ -650,11 +650,11 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
external_identifiers = GenericRelation('ExternalIdentifier')
@property
def structures(self):
def structures(self) -> QuerySet:
return CompoundStructure.objects.filter(compound=self)
@property
def normalized_structure(self):
def normalized_structure(self) -> 'CompoundStructure':
return CompoundStructure.objects.get(compound=self, normalized_structure=True)
def _url(self):
@@ -1635,8 +1635,8 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
return new_pathway
@transaction.atomic
def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None):
return Node.create(self, smiles, 0)
def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None, depth: Optional[int] = 0):
return Node.create(self, smiles, depth, name=name, description=description)
@transaction.atomic
def add_edge(self, start_nodes: List['Node'], end_nodes: List['Node'], rule: Optional['Rule'] = None,
@@ -1836,6 +1836,7 @@ class PackageBasedModel(EPModel):
eval_results = JSONField(null=True, blank=True, default=dict)
app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
default=None)
multigen_eval = models.BooleanField(null=False, blank=False, default=False)
INITIAL = "INITIAL"
INITIALIZING = "INITIALIZING"
@@ -1861,6 +1862,24 @@ class PackageBasedModel(EPModel):
def ready_for_prediction(self) -> bool:
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
@property
def pr_curve(self):
if self.model_status != self.FINISHED:
raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
res = []
thresholds = self.eval_results['average_precision_per_threshold'].keys()
for t in thresholds:
res.append({
'precision': self.eval_results['average_precision_per_threshold'][t],
'recall': self.eval_results['average_recall_per_threshold'][t],
'threshold': float(t)
})
return res
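Each entry of the returned list is one point on the precision/recall curve. A minimal consumption sketch (the model instance `mod` and the matplotlib usage are illustrative, not part of this commit):

import matplotlib.pyplot as plt
points = sorted(mod.pr_curve, key=lambda d: d['threshold'])  # raises unless FINISHED
plt.plot([p['recall'] for p in points], [p['precision'] for p in points], marker='o')
plt.xlabel('recall')
plt.ylabel('precision')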
@cached_property
def applicable_rules(self) -> List['Rule']:
"""
@@ -1897,14 +1916,6 @@ class PackageBasedModel(EPModel):
# TODO
return []
def _get_pathways(self):
pathway_qs = Pathway.objects.none()
for p in self.data_packages.all():
pathway_qs |= p.pathways
pathway_qs = pathway_qs.distinct()
return pathway_qs
def _get_reactions(self) -> QuerySet:
return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()
@@ -1937,9 +1948,277 @@ class PackageBasedModel(EPModel):
self.build_model()
@abstractmethod
def build_model(self):
def _fit_model(self, ds: Dataset):
pass
@abstractmethod
def _model_args(self) -> Dict[str, Any]:
pass
def build_model(self):
self.model_status = self.BUILDING
self.save()
ds = self.load_dataset()
mod = self._fit_model(ds)
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f)
if self.app_domain is not None:
logger.debug("Building applicability domain...")
self.app_domain.build()
logger.debug("Done building applicability domain.")
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
self.model_status = self.EVALUATING
self.save()
def train_func(X, y, train_index, model_kwargs):
clz = model_kwargs.pop('clz')
if clz == 'RuleBasedRelativeReasoning':
mod = RelativeReasoning(
**model_kwargs
)
else:
mod = EnsembleClassifierChain(
**model_kwargs
)
if train_index is not None:
X, y = X[train_index], y[train_index]
mod.fit(X, y)
return mod
def evaluate_sg(model, X, y, test_index, threshold):
X_test = X[test_index]
y_test = y[test_index]
y_pred = model.predict_proba(X_test)
y_thresholded = (y_pred >= threshold).astype(int)
# Flatten so that the np.nan entries can be masked out below
y_test = np.asarray(y_test).flatten()
y_pred = np.asarray(y_pred).flatten()
y_thresholded = np.asarray(y_thresholded).flatten()
mask = ~np.isnan(y_test)
y_test_filtered = y_test[mask]
y_pred_filtered = y_pred[mask]
y_thresholded_filtered = y_thresholded[mask]
acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
prec, rec = dict(), dict()
for t in np.arange(0, 1.05, 0.05):
temp_thresholded = (y_pred_filtered >= t).astype(int)
prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
return acc, prec, rec
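The masking matters because the multi-label matrix can contain np.nan for unknown labels; flattening first keeps the mask one-dimensional. A self-contained toy version of the same sweep (all data illustrative):

import numpy as np
from sklearn.metrics import precision_score, recall_score

y_true = np.array([1.0, 0.0, np.nan, 1.0, 0.0])  # nan = label unknown
y_prob = np.array([0.9, 0.2, 0.8, 0.4, 0.6])
mask = ~np.isnan(y_true)                          # drop unknown labels
y_t, y_p = y_true[mask], y_prob[mask]
for t in np.arange(0, 1.05, 0.05):
    y_hat = (y_p >= t).astype(int)
    print(f"{t:.2f}",
          precision_score(y_t, y_hat, zero_division=0),
          recall_score(y_t, y_hat, zero_division=0))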
def evaluate_mg(model, pathways: Union[QuerySet['Pathway'], List['Pathway']], threshold):
thresholds = np.arange(0.1, 1.1, 0.1)
precision = {f"{t:.2f}": [] for t in thresholds}
recall = {f"{t:.2f}": [] for t in thresholds}
# Note: only one root compound supported at this time
root_compounds = [[n.default_node_label.smiles for n in p.root_nodes][0] for p in pathways]
# As the prediction setting needs a model instance, get a fresh copy from the DB, overwrite its
# serialized model, and pass it to the setting used for prediction
if isinstance(self, MLRelativeReasoning):
mod = MLRelativeReasoning.objects.get(pk=self.pk)
elif isinstance(self, RuleBasedRelativeReasoning):
mod = RuleBasedRelativeReasoning.objects.get(pk=self.pk)
mod.model = model
setting = Setting()
setting.model = mod
setting.model_threshold = thresholds.min()
setting.max_depth = 10
setting.max_nodes = 50
from epdb.logic import SPathway
from utilities.ml import multigen_eval
pred_pathways = []
for i, root in enumerate(root_compounds):
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
spw = SPathway(root_nodes=root, prediction_setting=setting)
level = 0
while not spw.done:
spw.predict_step(from_depth=level)
level += 1
pred_pathways.append(spw)
mg_acc = 0.0
for t in thresholds:
for true, pred in zip(pathways, pred_pathways):
acc, pre, rec = multigen_eval(true, pred, t)
if abs(t - threshold) < 0.01:
mg_acc = acc
precision[f"{t:.2f}"].append(pre)
recall[f"{t:.2f}"].append(rec)
precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
return mg_acc, precision, recall
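utilities.ml.multigen_eval itself is not part of this diff. Purely to illustrate the kind of pathway-level score it returns, a hypothetical edge-overlap metric might look like this (compare_pathway_edges and its inputs are invented for the example):

def compare_pathway_edges(true_edges: set, pred_edges: set) -> float:
    # Jaccard overlap between observed and predicted transformations
    if not true_edges and not pred_edges:
        return 1.0
    return len(true_edges & pred_edges) / len(true_edges | pred_edges)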
ds = self.load_dataset()
if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
else:
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
n_splits = 20
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
splits = list(shuff.split(X))
from joblib import Parallel, delayed
models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits)
evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
for model, (_, test_index) in zip(models, splits))
def compute_averages(data):
num_items = len(data)
avg_first_item = sum(item[0] for item in data) / num_items
sum_dict2 = defaultdict(float)
sum_dict3 = defaultdict(float)
for _, dict2, dict3 in data:
for key in dict2:
sum_dict2[key] += dict2[key]
for key in dict3:
sum_dict3[key] += dict3[key]
avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
return {
"average_accuracy": float(avg_first_item),
"average_precision_per_threshold": avg_dict2,
"average_recall_per_threshold": avg_dict3
}
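compute_averages reduces the per-split tuples (accuracy, precision-per-threshold, recall-per-threshold) to element-wise means. A worked example with two hypothetical splits:

data = [
    (0.50, {"0.50": 0.8}, {"0.50": 0.4}),
    (0.70, {"0.50": 0.6}, {"0.50": 0.6}),
]
# compute_averages(data) == {
#     "average_accuracy": 0.6,
#     "average_precision_per_threshold": {"0.50": 0.7},
#     "average_recall_per_threshold": {"0.50": 0.5},
# }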
self.eval_results = compute_averages(evaluations)
if self.multigen_eval:
# Two cases to consider here:
# 1. No eval packages provided -> split the training pathways n_splits times, then train and evaluate a model per split
# 2. Eval packages provided -> use the already built model and evaluate on the provided set
if self.eval_packages.count() > 0:
pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
ret = evaluate_mg(self.model, pathway_qs, self.threshold)
self.eval_results.update({f"multigen_{k}": v for k, v in compute_averages([ret]).items()})
self.model_status = self.FINISHED
self.save()
return
pathway_qs = Pathway.objects.prefetch_related(
'node_set',
'node_set__out_edges',
'node_set__default_node_label',
'node_set__scenarios',
'edge_set',
'edge_set__start_nodes',
'edge_set__end_nodes',
'edge_set__edge_label',
'edge_set__scenarios'
).filter(package__in=self.data_packages.all()).distinct()
pathways = []
for pathway in pathway_qs:
# There is one pathway with no root compounds, so this check is required
if len(pathway.root_nodes) > 0:
pathways.append(pathway)
else:
logger.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
# build lookup reaction -> {uuid1, uuid2} for overlap check
reaction_to_educts = defaultdict(set)
for pathway in pathways:
for reaction in pathway.edges:
for e in reaction.edge_label.educts.all():
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
# build lookup to avoid recalculation of features, labels
id_to_index = {uuid: i for i, uuid in enumerate(ds[:, 0])}
# Compute splits of the collected pathway
splits = []
for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways):
train_pathways = [pathways[i] for i in train]
test_pathways = [pathways[i] for i in test]
# Collect structures from test pathways
test_educts = set()
for pathway in test_pathways:
for reaction in pathway.edges:
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
split_ids = []
overlap = 0
# Collect indices of the structures contained in train pathways iff they're not present in any of
# the test pathways
for pathway in train_pathways:
for reaction in pathway.edges:
for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
# Ensure compounds in the training set do not appear in the test set
if educt not in test_educts:
try:
split_ids.append(id_to_index[str(educt)])
except KeyError:
logger.debug(f"Couldn't find features in X for compound {educt}")
else:
overlap += 1
logger.debug(
f"{overlap} compounds had to be removed from the multigen split due to overlap with the test pathways")
# Get the rows from the dataset corresponding to compounds in the training set pathways
split_x, split_y = X[split_ids], y[split_ids]
splits.append([(split_x, split_y), test_pathways])
# Build model on subsets obtained by pathway split
trained_models = Parallel(n_jobs=10)(
delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits
)
# Parallelizing the multigen evaluation is non-trivial, so it runs sequentially (n_jobs=1) for now
multi_ret_vals = Parallel(n_jobs=1)(
delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in
zip(trained_models, splits)
)
self.eval_results.update({f"multigen_{k}": v for k, v in compute_averages(multi_ret_vals).items()})
self.model_status = self.FINISHED
self.save()
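A minimal end-to-end sketch, assuming an existing model instance mod with multigen_eval=True and data packages attached (the variable name is illustrative):

mod.build_model()      # BUILDING -> BUILT_NOT_EVALUATED
mod.evaluate_model()   # EVALUATING -> FINISHED
print(mod.eval_results["average_accuracy"])
# With multigen_eval enabled, pathway-level metrics are stored under a "multigen_" prefix:
print(mod.eval_results.get("multigen_average_accuracy"))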
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
@@ -2011,21 +2290,22 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return rbrr
def build_model(self):
self.model_status = self.BUILDING
self.save()
def _fit_model(self, ds: Dataset):
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
model = RelativeReasoning(
start_index=ds.triggered()[0],
end_index=ds.triggered()[1],
)
model.fit(X, y)
return model
def _model_args(self):
ds = self.load_dataset()
labels = ds.y(na_replacement=None)
mod = RelativeReasoning(*ds.triggered())
mod.fit(ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None))
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f)
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
return {
'clz': 'RuleBasedRelativeReasoning',
'start_index': ds.triggered()[0],
'end_index': ds.triggered()[1],
}
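_model_args and train_func form a small contract: the 'clz' tag selects the estimator class and the remaining keys become constructor kwargs. Sketched for this class (values elided):

kwargs = self._model_args()          # {'clz': 'RuleBasedRelativeReasoning', 'start_index': ..., 'end_index': ...}
clz = kwargs.pop('clz')              # tag consumed by train_func
model = RelativeReasoning(**kwargs)  # what train_func builds for this tag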
@cached_property
def model(self) -> 'RelativeReasoning':
@@ -2038,7 +2318,6 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
mod = self.model
pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))
res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])
@@ -2102,118 +2381,23 @@ class MLRelativeReasoning(PackageBasedModel):
return mlrr
def build_model(self):
self.model_status = self.BUILDING
self.save()
start = datetime.now()
ds = self.load_dataset()
def _fit_model(self, ds: Dataset):
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
mod = EnsembleClassifierChain(
model = EnsembleClassifierChain(
**s.DEFAULT_MODEL_PARAMS
)
mod.fit(X, y)
model.fit(X, y)
return model
end = datetime.now()
logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f)
if self.app_domain is not None:
logger.debug("Building applicability domain...")
self.app_domain.build()
logger.debug("Done building applicability domain.")
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
self.model_status = self.EVALUATING
self.save()
ds = self.load_dataset()
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
n_splits = 20
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
def train_and_evaluate(X, y, train_index, test_index, threshold):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
model = EnsembleClassifierChain(
**s.DEFAULT_MODEL_PARAMS
)
model.fit(X_train, y_train)
y_pred = model.predict_proba(X_test)
y_thresholded = (y_pred >= threshold).astype(int)
# Flatten them to get rid of np.nan
y_test = np.asarray(y_test).flatten()
y_pred = np.asarray(y_pred).flatten()
y_thresholded = np.asarray(y_thresholded).flatten()
mask = ~np.isnan(y_test)
y_test_filtered = y_test[mask]
y_pred_filtered = y_pred[mask]
y_thresholded_filtered = y_thresholded[mask]
acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)
prec, rec = dict(), dict()
for t in np.arange(0, 1.05, 0.05):
temp_thresholded = (y_pred_filtered >= t).astype(int)
prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)
return acc, prec, rec
from joblib import Parallel, delayed
ret_vals = Parallel(n_jobs=10)(
delayed(train_and_evaluate)(X, y, train_index, test_index, self.threshold)
for train_index, test_index in shuff.split(X)
)
def compute_averages(data):
num_items = len(data)
avg_first_item = sum(item[0] for item in data) / num_items
sum_dict2 = defaultdict(float)
sum_dict3 = defaultdict(float)
for _, dict2, dict3 in data:
for key in dict2:
sum_dict2[key] += dict2[key]
for key in dict3:
sum_dict3[key] += dict3[key]
avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
return {
"average_accuracy": float(avg_first_item),
"average_precision_per_threshold": avg_dict2,
"average_recall_per_threshold": avg_dict3
}
self.eval_results = compute_averages(ret_vals)
self.model_status = self.FINISHED
self.save()
def _model_args(self):
return {
'clz': 'MLRelativeReasoning',
**s.DEFAULT_MODEL_PARAMS,
}
@cached_property
def model(self):
def model(self) -> 'EnsembleClassifierChain':
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
mod.base_clf.n_jobs = -1
return mod
@@ -2230,24 +2414,6 @@ class MLRelativeReasoning(PackageBasedModel):
logger.info(f"Full predict took {(end - start).total_seconds()}s")
return res
@property
def pr_curve(self):
if self.model_status != self.FINISHED:
raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
res = []
thresholds = self.eval_results['average_precision_per_threshold'].keys()
for t in thresholds:
res.append({
'precision': self.eval_results['average_precision_per_threshold'][t],
'recall': self.eval_results['average_recall_per_threshold'][t],
'threshold': float(t)
})
return res
class ApplicabilityDomain(EnviPathModel):
model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)