[Enhancement] Refactor Dataset (#184)

# Summary
I have introduced a new base class `Dataset` in `ml.py` which all datasets should subclass. It stores the dataset as a polars DataFrame, with the column names and number of columns determined by the subclass. It implements generic methods such as `add_row`, indexing (which replaces `at` and `limit`) and dataset saving. It also declares the abstract methods that subclasses must implement: `X`, `y` and `generate_dataset`.

Two subclasses currently exist: `RuleBasedDataset` for the MLRR models and `EnviFormerDataset` for the enviFormer models.
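
The class layout described above might look roughly like the following. This is a minimal sketch, not the code in `ml.py`: a plain list of rows stands in for the polars DataFrame so that it runs with the standard library only, and `ToyDataset`, its column names and all method bodies are hypothetical.

```python
from abc import ABC, abstractmethod


class Dataset(ABC):
    """Sketch of a base class; a list of rows stands in for the polars DataFrame."""

    def __init__(self, columns, data=None):
        self.columns = list(columns)
        self.rows = []
        if data:
            self.add_rows(data)

    def add_rows(self, rows):
        # Only the row length is checked; column order is assumed to match.
        for row in rows:
            if len(row) != len(self.columns):
                raise ValueError("row length does not match the number of columns")
            self.rows.append(list(row))

    def add_row(self, row):
        self.add_rows([row])

    def __len__(self):
        return len(self.rows)

    @abstractmethod
    def X(self):
        """Model input."""

    @abstractmethod
    def y(self):
        """Model target."""

    @staticmethod
    @abstractmethod
    def generate_dataset(records):
        """Return an initialised subclass instance."""


class ToyDataset(Dataset):
    # Hypothetical subclass: first column is the input, second the target.
    def X(self):
        return [row[0] for row in self.rows]

    def y(self):
        return [row[1] for row in self.rows]

    @staticmethod
    def generate_dataset(records):
        return ToyDataset(["educts", "products"], records)


ds = ToyDataset.generate_dataset([("CCO", "CC=O"), ("CC=O", "CC(=O)O")])
```

Because `X`, `y` and `generate_dataset` are abstract, instantiating `Dataset` directly raises a `TypeError`; only subclasses that implement all three can be constructed.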

# Old Dataset to New RuleBasedDataset Functionality Translation

- [x] \_\_init\_\_
    - self.columns and self.num_labels moved to base Dataset class
    - self.data moved to the base class as self.df, which can be initialised from a list or from another DataFrame
    - struct_features, triggered and observed remain the same
- [x] \_block\_indices
    - function moved to base Dataset class
- [x] structure_id
    - stays in RuleBasedDataset, now requires an index for the row of interest
- [x] add_row
    - moved to base Dataset class, now calls add_rows so one or more rows can be added at a time
- [x] times_triggered
    - stays in RuleBasedDataset, now does a look up using polars df.filter
- [x] struct_features (see init)
- [x] triggered (see init)
- [x] observed (see init)
- [x] at
    - removed in favour of indexing with getitem
- [x] limit
    - removed in favour of indexing with getitem
- [x] classification_dataset
    - stays in RuleBasedDataset, largely the same just with new dataset construction using add_rows
- [x] generate_dataset
    - stays in RuleBasedDataset, largely the same just with new dataset construction using add_rows
- [x] X
    - moved to base Dataset as @abstractmethod, RuleBasedDataset implementation functionally the same but uses polars
- [x] trig
    - stays in RuleBasedDataset, functionally the same but uses polars
- [x] y
    - moved to base Dataset as @abstractmethod, RuleBasedDataset implementation functionally the same but uses polars
- [x] \_\_getitem\_\_
    - moved to base dataset, now passes item to the dataframe for polars to handle
- [x] to_arff
    - stays in RuleBasedDataset, functionally the same but uses polars
- [x] \_\_repr\_\_
    - moved to base dataset
- [x] \_\_iter\_\_
    - moved to base Dataset, now uses polars iter_rows

# Base Dataset class Features
The following functions are available in the base Dataset class

- init - Create the dataset from a list of column names and data given as a list of lists, or from an existing polars DataFrame; the latter is essential for the dataset to recreate itself during indexing. Passing only column names creates an empty dataset.
- add_rows - Add rows to the Dataset. We check that the new data has the expected length, but the column order is presumed to match the existing dataframe.
- add_row - Add one row, see add_rows
- block_indices - Returns the column indices that start with the given prefix
- columns - Property, returns dataframe.columns
- shape - Property, returns dataframe.shape
- X - Abstract method to be implemented by the subclasses, it should represent the input to a ML model
- y - Abstract method to be implemented by the subclasses, it should represent the target for a ML model
- generate_dataset - Abstract and static method to be implemented by the subclasses, should return an initialised subclass of Dataset
- iter - returns the iterable from dataframe.iter_rows()
- getitem - passes the item argument to the dataframe. If the result of indexing is another dataframe, it is packaged into a new Dataset of the same subclass. Any other result (int, float, polars Series) is returned as-is.
- save - Pickle and save the dataframe to the given path
- load - Static method to load the dataset from the given path
- to_numpy - returns the dataframe as a numpy array. Required for compatibility with training of the ECC model
- repr - return a representation of the dataset
- len - return the length of the dataframe
- iter_rows - returns dataframe.iter_rows with arguments passed through. Mainly used with named=True to get an iterable that yields rows as dicts of column name to column value instead of tuples of column values.
- filter - passes to dataframe.filter and recreates self with the result
- select - passes to dataframe.select and recreates self with the result
- with_columns - passes to dataframe.with_columns and recreates self with the result
- sort - passes to dataframe.sort and recreates self with the result
- item - passes to dataframe.item
- fill_nan - fills the dataframe's NaN values with the given value
- height - Property, returns the height (number of rows) of the dataframe

- [x] App domain
- [x] MACCS alternatives

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#184
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
2025-11-07 08:46:17 +13:00
committed by jebus
parent 98d62e1d1f
commit e26d5a21e3
10 changed files with 754 additions and 513 deletions


@ -1542,9 +1542,7 @@ class SPathway(object):
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)[0]
app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)
if self.persist is not None:
n = self.snode_persist_lookup[sub]
@ -1576,11 +1574,7 @@ class SPathway(object):
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)[
0
]
)
app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c))
self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment


@ -28,7 +28,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
EnviFormerDataset, Dataset
logger = logging.getLogger(__name__)
@ -2175,7 +2176,7 @@ class PackageBasedModel(EPModel):
applicable_rules = self.applicable_rules
reactions = list(self._get_reactions())
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@ -2184,7 +2185,7 @@ class PackageBasedModel(EPModel):
ds.save(f)
return ds
def load_dataset(self) -> "Dataset":
def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return Dataset.load(ds_path)
@ -2225,7 +2226,7 @@ class PackageBasedModel(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@ -2343,37 +2344,37 @@ class PackageBasedModel(EPModel):
eval_reactions = list(
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
)
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
else:
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(na_replacement=np.nan).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
ds = self.load_dataset()
if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
else:
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(na_replacement=np.nan).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
n_splits = 20
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
splits = list(shuff.split(X))
from joblib import Parallel, delayed
models = Parallel(n_jobs=10)(
models = Parallel(n_jobs=min(10, len(splits)))(
delayed(train_func)(X, y, train_index, self._model_args())
for train_index, _ in splits
)
evaluations = Parallel(n_jobs=10)(
evaluations = Parallel(n_jobs=min(10, len(splits)))(
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
for model, (_, test_index) in zip(models, splits)
)
@ -2585,11 +2586,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return rbrr
def _fit_model(self, ds: Dataset):
def _fit_model(self, ds: RuleBasedDataset):
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
model = RelativeReasoning(
start_index=ds.triggered()[0],
end_index=ds.triggered()[1],
end_index=ds.triggered()[-1],
)
model.fit(X, y)
return model
@ -2599,7 +2600,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return {
"clz": "RuleBaseRelativeReasoning",
"start_index": ds.triggered()[0],
"end_index": ds.triggered()[1],
"end_index": ds.triggered()[-1],
}
def _save_model(self, model):
@ -2687,11 +2688,11 @@ class MLRelativeReasoning(PackageBasedModel):
return mlrr
def _fit_model(self, ds: Dataset):
def _fit_model(self, ds: RuleBasedDataset):
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
model.fit(X, y)
model.fit(X.to_numpy(), y.to_numpy())
return model
def _model_args(self):
@ -2714,7 +2715,7 @@ class MLRelativeReasoning(PackageBasedModel):
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
pred = self.model.predict_proba(classify_ds.X())
pred = self.model.predict_proba(classify_ds.X().to_numpy())
res = MLRelativeReasoning.combine_products_and_probs(
self.applicable_rules, pred[0], classify_prods[0]
@ -2759,7 +2760,9 @@ class ApplicabilityDomain(EnviPathModel):
@cached_property
def training_set_probs(self):
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
ds = self.model.load_dataset()
col_ids = ds.block_indices("prob")
return ds[:, col_ids]
def build(self):
ds = self.model.load_dataset()
@ -2767,9 +2770,9 @@ class ApplicabilityDomain(EnviPathModel):
start = datetime.now()
# Get Trainingset probs and dump them as they're required when using the app domain
probs = self.model.model.predict_proba(ds.X())
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
joblib.dump(probs, f)
probs = self.model.model.predict_proba(ds.X().to_numpy())
ds.add_probs(probs)
ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
ad.build(ds)
@ -2792,16 +2795,19 @@ class ApplicabilityDomain(EnviPathModel):
joblib.dump(ad, f)
def assess(self, structure: Union[str, "CompoundStructure"]):
return self.assess_batch([structure])[0]
def assess_batch(self, structures: List["CompoundStructure | str"]):
ds = self.model.load_dataset()
if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
smiles = []
for struct in structures:
if isinstance(struct, CompoundStructure):
smiles.append(struct.smiles)
else:
smiles.append(struct)
assessment_ds, assessment_prods = ds.classification_dataset(
[structure], self.model.applicable_rules
)
assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
# qualified_neighbours_per_rule is a nested dictionary structured as:
# {
@ -2814,82 +2820,47 @@ class ApplicabilityDomain(EnviPathModel):
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
# with a given assessment structure under a particular rule.
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
lambda: defaultdict(list)
)
qualified_neighbours_per_rule: Dict = {}
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
feature = ds.columns[feature_index]
if feature.startswith("trig_"):
# TODO unroll loop
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
if int(cx[feature_index]) == 1:
for j, tx in enumerate(ds.X(exclude_id_col=False)):
if int(tx[feature_index]) == 1:
qualified_neighbours_per_rule[i][rule_idx].append(j)
probs = self.training_set_probs
# preds = self.model.model.predict_proba(assessment_ds.X())
import polars as pl
# Select only the triggered columns
for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
# Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
# trigger that rule.
train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
for trig_uuid, value in row.items() if value == 1}
qualified_neighbours_per_rule[i] = train_trig
rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}
preds = self.model.combine_products_and_probs(
self.model.applicable_rules,
self.model.model.predict_proba(assessment_ds.X())[0],
self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
assessment_prods[0],
)
assessments = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds):
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
rule_reliabilities = dict()
local_compatibilities = dict()
neighbours_per_rule = dict()
neighbor_probs_per_rule = dict()
# loop through rule indices together with the collected neighbours indices from train dataset
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
# collect the train dataset instances and store it along with the index (a.k.a. row number) of the
# train dataset
train_instances = []
for v in vals:
train_instances.append((v, ds.at(v)))
# sf is a tuple with start/end index of the features
sf = ds.struct_features()
# compute tanimoto distance for all neighbours
# result is a list of tuples with train index and computed distance
dists = self._compute_distances(
instance.X()[0][sf[0] : sf[1]],
[ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
)
dists_with_index = list()
for ti, dist in zip(train_instances, dists):
dists_with_index.append((ti[0], dist[1]))
for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
# compute tanimoto distance for all neighbours and add to dataset
dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
train_instances[:, train_instances.struct_features()].to_numpy())
train_instances = train_instances.with_columns(dist=pl.Series(dists))
# sort them in a descending way and take at most `self.num_neighbours`
dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
dists_with_index = dists_with_index[: self.num_neighbours]
# TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending)
train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
# compute average distance
rule_reliabilities[rule_idx] = (
sum([d[1] for d in dists_with_index]) / len(dists_with_index)
if len(dists_with_index) > 0
else 0.0
)
rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
# for local_compatibility we'll need the datasets for the indices having the highest similarity
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
local_compatibilities[rule_idx] = self._compute_compatibility(
rule_idx, probs, neighbour_datasets
)
neighbours_per_rule[rule_idx] = [
CompoundStructure.objects.get(uuid=ds[1].structure_id())
for ds in neighbour_datasets
]
neighbor_probs_per_rule[rule_idx] = [
probs[d[0]][rule_idx] for d in dists_with_index
]
local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()
ad_res = {
"ad_params": {
@ -2900,23 +2871,21 @@ class ApplicabilityDomain(EnviPathModel):
"local_compatibility_threshold": self.local_compatibilty_threshold,
},
"assessment": {
"smiles": smiles,
"inside_app_domain": self.pca.is_applicable(instance)[0],
"smiles": smiles[i],
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
},
}
transformations = list()
for rule_idx in rule_reliabilities.keys():
rule = Rule.objects.get(
uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
)
for rule_uuid in rule_reliabilities.keys():
rule = Rule.objects.get(uuid=rule_uuid)
rule_data = rule.simple_json()
rule_data["image"] = f"{rule.url}?image=svg"
neighbors = []
for n, n_prob in zip(
neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
):
neighbor = n.simple_json()
neighbor["image"] = f"{n.url}?image=svg"
@ -2933,14 +2902,14 @@ class ApplicabilityDomain(EnviPathModel):
transformation = {
"rule": rule_data,
"reliability": rule_reliabilities[rule_idx],
"reliability": rule_reliabilities[rule_uuid],
# We're setting it here to False, as we don't know whether "assess" is called during pathway
# prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
"is_predicted": False,
"local_compatibility": local_compatibilities[rule_idx],
"probability": preds[rule_idx].probability,
"local_compatibility": local_compatibilities[rule_uuid],
"probability": preds[rule_to_i[rule_uuid]].probability,
"transformation_products": [
x.product_set for x in preds[rule_idx].product_sets
x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
],
"times_triggered": ds.times_triggered(str(rule.uuid)),
"neighbors": neighbors,
@ -2958,32 +2927,21 @@ class ApplicabilityDomain(EnviPathModel):
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
from utilities.ml import tanimoto_distance
distances = [
(i, tanimoto_distance(classify_instance, train))
for i, train in enumerate(train_instances)
]
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
return distances
@staticmethod
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
accuracy = 0.0
for n in neighbours:
obs = n[1].y()[0][rule_idx]
pred = preds[n[0]][rule_idx]
if obs and pred:
tp += 1
elif not obs and pred:
fp += 1
elif obs and not pred:
fn += 1
else:
tn += 1
# Jaccard Index
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)
import polars as pl
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
# Compute tp, tn, fp, fn using polars expressions
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)
return accuracy
@ -3084,44 +3042,24 @@ class EnviFormer(PackageBasedModel):
self.save()
start = datetime.now()
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
co2 = {"C(=O)=O", "O=C=O"}
ds = []
for reaction in self._get_reactions():
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
if products not in co2:
ds.append(f"{educts}>>{products}")
ds = EnviFormerDataset.generate_dataset(self._get_reactions())
end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(f, "w") as d_file:
json.dump(ds, d_file)
ds.save(f)
return ds
def load_dataset(self) -> "Dataset":
def load_dataset(self):
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(ds_path) as d_file:
ds = json.load(d_file)
return ds
return EnviFormerDataset.load(ds_path)
def _fit_model(self, ds):
# Call to enviFormer's fine_tune function and return the model
from enviformer.finetune import fine_tune
start = datetime.now()
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
return model
@ -3137,7 +3075,7 @@ class EnviFormer(PackageBasedModel):
args = {"clz": "EnviFormer"}
return args
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@ -3152,21 +3090,20 @@ class EnviFormer(PackageBasedModel):
self.model_status = self.EVALUATING
self.save()
def evaluate_sg(test_reactions, predictions, model_thresh):
def evaluate_sg(test_ds, predictions, model_thresh):
# Group the true products of reactions with the same reactant together
assert len(test_ds) == len(predictions)
true_dict = {}
for r in test_reactions:
reactant, true_product_set = r.split(">>")
for r in test_ds:
reactant, true_product_set = r
true_product_set = {p for p in true_product_set.split(".")}
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
assert len(test_reactions) == len(predictions)
assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
# Group the predicted products of reactions with the same reactant together
pred_dict = {}
for k, pred in enumerate(predictions):
pred_smiles, pred_proba = zip(*pred.items())
reactant, true_product = test_reactions[k].split(">>")
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
for smiles, proba in zip(pred_smiles, pred_proba):
smiles = set(smiles.split("."))
@ -3201,7 +3138,7 @@ class EnviFormer(PackageBasedModel):
break
# Recall is TP (correct) / TP + FN (len(test_reactions))
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
# Precision is TP (correct) / TP + FP (predicted)
prec = {
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
@ -3280,47 +3217,32 @@ class EnviFormer(PackageBasedModel):
# If there are eval packages perform single generation evaluation on them instead of random splits
if self.eval_packages.count() > 0:
ds = []
for reaction in Reaction.objects.filter(
package__in=self.eval_packages.all()
).distinct():
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
ds.append(f"{educts}>>{products}")
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
package__in=self.eval_packages.all()).distinct())
test_result = self.model.predict_batch(ds.X())
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
from enviformer.finetune import fine_tune
ds = self.load_dataset()
n_splits = 20
n_splits = kwargs.get("n_splits", 20)
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
# this helps reduce the memory footprint.
single_gen_results = []
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
train = [ds[i] for i in train_index]
test = [ds[i] for i in test_index]
train = ds[train_index]
test = ds[test_index]
start = datetime.now()
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
)
model.to(s.ENVIFORMER_DEVICE)
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
test_result = model.predict_batch(test.X())
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
self.eval_results = self.compute_averages(single_gen_results)
@ -3391,31 +3313,15 @@ class EnviFormer(PackageBasedModel):
for pathway in train_pathways:
for reaction in pathway.edges:
reaction = reaction.edge_label
if any(
[
educt in test_educts
for educt in reaction_to_educts[str(reaction.uuid)]
]
):
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
overlap += 1
continue
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
train_reactions.append(f"{educts}>>{products}")
train_reactions.append(reaction)
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
logging.debug(
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
)
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
self.eval_results.update(


@ -892,7 +892,7 @@ def package_model(request, package_uuid, model_uuid):
return JsonResponse(res, safe=False)
else:
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
return JsonResponse(app_domain_assessment, safe=False)
context = get_base_context(request)