New RuleBasedDataset and EnviFormerDataset working for their respective models #120

Liam Brydon
2025-11-04 10:58:16 +13:00
parent ff51e48f90
commit ac5d370b18
5 changed files with 126 additions and 101 deletions

View File

@@ -2225,7 +2225,7 @@ class PackageBasedModel(EPModel):
         self.model_status = self.BUILT_NOT_EVALUATED
         self.save()

-    def evaluate_model(self):
+    def evaluate_model(self, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@@ -2354,18 +2354,18 @@ class PackageBasedModel(EPModel):
         X = np.array(ds.X(na_replacement=np.nan))
         y = np.array(ds.y(na_replacement=np.nan))
-        n_splits = 20
+        n_splits = kwargs.get("n_splits", 20)
         shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
         splits = list(shuff.split(X))
         from joblib import Parallel, delayed

-        models = Parallel(n_jobs=10)(
+        models = Parallel(n_jobs=min(10, len(splits)))(
             delayed(train_func)(X, y, train_index, self._model_args())
             for train_index, _ in splits
         )
-        evaluations = Parallel(n_jobs=10)(
+        evaluations = Parallel(n_jobs=min(10, len(splits)))(
             delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
             for model, (_, test_index) in zip(models, splits)
         )
@@ -2716,7 +2716,7 @@ class MLRelativeReasoning(PackageBasedModel):
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-        pred = self.model.predict_proba(classify_ds.X())
+        pred = self.model.predict_proba(np.array(classify_ds.X()))
         res = MLRelativeReasoning.combine_products_and_probs(
             self.applicable_rules, pred[0], classify_prods[0]
@@ -3096,7 +3096,7 @@ class EnviFormer(PackageBasedModel):
             ds.save(f)
         return ds

-    def load_dataset(self) -> "RuleBasedDataset":
+    def load_dataset(self):
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
         return EnviFormerDataset.load(ds_path)
@@ -3105,7 +3105,7 @@ class EnviFormer(PackageBasedModel):
         from enviformer.finetune import fine_tune

         start = datetime.now()
-        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
+        model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
         end = datetime.now()
         logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
         return model
@@ -3121,19 +3121,19 @@ class EnviFormer(PackageBasedModel):
         args = {"clz": "EnviFormer"}
         return args

-    def evaluate_model(self):
+    def evaluate_model(self, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
         self.model_status = self.EVALUATING
         self.save()

-        def evaluate_sg(test_reactions, predictions, model_thresh):
+        def evaluate_sg(test_ds, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
-            assert len(test_reactions) == len(predictions)
+            assert len(test_ds) == len(predictions)
             true_dict = {}
-            for r in test_reactions:
-                reactant, true_product_set = r.split(">>")
+            for r in test_ds:
+                reactant, true_product_set = r
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
@@ -3141,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
             pred_dict = {}
             for k, pred in enumerate(predictions):
                 pred_smiles, pred_proba = zip(*pred.items())
-                reactant, true_product = test_reactions[k].split(">>")
+                reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
                 pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                 for smiles, proba in zip(pred_smiles, pred_proba):
                     smiles = set(smiles.split("."))
@@ -3176,7 +3176,7 @@ class EnviFormer(PackageBasedModel):
                         break
             # Recall is TP (correct) / TP + FN (len(test_reactions))
-            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
+            rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
             # Precision is TP (correct) / TP + FP (predicted)
             prec = {
                 f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
@@ -3257,30 +3257,30 @@ class EnviFormer(PackageBasedModel):
         if self.eval_packages.count() > 0:
             ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
                 package__in=self.eval_packages.all()).distinct())
-            test_result = self.model.predict_batch(ds)
+            test_result = self.model.predict_batch(ds.X())
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             from enviformer.finetune import fine_tune

             ds = self.load_dataset()
-            n_splits = 20
+            n_splits = kwargs.get("n_splits", 20)
             shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
             # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
             # this helps reduce the memory footprint.
             single_gen_results = []
             for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
-                train = [ds[i] for i in train_index]
-                test = [ds[i] for i in test_index]
+                train = ds[train_index]
+                test = ds[test_index]
                 start = datetime.now()
-                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+                model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                 end = datetime.now()
                 logger.debug(
                     f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                 )
                 model.to(s.ENVIFORMER_DEVICE)
-                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+                test_result = model.predict_batch(test.X())
                 single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
             self.eval_results = self.compute_averages(single_gen_results)
@@ -3351,31 +3351,15 @@ class EnviFormer(PackageBasedModel):
                 for pathway in train_pathways:
                     for reaction in pathway.edges:
                         reaction = reaction.edge_label
-                        if any(
-                            [
-                                educt in test_educts
-                                for educt in reaction_to_educts[str(reaction.uuid)]
-                            ]
-                        ):
+                        if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
                             overlap += 1
                             continue
-                        educts = ".".join(
-                            [
-                                FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                                for smile in reaction.educts.all()
-                            ]
-                        )
-                        products = ".".join(
-                            [
-                                FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                                for smile in reaction.products.all()
-                            ]
-                        )
-                        train_reactions.append(f"{educts}>>{products}")
+                        train_reactions.append(reaction)
+                train_ds = EnviFormerDataset.generate_dataset(train_reactions)
                 logging.debug(
                     f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
                 )
-                model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+                model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
                 multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
             self.eval_results.update(

View File

@@ -1,8 +1,9 @@
+import os.path
+from tempfile import TemporaryDirectory

 from django.test import TestCase

 from epdb.logic import PackageManager
 from epdb.models import Reaction, Compound, User, Rule, Package
-from utilities.ml import RuleBasedDataset
+from utilities.ml import RuleBasedDataset, EnviFormerDataset


 class DatasetTest(TestCase):
@@ -45,11 +46,11 @@ class DatasetTest(TestCase):
     def test_generate_dataset(self):
         """Test generating dataset does not crash"""
-        self.generate_dataset()
+        self.generate_rule_dataset()

     def test_indexing(self):
         """Test indexing a few different ways to check for crashes"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds[5])
         print(ds[2, 5])
         print(ds[3:6, 2:8])
@@ -57,45 +58,45 @@ class DatasetTest(TestCase):
     def test_add_rows(self):
         """Test adding one row and adding multiple rows"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         ds.add_row(list(ds.df.row(1)))
         ds.add_rows([list(ds.df.row(i)) for i in range(5)])

     def test_times_triggered(self):
         """Check getting times triggered for a rule id"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds.times_triggered(rules[0].uuid))

     def test_block_indices(self):
         """Test the usages of _block_indices"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds.struct_features())
         print(ds.triggered())
         print(ds.observed())

     def test_structure_id(self):
         """Check getting a structure id from row index"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds.structure_id(0))

     def test_x(self):
         """Test getting X portion of the dataframe"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds.X().df.head())

     def test_trig(self):
         """Test getting the triggered portion of the dataframe"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds.trig().df.head())

     def test_y(self):
         """Test getting the Y portion of the dataframe"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         print(ds.y().df.head())

     def test_classification_dataset(self):
         """Test making the classification dataset"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
         class_ds, products = ds.classification_dataset(compounds, rules)
         print(class_ds.df.head(5))
@@ -103,12 +104,16 @@ class DatasetTest(TestCase):
     def test_to_arff(self):
         """Test exporting the arff version of the dataset"""
-        ds, reactions, rules = self.generate_dataset()
+        ds, reactions, rules = self.generate_rule_dataset()
         ds.to_arff("dataset_arff_test.arff")

     def test_save_load(self):
         """Test saving and loading dataset"""
-        ds, reactions, rules = self.generate_dataset()
+        with TemporaryDirectory() as tmpdir:
+            ds, reactions, rules = self.generate_rule_dataset()
+            ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
+            ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
+            self.assertTrue(ds.df.equals(ds_loaded.df))

     def test_dataset_example(self):
         """Test with a concrete example checking dataset size"""
@@ -120,9 +125,19 @@ class DatasetTest(TestCase):
         self.assertEqual(len(ds.y()), 1)
         self.assertEqual(ds.y().df.item(), 1)

-    def generate_dataset(self):
+    def test_enviformer_dataset(self):
+        ds, reactions = self.generate_enviformer_dataset()
+        print(ds.X().head())
+        print(ds.y().head())
+
+    def generate_rule_dataset(self):
         """Generate a RuleBasedDataset from test package data"""
         reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
         applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
         ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
         return ds, reactions, applicable_rules
+
+    def generate_enviformer_dataset(self):
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        ds = EnviFormerDataset.generate_dataset(reactions)
+        return ds, reactions

View File

@@ -50,7 +50,7 @@ class EnviFormerTest(TestCase):
                 mod.build_model()
                 mod.multigen_eval = True
                 mod.save()
-                mod.evaluate_model()
+                mod.evaluate_model(n_splits=2)
                 mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@@ -4,7 +4,7 @@ import numpy as np
 from django.test import TestCase

 from epdb.logic import PackageManager
-from epdb.models import User, MLRelativeReasoning, Package
+from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning


 class ModelTest(TestCase):
@@ -17,7 +17,7 @@ class ModelTest(TestCase):
         cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
         cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

-    def test_smoke(self):
+    def test_mlrr(self):
         with TemporaryDirectory() as tmpdir:
             with self.settings(MODEL_DIR=tmpdir):
                 threshold = float(0.5)
@@ -36,23 +36,11 @@ class ModelTest(TestCase):
                     description="Created MLRelativeReasoning in Testcase",
                 )
-                # mod = RuleBasedRelativeReasoning.create(
-                #     self.package,
-                #     rule_package_objs,
-                #     data_package_objs,
-                #     eval_packages_objs,
-                #     threshold=threshold,
-                #     min_count=5,
-                #     max_count=0,
-                #     name='ECC - BBD - 0.5',
-                #     description='Created MLRelativeReasoning in Testcase',
-                # )
                 mod.build_dataset()
                 mod.build_model()
                 mod.multigen_eval = True
                 mod.save()
-                mod.evaluate_model()
+                mod.evaluate_model(n_splits=2)
                 results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
@@ -73,3 +61,32 @@ class ModelTest(TestCase):
                 # from pprint import pprint
                 # pprint(mod.eval_results)

+    def test_rbrr(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+                rule_package_objs = [self.BBD_SUBSET]
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+                mod = RuleBasedRelativeReasoning.create(
+                    self.package,
+                    rule_package_objs,
+                    data_package_objs,
+                    eval_packages_objs,
+                    threshold=threshold,
+                    min_count=5,
+                    max_count=0,
+                    name='ECC - BBD - 0.5',
+                    description='Created MLRelativeReasoning in Testcase',
+                )
+                mod.build_dataset()
+                mod.build_model()
+                mod.multigen_eval = True
+                mod.save()
+                mod.evaluate_model(n_splits=2)
+                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@@ -65,6 +65,10 @@ class Dataset(ABC):
         """Use the polars dataframe columns"""
         return self.df.columns

+    @property
+    def shape(self):
+        return self.df.shape
+
     @abstractmethod
     def X(self, **kwargs):
         pass
@@ -123,13 +127,15 @@ class Dataset(ABC):
 class RuleBasedDataset(Dataset):
     def __init__(self, num_labels=None, columns=None, data=None):
         super().__init__(columns, data)
-        self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "trig_" in c])
+        # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
+        self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
         self.num_features: int = len(self.columns) - self.num_labels
         self._struct_features: Tuple[int, int] = self._block_indices("feature_")
         self._triggered: Tuple[int, int] = self._block_indices("trig_")
         self._observed: Tuple[int, int] = self._block_indices("obs_")

     def times_triggered(self, rule_uuid) -> int:
+        """Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
         return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height

     def struct_features(self) -> Tuple[int, int]:
@@ -142,21 +148,25 @@ class RuleBasedDataset(Dataset):
         return self._observed

     def structure_id(self, index: int):
+        """Get the UUID of a compound"""
         return self.df.item(index, "structure_id")

     def X(self, exclude_id_col=True, na_replacement=0):
+        """Get all the feature and trig columns"""
         res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res

     def trig(self, na_replacement=0):
+        """Get all the trig columns"""
         res = self[:, self._triggered[0]: self._triggered[1]]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res

     def y(self, na_replacement=0):
+        """Get all the obs columns"""
         res = self[:, len(self.columns) - self.num_labels:]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
@@ -164,7 +174,7 @@ class RuleBasedDataset(Dataset):
     @staticmethod
     def generate_dataset(reactions, applicable_rules, educts_only=True):
-        _structures = set()
+        _structures = set()  # Get all the structures
         for r in reactions:
             _structures.update(r.educts.all())
             if not educts_only:
@@ -254,8 +264,12 @@ class RuleBasedDataset(Dataset):
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
         classify_data = []
         classify_products = []
+        if isinstance(structures[0], str):
+            struct_smiles = structures[0]
+        else:
+            struct_smiles = structures[0].smiles
         ds_columns = (["structure_id"] +
-                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
+                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
                       [f"trig_{r.uuid}" for r in applicable_rules] +
                       [f"obs_{r.uuid}" for r in applicable_rules])
         for struct in structures:
@@ -306,43 +320,38 @@ class RuleBasedDataset(Dataset):
 class EnviFormerDataset(Dataset):
-    def __init__(self, educts, products):
-        super().__init__()
-        assert len(educts) == len(products), "Can't have unequal length educts and products"
+    def __init__(self, columns=None, data=None):
+        super().__init__(columns, data)
+
+    def X(self):
+        """Return the educts"""
+        return self["educts"]
+
+    def y(self):
+        """Return the products"""
+        return self["products"]

     @staticmethod
     def generate_dataset(reactions, *args, **kwargs):
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        educts = []
-        products = []
+        # Standardise reactions for the training data
+        stereo = kwargs.get("stereo", False)
+        rows = []
         for reaction in reactions:
             e = ".".join(
                 [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                     for smile in reaction.educts.all()
                 ]
             )
             p = ".".join(
                 [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                     for smile in reaction.products.all()
                 ]
             )
-            educts.append(e)
-            products.append(p)
-        return EnviFormerDataset(educts, products)
-
-    def X(self):
-        pass
-
-    def y(self):
-        pass
-
-    def __getitem__(self, item):
-        pass
-
-    def __len__(self):
-        pass
+            rows.append([e, p])
+        ds = EnviFormerDataset(["educts", "products"], rows)
+        return ds


 class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -524,7 +533,7 @@ class EnsembleClassifierChain:
         self.classifiers = []
         if self.num_labels is None:
-            self.num_labels = len(Y[0])
+            self.num_labels = Y.shape[1]
         for p in range(self.num_chains):
             logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@@ -555,7 +564,7 @@ class RelativeReasoning:
     def fit(self, X, Y):
         n_instances = len(Y)
-        n_attributes = len(Y[0])
+        n_attributes = Y.shape[1]
         for i in range(n_attributes):
             for j in range(n_attributes):
@@ -567,8 +576,8 @@ class RelativeReasoning:
                 countboth = 0
                 for k in range(n_instances):
-                    vi = Y[k][i]
-                    vj = Y[k][j]
+                    vi = Y[k, i]
+                    vj = Y[k, j]
                     if vi is None or vj is None:
                         continue
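
For orientation, a minimal usage sketch of the reworked dataset API touched by this commit, assuming a Django shell with the "Fixtures" test package loaded; the package filter below is hypothetical (it mirrors the test setup), and the fine_tune/predict_batch calls are left commented out because they need the enviformer dependencies and a configured device (s is the settings module used elsewhere in the diff):

    from sklearn.model_selection import ShuffleSplit

    from epdb.models import Reaction
    from utilities.ml import EnviFormerDataset

    # EnviFormerDataset now stores "educts" and "products" columns instead of raw "educt>>product" strings.
    reactions = Reaction.objects.filter(package__name="Fixtures").distinct()  # hypothetical filter
    ds = EnviFormerDataset.generate_dataset(reactions)
    print(ds.X().head())  # educt SMILES column
    print(ds.y().head())  # product SMILES column

    # Row-index slicing drives the single-generation evaluation loop in EnviFormer.evaluate_model().
    shuff = ShuffleSplit(n_splits=2, test_size=0.25, random_state=42)
    for train_index, test_index in shuff.split(ds):
        train, test = ds[train_index], ds[test_index]
        # model = fine_tune(train.X(), train.y(), s.MODEL_DIR, "split_0", device=s.ENVIFORMER_DEVICE)
        # predictions = model.predict_batch(test.X())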