forked from enviPath/enviPy
new RuleBasedDataset and EnviFormer dataset working for respective models #120
@@ -2225,7 +2225,7 @@ class PackageBasedModel(EPModel):
         self.model_status = self.BUILT_NOT_EVALUATED
         self.save()

-    def evaluate_model(self):
+    def evaluate_model(self, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

@@ -2354,18 +2354,18 @@ class PackageBasedModel(EPModel):
         X = np.array(ds.X(na_replacement=np.nan))
         y = np.array(ds.y(na_replacement=np.nan))

-        n_splits = 20
+        n_splits = kwargs.get("n_splits", 20)

         shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
         splits = list(shuff.split(X))

         from joblib import Parallel, delayed

-        models = Parallel(n_jobs=10)(
+        models = Parallel(n_jobs=min(10, len(splits)))(
             delayed(train_func)(X, y, train_index, self._model_args())
             for train_index, _ in splits
         )
-        evaluations = Parallel(n_jobs=10)(
+        evaluations = Parallel(n_jobs=min(10, len(splits)))(
             delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
             for model, (_, test_index) in zip(models, splits)
         )
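
Note: the two changes in this hunk share one idiom, which the following minimal standalone sketch illustrates. The names (evaluate, the np.sum payload) are illustrative only, not enviPy's; the point is that n_splits is now caller-overridable via **kwargs, and n_jobs is capped so no idle workers are spawned when fewer than 10 splits are requested.

    import numpy as np
    from joblib import Parallel, delayed
    from sklearn.model_selection import ShuffleSplit

    def evaluate(X, **kwargs):
        # n_splits can now be overridden by the caller; 20 stays the default
        n_splits = kwargs.get("n_splits", 20)
        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
        splits = list(shuff.split(X))
        # min(10, len(splits)) avoids spawning workers that would sit idle
        # whenever fewer than 10 splits are requested
        return Parallel(n_jobs=min(10, len(splits)))(
            delayed(np.sum)(X[train_index]) for train_index, _ in splits
        )

    print(evaluate(np.arange(100).reshape(25, 4), n_splits=3))
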
@@ -2716,7 +2716,7 @@ class MLRelativeReasoning(PackageBasedModel):
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-        pred = self.model.predict_proba(classify_ds.X())
+        pred = self.model.predict_proba(np.array(classify_ds.X()))

         res = MLRelativeReasoning.combine_products_and_probs(
             self.applicable_rules, pred[0], classify_prods[0]
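
Note: the explicit np.array(...) wrapper matters if, as elsewhere in this commit, the new RuleBasedDataset's X() returns plain Python lists; converting up front gives the estimator a single 2-D float array. A toy sketch of the access pattern (LogisticRegression stands in for the real model):

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    X_list = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]  # list-of-lists features
    y = [1, 0, 1, 0]

    clf = LogisticRegression().fit(np.array(X_list), y)
    proba = clf.predict_proba(np.array([[0.5, 0.5]]))  # shape (1, 2)
    print(proba[0])  # per-class probabilities for the single query row
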
@@ -3096,7 +3096,7 @@ class EnviFormer(PackageBasedModel):
             ds.save(f)
         return ds

-    def load_dataset(self) -> "RuleBasedDataset":
+    def load_dataset(self):
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
         return EnviFormerDataset.load(ds_path)

@@ -3105,7 +3105,7 @@ class EnviFormer(PackageBasedModel):
         from enviformer.finetune import fine_tune

         start = datetime.now()
-        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
+        model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
         end = datetime.now()
         logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
         return model
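
Note: fine_tune no longer receives the dataset object itself, only parallel sequences of reactant and product SMILES. A hypothetical sketch of the contract this relies on; the class below is a stand-in for the real EnviFormerDataset, not its implementation:

    class EnviFormerDataset:
        def __init__(self, pairs):        # pairs: [(educt_smiles, product_smiles), ...]
            self._pairs = pairs

        def X(self):                      # reactant side, e.g. "CCO"
            return [educts for educts, _ in self._pairs]

        def y(self):                      # product side, e.g. "CC=O"
            return [products for _, products in self._pairs]

    ds = EnviFormerDataset([("CCO", "CC=O"), ("CC=O", "CC(=O)O")])
    assert ds.X() == ["CCO", "CC=O"] and ds.y() == ["CC=O", "CC(=O)O"]
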
@@ -3121,19 +3121,19 @@ class EnviFormer(PackageBasedModel):
         args = {"clz": "EnviFormer"}
         return args

-    def evaluate_model(self):
+    def evaluate_model(self, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

         self.model_status = self.EVALUATING
         self.save()

-        def evaluate_sg(test_reactions, predictions, model_thresh):
+        def evaluate_sg(test_ds, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
-            assert len(test_reactions) == len(predictions)
+            assert len(test_ds) == len(predictions)
             true_dict = {}
-            for r in test_reactions:
-                reactant, true_product_set = r.split(">>")
+            for r in test_ds:
+                reactant, true_product_set = r
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]

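
Note: the grouping loop above now assumes that iterating the dataset yields (reactant, product_smiles) tuples rather than "educts>>products" SMIRKS strings, so r.split(">>") becomes plain tuple unpacking. A self-contained illustration with made-up SMILES:

    rows = [("CCO", "CC=O"), ("CCO", "C=C"), ("CC=O", "CC(=O)O")]

    true_dict = {}
    for reactant, true_product_set in rows:
        # multi-product entries are "."-joined, exactly as in the diff above
        true_product_set = {p for p in true_product_set.split(".")}
        true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]

    # two observed product sets for "CCO", one for "CC=O"
    print(true_dict)  # {'CCO': [{'CC=O'}, {'C=C'}], 'CC=O': [{'CC(=O)O'}]}
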
@@ -3141,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
             pred_dict = {}
             for k, pred in enumerate(predictions):
                 pred_smiles, pred_proba = zip(*pred.items())
-                reactant, true_product = test_reactions[k].split(">>")
+                reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
                 pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                 for smiles, proba in zip(pred_smiles, pred_proba):
                     smiles = set(smiles.split("."))
@@ -3176,7 +3176,7 @@ class EnviFormer(PackageBasedModel):
                         break

             # Recall is TP (correct) / TP + FN (len(test_reactions))
-            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
+            rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
             # Precision is TP (correct) / TP + FP (predicted)
             prec = {
                 f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
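
Note: a worked toy example of the metric dictionaries built above, with made-up counts. `correct` and `predicted` map a probability threshold to true-positive and total-prediction counts; recall divides by the test-set size, precision by the prediction count at that threshold:

    n_test = 10
    correct = {0.1: 8, 0.5: 6, 0.9: 3}     # true positives per threshold
    predicted = {0.1: 20, 0.5: 9, 0.9: 3}  # total predictions per threshold

    rec = {f"{k:.2f}": v / n_test for k, v in correct.items()}
    prec = {
        f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
    }
    print(rec)   # {'0.10': 0.8, '0.50': 0.6, '0.90': 0.3}
    print(prec)  # {'0.10': 0.4, '0.50': 0.666..., '0.90': 1.0}
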
@@ -3257,30 +3257,30 @@ class EnviFormer(PackageBasedModel):
         if self.eval_packages.count() > 0:
             ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
                 package__in=self.eval_packages.all()).distinct())
-            test_result = self.model.predict_batch(ds)
+            test_result = self.model.predict_batch(ds.X())
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             from enviformer.finetune import fine_tune

             ds = self.load_dataset()
-            n_splits = 20
+            n_splits = kwargs.get("n_splits", 20)
             shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)

             # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
             # this helps reduce the memory footprint.
             single_gen_results = []
             for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
-                train = [ds[i] for i in train_index]
-                test = [ds[i] for i in test_index]
+                train = ds[train_index]
+                test = ds[test_index]
                 start = datetime.now()
-                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+                model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                 end = datetime.now()
                 logger.debug(
                     f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                 )
                 model.to(s.ENVIFORMER_DEVICE)
-                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+                test_result = model.predict_batch(test.X())
                 single_gen_results.append(evaluate_sg(test, test_result, self.threshold))

             self.eval_results = self.compute_averages(single_gen_results)
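
Note: ds[train_index] replacing the [ds[i] for i in train_index] comprehension implies the dataset now supports NumPy-style index-array slicing that returns a sub-dataset. A hedged sketch of that behaviour; ToyDataset is a stand-in, not the real EnviFormerDataset:

    import numpy as np

    class ToyDataset:
        def __init__(self, rows):
            self._rows = rows

        def __len__(self):
            return len(self._rows)

        def __getitem__(self, idx):
            if isinstance(idx, (list, np.ndarray)):  # index array -> sub-dataset
                return ToyDataset([self._rows[i] for i in idx])
            return self._rows[idx]                   # single row -> tuple

        def X(self):
            return [educts for educts, _ in self._rows]

    ds = ToyDataset([("CCO", "CC=O"), ("CC=O", "CC(=O)O"), ("C=C", "CCO")])
    train = ds[np.array([0, 2])]
    print(len(train), train.X())  # 2 ['CCO', 'C=C']
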
@@ -3351,31 +3351,15 @@ class EnviFormer(PackageBasedModel):
             for pathway in train_pathways:
                 for reaction in pathway.edges:
                     reaction = reaction.edge_label
-                    if any(
-                        [
-                            educt in test_educts
-                            for educt in reaction_to_educts[str(reaction.uuid)]
-                        ]
-                    ):
+                    if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
                         overlap += 1
                         continue
-                    educts = ".".join(
-                        [
-                            FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                            for smile in reaction.educts.all()
-                        ]
-                    )
-                    products = ".".join(
-                        [
-                            FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                            for smile in reaction.products.all()
-                        ]
-                    )
-                    train_reactions.append(f"{educts}>>{products}")
+                    train_reactions.append(reaction)
+            train_ds = EnviFormerDataset.generate_dataset(train_reactions)
             logging.debug(
                 f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
             )
-            model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+            model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
             multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))

         self.eval_results.update(
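
Note: the leakage guard kept by this hunk drops (and counts) any training reaction whose educt also occurs in the test split, so multigen evaluation never trains on test compounds. A toy illustration with invented identifiers:

    reaction_to_educts = {"r1": ["CCO"], "r2": ["C=C"], "r3": ["CC=O", "CCO"]}
    test_educts = {"CCO"}

    train_reactions, overlap = [], 0
    for rid in ["r1", "r2", "r3"]:
        if any([educt in test_educts for educt in reaction_to_educts[rid]]):
            overlap += 1
            continue
        train_reactions.append(rid)

    print(train_reactions, overlap)  # ['r2'] 2
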