forked from enviPath/enviPy
new RuleBasedDataset and EnviFormer dataset working for respective models #120
This commit is contained in:
@ -2225,7 +2225,7 @@ class PackageBasedModel(EPModel):
|
|||||||
self.model_status = self.BUILT_NOT_EVALUATED
|
self.model_status = self.BUILT_NOT_EVALUATED
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def evaluate_model(self):
|
def evaluate_model(self, **kwargs):
|
||||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||||
|
|
||||||
@ -2354,18 +2354,18 @@ class PackageBasedModel(EPModel):
|
|||||||
X = np.array(ds.X(na_replacement=np.nan))
|
X = np.array(ds.X(na_replacement=np.nan))
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
y = np.array(ds.y(na_replacement=np.nan))
|
||||||
|
|
||||||
n_splits = 20
|
n_splits = kwargs.get("n_splits", 20)
|
||||||
|
|
||||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||||
splits = list(shuff.split(X))
|
splits = list(shuff.split(X))
|
||||||
|
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
|
|
||||||
models = Parallel(n_jobs=10)(
|
models = Parallel(n_jobs=min(10, len(splits)))(
|
||||||
delayed(train_func)(X, y, train_index, self._model_args())
|
delayed(train_func)(X, y, train_index, self._model_args())
|
||||||
for train_index, _ in splits
|
for train_index, _ in splits
|
||||||
)
|
)
|
||||||
evaluations = Parallel(n_jobs=10)(
|
evaluations = Parallel(n_jobs=min(10, len(splits)))(
|
||||||
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
||||||
for model, (_, test_index) in zip(models, splits)
|
for model, (_, test_index) in zip(models, splits)
|
||||||
)
|
)
|
||||||
@ -2716,7 +2716,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
|||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
ds = self.load_dataset()
|
ds = self.load_dataset()
|
||||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||||
pred = self.model.predict_proba(classify_ds.X())
|
pred = self.model.predict_proba(np.array(classify_ds.X()))
|
||||||
|
|
||||||
res = MLRelativeReasoning.combine_products_and_probs(
|
res = MLRelativeReasoning.combine_products_and_probs(
|
||||||
self.applicable_rules, pred[0], classify_prods[0]
|
self.applicable_rules, pred[0], classify_prods[0]
|
||||||
@ -3096,7 +3096,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
ds.save(f)
|
ds.save(f)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def load_dataset(self) -> "RuleBasedDataset":
|
def load_dataset(self):
|
||||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||||
return EnviFormerDataset.load(ds_path)
|
return EnviFormerDataset.load(ds_path)
|
||||||
|
|
||||||
@ -3105,7 +3105,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
from enviformer.finetune import fine_tune
|
from enviformer.finetune import fine_tune
|
||||||
|
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
|
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
||||||
return model
|
return model
|
||||||
@ -3121,19 +3121,19 @@ class EnviFormer(PackageBasedModel):
|
|||||||
args = {"clz": "EnviFormer"}
|
args = {"clz": "EnviFormer"}
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def evaluate_model(self):
|
def evaluate_model(self, **kwargs):
|
||||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||||
|
|
||||||
self.model_status = self.EVALUATING
|
self.model_status = self.EVALUATING
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def evaluate_sg(test_reactions, predictions, model_thresh):
|
def evaluate_sg(test_ds, predictions, model_thresh):
|
||||||
# Group the true products of reactions with the same reactant together
|
# Group the true products of reactions with the same reactant together
|
||||||
assert len(test_reactions) == len(predictions)
|
assert len(test_ds) == len(predictions)
|
||||||
true_dict = {}
|
true_dict = {}
|
||||||
for r in test_reactions:
|
for r in test_ds:
|
||||||
reactant, true_product_set = r.split(">>")
|
reactant, true_product_set = r
|
||||||
true_product_set = {p for p in true_product_set.split(".")}
|
true_product_set = {p for p in true_product_set.split(".")}
|
||||||
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
|
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
|
||||||
|
|
||||||
@ -3141,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
pred_dict = {}
|
pred_dict = {}
|
||||||
for k, pred in enumerate(predictions):
|
for k, pred in enumerate(predictions):
|
||||||
pred_smiles, pred_proba = zip(*pred.items())
|
pred_smiles, pred_proba = zip(*pred.items())
|
||||||
reactant, true_product = test_reactions[k].split(">>")
|
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
|
||||||
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
|
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
|
||||||
for smiles, proba in zip(pred_smiles, pred_proba):
|
for smiles, proba in zip(pred_smiles, pred_proba):
|
||||||
smiles = set(smiles.split("."))
|
smiles = set(smiles.split("."))
|
||||||
@ -3176,7 +3176,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Recall is TP (correct) / TP + FN (len(test_reactions))
|
# Recall is TP (correct) / TP + FN (len(test_reactions))
|
||||||
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
|
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
|
||||||
# Precision is TP (correct) / TP + FP (predicted)
|
# Precision is TP (correct) / TP + FP (predicted)
|
||||||
prec = {
|
prec = {
|
||||||
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
|
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
|
||||||
@ -3257,30 +3257,30 @@ class EnviFormer(PackageBasedModel):
|
|||||||
if self.eval_packages.count() > 0:
|
if self.eval_packages.count() > 0:
|
||||||
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
|
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
|
||||||
package__in=self.eval_packages.all()).distinct())
|
package__in=self.eval_packages.all()).distinct())
|
||||||
test_result = self.model.predict_batch(ds)
|
test_result = self.model.predict_batch(ds.X())
|
||||||
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
||||||
self.eval_results = self.compute_averages([single_gen_result])
|
self.eval_results = self.compute_averages([single_gen_result])
|
||||||
else:
|
else:
|
||||||
from enviformer.finetune import fine_tune
|
from enviformer.finetune import fine_tune
|
||||||
|
|
||||||
ds = self.load_dataset()
|
ds = self.load_dataset()
|
||||||
n_splits = 20
|
n_splits = kwargs.get("n_splits", 20)
|
||||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||||
|
|
||||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||||
# this helps reduce the memory footprint.
|
# this helps reduce the memory footprint.
|
||||||
single_gen_results = []
|
single_gen_results = []
|
||||||
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
||||||
train = [ds[i] for i in train_index]
|
train = ds[train_index]
|
||||||
test = [ds[i] for i in test_index]
|
test = ds[test_index]
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
|
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
|
||||||
)
|
)
|
||||||
model.to(s.ENVIFORMER_DEVICE)
|
model.to(s.ENVIFORMER_DEVICE)
|
||||||
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
test_result = model.predict_batch(test.X())
|
||||||
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
||||||
|
|
||||||
self.eval_results = self.compute_averages(single_gen_results)
|
self.eval_results = self.compute_averages(single_gen_results)
|
||||||
@ -3351,31 +3351,15 @@ class EnviFormer(PackageBasedModel):
|
|||||||
for pathway in train_pathways:
|
for pathway in train_pathways:
|
||||||
for reaction in pathway.edges:
|
for reaction in pathway.edges:
|
||||||
reaction = reaction.edge_label
|
reaction = reaction.edge_label
|
||||||
if any(
|
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
||||||
[
|
|
||||||
educt in test_educts
|
|
||||||
for educt in reaction_to_educts[str(reaction.uuid)]
|
|
||||||
]
|
|
||||||
):
|
|
||||||
overlap += 1
|
overlap += 1
|
||||||
continue
|
continue
|
||||||
educts = ".".join(
|
train_reactions.append(reaction)
|
||||||
[
|
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.educts.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
products = ".".join(
|
|
||||||
[
|
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.products.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
train_reactions.append(f"{educts}>>{products}")
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
|
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
|
||||||
)
|
)
|
||||||
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
|
||||||
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
||||||
|
|
||||||
self.eval_results.update(
|
self.eval_results.update(
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
|
import os.path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from epdb.logic import PackageManager
|
from epdb.logic import PackageManager
|
||||||
from epdb.models import Reaction, Compound, User, Rule, Package
|
from epdb.models import Reaction, Compound, User, Rule, Package
|
||||||
from utilities.ml import RuleBasedDataset
|
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
||||||
|
|
||||||
|
|
||||||
class DatasetTest(TestCase):
|
class DatasetTest(TestCase):
|
||||||
@ -45,11 +46,11 @@ class DatasetTest(TestCase):
|
|||||||
|
|
||||||
def test_generate_dataset(self):
|
def test_generate_dataset(self):
|
||||||
"""Test generating dataset does not crash"""
|
"""Test generating dataset does not crash"""
|
||||||
self.generate_dataset()
|
self.generate_rule_dataset()
|
||||||
|
|
||||||
def test_indexing(self):
|
def test_indexing(self):
|
||||||
"""Test indexing a few different ways to check for crashes"""
|
"""Test indexing a few different ways to check for crashes"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds[5])
|
print(ds[5])
|
||||||
print(ds[2, 5])
|
print(ds[2, 5])
|
||||||
print(ds[3:6, 2:8])
|
print(ds[3:6, 2:8])
|
||||||
@ -57,45 +58,45 @@ class DatasetTest(TestCase):
|
|||||||
|
|
||||||
def test_add_rows(self):
|
def test_add_rows(self):
|
||||||
"""Test adding one row and adding multiple rows"""
|
"""Test adding one row and adding multiple rows"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
ds.add_row(list(ds.df.row(1)))
|
ds.add_row(list(ds.df.row(1)))
|
||||||
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
|
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
|
||||||
|
|
||||||
def test_times_triggered(self):
|
def test_times_triggered(self):
|
||||||
"""Check getting times triggered for a rule id"""
|
"""Check getting times triggered for a rule id"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds.times_triggered(rules[0].uuid))
|
print(ds.times_triggered(rules[0].uuid))
|
||||||
|
|
||||||
def test_block_indices(self):
|
def test_block_indices(self):
|
||||||
"""Test the usages of _block_indices"""
|
"""Test the usages of _block_indices"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds.struct_features())
|
print(ds.struct_features())
|
||||||
print(ds.triggered())
|
print(ds.triggered())
|
||||||
print(ds.observed())
|
print(ds.observed())
|
||||||
|
|
||||||
def test_structure_id(self):
|
def test_structure_id(self):
|
||||||
"""Check getting a structure id from row index"""
|
"""Check getting a structure id from row index"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds.structure_id(0))
|
print(ds.structure_id(0))
|
||||||
|
|
||||||
def test_x(self):
|
def test_x(self):
|
||||||
"""Test getting X portion of the dataframe"""
|
"""Test getting X portion of the dataframe"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds.X().df.head())
|
print(ds.X().df.head())
|
||||||
|
|
||||||
def test_trig(self):
|
def test_trig(self):
|
||||||
"""Test getting the triggered portion of the dataframe"""
|
"""Test getting the triggered portion of the dataframe"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds.trig().df.head())
|
print(ds.trig().df.head())
|
||||||
|
|
||||||
def test_y(self):
|
def test_y(self):
|
||||||
"""Test getting the Y portion of the dataframe"""
|
"""Test getting the Y portion of the dataframe"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
print(ds.y().df.head())
|
print(ds.y().df.head())
|
||||||
|
|
||||||
def test_classification_dataset(self):
|
def test_classification_dataset(self):
|
||||||
"""Test making the classification dataset"""
|
"""Test making the classification dataset"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
|
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
|
||||||
class_ds, products = ds.classification_dataset(compounds, rules)
|
class_ds, products = ds.classification_dataset(compounds, rules)
|
||||||
print(class_ds.df.head(5))
|
print(class_ds.df.head(5))
|
||||||
@ -103,12 +104,16 @@ class DatasetTest(TestCase):
|
|||||||
|
|
||||||
def test_to_arff(self):
|
def test_to_arff(self):
|
||||||
"""Test exporting the arff version of the dataset"""
|
"""Test exporting the arff version of the dataset"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
ds.to_arff("dataset_arff_test.arff")
|
ds.to_arff("dataset_arff_test.arff")
|
||||||
|
|
||||||
def test_save_load(self):
|
def test_save_load(self):
|
||||||
"""Test saving and loading dataset"""
|
"""Test saving and loading dataset"""
|
||||||
ds, reactions, rules = self.generate_dataset()
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||||
|
ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||||
|
self.assertTrue(ds.df.equals(ds_loaded.df))
|
||||||
|
|
||||||
def test_dataset_example(self):
|
def test_dataset_example(self):
|
||||||
"""Test with a concrete example checking dataset size"""
|
"""Test with a concrete example checking dataset size"""
|
||||||
@ -120,9 +125,19 @@ class DatasetTest(TestCase):
|
|||||||
self.assertEqual(len(ds.y()), 1)
|
self.assertEqual(len(ds.y()), 1)
|
||||||
self.assertEqual(ds.y().df.item(), 1)
|
self.assertEqual(ds.y().df.item(), 1)
|
||||||
|
|
||||||
def generate_dataset(self):
|
def test_enviformer_dataset(self):
|
||||||
|
ds, reactions = self.generate_enviformer_dataset()
|
||||||
|
print(ds.X().head())
|
||||||
|
print(ds.y().head())
|
||||||
|
|
||||||
|
def generate_rule_dataset(self):
|
||||||
"""Generate a RuleBasedDataset from test package data"""
|
"""Generate a RuleBasedDataset from test package data"""
|
||||||
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||||
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||||
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||||
return ds, reactions, applicable_rules
|
return ds, reactions, applicable_rules
|
||||||
|
|
||||||
|
def generate_enviformer_dataset(self):
|
||||||
|
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
ds = EnviFormerDataset.generate_dataset(reactions)
|
||||||
|
return ds, reactions
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class EnviFormerTest(TestCase):
|
|||||||
mod.build_model()
|
mod.build_model()
|
||||||
mod.multigen_eval = True
|
mod.multigen_eval = True
|
||||||
mod.save()
|
mod.save()
|
||||||
mod.evaluate_model()
|
mod.evaluate_model(n_splits=2)
|
||||||
|
|
||||||
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from epdb.logic import PackageManager
|
from epdb.logic import PackageManager
|
||||||
from epdb.models import User, MLRelativeReasoning, Package
|
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning
|
||||||
|
|
||||||
|
|
||||||
class ModelTest(TestCase):
|
class ModelTest(TestCase):
|
||||||
@ -17,7 +17,7 @@ class ModelTest(TestCase):
|
|||||||
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
||||||
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
||||||
|
|
||||||
def test_smoke(self):
|
def test_mlrr(self):
|
||||||
with TemporaryDirectory() as tmpdir:
|
with TemporaryDirectory() as tmpdir:
|
||||||
with self.settings(MODEL_DIR=tmpdir):
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
threshold = float(0.5)
|
threshold = float(0.5)
|
||||||
@ -36,23 +36,11 @@ class ModelTest(TestCase):
|
|||||||
description="Created MLRelativeReasoning in Testcase",
|
description="Created MLRelativeReasoning in Testcase",
|
||||||
)
|
)
|
||||||
|
|
||||||
# mod = RuleBasedRelativeReasoning.create(
|
|
||||||
# self.package,
|
|
||||||
# rule_package_objs,
|
|
||||||
# data_package_objs,
|
|
||||||
# eval_packages_objs,
|
|
||||||
# threshold=threshold,
|
|
||||||
# min_count=5,
|
|
||||||
# max_count=0,
|
|
||||||
# name='ECC - BBD - 0.5',
|
|
||||||
# description='Created MLRelativeReasoning in Testcase',
|
|
||||||
# )
|
|
||||||
|
|
||||||
mod.build_dataset()
|
mod.build_dataset()
|
||||||
mod.build_model()
|
mod.build_model()
|
||||||
mod.multigen_eval = True
|
mod.multigen_eval = True
|
||||||
mod.save()
|
mod.save()
|
||||||
mod.evaluate_model()
|
mod.evaluate_model(n_splits=2)
|
||||||
|
|
||||||
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|
||||||
@ -73,3 +61,32 @@ class ModelTest(TestCase):
|
|||||||
|
|
||||||
# from pprint import pprint
|
# from pprint import pprint
|
||||||
# pprint(mod.eval_results)
|
# pprint(mod.eval_results)
|
||||||
|
|
||||||
|
def test_rbrr(self):
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
|
threshold = float(0.5)
|
||||||
|
|
||||||
|
rule_package_objs = [self.BBD_SUBSET]
|
||||||
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
|
eval_packages_objs = [self.BBD_SUBSET]
|
||||||
|
|
||||||
|
mod = RuleBasedRelativeReasoning.create(
|
||||||
|
self.package,
|
||||||
|
rule_package_objs,
|
||||||
|
data_package_objs,
|
||||||
|
eval_packages_objs,
|
||||||
|
threshold=threshold,
|
||||||
|
min_count=5,
|
||||||
|
max_count=0,
|
||||||
|
name='ECC - BBD - 0.5',
|
||||||
|
description='Created MLRelativeReasoning in Testcase',
|
||||||
|
)
|
||||||
|
|
||||||
|
mod.build_dataset()
|
||||||
|
mod.build_model()
|
||||||
|
mod.multigen_eval = True
|
||||||
|
mod.save()
|
||||||
|
mod.evaluate_model(n_splits=2)
|
||||||
|
|
||||||
|
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|||||||
@ -65,6 +65,10 @@ class Dataset(ABC):
|
|||||||
"""Use the polars dataframe columns"""
|
"""Use the polars dataframe columns"""
|
||||||
return self.df.columns
|
return self.df.columns
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return self.df.shape
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def X(self, **kwargs):
|
def X(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
@ -123,13 +127,15 @@ class Dataset(ABC):
|
|||||||
class RuleBasedDataset(Dataset):
|
class RuleBasedDataset(Dataset):
|
||||||
def __init__(self, num_labels=None, columns=None, data=None):
|
def __init__(self, num_labels=None, columns=None, data=None):
|
||||||
super().__init__(columns, data)
|
super().__init__(columns, data)
|
||||||
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "trig_" in c])
|
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
|
||||||
|
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
|
||||||
self.num_features: int = len(self.columns) - self.num_labels
|
self.num_features: int = len(self.columns) - self.num_labels
|
||||||
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
||||||
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
||||||
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
||||||
|
|
||||||
def times_triggered(self, rule_uuid) -> int:
|
def times_triggered(self, rule_uuid) -> int:
|
||||||
|
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
|
||||||
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
|
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
|
||||||
|
|
||||||
def struct_features(self) -> Tuple[int, int]:
|
def struct_features(self) -> Tuple[int, int]:
|
||||||
@ -142,21 +148,25 @@ class RuleBasedDataset(Dataset):
|
|||||||
return self._observed
|
return self._observed
|
||||||
|
|
||||||
def structure_id(self, index: int):
|
def structure_id(self, index: int):
|
||||||
|
"""Get the UUID of a compound"""
|
||||||
return self.df.item(index, "structure_id")
|
return self.df.item(index, "structure_id")
|
||||||
|
|
||||||
def X(self, exclude_id_col=True, na_replacement=0):
|
def X(self, exclude_id_col=True, na_replacement=0):
|
||||||
|
"""Get all the feature and trig columns"""
|
||||||
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
|
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
|
||||||
if na_replacement is not None:
|
if na_replacement is not None:
|
||||||
res.df = res.df.fill_null(na_replacement)
|
res.df = res.df.fill_null(na_replacement)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def trig(self, na_replacement=0):
|
def trig(self, na_replacement=0):
|
||||||
|
"""Get all the trig columns"""
|
||||||
res = self[:, self._triggered[0]: self._triggered[1]]
|
res = self[:, self._triggered[0]: self._triggered[1]]
|
||||||
if na_replacement is not None:
|
if na_replacement is not None:
|
||||||
res.df = res.df.fill_null(na_replacement)
|
res.df = res.df.fill_null(na_replacement)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def y(self, na_replacement=0):
|
def y(self, na_replacement=0):
|
||||||
|
"""Get all the obs columns"""
|
||||||
res = self[:, len(self.columns) - self.num_labels:]
|
res = self[:, len(self.columns) - self.num_labels:]
|
||||||
if na_replacement is not None:
|
if na_replacement is not None:
|
||||||
res.df = res.df.fill_null(na_replacement)
|
res.df = res.df.fill_null(na_replacement)
|
||||||
@ -164,7 +174,7 @@ class RuleBasedDataset(Dataset):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_dataset(reactions, applicable_rules, educts_only=True):
|
def generate_dataset(reactions, applicable_rules, educts_only=True):
|
||||||
_structures = set()
|
_structures = set() # Get all the structures
|
||||||
for r in reactions:
|
for r in reactions:
|
||||||
_structures.update(r.educts.all())
|
_structures.update(r.educts.all())
|
||||||
if not educts_only:
|
if not educts_only:
|
||||||
@ -254,8 +264,12 @@ class RuleBasedDataset(Dataset):
|
|||||||
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||||
classify_data = []
|
classify_data = []
|
||||||
classify_products = []
|
classify_products = []
|
||||||
|
if isinstance(structures[0], str):
|
||||||
|
struct_smiles = structures[0]
|
||||||
|
else:
|
||||||
|
struct_smiles = structures[0].smiles
|
||||||
ds_columns = (["structure_id"] +
|
ds_columns = (["structure_id"] +
|
||||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
|
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
|
||||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||||
for struct in structures:
|
for struct in structures:
|
||||||
@ -306,43 +320,38 @@ class RuleBasedDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class EnviFormerDataset(Dataset):
|
class EnviFormerDataset(Dataset):
|
||||||
def __init__(self, educts, products):
|
def __init__(self, columns=None, data=None):
|
||||||
super().__init__()
|
super().__init__(columns, data)
|
||||||
assert len(educts) == len(products), "Can't have unequal length educts and products"
|
|
||||||
|
def X(self):
|
||||||
|
"""Return the educts"""
|
||||||
|
return self["educts"]
|
||||||
|
|
||||||
|
def y(self):
|
||||||
|
"""Return the products"""
|
||||||
|
return self["products"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_dataset(reactions, *args, **kwargs):
|
def generate_dataset(reactions, *args, **kwargs):
|
||||||
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
|
# Standardise reactions for the training data
|
||||||
educts = []
|
stereo = kwargs.get("stereo", False)
|
||||||
products = []
|
rows = []
|
||||||
for reaction in reactions:
|
for reaction in reactions:
|
||||||
e = ".".join(
|
e = ".".join(
|
||||||
[
|
[
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
|
||||||
for smile in reaction.educts.all()
|
for smile in reaction.educts.all()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
p = ".".join(
|
p = ".".join(
|
||||||
[
|
[
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
|
||||||
for smile in reaction.products.all()
|
for smile in reaction.products.all()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
educts.append(e)
|
rows.append([e, p])
|
||||||
products.append(p)
|
ds = EnviFormerDataset(["educts", "products"], rows)
|
||||||
return EnviFormerDataset(educts, products)
|
return ds
|
||||||
|
|
||||||
def X(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def y(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||||
@ -524,7 +533,7 @@ class EnsembleClassifierChain:
|
|||||||
self.classifiers = []
|
self.classifiers = []
|
||||||
|
|
||||||
if self.num_labels is None:
|
if self.num_labels is None:
|
||||||
self.num_labels = len(Y[0])
|
self.num_labels = Y.shape[1]
|
||||||
|
|
||||||
for p in range(self.num_chains):
|
for p in range(self.num_chains):
|
||||||
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||||
@ -555,7 +564,7 @@ class RelativeReasoning:
|
|||||||
|
|
||||||
def fit(self, X, Y):
|
def fit(self, X, Y):
|
||||||
n_instances = len(Y)
|
n_instances = len(Y)
|
||||||
n_attributes = len(Y[0])
|
n_attributes = Y.shape[1]
|
||||||
|
|
||||||
for i in range(n_attributes):
|
for i in range(n_attributes):
|
||||||
for j in range(n_attributes):
|
for j in range(n_attributes):
|
||||||
@ -567,8 +576,8 @@ class RelativeReasoning:
|
|||||||
countboth = 0
|
countboth = 0
|
||||||
|
|
||||||
for k in range(n_instances):
|
for k in range(n_instances):
|
||||||
vi = Y[k][i]
|
vi = Y[k, i]
|
||||||
vj = Y[k][j]
|
vj = Y[k, j]
|
||||||
|
|
||||||
if vi is None or vj is None:
|
if vi is None or vj is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user