new RuleBasedDataset and EnviFormer dataset working for respective models #120

This commit is contained in:
Liam Brydon
2025-11-04 10:58:16 +13:00
parent ff51e48f90
commit ac5d370b18
5 changed files with 126 additions and 101 deletions

View File

@ -2225,7 +2225,7 @@ class PackageBasedModel(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self):
def evaluate_model(self, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@ -2354,18 +2354,18 @@ class PackageBasedModel(EPModel):
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
n_splits = 20
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
splits = list(shuff.split(X))
from joblib import Parallel, delayed
models = Parallel(n_jobs=10)(
models = Parallel(n_jobs=min(10, len(splits)))(
delayed(train_func)(X, y, train_index, self._model_args())
for train_index, _ in splits
)
evaluations = Parallel(n_jobs=10)(
evaluations = Parallel(n_jobs=min(10, len(splits)))(
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
for model, (_, test_index) in zip(models, splits)
)
@ -2716,7 +2716,7 @@ class MLRelativeReasoning(PackageBasedModel):
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
pred = self.model.predict_proba(classify_ds.X())
pred = self.model.predict_proba(np.array(classify_ds.X()))
res = MLRelativeReasoning.combine_products_and_probs(
self.applicable_rules, pred[0], classify_prods[0]
@ -3096,7 +3096,7 @@ class EnviFormer(PackageBasedModel):
ds.save(f)
return ds
def load_dataset(self) -> "RuleBasedDataset":
def load_dataset(self):
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
return EnviFormerDataset.load(ds_path)
@ -3105,7 +3105,7 @@ class EnviFormer(PackageBasedModel):
from enviformer.finetune import fine_tune
start = datetime.now()
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
return model
@ -3121,19 +3121,19 @@ class EnviFormer(PackageBasedModel):
args = {"clz": "EnviFormer"}
return args
def evaluate_model(self):
def evaluate_model(self, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
self.model_status = self.EVALUATING
self.save()
def evaluate_sg(test_reactions, predictions, model_thresh):
def evaluate_sg(test_ds, predictions, model_thresh):
# Group the true products of reactions with the same reactant together
assert len(test_reactions) == len(predictions)
assert len(test_ds) == len(predictions)
true_dict = {}
for r in test_reactions:
reactant, true_product_set = r.split(">>")
for r in test_ds:
reactant, true_product_set = r
true_product_set = {p for p in true_product_set.split(".")}
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
@ -3141,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
pred_dict = {}
for k, pred in enumerate(predictions):
pred_smiles, pred_proba = zip(*pred.items())
reactant, true_product = test_reactions[k].split(">>")
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
for smiles, proba in zip(pred_smiles, pred_proba):
smiles = set(smiles.split("."))
@ -3176,7 +3176,7 @@ class EnviFormer(PackageBasedModel):
break
# Recall is TP (correct) / (TP + FN), where TP + FN is the number of test reactions
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
# Precision is TP (correct) / (TP + FP), where TP + FP is the number predicted
prec = {
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
@ -3257,30 +3257,30 @@ class EnviFormer(PackageBasedModel):
if self.eval_packages.count() > 0:
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
package__in=self.eval_packages.all()).distinct())
test_result = self.model.predict_batch(ds)
test_result = self.model.predict_batch(ds.X())
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
from enviformer.finetune import fine_tune
ds = self.load_dataset()
n_splits = 20
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
# Single-gen evaluation trains and then evaluates one split at a time, rather than storing all n_splits trained models;
# this helps reduce the memory footprint.
single_gen_results = []
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
train = [ds[i] for i in train_index]
test = [ds[i] for i in test_index]
train = ds[train_index]
test = ds[test_index]
start = datetime.now()
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
)
model.to(s.ENVIFORMER_DEVICE)
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
test_result = model.predict_batch(test.X())
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
self.eval_results = self.compute_averages(single_gen_results)
@ -3351,31 +3351,15 @@ class EnviFormer(PackageBasedModel):
for pathway in train_pathways:
for reaction in pathway.edges:
reaction = reaction.edge_label
if any(
[
educt in test_educts
for educt in reaction_to_educts[str(reaction.uuid)]
]
):
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
overlap += 1
continue
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
train_reactions.append(f"{educts}>>{products}")
train_reactions.append(reaction)
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
logging.debug(
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
)
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
self.eval_results.update(