New RuleBasedDataset and EnviFormerDataset working for their respective models #120

Liam Brydon
2025-11-04 10:58:16 +13:00
parent ff51e48f90
commit ac5d370b18
5 changed files with 126 additions and 101 deletions

View File

@@ -2225,7 +2225,7 @@ class PackageBasedModel(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self):
def evaluate_model(self, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@@ -2354,18 +2354,18 @@ class PackageBasedModel(EPModel):
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
n_splits = 20
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
splits = list(shuff.split(X))
from joblib import Parallel, delayed
models = Parallel(n_jobs=10)(
models = Parallel(n_jobs=min(10, len(splits)))(
delayed(train_func)(X, y, train_index, self._model_args())
for train_index, _ in splits
)
evaluations = Parallel(n_jobs=10)(
evaluations = Parallel(n_jobs=min(10, len(splits)))(
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
for model, (_, test_index) in zip(models, splits)
)
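Taken together, this hunk makes the split count injectable and caps the worker pool at the number of splits. A minimal, self-contained sketch of the pattern, with train_func and evaluate_sg as stand-ins for the module's actual helpers (the _model_args plumbing is omitted):

from joblib import Parallel, delayed
from sklearn.model_selection import ShuffleSplit

def cross_evaluate(X, y, train_func, evaluate_sg, threshold, **kwargs):
    # n_splits defaults to 20 but can be overridden, e.g. n_splits=2 in tests.
    n_splits = kwargs.get("n_splits", 20)
    shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
    splits = list(shuff.split(X))
    # Capping n_jobs at len(splits) avoids spawning workers that would sit
    # idle whenever n_splits < 10.
    models = Parallel(n_jobs=min(10, len(splits)))(
        delayed(train_func)(X, y, train_index) for train_index, _ in splits
    )
    return Parallel(n_jobs=min(10, len(splits)))(
        delayed(evaluate_sg)(model, X, y, test_index, threshold)
        for model, (_, test_index) in zip(models, splits)
    )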
@@ -2716,7 +2716,7 @@ class MLRelativeReasoning(PackageBasedModel):
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
pred = self.model.predict_proba(classify_ds.X())
pred = self.model.predict_proba(np.array(classify_ds.X()))
res = MLRelativeReasoning.combine_products_and_probs(
self.applicable_rules, pred[0], classify_prods[0]
@@ -3096,7 +3096,7 @@ class EnviFormer(PackageBasedModel):
ds.save(f)
return ds
def load_dataset(self) -> "RuleBasedDataset":
def load_dataset(self):
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
return EnviFormerDataset.load(ds_path)
@@ -3105,7 +3105,7 @@ class EnviFormer(PackageBasedModel):
from enviformer.finetune import fine_tune
start = datetime.now()
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
return model
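fine_tune now receives the educt and product columns separately instead of the whole dataset object. Annotated shape of the call as inferred from this diff (enviformer's own signature may carry more options):

model = fine_tune(
    ds.X(),                      # educt SMILES, one string per reaction
    ds.y(),                      # '.'-joined product SMILES per reaction
    s.MODEL_DIR,                 # where checkpoints are written
    str(self.uuid),              # run/checkpoint identifier
    device=s.ENVIFORMER_DEVICE,  # e.g. "cuda" or "cpu"
)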
@@ -3121,19 +3121,19 @@ class EnviFormer(PackageBasedModel):
args = {"clz": "EnviFormer"}
return args
def evaluate_model(self):
def evaluate_model(self, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
self.model_status = self.EVALUATING
self.save()
def evaluate_sg(test_reactions, predictions, model_thresh):
def evaluate_sg(test_ds, predictions, model_thresh):
# Group the true products of reactions with the same reactant together
assert len(test_reactions) == len(predictions)
assert len(test_ds) == len(predictions)
true_dict = {}
for r in test_reactions:
reactant, true_product_set = r.split(">>")
for r in test_ds:
reactant, true_product_set = r
true_product_set = {p for p in true_product_set.split(".")}
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
@@ -3141,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
pred_dict = {}
for k, pred in enumerate(predictions):
pred_smiles, pred_proba = zip(*pred.items())
reactant, true_product = test_reactions[k].split(">>")
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
for smiles, proba in zip(pred_smiles, pred_proba):
smiles = set(smiles.split("."))
@@ -3176,7 +3176,7 @@ class EnviFormer(PackageBasedModel):
break
# Recall is TP (correct) / (TP + FN), where TP + FN = len(test_ds)
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
# Precision is TP (correct) / (TP + FP), where TP + FP = predicted[k]
prec = {
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
@@ -3257,30 +3257,30 @@ class EnviFormer(PackageBasedModel):
if self.eval_packages.count() > 0:
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
package__in=self.eval_packages.all()).distinct())
test_result = self.model.predict_batch(ds)
test_result = self.model.predict_batch(ds.X())
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
from enviformer.finetune import fine_tune
ds = self.load_dataset()
n_splits = 20
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
# Single-gen eval trains and evaluates inside one loop instead of storing all
# n_splits trained models at once; this reduces the memory footprint.
single_gen_results = []
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
train = [ds[i] for i in train_index]
test = [ds[i] for i in test_index]
train = ds[train_index]
test = ds[test_index]
start = datetime.now()
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
)
model.to(s.ENVIFORMER_DEVICE)
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
test_result = model.predict_batch(test.X())
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
self.eval_results = self.compute_averages(single_gen_results)
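For readability, the post-change body of that loop, reassembled from the interleaved old/new lines above:

single_gen_results = []
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
    train, test = ds[train_index], ds[test_index]  # index straight into the dataset
    start = datetime.now()
    model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id),
                      device=s.ENVIFORMER_DEVICE)
    logger.debug(f"EnviFormer finetuning took {(datetime.now() - start).total_seconds():.2f} seconds")
    model.to(s.ENVIFORMER_DEVICE)
    # Evaluate immediately so only one fine-tuned model is in memory at a time.
    test_result = model.predict_batch(test.X())
    single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
self.eval_results = self.compute_averages(single_gen_results)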
@@ -3351,31 +3351,15 @@ class EnviFormer(PackageBasedModel):
for pathway in train_pathways:
for reaction in pathway.edges:
reaction = reaction.edge_label
if any(
[
educt in test_educts
for educt in reaction_to_educts[str(reaction.uuid)]
]
):
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
overlap += 1
continue
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
train_reactions.append(f"{educts}>>{products}")
train_reactions.append(reaction)
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
logger.debug(
f"{overlap} compounds had to be removed from the multigen split due to overlap within pathways"
)
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
self.eval_results.update(

View File

@@ -1,8 +1,9 @@
import os.path
from tempfile import TemporaryDirectory
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule, Package
from utilities.ml import RuleBasedDataset
from utilities.ml import RuleBasedDataset, EnviFormerDataset
class DatasetTest(TestCase):
@@ -45,11 +46,11 @@ class DatasetTest(TestCase):
def test_generate_dataset(self):
"""Test generating dataset does not crash"""
self.generate_dataset()
self.generate_rule_dataset()
def test_indexing(self):
"""Test indexing a few different ways to check for crashes"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds[5])
print(ds[2, 5])
print(ds[3:6, 2:8])
@@ -57,45 +58,45 @@ class DatasetTest(TestCase):
def test_add_rows(self):
"""Test adding one row and adding multiple rows"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
ds.add_row(list(ds.df.row(1)))
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
def test_times_triggered(self):
"""Check getting times triggered for a rule id"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds.times_triggered(rules[0].uuid))
def test_block_indices(self):
"""Test the usages of _block_indices"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds.struct_features())
print(ds.triggered())
print(ds.observed())
def test_structure_id(self):
"""Check getting a structure id from row index"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds.structure_id(0))
def test_x(self):
"""Test getting X portion of the dataframe"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds.X().df.head())
def test_trig(self):
"""Test getting the triggered portion of the dataframe"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds.trig().df.head())
def test_y(self):
"""Test getting the Y portion of the dataframe"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
print(ds.y().df.head())
def test_classification_dataset(self):
"""Test making the classification dataset"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
class_ds, products = ds.classification_dataset(compounds, rules)
print(class_ds.df.head(5))
@@ -103,12 +104,16 @@ class DatasetTest(TestCase):
def test_to_arff(self):
"""Test exporting the arff version of the dataset"""
ds, reactions, rules = self.generate_dataset()
ds, reactions, rules = self.generate_rule_dataset()
ds.to_arff("dataset_arff_test.arff")
def test_save_load(self):
"""Test saving and loading dataset"""
ds, reactions, rules = self.generate_dataset()
with TemporaryDirectory() as tmpdir:
ds, reactions, rules = self.generate_rule_dataset()
ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
self.assertTrue(ds.df.equals(ds_loaded.df))
def test_dataset_example(self):
"""Test with a concrete example checking dataset size"""
@@ -120,9 +125,19 @@ class DatasetTest(TestCase):
self.assertEqual(len(ds.y()), 1)
self.assertEqual(ds.y().df.item(), 1)
def generate_dataset(self):
def test_enviformer_dataset(self):
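"""Test generating the EnviFormer dataset and reading its X and y columns"""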
ds, reactions = self.generate_enviformer_dataset()
print(ds.X().head())
print(ds.y().head())
def generate_rule_dataset(self):
"""Generate a RuleBasedDataset from test package data"""
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
return ds, reactions, applicable_rules
def generate_enviformer_dataset(self):
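"""Generate an EnviFormerDataset from test package data"""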
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
ds = EnviFormerDataset.generate_dataset(reactions)
return ds, reactions

View File

@@ -50,7 +50,7 @@ class EnviFormerTest(TestCase):
mod.build_model()
mod.multigen_eval = True
mod.save()
mod.evaluate_model()
mod.evaluate_model(n_splits=2)
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@@ -4,7 +4,7 @@ import numpy as np
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, Package
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning
class ModelTest(TestCase):
@@ -17,7 +17,7 @@ class ModelTest(TestCase):
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_smoke(self):
def test_mlrr(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
@@ -36,23 +36,11 @@ class ModelTest(TestCase):
description="Created MLRelativeReasoning in Testcase",
)
# mod = RuleBasedRelativeReasoning.create(
# self.package,
# rule_package_objs,
# data_package_objs,
# eval_packages_objs,
# threshold=threshold,
# min_count=5,
# max_count=0,
# name='ECC - BBD - 0.5',
# description='Created MLRelativeReasoning in Testcase',
# )
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
mod.evaluate_model()
mod.evaluate_model(n_splits=2)
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
@@ -73,3 +61,32 @@ class ModelTest(TestCase):
# from pprint import pprint
# pprint(mod.eval_results)
def test_rbrr(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mod = RuleBasedRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold=threshold,
min_count=5,
max_count=0,
name='RBRR - BBD - 0.5',
description='Created RuleBasedRelativeReasoning in Testcase',
)
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
mod.evaluate_model(n_splits=2)
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@@ -65,6 +65,10 @@ class Dataset(ABC):
"""Use the polars dataframe columns"""
return self.df.columns
@property
def shape(self):
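"""Shape (rows, columns) of the underlying polars dataframe"""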
return self.df.shape
@abstractmethod
def X(self, **kwargs):
pass
@@ -123,13 +127,15 @@ class Dataset(ABC):
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "trig_" in c])
# Computing num_labels here unifies __init__, letting shared methods like __getitem__ live in the base Dataset.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
self.num_features: int = len(self.columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def times_triggered(self, rule_uuid) -> int:
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
@@ -142,21 +148,25 @@ class RuleBasedDataset(Dataset):
return self._observed
def structure_id(self, index: int):
"""Get the UUID of a compound"""
return self.df.item(index, "structure_id")
def X(self, exclude_id_col=True, na_replacement=0):
"""Get all the feature and trig columns"""
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def trig(self, na_replacement=0):
"""Get all the trig columns"""
res = self[:, self._triggered[0]: self._triggered[1]]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
"""Get all the obs columns"""
res = self[:, len(self.columns) - self.num_labels:]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
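The na_replacement handling in X()/trig()/y() leans on polars' fill_null; a minimal standalone illustration:

import polars as pl

df = pl.DataFrame({"feature_0": [1, None, 0], "obs_a": [None, 1, 0]})
print(df.fill_null(0))  # nulls become 0, the slicing helpers' default
# Passing na_replacement=None skips the fill and leaves nulls in place,
# e.g. for callers that map them to np.nan themselves (see evaluate_model).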
@@ -164,7 +174,7 @@ class RuleBasedDataset(Dataset):
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
_structures = set()
_structures = set()  # Collect every unique structure appearing in the reactions
for r in reactions:
_structures.update(r.educts.all())
if not educts_only:
@@ -254,8 +264,12 @@
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
if isinstance(structures[0], str):
struct_smiles = structures[0]
else:
struct_smiles = structures[0].smiles
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
for struct in structures:
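The isinstance check above lets classification_dataset size its MACCS feature block from either a raw SMILES string or a structure object exposing .smiles; condensed, the dispatch is:

first = structures[0]
struct_smiles = first if isinstance(first, str) else first.smiles
# One fingerprint of the first structure fixes the number of feature_* columns.
feature_cols = [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))]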
@@ -306,43 +320,38 @@
class EnviFormerDataset(Dataset):
def __init__(self, educts, products):
super().__init__()
assert len(educts) == len(products), "Can't have unequal length educts and products"
def __init__(self, columns=None, data=None):
super().__init__(columns, data)
def X(self):
"""Return the educts"""
return self["educts"]
def y(self):
"""Return the products"""
return self["products"]
@staticmethod
def generate_dataset(reactions, *args, **kwargs):
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
educts = []
products = []
# Standardise reactions for the training data
stereo = kwargs.get("stereo", False)
rows = []
for reaction in reactions:
e = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.educts.all()
]
)
p = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.products.all()
]
)
educts.append(e)
products.append(p)
return EnviFormerDataset(educts, products)
def X(self):
pass
def y(self):
pass
def __getitem__(self, item):
pass
def __len__(self):
pass
rows.append([e, p])
ds = EnviFormerDataset(["educts", "products"], rows)
return ds
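With this rewrite, EnviFormerDataset is just a two-column Dataset ("educts", "products"). A hypothetical round trip, assuming the base class stores rows in a polars dataframe and supports len():

# Hypothetical SMILES; real datasets come from Reaction querysets.
rows = [["CCO", "CC=O"], ["CC(=O)O", "CO"]]
ds = EnviFormerDataset(["educts", "products"], rows)
print(len(ds))   # 2 reactions
print(ds.X())    # the "educts" column: reactant SMILES
print(ds.y())    # the "products" column: '.'-joined product SMILES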
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -524,7 +533,7 @@ class EnsembleClassifierChain:
self.classifiers = []
if self.num_labels is None:
self.num_labels = len(Y[0])
self.num_labels = Y.shape[1]
for p in range(self.num_chains):
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@@ -555,7 +564,7 @@ class RelativeReasoning:
def fit(self, X, Y):
n_instances = len(Y)
n_attributes = len(Y[0])
n_attributes = Y.shape[1]
for i in range(n_attributes):
for j in range(n_attributes):
@@ -567,8 +576,8 @@ class RelativeReasoning:
countboth = 0
for k in range(n_instances):
vi = Y[k][i]
vj = Y[k][j]
vi = Y[k, i]
vj = Y[k, j]
if vi is None or vj is None:
continue
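Both indexing changes assume Y is now a 2-D numpy array (with None marking missing labels) rather than a list of lists: Y[k][i] and Y[k, i] read the same cell on an ndarray, but Y.shape and tuple indexing fail on plain lists, so the new spellings pin down the array contract. A standalone check:

import numpy as np

# object dtype so missing labels can stay as None
Y = np.array([[1, 0, None], [0, 1, 1]], dtype=object)
n_instances, n_attributes = Y.shape  # replaces len(Y) and len(Y[0])
vi, vj = Y[0, 0], Y[0, 2]            # replaces Y[0][0] and Y[0][2]
print(n_attributes, vi, vj is None)  # -> 3 1 True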