forked from enviPath/enviPy
[Feature] EnviFormer fine-tuning and evaluation
## Changes

- I have finished the backend integration of EnviFormer (#19); this includes dataset building, model fine-tuning, model evaluation, and prediction with the fine-tuned model.
- `PackageBasedModel` has been adjusted to be more abstract: `_save_model` is now an abstract method and `compute_averages` is a static method.
- I had to bump the Python version constraint in `pyproject.toml` from >=3.11 to >=3.12, otherwise uv failed to install EnviFormer.
- The default EnviFormer loading during `settings.py` has been removed.

## Future Fix

I noticed there is a little bit of code in `PackageBasedModel` -> `evaluate_model` for using the `eval_packages` during evaluation instead of train/test splits on `data_packages`. It doesn't seem finished; I presume we want this for all models, so I will take care of it in a new branch/pull request after this one is merged. Also, I haven't done anything for a POST request to fine-tune the model; I'm not sure if that is something we want now.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#141
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
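A minimal sketch of the adjusted `PackageBasedModel` contract (simplified for illustration; the real class extends `EPModel` and `build_model` does more, see the diff below):

```python
from abc import ABC, abstractmethod

class PackageBasedModel(ABC):  # stands in here for the Django EPModel subclass
    @abstractmethod
    def load_dataset(self): ...

    @abstractmethod
    def _fit_model(self, ds): ...

    @abstractmethod
    def _save_model(self, model):
        """Persistence is delegated to the subclass: the RelativeReasoning
        models keep their joblib pickle, EnviFormer writes a .ckpt instead."""

    def build_model(self):
        ds = self.load_dataset()
        mod = self._fit_model(ds)
        self._save_model(mod)  # replaces the hard-coded joblib.dump(...)

    @staticmethod
    def compute_averages(data):
        # Static so single-gen and multigen evaluation (and subclasses such as
        # EnviFormer) can share it; `data` is a list of
        # (accuracy, precision_per_threshold, recall_per_threshold) tuples.
        # Simplified: the full version also averages the per-threshold dicts.
        num_items = len(data)
        return {"average_accuracy": sum(item[0] for item in data) / num_items}
```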
This commit is contained in:
413 epdb/models.py
@@ -9,7 +9,7 @@ from collections import defaultdict
 from datetime import datetime
 from typing import Union, List, Optional, Dict, Tuple, Set, Any
 from uuid import uuid4

+import math
 import joblib
 import numpy as np
 from django.conf import settings as s
@@ -2002,6 +2002,10 @@ class PackageBasedModel(EPModel):
     def _model_args(self) -> Dict[str, Any]:
         pass

+    @abstractmethod
+    def _save_model(self, model):
+        pass
+
     def build_model(self):
         self.model_status = self.BUILDING
         self.save()
@@ -2010,8 +2014,7 @@ class PackageBasedModel(EPModel):

         mod = self._fit_model(ds)

-        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
-        joblib.dump(mod, f)
+        self._save_model(mod)

         if self.app_domain is not None:
             logger.debug("Building applicability domain...")
@@ -2116,7 +2119,7 @@ class PackageBasedModel(EPModel):

             mg_acc = 0.0
             for t in thresholds:
-                for true, pred in zip(test_pathways, pred_pathways):
+                for true, pred in zip(pathways, pred_pathways):
                     acc, pre, rec = multigen_eval(true, pred, t)
                     if abs(t - threshold) < 0.01:
                         mg_acc = acc
@@ -2146,29 +2149,7 @@ class PackageBasedModel(EPModel):
         evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
                                           for model, (_, test_index) in zip(models, splits))

-        def compute_averages(data):
-            num_items = len(data)
-            avg_first_item = sum(item[0] for item in data) / num_items
-
-            sum_dict2 = defaultdict(float)
-            sum_dict3 = defaultdict(float)
-
-            for _, dict2, dict3 in data:
-                for key in dict2:
-                    sum_dict2[key] += dict2[key]
-                for key in dict3:
-                    sum_dict3[key] += dict3[key]
-
-            avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
-            avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
-
-            return {
-                "average_accuracy": float(avg_first_item),
-                "average_precision_per_threshold": avg_dict2,
-                "average_recall_per_threshold": avg_dict3
-            }
-
-        self.eval_results = compute_averages(evaluations)
+        self.eval_results = self.compute_averages(evaluations)

         if self.multigen_eval:
@@ -2209,7 +2190,7 @@ class PackageBasedModel(EPModel):
                         reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))

             # build lookup to avoid recalculation of features, labels
-            id_to_index = {uuid: i for i, uuid in enumerate(ds[:, 0])}
+            id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}

             # Compute splits of the collected pathway
             splits = []
@@ -2233,10 +2214,8 @@ class PackageBasedModel(EPModel):
                     # Ensure compounds in the training set do not appear in the test set
                     if educt not in test_educts:
-                        try:
-                            split_ids.append(id_to_index[str(educt)])
-                        except KeyError:
-                            split_ids.append(id_to_index[educt])
+                        if educt in id_to_index:
+                            split_ids.append(id_to_index[str(educt)])
+                        else:
+                            logger.debug(f"Couldn't find features in X for compound {educt}")
                     else:
                         overlap += 1
@@ -2260,12 +2239,34 @@ class PackageBasedModel(EPModel):
                 zip(trained_models, splits)
             )

-            self.eval_results.update({f"multigen_{k}": v for k, v in compute_averages(multi_ret_vals).items()})
+            self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})

         self.model_status = self.FINISHED
         self.save()

+    @staticmethod
+    def compute_averages(data):
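+        # `data` is a list of (accuracy, precision_per_threshold, recall_per_threshold) tuples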
+        num_items = len(data)
+        avg_first_item = sum(item[0] for item in data) / num_items
+
+        sum_dict2 = defaultdict(float)
+        sum_dict3 = defaultdict(float)
+
+        for _, dict2, dict3 in data:
+            for key in dict2:
+                sum_dict2[key] += dict2[key]
+            for key in dict3:
+                sum_dict3[key] += dict3[key]
+
+        avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
+        avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
+
+        return {
+            "average_accuracy": float(avg_first_item),
+            "average_precision_per_threshold": avg_dict2,
+            "average_recall_per_threshold": avg_dict3
+        }
+
     @staticmethod
     def combine_products_and_probs(rules: List['Rule'], probabilities, products):
         res = []
@@ -2354,6 +2355,10 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
             'end_index': ds.triggered()[1],
         }

+    def _save_model(self, model):
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
+        joblib.dump(model, f)
+
     @cached_property
     def model(self) -> 'RelativeReasoning':
         mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
@@ -2443,6 +2448,10 @@ class MLRelativeReasoning(PackageBasedModel):
             **s.DEFAULT_MODEL_PARAMS,
         }

+    def _save_model(self, model):
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
+        joblib.dump(model, f)
+
     @cached_property
     def model(self) -> 'EnsembleClassifierChain':
         mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
@@ -2696,57 +2705,335 @@ class ApplicabilityDomain(EnviPathModel):
         return accuracy


-class EnviFormer(EPModel):
-    threshold = models.FloatField(null=False, blank=False, default=0.5)
+class EnviFormer(PackageBasedModel):

     @staticmethod
     @transaction.atomic
-    def create(package, name, description, threshold):
+    def create(package: 'Package', data_packages: List['Package'], eval_packages: List['Package'],
+               threshold: float = 0.5, name: 'str' = None, description: str = None, build_app_domain: bool = False,
+               app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
+               app_domain_local_compatibility_threshold: float = None):
         mod = EnviFormer()
         mod.package = package

         if name is None or name.strip() == '':
             name = f"EnviFormer {EnviFormer.objects.filter(package=package).count() + 1}"

         mod.name = name
-        mod.description = description
+
+        if description is not None and description.strip() != '':
+            mod.description = description

         if threshold is None or (threshold <= 0 or 1 <= threshold):
             raise ValueError("Threshold must be a float between 0 and 1.")

         mod.threshold = threshold

+        if len(data_packages) == 0:
+            raise ValueError("At least one data package must be provided.")
+
+        mod.save()
+
+        for p in data_packages:
+            mod.data_packages.add(p)
+
+        if eval_packages:
+            for p in eval_packages:
+                mod.eval_packages.add(p)
+
+        # if build_app_domain:
+        #     ad = ApplicabilityDomain.create(mod, app_domain_num_neighbours, app_domain_reliability_threshold,
+        #                                     app_domain_local_compatibility_threshold)
+        #     mod.app_domain = ad
+
         mod.save()
         return mod

     @cached_property
     def model(self):
-        mod = getattr(s, 'ENVIFORMER_INSTANCE', None)
-        logger.info(f"Model from settings {hash(mod)}")
-        return mod
+        from enviformer import load
+        ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
+        return load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
     def predict(self, smiles) -> List['PredictionResult']:
-        # example = {
-        #     'C#N': 0.46326889595136767,
-        #     'C#C': 0.04981685951409509,
-        # }
-        from rdkit import Chem
-        m = Chem.MolFromSmiles(smiles)
-        Chem.Kekulize(m)
-        kek = Chem.MolToSmiles(m, kekuleSmiles=True)
-        logger.info(f"Submitting {kek} to {hash(self.model)}")
-        products = self.model.predict(kek)
-        logger.info(f"Got results {products}")
+        return self.predict_batch([smiles])[0]

-        res = []
-        for smi, prob in products.items():
-            res.append(PredictionResult([ProductSet([smi])], prob, None))
+    def predict_batch(self, smiles_list):
+        # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
+        canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
+        logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
+        products_list = self.model.predict_batch(canon_smiles)
+        logger.info(f"Got results {products_list}")

-        return res
+        results = []
+        for products in products_list:
+            res = []
+            for smi, prob in products.items():
+                try:
+                    smi = ".".join([FormatConverter.standardize(smile, remove_stereo=True) for smile in smi.split(".")])
+                except ValueError:  # This occurs when the predicted string is an invalid SMILES
+                    logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
+                    continue
+                res.append(PredictionResult([ProductSet([smi])], prob, None))
+            results.append(res)
+
+        return results
+    def build_dataset(self):
+        self.model_status = self.INITIALIZING
+        self.save()
+
+        start = datetime.now()
+        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
+        ds = []
+        for reaction in self._get_reactions():
+            educts = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+            products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
+            ds.append(f"{educts}>>{products}")
+
+        end = datetime.now()
+        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
+        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
+        with open(f, "w") as d_file:
+            json.dump(ds, d_file)
+        return ds
+
+    def load_dataset(self) -> 'Dataset':
+        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
+        with open(ds_path) as d_file:
+            ds = json.load(d_file)
+        return ds
+
+    def _fit_model(self, ds):
+        # Call to enviFormer's fine_tune function and return the model
+        from enviformer.finetune import fine_tune
+        start = datetime.now()
+        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
+        end = datetime.now()
+        logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
+        return model
+
+    def _save_model(self, model):
+        from enviformer.utils import save_model
+        save_model(model, os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt"))
+
+    def _model_args(self) -> Dict[str, Any]:
+        args = {"clz": "EnviFormer"}
+        return args
+    def evaluate_model(self):
+        if self.model_status != self.BUILT_NOT_EVALUATED:
+            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
+
+        self.model_status = self.EVALUATING
+        self.save()
+
+        def evaluate_sg(test_reactions, predictions, model_thresh):
+            # Group the true products of reactions with the same reactant together
+            true_dict = {}
+            for r in test_reactions:
+                reactant, true_product_set = r.split(">>")
+                true_product_set = {p for p in true_product_set.split(".")}
+                true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
+            assert len(test_reactions) == len(predictions)
+            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
+
+            # Group the predicted products of reactions with the same reactant together
+            pred_dict = {}
+            for k, pred in enumerate(predictions):
+                pred_smiles, pred_proba = zip(*pred.items())
+                reactant, true_product = test_reactions[k].split(">>")
+                pred_dict.setdefault(reactant, {"predict": [], "scores": []})
+                for smiles, proba in zip(pred_smiles, pred_proba):
+                    smiles = set(smiles.split("."))
+                    if smiles not in pred_dict[reactant]["predict"]:
+                        pred_dict[reactant]["predict"].append(smiles)
+                        pred_dict[reactant]["scores"].append(proba)
+
+            # EnviFormer thresholds need to be different from other models due to the probabilities often being closer to zero.
+            thresholds = set()
+            thresholds.update({i / 5 for i in range(-75, -10, 15)})
+            thresholds.update({i / 50 for i in range(-100, -10, 10)})
+            thresholds = {math.exp(t) for t in thresholds}
+            thresholds.add(model_thresh)
+            thresholds = sorted(thresholds)
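+            # i.e. log-spaced cut-offs from exp(-15) ≈ 3e-7 up to exp(-0.4) ≈ 0.67, plus the model's own threshold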
+
+            # Calculate the number correct and predicted for each threshold and at each top-k
+            correct = {t: 0 for t in thresholds}
+            predicted = {t: 0 for t in thresholds}
+            for reactant, product_sets in true_dict.items():
+                pred_smiles = pred_dict[reactant]["predict"]
+                pred_scores = pred_dict[reactant]["scores"]
+
+                for true_set in product_sets:
+                    for threshold in correct:
+                        pred_s = [s for i, s in enumerate(pred_smiles) if pred_scores[i] > threshold]
+                        predicted[threshold] += len(pred_s)
+                        for pred_set in pred_s:
+                            if len(true_set - pred_set) == 0:
+                                correct[threshold] += 1
+                                break
+
+            # Recall is TP (correct) / TP + FN (len(test_reactions))
+            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
+            # Precision is TP (correct) / TP + FP (predicted)
+            prec = {f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()}
+            # Accuracy for EnviFormer is just recall
+            return rec[f"{model_thresh:.2f}"], prec, rec
+        def evaluate_mg(model, pathways, threshold):
+            # EnviFormer thresholds need to be different from other models due to the probabilities often being closer to zero.
+            thresholds = set()
+            thresholds.update({i / 5 for i in range(-75, -10, 15)})
+            thresholds.update({i / 50 for i in range(-100, -10, 10)})
+            thresholds = {math.exp(t) for t in thresholds}
+            thresholds.add(threshold)
+            thresholds = sorted(thresholds)
+
+            precision = {f"{t:.2f}": [] for t in thresholds}
+            recall = {f"{t:.2f}": [] for t in thresholds}
+
+            # Note: only one root compound supported at this time
+            root_compounds = []
+            for p in pathways:
+                root_node = p.root_nodes
+                if len(root_node) > 1:
+                    logging.warning(f"Pathway {p.name} has more than one root compound, only {root_node[0]} will be used")
+                root_node = ".".join([FormatConverter.standardize(smile) for smile in root_node[0].default_node_label.smiles.split(".")])
+                root_compounds.append(root_node)
+
+            # As we need a Model Instance in our setting, get a fresh copy from db, overwrite the serialized model and
+            # pass it to the setting used in prediction
+            mod = EnviFormer.objects.get(pk=self.pk)
+            mod.model = model
+
+            s = Setting()
+            s.model = mod
+            s.model_threshold = min(thresholds)
+            s.max_depth = 10
+            s.max_nodes = 50
+
+            from epdb.logic import SPathway
+            from utilities.ml import multigen_eval
+
+            # Predict pathways from each root compound
+            pred_pathways = []
+            for i, root in enumerate(root_compounds):
+                logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
+
+                spw = SPathway(root_nodes=root, prediction_setting=s)
+                level = 0
+
+                while not spw.done:
+                    spw.predict_step(from_depth=level)
+                    level += 1
+
+                pred_pathways.append(spw)
+
+            mg_acc = 0.0
+            for t in thresholds:
+                for true, pred in zip(pathways, pred_pathways):
+                    # Calculate multigen statistics
+                    acc, pre, rec = multigen_eval(true, pred, t)
+                    if t == threshold:
+                        mg_acc = acc
+                    precision[f"{t:.2f}"].append(pre)
+                    recall[f"{t:.2f}"].append(rec)
+
+            precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
+            recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
+            return mg_acc, precision, recall
+        from enviformer.finetune import fine_tune
+        ds = self.load_dataset()
+        n_splits = 20
+        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
+
+        # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models,
+        # this helps reduce the memory footprint.
+        single_gen_results = []
+        for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
+            train = [ds[i] for i in train_index]
+            test = [ds[i] for i in test_index]
+            start = datetime.now()
+            model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+            end = datetime.now()
+            logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
+            model.to(s.ENVIFORMER_DEVICE)
+            test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+            single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
+
+        self.eval_results = self.compute_averages(single_gen_results)
+
+        if self.multigen_eval:
+            pathway_qs = Pathway.objects.prefetch_related(
+                'node_set',
+                'node_set__out_edges',
+                'node_set__default_node_label',
+                'node_set__scenarios',
+                'edge_set',
+                'edge_set__start_nodes',
+                'edge_set__end_nodes',
+                'edge_set__edge_label',
+                'edge_set__scenarios'
+            ).filter(package__in=self.data_packages.all()).distinct()
+
+            pathways = []
+            for pathway in pathway_qs:
+                # There is one pathway with no root compounds, so this check is required
+                if len(pathway.root_nodes) > 0:
+                    pathways.append(pathway)
+                else:
+                    logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
+
+            # build lookup reaction -> {uuid1, uuid2} for overlap check
+            reaction_to_educts = defaultdict(set)
+            for pathway in pathways:
+                for reaction in pathway.edges:
+                    for e in reaction.edge_label.educts.all():
+                        reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
+
+            multi_gen_results = []
+            # Compute splits of the collected pathways and evaluate. Like single gen, we train and evaluate in each
+            # iteration instead of storing all trained models.
+            for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
+                train_pathways = [pathways[i] for i in train]
+                test_pathways = [pathways[i] for i in test]
+
+                # Collect structures from test pathways
+                test_educts = set()
+                for pathway in test_pathways:
+                    for reaction in pathway.edges:
+                        test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
+
+                train_reactions = []
+                overlap = 0
+                # Collect indices of the structures contained in train pathways iff they're not present in any of
+                # the test pathways
+                for pathway in train_pathways:
+                    for reaction in pathway.edges:
+                        reaction = reaction.edge_label
+                        if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
+                            overlap += 1
+                            continue
+                        educts = ".".join(
+                            [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
+                        products = ".".join(
+                            [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
+                        train_reactions.append(f"{educts}>>{products}")
+                logging.debug(
+                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
+                model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+                multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
+
+            self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
+
+        self.model_status = self.FINISHED
+        self.save()
+
     @cached_property
     def applicable_rules(self):
         return []

     def status(self):
         return "Model is built and can be used for predictions, Model is not evaluated yet."

     def ready_for_prediction(self) -> bool:
         return True


 class PluginModel(EPModel):
     pass
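For reference, an untested sketch of how the new lifecycle is intended to be driven (`package` is a placeholder for an existing `Package` with reaction data; the exact build orchestration lives in `PackageBasedModel.build_model`):

```python
# Hypothetical driver code, not part of the commit.
mod = EnviFormer.create(package, data_packages=[package], eval_packages=[], threshold=0.5)
mod.build_model()      # fine-tunes via enviformer.finetune.fine_tune, saves the .ckpt through _save_model
mod.evaluate_model()   # 20-split ShuffleSplit single-gen eval, plus multigen if enabled
results = mod.predict("C1=CC=CC=C1")  # -> List[PredictionResult]
```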