Merge remote-tracking branch 'origin/develop' into fix/xss

# Conflicts:
#	pyproject.toml
#	uv.lock

epdb/models.py | 308 changed lines
@@ -29,7 +29,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit

 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
+    EnviFormerDataset, Dataset

 logger = logging.getLogger(__name__)

@@ -2178,7 +2179,7 @@ class PackageBasedModel(EPModel):

         applicable_rules = self.applicable_rules
         reactions = list(self._get_reactions())
-        ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)

         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@@ -2187,7 +2188,7 @@
         ds.save(f)
         return ds

-    def load_dataset(self) -> "Dataset":
+    def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
         return Dataset.load(ds_path)

@@ -2228,7 +2229,7 @@
         self.model_status = self.BUILT_NOT_EVALUATED
         self.save()

-    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
+    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

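Note: accepting **kwargs here lets callers tune the evaluation without another signature change; n_splits is read out further down via kwargs.get("n_splits", 20). A hypothetical call site, assuming a built model instance:

    model.evaluate_model(multigen=False, n_splits=5)
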
@@ -2346,37 +2347,37 @@ class PackageBasedModel(EPModel):
             eval_reactions = list(
                 Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
             )
-            ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
+            ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             ds = self.load_dataset()

             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()

-            n_splits = 20
+            n_splits = kwargs.get("n_splits", 20)

             shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
             splits = list(shuff.split(X))

             from joblib import Parallel, delayed

-            models = Parallel(n_jobs=10)(
+            models = Parallel(n_jobs=min(10, len(splits)))(
                 delayed(train_func)(X, y, train_index, self._model_args())
                 for train_index, _ in splits
             )
-            evaluations = Parallel(n_jobs=10)(
+            evaluations = Parallel(n_jobs=min(10, len(splits)))(
                 delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
                 for model, (_, test_index) in zip(models, splits)
             )
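Note: this branch pairs scikit-learn's ShuffleSplit with joblib so per-split training and scoring run in parallel, and capping n_jobs at len(splits) avoids spawning idle workers when few splits are requested. A minimal, self-contained sketch of the pattern, with a toy train_func standing in for the real training routine:

    import numpy as np
    from joblib import Parallel, delayed
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import ShuffleSplit

    def train_func(X, y, train_index):
        # toy stand-in for the real per-split training routine
        return LogisticRegression().fit(X[train_index], y[train_index])

    X = np.random.rand(100, 8)
    y = (X[:, 0] > 0.5).astype(int)

    splits = list(ShuffleSplit(n_splits=20, test_size=0.25, random_state=42).split(X))
    models = Parallel(n_jobs=min(10, len(splits)))(
        delayed(train_func)(X, y, train_index) for train_index, _ in splits
    )
    scores = [m.score(X[test], y[test]) for m, (_, test) in zip(models, splits)]
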
@@ -2588,11 +2589,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):

         return rbrr

-    def _fit_model(self, ds: Dataset):
+    def _fit_model(self, ds: RuleBasedDataset):
         X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
         model = RelativeReasoning(
             start_index=ds.triggered()[0],
-            end_index=ds.triggered()[1],
+            end_index=ds.triggered()[-1],
         )
         model.fit(X, y)
         return model
@@ -2602,7 +2603,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         return {
             "clz": "RuleBaseRelativeReasoning",
             "start_index": ds.triggered()[0],
-            "end_index": ds.triggered()[1],
+            "end_index": ds.triggered()[-1],
         }

     def _save_model(self, model):
@@ -2690,11 +2691,11 @@ class MLRelativeReasoning(PackageBasedModel):

         return mlrr

-    def _fit_model(self, ds: Dataset):
+    def _fit_model(self, ds: RuleBasedDataset):
         X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)

         model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
-        model.fit(X, y)
+        model.fit(X.to_numpy(), y.to_numpy())
         return model

     def _model_args(self):
@@ -2717,7 +2718,7 @@
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-        pred = self.model.predict_proba(classify_ds.X())
+        pred = self.model.predict_proba(classify_ds.X().to_numpy())

         res = MLRelativeReasoning.combine_products_and_probs(
             self.applicable_rules, pred[0], classify_prods[0]
@@ -2762,7 +2763,9 @@ class ApplicabilityDomain(EnviPathModel):

     @cached_property
     def training_set_probs(self):
-        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
+        ds = self.model.load_dataset()
+        col_ids = ds.block_indices("prob")
+        return ds[:, col_ids]

     def build(self):
         ds = self.model.load_dataset()
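Note: training-set probabilities now live inside the persisted dataset instead of a separate pickle, selected back as a column block. block_indices is part of the project's dataset API; a rough sketch of the underlying idea, assuming a polars frame whose probability columns share a prob_ prefix:

    import polars as pl

    ds = pl.DataFrame({
        "structure_id": ["s1", "s2"],
        "trig_r1": [1, 0],
        "prob_r1": [0.8, 0.1],
        "prob_r2": [0.3, 0.6],
    })

    # stand-in for ds.block_indices("prob"): locate the probability block by prefix
    prob_cols = [c for c in ds.columns if c.startswith("prob_")]
    train_probs = ds.select(prob_cols)
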
@@ -2770,9 +2773,9 @@
         start = datetime.now()

         # Get Trainingset probs and dump them as they're required when using the app domain
-        probs = self.model.model.predict_proba(ds.X())
-        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
-        joblib.dump(probs, f)
+        probs = self.model.model.predict_proba(ds.X().to_numpy())
+        ds.add_probs(probs)
+        ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))

         ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
         ad.build(ds)
@@ -2795,16 +2798,19 @@
         joblib.dump(ad, f)

     def assess(self, structure: Union[str, "CompoundStructure"]):
+        return self.assess_batch([structure])[0]
+
+    def assess_batch(self, structures: List["CompoundStructure | str"]):
         ds = self.model.load_dataset()

-        if isinstance(structure, CompoundStructure):
-            smiles = structure.smiles
-        else:
-            smiles = structure
+        smiles = []
+        for struct in structures:
+            if isinstance(struct, CompoundStructure):
+                smiles.append(struct.smiles)
+            else:
+                smiles.append(struct)

-        assessment_ds, assessment_prods = ds.classification_dataset(
-            [structure], self.model.applicable_rules
-        )
+        assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)

         # qualified_neighbours_per_rule is a nested dictionary structured as:
         # {
@@ -2817,82 +2823,47 @@ class ApplicabilityDomain(EnviPathModel):
         # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
         # This is used to find "qualified neighbours" — training examples that share the same triggered feature
         # with a given assessment structure under a particular rule.
-        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
-            lambda: defaultdict(list)
-        )
+        qualified_neighbours_per_rule: Dict = {}

-        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
-            feature = ds.columns[feature_index]
-            if feature.startswith("trig_"):
-                # TODO unroll loop
-                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
-                    if int(cx[feature_index]) == 1:
-                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
-                            if int(tx[feature_index]) == 1:
-                                qualified_neighbours_per_rule[i][rule_idx].append(j)
-
-        probs = self.training_set_probs
-        # preds = self.model.model.predict_proba(assessment_ds.X())
+        import polars as pl
+        # Select only the triggered columns
+        for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
+            # Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
+            # trigger that rule.
+            train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
+                          for trig_uuid, value in row.items() if value == 1}
+            qualified_neighbours_per_rule[i] = train_trig

+        rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}
         preds = self.model.combine_products_and_probs(
             self.model.applicable_rules,
-            self.model.model.predict_proba(assessment_ds.X())[0],
+            self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
             assessment_prods[0],
         )

         assessments = list()

         # loop through our assessment dataset
-        for i, instance in enumerate(assessment_ds):
+        for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
             rule_reliabilities = dict()
             local_compatibilities = dict()
             neighbours_per_rule = dict()
             neighbor_probs_per_rule = dict()

-            # loop through rule indices together with the collected neighbours indices from train dataset
-            for rule_idx, vals in qualified_neighbours_per_rule[i].items():
-                # collect the train dataset instances and store it along with the index (a.k.a. row number) of the
-                # train dataset
-                train_instances = []
-                for v in vals:
-                    train_instances.append((v, ds.at(v)))
-
-                # sf is a tuple with start/end index of the features
-                sf = ds.struct_features()
-
-                # compute tanimoto distance for all neighbours
-                # result is a list of tuples with train index and computed distance
-                dists = self._compute_distances(
-                    instance.X()[0][sf[0] : sf[1]],
-                    [ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
-                )
-
-                dists_with_index = list()
-                for ti, dist in zip(train_instances, dists):
-                    dists_with_index.append((ti[0], dist[1]))
+            for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
+                # compute tanimoto distance for all neighbours and add to dataset
+                dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
+                                                train_instances[:, train_instances.struct_features()].to_numpy())
+                train_instances = train_instances.with_columns(dist=pl.Series(dists))

                 # sort them in a descending way and take at most `self.num_neighbours`
-                dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
-                dists_with_index = dists_with_index[: self.num_neighbours]
-
+                # TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending)
+                train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
                 # compute average distance
-                rule_reliabilities[rule_idx] = (
-                    sum([d[1] for d in dists_with_index]) / len(dists_with_index)
-                    if len(dists_with_index) > 0
-                    else 0.0
-                )
-
+                rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
                 # for local_compatibility we'll need the datasets for the indices having the highest similarity
-                neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
-                local_compatibilities[rule_idx] = self._compute_compatibility(
-                    rule_idx, probs, neighbour_datasets
-                )
-                neighbours_per_rule[rule_idx] = [
-                    CompoundStructure.objects.get(uuid=ds[1].structure_id())
-                    for ds in neighbour_datasets
-                ]
-                neighbor_probs_per_rule[rule_idx] = [
-                    probs[d[0]][rule_idx] for d in dists_with_index
-                ]
+                local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
+                neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
+                neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()

             ad_res = {
                 "ad_params": {
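Note: the rewritten neighbour search replaces the quadratic Python loop with one polars filter per triggered rule. A minimal sketch of that filtering step on a toy frame, with column names following the trig_<rule-uuid> convention used above:

    import polars as pl

    train = pl.DataFrame({
        "structure_id": ["a", "b", "c"],
        "trig_r1": [1, 0, 1],
        "trig_r2": [0, 1, 1],
    })
    row = {"trig_r1": 1, "trig_r2": 0}  # one assessed structure

    # for each rule the structure triggers, keep training rows that also trigger it
    qualified = {col.split("_")[-1]: train.filter(pl.col(col).eq(1))
                 for col, value in row.items() if value == 1}
    assert qualified["r1"]["structure_id"].to_list() == ["a", "c"]
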
@@ -2903,23 +2874,21 @@ class ApplicabilityDomain(EnviPathModel):
                     "local_compatibility_threshold": self.local_compatibilty_threshold,
                 },
                 "assessment": {
-                    "smiles": smiles,
-                    "inside_app_domain": self.pca.is_applicable(instance)[0],
+                    "smiles": smiles[i],
+                    "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
                 },
             }

             transformations = list()
-            for rule_idx in rule_reliabilities.keys():
-                rule = Rule.objects.get(
-                    uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
-                )
+            for rule_uuid in rule_reliabilities.keys():
+                rule = Rule.objects.get(uuid=rule_uuid)

                 rule_data = rule.simple_json()
                 rule_data["image"] = f"{rule.url}?image=svg"

                 neighbors = []
                 for n, n_prob in zip(
-                    neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
+                    neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
                 ):
                     neighbor = n.simple_json()
                     neighbor["image"] = f"{n.url}?image=svg"
@@ -2936,14 +2905,14 @@

                 transformation = {
                     "rule": rule_data,
-                    "reliability": rule_reliabilities[rule_idx],
+                    "reliability": rule_reliabilities[rule_uuid],
                     # We're setting it here to False, as we don't know whether "assess" is called during pathway
                     # prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
                     "is_predicted": False,
-                    "local_compatibility": local_compatibilities[rule_idx],
-                    "probability": preds[rule_idx].probability,
+                    "local_compatibility": local_compatibilities[rule_uuid],
+                    "probability": preds[rule_to_i[rule_uuid]].probability,
                     "transformation_products": [
-                        x.product_set for x in preds[rule_idx].product_sets
+                        x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
                     ],
                     "times_triggered": ds.times_triggered(str(rule.uuid)),
                     "neighbors": neighbors,
@@ -2961,32 +2930,21 @@
     def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
         from utilities.ml import tanimoto_distance

-        distances = [
-            (i, tanimoto_distance(classify_instance, train))
-            for i, train in enumerate(train_instances)
-        ]
+        distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
         return distances

-    @staticmethod
-    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
-        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
+    def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
         accuracy = 0.0

-        for n in neighbours:
-            obs = n[1].y()[0][rule_idx]
-            pred = preds[n[0]][rule_idx]
-            if obs and pred:
-                tp += 1
-            elif not obs and pred:
-                fp += 1
-            elif obs and not pred:
-                fn += 1
-            else:
-                tn += 1
-        # Jaccard Index
-        if tp + tn > 0.0:
-            accuracy = (tp + tn) / (tp + tn + fp + fn)
-
+        import polars as pl
+        obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
+                                     pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
+        # Compute tp, tn, fp, fn using polars expressions
+        tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
+        tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
+        fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
+        fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height
+        if tp + tn > 0.0:
+            accuracy = (tp + tn) / (tp + tn + fp + fn)
         return accuracy


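Note: _compute_compatibility now derives the confusion counts with polars expressions instead of a Python loop; the resulting ratio is a plain accuracy over the neighbour set (the old "# Jaccard Index" comment was mislabelled, as Jaccard would be tp / (tp + fp + fn)). A self-contained sketch of the counting step on a toy neighbour frame:

    import polars as pl

    neighbours = pl.DataFrame({"obs_r1": [1, 0, 1], "prob_r1": [0.9, 0.7, 0.2]})
    threshold = 0.5

    # observed labels vs. thresholded predictions for rule r1
    obs_pred = neighbours.select(obs=pl.col("obs_r1").cast(pl.Boolean),
                                 pred=pl.col("prob_r1") >= threshold)
    tp = obs_pred.filter(pl.col("obs") & pl.col("pred")).height
    tn = obs_pred.filter(~pl.col("obs") & ~pl.col("pred")).height
    fp = obs_pred.filter(~pl.col("obs") & pl.col("pred")).height
    fn = obs_pred.filter(pl.col("obs") & ~pl.col("pred")).height
    accuracy = (tp + tn) / (tp + tn + fp + fn) if tp + tn > 0 else 0.0  # 1/3 here
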
@@ -3087,44 +3045,24 @@ class EnviFormer(PackageBasedModel):
         self.save()

         start = datetime.now()
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        co2 = {"C(=O)=O", "O=C=O"}
-        ds = []
-        for reaction in self._get_reactions():
-            educts = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.educts.all()
-                ]
-            )
-            products = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.products.all()
-                ]
-            )
-            if products not in co2:
-                ds.append(f"{educts}>>{products}")
+        ds = EnviFormerDataset.generate_dataset(self._get_reactions())

         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
         f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(f, "w") as d_file:
-            json.dump(ds, d_file)
+        ds.save(f)
         return ds

-    def load_dataset(self) -> "Dataset":
+    def load_dataset(self):
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(ds_path) as d_file:
-            ds = json.load(d_file)
-        return ds
+        return EnviFormerDataset.load(ds_path)

     def _fit_model(self, ds):
         # Call to enviFormer's fine_tune function and return the model
         from enviformer.finetune import fine_tune

         start = datetime.now()
-        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
+        model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
         end = datetime.now()
         logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
         return model
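Note: the inline standardisation loop moved into EnviFormerDataset.generate_dataset. Based on the removed lines, a rough sketch of what that helper presumably encapsulates (reaction_to_smirks is a hypothetical name; the real logic lives on the dataset class):

    from utilities.chem import FormatConverter

    CO2 = {"C(=O)=O", "O=C=O"}

    def reaction_to_smirks(reaction):
        # standardise educts/products, strip stereochemistry, and skip CO2-only products
        educts = ".".join(FormatConverter.standardize(s.smiles, remove_stereo=True)
                          for s in reaction.educts.all())
        products = ".".join(FormatConverter.standardize(s.smiles, remove_stereo=True)
                            for s in reaction.products.all())
        return None if products in CO2 else f"{educts}>>{products}"
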
@@ -3140,7 +3078,7 @@ class EnviFormer(PackageBasedModel):
         args = {"clz": "EnviFormer"}
         return args

-    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
+    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

@@ -3155,21 +3093,20 @@ class EnviFormer(PackageBasedModel):
         self.model_status = self.EVALUATING
         self.save()

-        def evaluate_sg(test_reactions, predictions, model_thresh):
+        def evaluate_sg(test_ds, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
+            assert len(test_ds) == len(predictions)
             true_dict = {}
-            for r in test_reactions:
-                reactant, true_product_set = r.split(">>")
+            for r in test_ds:
+                reactant, true_product_set = r
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
-            assert len(test_reactions) == len(predictions)
-            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)

             # Group the predicted products of reactions with the same reactant together
             pred_dict = {}
             for k, pred in enumerate(predictions):
                 pred_smiles, pred_proba = zip(*pred.items())
-                reactant, true_product = test_reactions[k].split(">>")
+                reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
                 pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                 for smiles, proba in zip(pred_smiles, pred_proba):
                     smiles = set(smiles.split("."))
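Note: evaluate_sg first groups true product sets by reactant, since the same educt can appear in several test reactions. A minimal sketch of that grouping on toy (educts, products) pairs:

    test_ds = [("CCO", "CC=O"), ("CCO", "CC(=O)O"), ("c1ccccc1", "Oc1ccccc1")]

    true_dict = {}
    for reactant, product_str in test_ds:
        true_dict.setdefault(reactant, []).append(set(product_str.split(".")))
    # {'CCO': [{'CC=O'}, {'CC(=O)O'}], 'c1ccccc1': [{'Oc1ccccc1'}]}
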
@@ -3204,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
                         break

             # Recall is TP (correct) / TP + FN (len(test_reactions))
-            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
+            rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
             # Precision is TP (correct) / TP + FP (predicted)
             prec = {
                 f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
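Note: correct and predicted are per-threshold counters accumulated above, so recall divides by the size of the test set while precision divides by the number of predictions actually emitted at each threshold. A toy instance of the two dict comprehensions, with made-up counts:

    correct = {0.5: 40, 0.9: 25}    # true positives per threshold
    predicted = {0.5: 80, 0.9: 30}  # predictions emitted per threshold
    n_test = 100

    rec = {f"{k:.2f}": v / n_test for k, v in correct.items()}
    prec = {f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()}
    # rec == {'0.50': 0.4, '0.90': 0.25}; prec == {'0.50': 0.5, '0.90': 0.833...}
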
@@ -3283,47 +3220,32 @@ class EnviFormer(PackageBasedModel):

         # If there are eval packages perform single generation evaluation on them instead of random splits
         if self.eval_packages.count() > 0:
-            ds = []
-            for reaction in Reaction.objects.filter(
-                package__in=self.eval_packages.all()
-            ).distinct():
-                educts = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.educts.all()
-                    ]
-                )
-                products = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.products.all()
-                    ]
-                )
-                ds.append(f"{educts}>>{products}")
-            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
+                package__in=self.eval_packages.all()).distinct())
+            test_result = self.model.predict_batch(ds.X())
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             from enviformer.finetune import fine_tune

             ds = self.load_dataset()
-            n_splits = 20
+            n_splits = kwargs.get("n_splits", 20)
             shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)

             # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
             # this helps reduce the memory footprint.
             single_gen_results = []
             for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
-                train = [ds[i] for i in train_index]
-                test = [ds[i] for i in test_index]
+                train = ds[train_index]
+                test = ds[test_index]
                 start = datetime.now()
-                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+                model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                 end = datetime.now()
                 logger.debug(
                     f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                 )
                 model.to(s.ENVIFORMER_DEVICE)
-                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+                test_result = model.predict_batch(test.X())
                 single_gen_results.append(evaluate_sg(test, test_result, self.threshold))

             self.eval_results = self.compute_averages(single_gen_results)
@@ -3394,31 +3316,15 @@ class EnviFormer(PackageBasedModel):
             for pathway in train_pathways:
                 for reaction in pathway.edges:
                     reaction = reaction.edge_label
-                    if any(
-                        [
-                            educt in test_educts
-                            for educt in reaction_to_educts[str(reaction.uuid)]
-                        ]
-                    ):
+                    if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
                         overlap += 1
                         continue
-                    educts = ".".join(
-                        [
-                            FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                            for smile in reaction.educts.all()
-                        ]
-                    )
-                    products = ".".join(
-                        [
-                            FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                            for smile in reaction.products.all()
-                        ]
-                    )
-                    train_reactions.append(f"{educts}>>{products}")
+                    train_reactions.append(reaction)
+            train_ds = EnviFormerDataset.generate_dataset(train_reactions)
             logging.debug(
                 f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
             )
-            model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+            model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
             multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))

             self.eval_results.update(