diff --git a/epdb/logic.py b/epdb/logic.py
index 82324100..d8415977 100644
--- a/epdb/logic.py
+++ b/epdb/logic.py
@@ -1552,9 +1552,7 @@ class SPathway(object):
         if sub.app_domain_assessment is None:
             if self.prediction_setting.model:
                 if self.prediction_setting.model.app_domain:
-                    app_domain_assessment = self.prediction_setting.model.app_domain.assess(
-                        sub.smiles
-                    )[0]
+                    app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)
 
         if self.persist is not None:
             n = self.snode_persist_lookup[sub]
@@ -1586,11 +1584,7 @@ class SPathway(object):
                 app_domain_assessment = None
                 if self.prediction_setting.model:
                     if self.prediction_setting.model.app_domain:
-                        app_domain_assessment = (
-                            self.prediction_setting.model.app_domain.assess(c)[
-                                0
-                            ]
-                        )
+                        app_domain_assessment = self.prediction_setting.model.app_domain.assess(c)
 
                 self.smiles_to_node[c] = SNode(
                     c, sub.depth + 1, app_domain_assessment
diff --git a/epdb/models.py b/epdb/models.py
index ceabd6d6..e3ab0476 100644
--- a/epdb/models.py
+++ b/epdb/models.py
@@ -29,7 +29,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
+    EnviFormerDataset, Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -2178,7 +2179,7 @@ class PackageBasedModel(EPModel):
         applicable_rules = self.applicable_rules
         reactions = list(self._get_reactions())
 
-        ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
 
         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@@ -2187,7 +2188,7 @@ class PackageBasedModel(EPModel):
             ds.save(f)
         return ds
 
-    def load_dataset(self) -> "Dataset":
+    def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
         return Dataset.load(ds_path)
 
@@ -2228,7 +2229,7 @@ class PackageBasedModel(EPModel):
         self.model_status = self.BUILT_NOT_EVALUATED
         self.save()
 
-    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
+    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
 
@@ -2346,37 +2347,37 @@ class PackageBasedModel(EPModel):
             eval_reactions = list(
                 Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
             )
-            ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
+            ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
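+                # X/y are polars-backed Dataset slices now; .to_numpy() keeps the
+                # numpy interface the sklearn-style estimators below expect.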
             single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             ds = self.load_dataset()
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
 
-            n_splits = 20
+            n_splits = kwargs.get("n_splits", 20)
             shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
             splits = list(shuff.split(X))
 
             from joblib import Parallel, delayed
 
-            models = Parallel(n_jobs=10)(
+            models = Parallel(n_jobs=min(10, len(splits)))(
                 delayed(train_func)(X, y, train_index, self._model_args())
                 for train_index, _ in splits
             )
 
-            evaluations = Parallel(n_jobs=10)(
+            evaluations = Parallel(n_jobs=min(10, len(splits)))(
                 delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
                 for model, (_, test_index) in zip(models, splits)
             )
@@ -2588,11 +2589,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
 
         return rbrr
 
-    def _fit_model(self, ds: Dataset):
+    def _fit_model(self, ds: RuleBasedDataset):
         X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
         model = RelativeReasoning(
             start_index=ds.triggered()[0],
-            end_index=ds.triggered()[1],
+            end_index=ds.triggered()[-1],
         )
         model.fit(X, y)
         return model
@@ -2602,7 +2603,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         return {
             "clz": "RuleBaseRelativeReasoning",
             "start_index": ds.triggered()[0],
-            "end_index": ds.triggered()[1],
+            "end_index": ds.triggered()[-1],
         }
 
     def _save_model(self, model):
@@ -2690,11 +2691,11 @@ class MLRelativeReasoning(PackageBasedModel):
 
         return mlrr
 
-    def _fit_model(self, ds: Dataset):
+    def _fit_model(self, ds: RuleBasedDataset):
         X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
         model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
-        model.fit(X, y)
+        model.fit(X.to_numpy(), y.to_numpy())
         return model
 
     def _model_args(self):
@@ -2717,7 +2718,7 @@ class MLRelativeReasoning(PackageBasedModel):
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-        pred = self.model.predict_proba(classify_ds.X())
+        pred = self.model.predict_proba(classify_ds.X().to_numpy())
 
         res = MLRelativeReasoning.combine_products_and_probs(
             self.applicable_rules, pred[0], classify_prods[0]
@@ -2762,7 +2763,9 @@ class ApplicabilityDomain(EnviPathModel):
 
     @cached_property
     def training_set_probs(self):
-        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
+        ds = self.model.load_dataset()
+        col_ids = ds.block_indices("prob")
+        return ds[:, col_ids]
 
     def build(self):
         ds = self.model.load_dataset()
@@ -2770,9 +2773,9 @@ class ApplicabilityDomain(EnviPathModel):
         start = datetime.now()
 
         # Get Trainingset probs and dump them as they're required when using the app domain
-        probs = self.model.model.predict_proba(ds.X())
-        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
-        joblib.dump(probs, f)
+        probs = self.model.model.predict_proba(ds.X().to_numpy())
+        ds.add_probs(probs)
+        ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
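+        # The training-set probabilities now live inside the dataset pickle as
+        # prob_<rule_uuid> columns (see RuleBasedDataset.add_probs), replacing the
+        # separate *_train_probs.pkl dump that training_set_probs used to load.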
 
         ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
         ad.build(ds)
@@ -2795,16 +2798,19 @@ class ApplicabilityDomain(EnviPathModel):
         joblib.dump(ad, f)
 
     def assess(self, structure: Union[str, "CompoundStructure"]):
+        return self.assess_batch([structure])[0]
+
+    def assess_batch(self, structures: List["CompoundStructure | str"]):
         ds = self.model.load_dataset()
 
-        if isinstance(structure, CompoundStructure):
-            smiles = structure.smiles
-        else:
-            smiles = structure
+        smiles = []
+        for struct in structures:
+            if isinstance(struct, CompoundStructure):
+                smiles.append(struct.smiles)
+            else:
+                smiles.append(struct)
 
-        assessment_ds, assessment_prods = ds.classification_dataset(
-            [structure], self.model.applicable_rules
-        )
+        assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
 
         # qualified_neighbours_per_rule is a nested dictionary structured as:
         # {
         #
         #
         #
         # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
         # This is used to find "qualified neighbours" — training examples that share the same triggered feature
         # with a given assessment structure under a particular rule.
@@ -2817,82 +2823,47 @@ class ApplicabilityDomain(EnviPathModel):
-        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
-            lambda: defaultdict(list)
-        )
+        qualified_neighbours_per_rule: Dict = {}
 
-        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
-            feature = ds.columns[feature_index]
-            if feature.startswith("trig_"):
-                # TODO unroll loop
-                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
-                    if int(cx[feature_index]) == 1:
-                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
-                            if int(tx[feature_index]) == 1:
-                                qualified_neighbours_per_rule[i][rule_idx].append(j)
-
-        probs = self.training_set_probs
-        # preds = self.model.model.predict_proba(assessment_ds.X())
+        import polars as pl
+        # Select only the triggered columns
+        for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
+            # Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
+            # trigger that rule.
+            train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
+                          for trig_uuid, value in row.items() if value == 1}
+            qualified_neighbours_per_rule[i] = train_trig
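+        # Resulting shape, uuids shortened for illustration:
+        #   {0: {"<rule-uuid>": <training rows where trig_<rule-uuid> == 1>}, ...}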
+        rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}
         preds = self.model.combine_products_and_probs(
             self.model.applicable_rules,
-            self.model.model.predict_proba(assessment_ds.X())[0],
+            self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
             assessment_prods[0],
         )
 
         assessments = list()
-        # loop through our assessment dataset
-        for i, instance in enumerate(assessment_ds):
+        for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
             rule_reliabilities = dict()
             local_compatibilities = dict()
             neighbours_per_rule = dict()
             neighbor_probs_per_rule = dict()
 
             # loop through rule indices together with the collected neighbours indices from train dataset
-            for rule_idx, vals in qualified_neighbours_per_rule[i].items():
-                # collect the train dataset instances and store it along with the index (a.k.a. row number) of the
-                # train dataset
-                train_instances = []
-                for v in vals:
-                    train_instances.append((v, ds.at(v)))
-
-                # sf is a tuple with start/end index of the features
-                sf = ds.struct_features()
-
-                # compute tanimoto distance for all neighbours
-                # result ist a list of tuples with train index and computed distance
-                dists = self._compute_distances(
-                    instance.X()[0][sf[0] : sf[1]],
-                    [ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
-                )
-
-                dists_with_index = list()
-                for ti, dist in zip(train_instances, dists):
-                    dists_with_index.append((ti[0], dist[1]))
+            for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
+                # compute tanimoto distance for all neighbours and add to dataset
+                dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
+                                                train_instances[:, train_instances.struct_features()].to_numpy())
+                train_instances = train_instances.with_columns(dist=pl.Series(dists))
 
                 # sort them in a descending way and take at most `self.num_neighbours`
-                dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
-                dists_with_index = dists_with_index[: self.num_neighbours]
-
+                # TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending)
+                train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
                 # compute average distance
-                rule_reliabilities[rule_idx] = (
-                    sum([d[1] for d in dists_with_index]) / len(dists_with_index)
-                    if len(dists_with_index) > 0
-                    else 0.0
-                )
-
+                rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
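+                # e.g. neighbour dists of [0.9, 0.8, 0.7] give reliability 0.8 (their mean)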
                 # for local_compatibility we'll need the datasets for the indices having the highest similarity
-                neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
-                local_compatibilities[rule_idx] = self._compute_compatibility(
-                    rule_idx, probs, neighbour_datasets
-                )
-                neighbours_per_rule[rule_idx] = [
-                    CompoundStructure.objects.get(uuid=ds[1].structure_id())
-                    for ds in neighbour_datasets
-                ]
-                neighbor_probs_per_rule[rule_idx] = [
-                    probs[d[0]][rule_idx] for d in dists_with_index
-                ]
+                local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
+                neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
+                neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()
 
             ad_res = {
                 "ad_params": {
@@ -2903,23 +2874,21 @@ class ApplicabilityDomain(EnviPathModel):
                     "local_compatibility_threshold": self.local_compatibilty_threshold,
                 },
                 "assessment": {
-                    "smiles": smiles,
-                    "inside_app_domain": self.pca.is_applicable(instance)[0],
+                    "smiles": smiles[i],
+                    "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
                 },
             }
 
             transformations = list()
-            for rule_idx in rule_reliabilities.keys():
-                rule = Rule.objects.get(
-                    uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
-                )
+            for rule_uuid in rule_reliabilities.keys():
+                rule = Rule.objects.get(uuid=rule_uuid)
                 rule_data = rule.simple_json()
                 rule_data["image"] = f"{rule.url}?image=svg"
 
                 neighbors = []
                 for n, n_prob in zip(
-                    neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
+                    neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
                 ):
                     neighbor = n.simple_json()
                     neighbor["image"] = f"{n.url}?image=svg"
@@ -2936,14 +2905,14 @@ class ApplicabilityDomain(EnviPathModel):
 
                 transformation = {
                     "rule": rule_data,
-                    "reliability": rule_reliabilities[rule_idx],
+                    "reliability": rule_reliabilities[rule_uuid],
                     # We're setting it here to False, as we don't know whether "assess" is called during pathway
                     # prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
                     "is_predicted": False,
-                    "local_compatibility": local_compatibilities[rule_idx],
-                    "probability": preds[rule_idx].probability,
+                    "local_compatibility": local_compatibilities[rule_uuid],
+                    "probability": preds[rule_to_i[rule_uuid]].probability,
                     "transformation_products": [
-                        x.product_set for x in preds[rule_idx].product_sets
+                        x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
                     ],
                     "times_triggered": ds.times_triggered(str(rule.uuid)),
                     "neighbors": neighbors,
@@ -2961,32 +2930,21 @@ class ApplicabilityDomain(EnviPathModel):
     def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
         from utilities.ml import tanimoto_distance
 
-        distances = [
-            (i, tanimoto_distance(classify_instance, train))
-            for i, train in enumerate(train_instances)
-        ]
+        distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
         return distances
 
-    @staticmethod
-    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
-        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
+    def _compute_compatibility(self, rule_uuid: str, neighbours: "RuleBasedDataset"):
         accuracy = 0.0
-
-        for n in neighbours:
-            obs = n[1].y()[0][rule_idx]
-            pred = preds[n[0]][rule_idx]
-            if obs and pred:
-                tp += 1
-            elif not obs and pred:
-                fp += 1
-            elif obs and not pred:
-                fn += 1
-            else:
-                tn += 1
-        # Jaccard Index
-        if tp + tn > 0.0:
-            accuracy = (tp + tn) / (tp + tn + fp + fn)
-
+        import polars as pl
+        obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_uuid}").cast(pl.Boolean),
+                                     pred=pl.col(f"prob_{rule_uuid}") >= self.model.threshold)
+        # Compute tp, tn, fp, fn using polars expressions
+        tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
+        tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
+        fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
+        fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height
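+        # e.g. observation and thresholded prediction agreeing for 3 of 5 neighbours
+        # yields (tp + tn) / (tp + tn + fp + fn) = 0.6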
+        if tp + tn > 0.0:
+            accuracy = (tp + tn) / (tp + tn + fp + fn)
         return accuracy
 
@@ -3087,44 +3045,24 @@ class EnviFormer(PackageBasedModel):
         self.save()
 
         start = datetime.now()
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        co2 = {"C(=O)=O", "O=C=O"}
-        ds = []
-        for reaction in self._get_reactions():
-            educts = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.educts.all()
-                ]
-            )
-            products = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.products.all()
-                ]
-            )
-            if products not in co2:
-                ds.append(f"{educts}>>{products}")
+        ds = EnviFormerDataset.generate_dataset(self._get_reactions())
 
         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
 
         f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(f, "w") as d_file:
-            json.dump(ds, d_file)
+        ds.save(f)
         return ds
 
-    def load_dataset(self) -> "Dataset":
+    def load_dataset(self):
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(ds_path) as d_file:
-            ds = json.load(d_file)
-        return ds
+        return EnviFormerDataset.load(ds_path)
 
     def _fit_model(self, ds):
         # Call to enviFormer's fine_tune function and return the model
         from enviformer.finetune import fine_tune
 
         start = datetime.now()
-        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
+        model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
         end = datetime.now()
         logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
         return model
@@ -3140,7 +3078,7 @@ class EnviFormer(PackageBasedModel):
         args = {"clz": "EnviFormer"}
         return args
 
-    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
+    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
         if self.model_status != self.BUILT_NOT_EVALUATED:
             raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
 
@@ -3155,21 +3093,20 @@ class EnviFormer(PackageBasedModel):
         self.model_status = self.EVALUATING
         self.save()
 
-        def evaluate_sg(test_reactions, predictions, model_thresh):
+        def evaluate_sg(test_ds, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
+            assert len(test_ds) == len(predictions)
             true_dict = {}
-            for r in test_reactions:
-                reactant, true_product_set = r.split(">>")
+            for r in test_ds:
+                reactant, true_product_set = r
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
-            assert len(test_reactions) == len(predictions)
-            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
 
             # Group the predicted products of reactions with the same reactant together
             pred_dict = {}
             for k, pred in enumerate(predictions):
                 pred_smiles, pred_proba = zip(*pred.items())
-                reactant, true_product = test_reactions[k].split(">>")
+                reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
                 pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                 for smiles, proba in zip(pred_smiles, pred_proba):
                     smiles = set(smiles.split("."))
@@ -3204,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
                         break
 
             # Recall is TP (correct) / TP + FN (len(test_reactions))
-            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
+            rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
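+            # e.g. 8 of 10 test reactions recovered at threshold 0.50 -> rec["0.50"] == 0.8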
             # Precision is TP (correct) / TP + FP (predicted)
             prec = {
                 f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
             }
@@ -3283,47 +3220,32 @@ class EnviFormer(PackageBasedModel):
 
         # If there are eval packages perform single generation evaluation on them instead of random splits
        if self.eval_packages.count() > 0:
-            ds = []
-            for reaction in Reaction.objects.filter(
-                package__in=self.eval_packages.all()
-            ).distinct():
-                educts = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.educts.all()
-                    ]
-                )
-                products = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.products.all()
-                    ]
-                )
-                ds.append(f"{educts}>>{products}")
-            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
+                package__in=self.eval_packages.all()).distinct())
+            test_result = self.model.predict_batch(ds.X())
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             from enviformer.finetune import fine_tune
 
             ds = self.load_dataset()
-            n_splits = 20
+            n_splits = kwargs.get("n_splits", 20)
             shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
             # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
             # this helps reduce the memory footprint.
             single_gen_results = []
             for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
-                train = [ds[i] for i in train_index]
-                test = [ds[i] for i in test_index]
+                train = ds[train_index]
+                test = ds[test_index]
                 start = datetime.now()
-                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
+                model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                 end = datetime.now()
                 logger.debug(
                     f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                 )
                 model.to(s.ENVIFORMER_DEVICE)
-                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
+                test_result = model.predict_batch(test.X())
                 single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
 
             self.eval_results = self.compute_averages(single_gen_results)
@@ -3394,31 +3316,15 @@ class EnviFormer(PackageBasedModel):
             for pathway in train_pathways:
                 for reaction in pathway.edges:
                     reaction = reaction.edge_label
-                    if any(
-                        [
-                            educt in test_educts
-                            for educt in reaction_to_educts[str(reaction.uuid)]
-                        ]
-                    ):
+                    if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
                        overlap += 1
                         continue
-                    educts = ".".join(
-                        [
-                            FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                            for smile in reaction.educts.all()
-                        ]
-                    )
-                    products = ".".join(
-                        [
-                            FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                            for smile in reaction.products.all()
-                        ]
-                    )
-                    train_reactions.append(f"{educts}>>{products}")
+                    train_reactions.append(reaction)
+            train_ds = EnviFormerDataset.generate_dataset(train_reactions)
             logging.debug(
                 f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
             )
-            model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
+            model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
             multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
 
         self.eval_results.update(
diff --git a/epdb/views.py b/epdb/views.py
index 10a8027b..4a3a131a 100644
--- a/epdb/views.py
+++ b/epdb/views.py
@@ -894,7 +894,7 @@ def package_model(request, package_uuid, model_uuid):
 
             return JsonResponse(res, safe=False)
         else:
-            app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
+            app_domain_assessment = current_model.app_domain.assess(stand_smiles)
             return JsonResponse(app_domain_assessment, safe=False)
 
     context = get_base_context(request)
diff --git a/pyproject.toml b/pyproject.toml
index 0dfbe118..347f1e04 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,11 +27,12 @@ dependencies = [
     "scikit-learn>=1.6.1",
     "sentry-sdk[django]>=2.32.0",
     "setuptools>=80.8.0",
-    "nh3==0.3.2"
+    "nh3==0.3.2",
+    "polars==1.35.1",
 ]
 
 [tool.uv.sources]
-enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" }
+enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" }
 envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
 envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"}
 envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index eb5a7924..300ab1ed 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -1,8 +1,10 @@
+import os.path
+from tempfile import TemporaryDirectory
 from django.test import TestCase
-
 from epdb.logic import PackageManager
-from epdb.models import Reaction, Compound, User, Rule
-from utilities.ml import Dataset
+from epdb.models import Reaction, Compound, User, Rule, Package
+from utilities.chem import FormatConverter
+from utilities.ml import RuleBasedDataset, EnviFormerDataset
 
 
 class DatasetTest(TestCase):
@@ -41,12 +43,108 @@ class DatasetTest(TestCase):
         super(DatasetTest, cls).setUpClass()
         cls.user = User.objects.get(username="anonymous")
         cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
+        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
 
-    def test_smoke(self):
+    def test_generate_dataset(self):
+        """Test generating dataset does not crash"""
+        self.generate_rule_dataset()
+
+    def test_indexing(self):
+        """Test indexing a few different ways to check for crashes"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds[5])
+        print(ds[2, 5])
+        print(ds[3:6, 2:8])
+        print(ds[:2, "structure_id"])
+
+    def test_add_rows(self):
+        """Test adding one row and adding multiple rows"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        ds.add_row(list(ds.df.row(1)))
+        ds.add_rows([list(ds.df.row(i)) for i in range(5)])
+
+    def test_times_triggered(self):
+        """Check getting times triggered for a rule id"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.times_triggered(rules[0].uuid))
+
+    def test_block_indices(self):
+        """Test the usages of _block_indices"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.struct_features())
+        print(ds.triggered())
+        print(ds.observed())
+
+    def test_structure_id(self):
+        """Check getting a structure id from row index"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.structure_id(0))
+
+    def test_x(self):
+        """Test getting X portion of the dataframe"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.X().df.head())
+
+    def test_trig(self):
+        """Test getting the triggered portion of the dataframe"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.trig().df.head())
+
+    def test_y(self):
+        """Test getting the Y portion of the dataframe"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        print(ds.y().df.head())
+
+    def test_classification_dataset(self):
+        """Test making the classification dataset"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
+        class_ds, products = ds.classification_dataset(compounds, rules)
+        print(class_ds.df.head(5))
+        print(products[:5])
+
+    def test_extra_features(self):
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
+        print(ds.shape)
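+        # Illustrative expectation: 1 id column + 167 MACCS bits + 2048 Morgan bits
+        # + one trig_/obs_ column per rule (bit counts assume the RDKit defaults used above)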
+
+    def test_to_arff(self):
+        """Test exporting the arff version of the dataset"""
+        ds, reactions, rules = self.generate_rule_dataset()
+        ds.to_arff("dataset_arff_test.arff")
+
+    def test_save_load(self):
+        """Test saving and loading dataset"""
+        with TemporaryDirectory() as tmpdir:
+            ds, reactions, rules = self.generate_rule_dataset()
+            ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
+            ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
+            self.assertTrue(ds.df.equals(ds_loaded.df))
+
+    def test_dataset_example(self):
+        """Test with a concrete example checking dataset size"""
         reactions = [r for r in Reaction.objects.filter(package=self.package)]
         applicable_rules = [self.rule1]
-        ds = Dataset.generate_dataset(reactions, applicable_rules)
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
         self.assertEqual(len(ds.y()), 1)
-        self.assertEqual(sum(ds.y()[0]), 1)
+        self.assertEqual(ds.y().df.item(), 1)
+
+    def test_enviformer_dataset(self):
+        ds, reactions = self.generate_enviformer_dataset()
+        print(ds.X().head())
+        print(ds.y().head())
+
+    def generate_rule_dataset(self):
+        """Generate a RuleBasedDataset from test package data"""
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
+        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
+        return ds, reactions, applicable_rules
+
+    def generate_enviformer_dataset(self):
+        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
+        ds = EnviFormerDataset.generate_dataset(reactions)
+        return ds, reactions
diff --git a/tests/test_enviformer.py b/tests/test_enviformer.py
index 647433fc..b7994d4a 100644
--- a/tests/test_enviformer.py
+++ b/tests/test_enviformer.py
@@ -42,13 +42,11 @@ class EnviFormerTest(TestCase):
                 threshold = float(0.5)
                 data_package_objs = [self.BBD_SUBSET]
                 eval_packages_objs = [self.BBD_SUBSET]
-                mod = EnviFormer.create(
-                    self.package, data_package_objs, eval_packages_objs, threshold=threshold
-                )
+                mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
 
                 mod.build_dataset()
                 mod.build_model()
-                mod.evaluate_model(True, eval_packages_objs)
+                mod.evaluate_model(True, eval_packages_objs, n_splits=2)
 
                 mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
 
@@ -57,12 +55,9 @@ class EnviFormerTest(TestCase):
             with self.settings(MODEL_DIR=tmpdir):
                 threshold = float(0.5)
                 data_package_objs = [self.BBD_SUBSET]
-                eval_packages_objs = [self.BBD_SUBSET]
                 mods = []
                 for _ in range(4):
-                    mod = EnviFormer.create(
-                        self.package, data_package_objs, eval_packages_objs, threshold=threshold
-                    )
+                    mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
                     mod.build_dataset()
                     mod.build_model()
                     mods.append(mod)
@@ -73,15 +68,11 @@ class EnviFormerTest(TestCase):
 
                 # Test pathway prediction
                 times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
-                print(
-                    f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}"
-                )
+                print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
 
                 # Test eviction by performing three prediction with every model, twice.
                 times = defaultdict(list)
-                for _ in range(
-                    2
-                ):  # Eviction should cause the second iteration here to have to reload the models
+                for _ in range(2):  # Eviction should cause the second iteration here to have to reload the models
                     for mod in mods:
                         for _ in range(3):
                             times[mod.pk].append(measure_predict(mod))
diff --git a/tests/test_model.py b/tests/test_model.py
index f0355be9..50dfee19 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -4,7 +4,7 @@ import numpy as np
 from django.test import TestCase
 
 from epdb.logic import PackageManager
-from epdb.models import User, MLRelativeReasoning, Package
+from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning
 
 
 class ModelTest(TestCase):
@@ -17,7 +17,7 @@ class ModelTest(TestCase):
         cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
         cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
 
-    def test_smoke(self):
+    def test_mlrr(self):
         with TemporaryDirectory() as tmpdir:
             with self.settings(MODEL_DIR=tmpdir):
                 threshold = float(0.5)
@@ -35,21 +35,9 @@ class ModelTest(TestCase):
                     description="Created MLRelativeReasoning in Testcase",
                 )
 
-                # mod = RuleBasedRelativeReasoning.create(
-                #     self.package,
-                #     rule_package_objs,
-                #     data_package_objs,
-                #     eval_packages_objs,
-                #     threshold=threshold,
-                #     min_count=5,
-                #     max_count=0,
-                #     name='ECC - BBD - 0.5',
-                #     description='Created MLRelativeReasoning in Testcase',
-                # )
-
                 mod.build_dataset()
                 mod.build_model()
-                mod.evaluate_model(True, eval_packages_objs)
+                mod.evaluate_model(True, eval_packages_objs, n_splits=2)
 
                 results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
 
@@ -70,3 +58,57 @@ class ModelTest(TestCase):
 
         # from pprint import pprint
         # pprint(mod.eval_results)
+
+    def test_applicability(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+
+                rule_package_objs = [self.BBD_SUBSET]
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+
+                mod = MLRelativeReasoning.create(
+                    self.package,
+                    rule_package_objs,
+                    data_package_objs,
+                    threshold=threshold,
+                    name="ECC - BBD - 0.5",
+                    description="Created MLRelativeReasoning in Testcase",
+                    build_app_domain=True,  # To test the applicability domain this must be True
+                    app_domain_num_neighbours=5,
+                    app_domain_local_compatibility_threshold=0.5,
+                    app_domain_reliability_threshold=0.5,
+                )
+
+                mod.build_dataset()
+                mod.build_model()
+                mod.evaluate_model(True, eval_packages_objs, n_splits=2)
+
+                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
+
+    def test_rbrr(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+
+                rule_package_objs = [self.BBD_SUBSET]
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+
+                mod = RuleBasedRelativeReasoning.create(
+                    self.package,
+                    rule_package_objs,
+                    data_package_objs,
+                    threshold=threshold,
+                    min_count=5,
+                    max_count=0,
+                    name="ECC - BBD - 0.5",
+                    description="Created MLRelativeReasoning in Testcase",
+                )
+
+                mod.build_dataset()
+                mod.build_model()
+                mod.evaluate_model(True, eval_packages_objs, n_splits=2)
+
+                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
diff --git a/utilities/chem.py b/utilities/chem.py
index 250ccfb6..d7a68d75 100644
--- a/utilities/chem.py
+++ b/utilities/chem.py
@@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
 from indigo import Indigo, IndigoException, IndigoObject
 from indigo.renderer import IndigoRenderer
 from rdkit import Chem, rdBase
-from rdkit.Chem import MACCSkeys, Descriptors
+from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
 from rdkit.Chem import rdChemReactions
 from rdkit.Chem.Draw import rdMolDraw2D
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -107,6 +107,13 @@ class FormatConverter(object):
         bitvec = MACCSkeys.GenMACCSKeys(mol)
         return bitvec.ToList()
 
+    @staticmethod
+    def morgan(smiles, radius=3, fpSize=2048):
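+        # Folded Morgan (ECFP-like) fingerprint; e.g. FormatConverter.morgan("CCO")
+        # returns a list of 2048 zeros/ones with the defaults above.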
+        finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
+        mol = Chem.MolFromSmiles(smiles)
+        fp = finger_gen.GetFingerprint(mol)
+        return fp.ToList()
+
     @staticmethod
     def get_functional_groups(smiles: str) -> List[str]:
         res = list()
diff --git a/utilities/ml.py b/utilities/ml.py
index a93fafd9..5df5dce8 100644
--- a/utilities/ml.py
+++ b/utilities/ml.py
@@ -5,11 +5,14 @@ import logging
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import List, Dict, Set, Tuple, TYPE_CHECKING
+from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
+from abc import ABC, abstractmethod
 
 import networkx as nx
 import numpy as np
+from envipy_plugins import Descriptor
 from numpy.random import default_rng
+import polars as pl
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
 from sklearn.dummy import DummyClassifier
@@ -26,70 +29,281 @@ if TYPE_CHECKING:
     from epdb.models import Rule, CompoundStructure, Reaction
 
 
-class Dataset:
-    def __init__(
-        self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
-    ):
-        self.columns: List[str] = columns
-        self.num_labels: int = num_labels
-
-        if data is None:
-            self.data: List[List[str | int | float]] = list()
+class Dataset(ABC):
+    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
+        if isinstance(data, pl.DataFrame):  # Allows for re-creation of self in cases like indexing with __getitem__
+            self.df = data
         else:
-            self.data = data
+            # Build either an empty dataframe with columns or fill it with list of list data
+            if columns is None:
+                raise ValueError("Columns can't be None if data is not already a DataFrame")
+            if data is not None and len(columns) != len(data[0]):
+                raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
+            self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
 
-        self.num_features: int = len(columns) - self.num_labels
-        self._struct_features: Tuple[int, int] = self._block_indices("feature_")
-        self._triggered: Tuple[int, int] = self._block_indices("trig_")
-        self._observed: Tuple[int, int] = self._block_indices("obs_")
+    def add_rows(self, rows: List[List[str | int | float]]):
+        """Add rows to the dataset. Extends the polars dataframe stored in self"""
+        if len(self.columns) != len(rows[0]):
+            raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
+        new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
+        self.df.extend(new_rows)
 
-    def _block_indices(self, prefix) -> Tuple[int, int]:
+    def add_row(self, row: List[str | int | float]):
+        """See add_rows"""
+        self.add_rows([row])
+
+    def block_indices(self, prefix) -> List[int]:
+        """Find the indices of the column labels that have the given prefix"""
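+        # e.g. columns ["structure_id", "feature_0", "trig_<uuid>", "obs_<uuid>"]
+        # give block_indices("trig_") == [2] (column names illustrative)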
         indices: List[int] = []
         for i, feature in enumerate(self.columns):
             if feature.startswith(prefix):
                 indices.append(i)
+        return indices
 
-        return min(indices), max(indices)
+    @property
+    def columns(self) -> List[str]:
+        """Use the polars dataframe columns"""
+        return self.df.columns
 
-    def structure_id(self):
-        return self.data[0][0]
+    @property
+    def shape(self):
+        return self.df.shape
 
-    def add_row(self, row: List[str | int | float]):
-        if len(self.columns) != len(row):
-            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
-        self.data.append(row)
+    @abstractmethod
+    def X(self, **kwargs):
+        pass
 
-    def times_triggered(self, rule_uuid) -> int:
-        idx = self.columns.index(f"trig_{rule_uuid}")
+    @abstractmethod
+    def y(self, **kwargs):
+        pass
 
-        times_triggered = 0
-        for row in self.data:
-            if row[idx] == 1:
-                times_triggered += 1
-
-        return times_triggered
-
-    def struct_features(self) -> Tuple[int, int]:
-        return self._struct_features
-
-    def triggered(self) -> Tuple[int, int]:
-        return self._triggered
-
-    def observed(self) -> Tuple[int, int]:
-        return self._observed
-
-    def at(self, position: int) -> Dataset:
-        return Dataset(self.columns, self.num_labels, [self.data[position]])
-
-    def limit(self, limit: int) -> Dataset:
-        return Dataset(self.columns, self.num_labels, self.data[:limit])
+    @staticmethod
+    @abstractmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        pass
 
     def __iter__(self):
-        return (self.at(i) for i, _ in enumerate(self.data))
+        """Use polars iter_rows for iterating over the dataset"""
+        return self.df.iter_rows()
+
+    def __getitem__(self, item):
+        """Item is passed to polars, allowing for advanced indexing.
+        See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
+        res = self.df[item]
+        if isinstance(res, pl.DataFrame):  # If we get a dataframe back from indexing, make a new self with it
+            return self.__class__(data=res)
+        else:  # If we don't get a dataframe back (likely a base type: int, str, float etc.), return the item
+            return res
+
+    def save(self, path: "Path | str"):
+        import pickle
+
+        with open(path, "wb") as fh:
+            pickle.dump(self, fh)
+
+    @staticmethod
+    def load(path: "str | Path") -> "Dataset":
+        import pickle
+
+        return pickle.load(open(path, "rb"))
+
+    def to_numpy(self):
+        return self.df.to_numpy()
+
+    def __repr__(self):
+        return (
+            f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
+        )
+
+    def __len__(self):
+        return len(self.df)
+
+    def iter_rows(self, named=False):
+        return self.df.iter_rows(named=named)
+
+    def filter(self, *predicates, **constraints):
+        return self.__class__(data=self.df.filter(*predicates, **constraints))
+
+    def select(self, *exprs, **named_exprs):
+        return self.__class__(data=self.df.select(*exprs, **named_exprs))
+
+    def with_columns(self, *exprs, **name_exprs):
+        return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
+
+    def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
+        return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
+                                                multithreaded=multithreaded, maintain_order=maintain_order))
+
+    def item(self, row=None, column=None):
+        return self.df.item(row, column)
+
+    def fill_nan(self, value):
+        return self.__class__(data=self.df.fill_nan(value))
+
+    @property
+    def height(self):
+        return self.df.height
+
+
+class RuleBasedDataset(Dataset):
+    def __init__(self, num_labels=None, columns=None, data=None):
+        super().__init__(columns, data)
+        # Calculating num_labels allows functions like getitem to live in the base Dataset as it unifies the init.
+        self.num_labels: int = num_labels if num_labels else sum(1 for c in self.columns if c.startswith("obs_"))
+        # Pre-calculate the ids of columns for features/labels, useful later in X and y
+        self._struct_features: List[int] = self.block_indices("feature_")
+        self._triggered: List[int] = self.block_indices("trig_")
+        self._observed: List[int] = self.block_indices("obs_")
+        self.feature_cols: List[int] = self._struct_features + self._triggered
+        self.num_features: int = len(self.feature_cols)
+        self.has_probs = False
+
+    def times_triggered(self, rule_uuid) -> int:
+        """Count how many times a rule is triggered, i.e. the number of rows with a one in the rule's trig column"""
+        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
+
+    def struct_features(self) -> List[int]:
+        return self._struct_features
+
+    def triggered(self) -> List[int]:
+        return self._triggered
+
+    def observed(self) -> List[int]:
+        return self._observed
+
+    def structure_id(self, index: int):
+        """Get the UUID of a compound"""
+        return self.item(index, "structure_id")
+
+    def X(self, exclude_id_col=True, na_replacement=0):
+        """Get all the feature and trig columns"""
+        _col_ids = self.feature_cols
+        if not exclude_id_col:
+            _col_ids = [0] + _col_ids
+        res = self[:, _col_ids]
+        if na_replacement is not None:
+            res.df = res.df.fill_null(na_replacement)
+        return res
+
+    def trig(self, na_replacement=0):
+        """Get all the trig columns"""
+        res = self[:, self._triggered]
+        if na_replacement is not None:
+            res.df = res.df.fill_null(na_replacement)
+        return res
+
+    def y(self, na_replacement=0):
+        """Get all the obs columns"""
+        res = self[:, self._observed]
+        if na_replacement is not None:
+            res.df = res.df.fill_null(na_replacement)
+        return res
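+    # Column layout assembled by generate_dataset below (names illustrative):
+    #   structure_id | feature_* | trig_<rule_uuid> ... | obs_<rule_uuid> ... [| prob_<rule_uuid> ... after add_probs]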
+
+    @staticmethod
+    def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"] = None):
+        if feat_funcs is None:
+            feat_funcs = [FormatConverter.maccs]
+        _structures = set()  # Get all the structures
+        for r in reactions:
+            _structures.update(r.educts.all())
+            if not educts_only:
+                _structures.update(r.products.all())
+
+        compounds = sorted(_structures, key=lambda x: x.url)
+        triggered: Dict[str, Set[str]] = defaultdict(set)
+        observed: Set[str] = set()
+
+        # Apply rules on collected compounds and store tps
+        for i, comp in enumerate(compounds):
+            logger.debug(f"{i + 1}/{len(compounds)}...")
+
+            for rule in applicable_rules:
+                product_sets = rule.apply(comp.smiles)
+                if len(product_sets) == 0:
+                    continue
+
+                key = f"{rule.uuid} + {comp.uuid}"
+                if key in triggered:
+                    logger.info(f"{key} already present. Duplicate reaction?")
+
+                for prod_set in product_sets:
+                    for smi in prod_set:
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        triggered[key].add(smi)
+
+        for i, r in enumerate(reactions):
+            logger.debug(f"{i + 1}/{len(reactions)}...")
+
+            if len(r.educts.all()) != 1:
+                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
+                continue
+
+            for comp in r.educts.all():
+                for rule in applicable_rules:
+                    key = f"{rule.uuid} + {comp.uuid}"
+                    if key not in triggered:
+                        continue
+
+                    # standardize products from reactions for comparison
+                    standardized_products = []
+                    for cs in r.products.all():
+                        smi = cs.smiles
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception as e:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        standardized_products.append(smi)
+                    if len(set(standardized_products).difference(triggered[key])) == 0:
+                        observed.add(key)
+        feat_columns = []
+        for feat_func in feat_funcs:
+            if isinstance(feat_func, Descriptor):
+                feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
+            else:
+                feats = feat_func(compounds[0].smiles)
+            start_i = len(feat_columns)
+            feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
+        ds_columns = (["structure_id"]
+                      + feat_columns
+                      + [f"trig_{r.uuid}" for r in applicable_rules]
+                      + [f"obs_{r.uuid}" for r in applicable_rules])
+        rows = []
+
+        for i, comp in enumerate(compounds):
+            # Features
+            feats = []
+            for feat_func in feat_funcs:
+                if isinstance(feat_func, Descriptor):
+                    feat = feat_func.get_molecule_descriptors(comp.smiles)
+                else:
+                    feat = feat_func(comp.smiles)
+                feats.extend(feat)
+            trig = []
+            obs = []
+            for rule in applicable_rules:
+                key = f"{rule.uuid} + {comp.uuid}"
+                # Check triggered
+                if key in triggered:
+                    trig.append(1)
+                else:
+                    trig.append(0)
+                # Check obs
+                if key in observed:
+                    obs.append(1)
+                elif key not in triggered:
+                    obs.append(None)
+                else:
+                    obs.append(0)
+            rows.append([str(comp.uuid)] + feats + trig + obs)
+        ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
+        return ds
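+    # Typical call, mirroring tests/test_dataset.py:
+    #   RuleBasedDataset.generate_dataset(reactions, rules,
+    #                                     feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])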
 
     def classification_dataset(
         self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
-    ) -> Tuple[Dataset, List[List[PredictionResult]]]:
+    ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
         classify_data = []
         classify_products = []
         for struct in structures:
@@ -113,186 +327,18 @@ class Dataset:
                 else:
                     trig.append(0)
                     prods.append([])
-
-            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
+            new_row = [struct_id] + features + trig + ([-1] * len(trig))
+            if self.has_probs:
+                new_row += [-1] * len(trig)
+            classify_data.append(new_row)
             classify_products.append(prods)
+        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
+        return ds, classify_products
 
-        return Dataset(
-            columns=self.columns, num_labels=self.num_labels, data=classify_data
-        ), classify_products
-
-    @staticmethod
-    def generate_dataset(
-        reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
-    ) -> Dataset:
-        _structures = set()
-
-        for r in reactions:
-            for e in r.educts.all():
-                _structures.add(e)
-
-            if not educts_only:
-                for e in r.products:
-                    _structures.add(e)
-
-        compounds = sorted(_structures, key=lambda x: x.url)
-
-        triggered: Dict[str, Set[str]] = defaultdict(set)
-        observed: Set[str] = set()
-
-        # Apply rules on collected compounds and store tps
-        for i, comp in enumerate(compounds):
-            logger.debug(f"{i + 1}/{len(compounds)}...")
-
-            for rule in applicable_rules:
-                product_sets = rule.apply(comp.smiles)
-
-                if len(product_sets) == 0:
-                    continue
-
-                key = f"{rule.uuid} + {comp.uuid}"
-
-                if key in triggered:
-                    logger.info(f"{key} already present. Duplicate reaction?")
-
-                for prod_set in product_sets:
-                    for smi in prod_set:
-                        try:
-                            smi = FormatConverter.standardize(smi, remove_stereo=True)
-                        except Exception:
-                            # :shrug:
-                            logger.debug(f"Standardizing SMILES failed for {smi}")
-                            pass
-
-                        triggered[key].add(smi)
-
-        for i, r in enumerate(reactions):
-            logger.debug(f"{i + 1}/{len(reactions)}...")
-
-            if len(r.educts.all()) != 1:
-                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
-                continue
-
-            for comp in r.educts.all():
-                for rule in applicable_rules:
-                    key = f"{rule.uuid} + {comp.uuid}"
-
-                    if key not in triggered:
-                        continue
-
-                    # standardize products from reactions for comparison
-                    standardized_products = []
-                    for cs in r.products.all():
-                        smi = cs.smiles
-
-                        try:
-                            smi = FormatConverter.standardize(smi, remove_stereo=True)
-                        except Exception as e:
-                            # :shrug:
-                            logger.debug(f"Standardizing SMILES failed for {smi}")
-                            pass
-
-                        standardized_products.append(smi)
-
-                    if len(set(standardized_products).difference(triggered[key])) == 0:
-                        observed.add(key)
-                    else:
-                        pass
-
-        ds = None
-
-        for i, comp in enumerate(compounds):
-            # Features
-            feat = FormatConverter.maccs(comp.smiles)
-            trig = []
-            obs = []
-
-            for rule in applicable_rules:
-                key = f"{rule.uuid} + {comp.uuid}"
-
-                # Check triggered
-                if key in triggered:
-                    trig.append(1)
-                else:
-                    trig.append(0)
-
-                # Check obs
-                if key in observed:
-                    obs.append(1)
-                elif key not in triggered:
-                    obs.append(None)
-                else:
-                    obs.append(0)
-
-            if ds is None:
-                header = (
-                    ["structure_id"]
-                    + [f"feature_{i}" for i, _ in enumerate(feat)]
-                    + [f"trig_{r.uuid}" for r in applicable_rules]
-                    + [f"obs_{r.uuid}" for r in applicable_rules]
-                )
-                ds = Dataset(header, len(applicable_rules))
-
-            ds.add_row([str(comp.uuid)] + feat + trig + obs)
-
-        return ds
-
-    def X(self, exclude_id_col=True, na_replacement=0):
-        res = self.__getitem__(
-            (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
-        )
-        if na_replacement is not None:
-            res = [[x if x is not None else na_replacement for x in row] for row in res]
-        return res
-
-    def trig(self, na_replacement=0):
-        res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
-        if na_replacement is not None:
-            res = [[x if x is not None else na_replacement for x in row] for row in res]
-        return res
-
-    def y(self, na_replacement=0):
-        res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
-        if na_replacement is not None:
-            res = [[x if x is not None else na_replacement for x in row] for row in res]
-        return res
-
-    def __getitem__(self, key):
-        if not isinstance(key, tuple):
-            raise TypeError("Dataset must be indexed with dataset[rows, columns]")
-
-        row_key, col_key = key
-
-        # Normalize rows
-        if isinstance(row_key, int):
-            rows = [self.data[row_key]]
-        else:
-            rows = self.data[row_key]
-
-        # Normalize columns
-        if isinstance(col_key, int):
-            res = [row[col_key] for row in rows]
-        else:
-            res = [
-                [row[i] for i in range(*col_key.indices(len(row)))]
-                if isinstance(col_key, slice)
-                else [row[i] for i in col_key]
-                for row in rows
-            ]
-
-        return res
-
-    def save(self, path: "Path"):
-        import pickle
-
-        with open(path, "wb") as fh:
-            pickle.dump(self, fh)
-
-    @staticmethod
-    def load(path: "Path") -> "Dataset":
-        import pickle
-
-        return pickle.load(open(path, "rb"))
+    def add_probs(self, probs):
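+        # Naming assumption: prob columns mirror the obs_ suffixes
+        # (obs_<rule_uuid> -> prob_<rule_uuid>), which training_set_probs
+        # later slices out via block_indices("prob")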
+        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
+        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
+        self.has_probs = True
 
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
@@ -304,7 +350,7 @@ class Dataset:
             arff += f"@attribute {c} {{0,1}}\n"
 
         arff += "\n@data\n"
-        for d in self.data:
+        for d in self:
             ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
             xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
             arff += f"{ys},{xs}\n"
@@ -313,10 +359,40 @@ class Dataset:
             fh.write(arff)
             fh.flush()
 
-    def __repr__(self):
-        return (
-            f"<Dataset #rows={len(self.data)} #cols={len(self.columns)}>"
-        )
+
+class EnviFormerDataset(Dataset):
+    def __init__(self, columns=None, data=None):
+        super().__init__(columns, data)
+
+    def X(self):
+        """Return the educts"""
+        return self["educts"]
+
+    def y(self):
+        """Return the products"""
+        return self["products"]
+
+    @staticmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        # Standardise reactions for the training data
+        stereo = kwargs.get("stereo", False)
+        rows = []
+        for reaction in reactions:
+            e = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
+                    for smile in reaction.educts.all()
+                ]
+            )
+            p = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
+                    for smile in reaction.products.all()
+                ]
+            )
+            rows.append([e, p])
+        ds = EnviFormerDataset(["educts", "products"], rows)
+        return ds
 
 
 class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -498,7 +574,7 @@ class EnsembleClassifierChain:
         self.classifiers = []
 
         if self.num_labels is None:
-            self.num_labels = len(Y[0])
+            self.num_labels = Y.shape[1]
 
         for p in range(self.num_chains):
             logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@@ -529,7 +605,7 @@ class RelativeReasoning:
 
     def fit(self, X, Y):
         n_instances = len(Y)
-        n_attributes = len(Y[0])
+        n_attributes = Y.shape[1]
 
         for i in range(n_attributes):
             for j in range(n_attributes):
@@ -541,8 +617,8 @@ class RelativeReasoning:
                 countboth = 0
 
                 for k in range(n_instances):
-                    vi = Y[k][i]
-                    vj = Y[k][j]
+                    vi = Y[k, i]
+                    vj = Y[k, j]
 
                     if vi is None or vj is None:
                         continue
@@ -598,7 +674,7 @@ class ApplicabilityDomainPCA(PCA):
         self.min_vals = None
         self.max_vals = None
 
-    def build(self, train_dataset: "Dataset"):
+    def build(self, train_dataset: "RuleBasedDataset"):
         # transform
         X_scaled = self.scaler.fit_transform(train_dataset.X())
         # fit pca
@@ -612,7 +688,7 @@ class ApplicabilityDomainPCA(PCA):
         instances_pca = self.transform(instances_scaled)
         return instances_pca
 
-    def is_applicable(self, classify_instances: "Dataset"):
+    def is_applicable(self, classify_instances: "RuleBasedDataset"):
         instances_pca = self.__transform(classify_instances.X())
 
         is_applicable = []