diff --git a/epdb/logic.py b/epdb/logic.py index 19f03ae2..0aaebf32 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -1542,9 +1542,7 @@ class SPathway(object): if sub.app_domain_assessment is None: if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = self.prediction_setting.model.app_domain.assess( - sub.smiles - )[0] + app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles) if self.persist is not None: n = self.snode_persist_lookup[sub] @@ -1576,11 +1574,7 @@ class SPathway(object): app_domain_assessment = None if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = ( - self.prediction_setting.model.app_domain.assess(c)[ - 0 - ] - ) + app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c)) self.smiles_to_node[c] = SNode( c, sub.depth + 1, app_domain_assessment diff --git a/epdb/models.py b/epdb/models.py index 324fe301..3db8ce0f 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -28,7 +28,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score from sklearn.model_selection import ShuffleSplit from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils -from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning +from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \ + EnviFormerDataset, Dataset logger = logging.getLogger(__name__) @@ -2175,7 +2176,7 @@ class PackageBasedModel(EPModel): applicable_rules = self.applicable_rules reactions = list(self._get_reactions()) - ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True) + ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True) end = datetime.now() logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds") @@ -2184,7 +2185,7 @@ class PackageBasedModel(EPModel): ds.save(f) return ds - def load_dataset(self) -> "Dataset": + def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset": ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl") return Dataset.load(ds_path) @@ -2225,7 +2226,7 @@ class PackageBasedModel(EPModel): self.model_status = self.BUILT_NOT_EVALUATED self.save() - def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None): + def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs): if self.model_status != self.BUILT_NOT_EVALUATED: raise ValueError(f"Can't evaluate a model in state {self.model_status}!") @@ -2343,37 +2344,37 @@ class PackageBasedModel(EPModel): eval_reactions = list( Reaction.objects.filter(package__in=self.eval_packages.all()).distinct() ) - ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True) + ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True) if isinstance(self, RuleBasedRelativeReasoning): - X = np.array(ds.X(exclude_id_col=False, na_replacement=None)) - y = np.array(ds.y(na_replacement=np.nan)) + X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy() + y = ds.y(na_replacement=np.nan).to_numpy() else: - X = np.array(ds.X(na_replacement=np.nan)) - y = np.array(ds.y(na_replacement=np.nan)) + X = ds.X(na_replacement=np.nan).to_numpy() + y = ds.y(na_replacement=np.nan).to_numpy() single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold) self.eval_results = 
self.compute_averages([single_gen_result]) else: ds = self.load_dataset() if isinstance(self, RuleBasedRelativeReasoning): - X = np.array(ds.X(exclude_id_col=False, na_replacement=None)) - y = np.array(ds.y(na_replacement=np.nan)) + X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy() + y = ds.y(na_replacement=np.nan).to_numpy() else: - X = np.array(ds.X(na_replacement=np.nan)) - y = np.array(ds.y(na_replacement=np.nan)) + X = ds.X(na_replacement=np.nan).to_numpy() + y = ds.y(na_replacement=np.nan).to_numpy() - n_splits = 20 + n_splits = kwargs.get("n_splits", 20) shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42) splits = list(shuff.split(X)) from joblib import Parallel, delayed - models = Parallel(n_jobs=10)( + models = Parallel(n_jobs=min(10, len(splits)))( delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits ) - evaluations = Parallel(n_jobs=10)( + evaluations = Parallel(n_jobs=min(10, len(splits)))( delayed(evaluate_sg)(model, X, y, test_index, self.threshold) for model, (_, test_index) in zip(models, splits) ) @@ -2585,11 +2586,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel): return rbrr - def _fit_model(self, ds: Dataset): + def _fit_model(self, ds: RuleBasedDataset): X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None) model = RelativeReasoning( start_index=ds.triggered()[0], - end_index=ds.triggered()[1], + end_index=ds.triggered()[-1], ) model.fit(X, y) return model @@ -2599,7 +2600,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel): return { "clz": "RuleBaseRelativeReasoning", "start_index": ds.triggered()[0], - "end_index": ds.triggered()[1], + "end_index": ds.triggered()[-1], } def _save_model(self, model): @@ -2687,11 +2688,11 @@ class MLRelativeReasoning(PackageBasedModel): return mlrr - def _fit_model(self, ds: Dataset): + def _fit_model(self, ds: RuleBasedDataset): X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan) model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS) - model.fit(X, y) + model.fit(X.to_numpy(), y.to_numpy()) return model def _model_args(self): @@ -2714,7 +2715,7 @@ class MLRelativeReasoning(PackageBasedModel): start = datetime.now() ds = self.load_dataset() classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules) - pred = self.model.predict_proba(classify_ds.X()) + pred = self.model.predict_proba(classify_ds.X().to_numpy()) res = MLRelativeReasoning.combine_products_and_probs( self.applicable_rules, pred[0], classify_prods[0] @@ -2759,7 +2760,9 @@ class ApplicabilityDomain(EnviPathModel): @cached_property def training_set_probs(self): - return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")) + ds = self.model.load_dataset() + col_ids = ds.block_indices("prob") + return ds[:, col_ids] def build(self): ds = self.model.load_dataset() @@ -2767,9 +2770,9 @@ class ApplicabilityDomain(EnviPathModel): start = datetime.now() # Get Trainingset probs and dump them as they're required when using the app domain - probs = self.model.model.predict_proba(ds.X()) - f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl") - joblib.dump(probs, f) + probs = self.model.model.predict_proba(ds.X().to_numpy()) + ds.add_probs(probs) + ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl")) ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours) ad.build(ds) @@ -2792,16 +2795,19 @@ class ApplicabilityDomain(EnviPathModel): joblib.dump(ad, f) def assess(self, structure: 
Union[str, "CompoundStructure"]): + return self.assess_batch([structure])[0] + + def assess_batch(self, structures: List["CompoundStructure | str"]): ds = self.model.load_dataset() - if isinstance(structure, CompoundStructure): - smiles = structure.smiles - else: - smiles = structure + smiles = [] + for struct in structures: + if isinstance(struct, CompoundStructure): + smiles.append(structures.smiles) + else: + smiles.append(structures) - assessment_ds, assessment_prods = ds.classification_dataset( - [structure], self.model.applicable_rules - ) + assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules) # qualified_neighbours_per_rule is a nested dictionary structured as: # { @@ -2814,82 +2820,47 @@ class ApplicabilityDomain(EnviPathModel): # it identifies all training structures that have the same trigger reaction activated (i.e., value 1). # This is used to find "qualified neighbours" — training examples that share the same triggered feature # with a given assessment structure under a particular rule. - qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict( - lambda: defaultdict(list) - ) + qualified_neighbours_per_rule: Dict = {} - for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())): - feature = ds.columns[feature_index] - if feature.startswith("trig_"): - # TODO unroll loop - for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)): - if int(cx[feature_index]) == 1: - for j, tx in enumerate(ds.X(exclude_id_col=False)): - if int(tx[feature_index]) == 1: - qualified_neighbours_per_rule[i][rule_idx].append(j) - - probs = self.training_set_probs - # preds = self.model.model.predict_proba(assessment_ds.X()) + import polars as pl + # Select only the triggered columns + for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)): + # Find the rules the structure triggers. For each rule, filter the training dataset to rows that also + # trigger that rule. + train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1)) + for trig_uuid, value in row.items() if value == 1} + qualified_neighbours_per_rule[i] = train_trig + rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)} preds = self.model.combine_products_and_probs( self.model.applicable_rules, - self.model.model.predict_proba(assessment_ds.X())[0], + self.model.model.predict_proba(assessment_ds.X().to_numpy())[0], assessment_prods[0], ) assessments = list() - # loop through our assessment dataset - for i, instance in enumerate(assessment_ds): + for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]): rule_reliabilities = dict() local_compatibilities = dict() neighbours_per_rule = dict() neighbor_probs_per_rule = dict() # loop through rule indices together with the collected neighbours indices from train dataset - for rule_idx, vals in qualified_neighbours_per_rule[i].items(): - # collect the train dataset instances and store it along with the index (a.k.a. 
row number) of the - # train dataset - train_instances = [] - for v in vals: - train_instances.append((v, ds.at(v))) - - # sf is a tuple with start/end index of the features - sf = ds.struct_features() - - # compute tanimoto distance for all neighbours - # result ist a list of tuples with train index and computed distance - dists = self._compute_distances( - instance.X()[0][sf[0] : sf[1]], - [ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances], - ) - - dists_with_index = list() - for ti, dist in zip(train_instances, dists): - dists_with_index.append((ti[0], dist[1])) + for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items(): + # compute tanimoto distance for all neighbours and add to dataset + dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0], + train_instances[:, train_instances.struct_features()].to_numpy()) + train_instances = train_instances.with_columns(dist=pl.Series(dists)) # sort them in a descending way and take at most `self.num_neighbours` - dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True) - dists_with_index = dists_with_index[: self.num_neighbours] - + # TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending) + train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours] # compute average distance - rule_reliabilities[rule_idx] = ( - sum([d[1] for d in dists_with_index]) / len(dists_with_index) - if len(dists_with_index) > 0 - else 0.0 - ) - + rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item() # for local_compatibility we'll need the datasets for the indices having the highest similarity - neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index] - local_compatibilities[rule_idx] = self._compute_compatibility( - rule_idx, probs, neighbour_datasets - ) - neighbours_per_rule[rule_idx] = [ - CompoundStructure.objects.get(uuid=ds[1].structure_id()) - for ds in neighbour_datasets - ] - neighbor_probs_per_rule[rule_idx] = [ - probs[d[0]][rule_idx] for d in dists_with_index - ] + local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances) + neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"])) + neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list() ad_res = { "ad_params": { @@ -2900,23 +2871,21 @@ class ApplicabilityDomain(EnviPathModel): "local_compatibility_threshold": self.local_compatibilty_threshold, }, "assessment": { - "smiles": smiles, - "inside_app_domain": self.pca.is_applicable(instance)[0], + "smiles": smiles[i], + "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0], }, } transformations = list() - for rule_idx in rule_reliabilities.keys(): - rule = Rule.objects.get( - uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "") - ) + for rule_uuid in rule_reliabilities.keys(): + rule = Rule.objects.get(uuid=rule_uuid) rule_data = rule.simple_json() rule_data["image"] = f"{rule.url}?image=svg" neighbors = [] for n, n_prob in zip( - neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx] + neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid] ): neighbor = n.simple_json() neighbor["image"] = f"{n.url}?image=svg" @@ -2933,14 +2902,14 @@ class ApplicabilityDomain(EnviPathModel): transformation = { "rule": rule_data, - "reliability": rule_reliabilities[rule_idx], + "reliability": 
rule_reliabilities[rule_uuid], # We're setting it here to False, as we don't know whether "assess" is called during pathway # prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime "is_predicted": False, - "local_compatibility": local_compatibilities[rule_idx], - "probability": preds[rule_idx].probability, + "local_compatibility": local_compatibilities[rule_uuid], + "probability": preds[rule_to_i[rule_uuid]].probability, "transformation_products": [ - x.product_set for x in preds[rule_idx].product_sets + x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets ], "times_triggered": ds.times_triggered(str(rule.uuid)), "neighbors": neighbors, @@ -2958,32 +2927,21 @@ class ApplicabilityDomain(EnviPathModel): def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]): from utilities.ml import tanimoto_distance - distances = [ - (i, tanimoto_distance(classify_instance, train)) - for i, train in enumerate(train_instances) - ] + distances = [tanimoto_distance(classify_instance, train) for train in train_instances] return distances - @staticmethod - def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]): - tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0 + def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"): accuracy = 0.0 - - for n in neighbours: - obs = n[1].y()[0][rule_idx] - pred = preds[n[0]][rule_idx] - if obs and pred: - tp += 1 - elif not obs and pred: - fp += 1 - elif obs and not pred: - fn += 1 - else: - tn += 1 - # Jaccard Index - if tp + tn > 0.0: - accuracy = (tp + tn) / (tp + tn + fp + fn) - + import polars as pl + obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean), + pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold) + # Compute tp, tn, fp, fn using polars expressions + tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height + tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height + fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height + fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height + if tp + tn > 0.0: + accuracy = (tp + tn) / (tp + tn + fp + fn) return accuracy @@ -3084,44 +3042,24 @@ class EnviFormer(PackageBasedModel): self.save() start = datetime.now() - # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently - co2 = {"C(=O)=O", "O=C=O"} - ds = [] - for reaction in self._get_reactions(): - educts = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.educts.all() - ] - ) - products = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.products.all() - ] - ) - if products not in co2: - ds.append(f"{educts}>>{products}") + ds = EnviFormerDataset.generate_dataset(self._get_reactions()) end = datetime.now() logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds") f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json") - with open(f, "w") as d_file: - json.dump(ds, d_file) + ds.save(f) return ds - def load_dataset(self) -> "Dataset": + def load_dataset(self): ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json") - with open(ds_path) as d_file: - ds = json.load(d_file) - return ds + return EnviFormerDataset.load(ds_path) def _fit_model(self, ds): # Call to enviFormer's fine_tune function and return the model from enviformer.finetune import fine_tune start = datetime.now() - model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), 
device=s.ENVIFORMER_DEVICE) + model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE) end = datetime.now() logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds") return model @@ -3137,7 +3075,7 @@ class EnviFormer(PackageBasedModel): args = {"clz": "EnviFormer"} return args - def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None): + def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs): if self.model_status != self.BUILT_NOT_EVALUATED: raise ValueError(f"Can't evaluate a model in state {self.model_status}!") @@ -3152,21 +3090,20 @@ class EnviFormer(PackageBasedModel): self.model_status = self.EVALUATING self.save() - def evaluate_sg(test_reactions, predictions, model_thresh): + def evaluate_sg(test_ds, predictions, model_thresh): # Group the true products of reactions with the same reactant together + assert len(test_ds) == len(predictions) true_dict = {} - for r in test_reactions: - reactant, true_product_set = r.split(">>") + for r in test_ds: + reactant, true_product_set = r true_product_set = {p for p in true_product_set.split(".")} true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set] - assert len(test_reactions) == len(predictions) - assert sum(len(v) for v in true_dict.values()) == len(test_reactions) # Group the predicted products of reactions with the same reactant together pred_dict = {} for k, pred in enumerate(predictions): pred_smiles, pred_proba = zip(*pred.items()) - reactant, true_product = test_reactions[k].split(">>") + reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"] pred_dict.setdefault(reactant, {"predict": [], "scores": []}) for smiles, proba in zip(pred_smiles, pred_proba): smiles = set(smiles.split(".")) @@ -3201,7 +3138,7 @@ class EnviFormer(PackageBasedModel): break # Recall is TP (correct) / TP + FN (len(test_reactions)) - rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()} + rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()} # Precision is TP (correct) / TP + FP (predicted) prec = { f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items() @@ -3280,47 +3217,32 @@ class EnviFormer(PackageBasedModel): # If there are eval packages perform single generation evaluation on them instead of random splits if self.eval_packages.count() > 0: - ds = [] - for reaction in Reaction.objects.filter( - package__in=self.eval_packages.all() - ).distinct(): - educts = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.educts.all() - ] - ) - products = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.products.all() - ] - ) - ds.append(f"{educts}>>{products}") - test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds]) + ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter( + package__in=self.eval_packages.all()).distinct()) + test_result = self.model.predict_batch(ds.X()) single_gen_result = evaluate_sg(ds, test_result, self.threshold) self.eval_results = self.compute_averages([single_gen_result]) else: from enviformer.finetune import fine_tune ds = self.load_dataset() - n_splits = 20 + n_splits = kwargs.get("n_splits", 20) shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42) # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models # this helps reduce 
the memory footprint. single_gen_results = [] for split_id, (train_index, test_index) in enumerate(shuff.split(ds)): - train = [ds[i] for i in train_index] - test = [ds[i] for i in test_index] + train = ds[train_index] + test = ds[test_index] start = datetime.now() - model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE) + model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE) end = datetime.now() logger.debug( f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds" ) model.to(s.ENVIFORMER_DEVICE) - test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test]) + test_result = model.predict_batch(test.X()) single_gen_results.append(evaluate_sg(test, test_result, self.threshold)) self.eval_results = self.compute_averages(single_gen_results) @@ -3391,31 +3313,15 @@ class EnviFormer(PackageBasedModel): for pathway in train_pathways: for reaction in pathway.edges: reaction = reaction.edge_label - if any( - [ - educt in test_educts - for educt in reaction_to_educts[str(reaction.uuid)] - ] - ): + if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]): overlap += 1 continue - educts = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.educts.all() - ] - ) - products = ".".join( - [ - FormatConverter.standardize(smile.smiles, remove_stereo=True) - for smile in reaction.products.all() - ] - ) - train_reactions.append(f"{educts}>>{products}") + train_reactions.append(reaction) + train_ds = EnviFormerDataset.generate_dataset(train_reactions) logging.debug( f"{overlap} compounds had to be removed from multigen split due to overlap within pathways" ) - model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}") + model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}") multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold)) self.eval_results.update( diff --git a/epdb/views.py b/epdb/views.py index b6ef865c..1a2ce23c 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -892,7 +892,7 @@ def package_model(request, package_uuid, model_uuid): return JsonResponse(res, safe=False) else: - app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0] + app_domain_assessment = current_model.app_domain.assess(stand_smiles) return JsonResponse(app_domain_assessment, safe=False) context = get_base_context(request) diff --git a/pyproject.toml b/pyproject.toml index 1fba9371..26371296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,10 +27,11 @@ dependencies = [ "scikit-learn>=1.6.1", "sentry-sdk[django]>=2.32.0", "setuptools>=80.8.0", + "polars==1.35.1", ] [tool.uv.sources] -enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" } +enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" } envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" } envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"} envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" } diff --git a/tests/test_dataset.py b/tests/test_dataset.py index eb5a7924..300ab1ed 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,8 +1,10 @@ +import os.path +from tempfile import TemporaryDirectory from django.test import TestCase - from epdb.logic import PackageManager -from epdb.models import Reaction, 
Compound, User, Rule -from utilities.ml import Dataset +from epdb.models import Reaction, Compound, User, Rule, Package +from utilities.chem import FormatConverter +from utilities.ml import RuleBasedDataset, EnviFormerDataset class DatasetTest(TestCase): @@ -41,12 +43,108 @@ class DatasetTest(TestCase): super(DatasetTest, cls).setUpClass() cls.user = User.objects.get(username="anonymous") cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") + cls.BBD_SUBSET = Package.objects.get(name="Fixtures") - def test_smoke(self): + def test_generate_dataset(self): + """Test generating dataset does not crash""" + self.generate_rule_dataset() + + def test_indexing(self): + """Test indexing a few different ways to check for crashes""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds[5]) + print(ds[2, 5]) + print(ds[3:6, 2:8]) + print(ds[:2, "structure_id"]) + + def test_add_rows(self): + """Test adding one row and adding multiple rows""" + ds, reactions, rules = self.generate_rule_dataset() + ds.add_row(list(ds.df.row(1))) + ds.add_rows([list(ds.df.row(i)) for i in range(5)]) + + def test_times_triggered(self): + """Check getting times triggered for a rule id""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds.times_triggered(rules[0].uuid)) + + def test_block_indices(self): + """Test the usages of _block_indices""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds.struct_features()) + print(ds.triggered()) + print(ds.observed()) + + def test_structure_id(self): + """Check getting a structure id from row index""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds.structure_id(0)) + + def test_x(self): + """Test getting X portion of the dataframe""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds.X().df.head()) + + def test_trig(self): + """Test getting the triggered portion of the dataframe""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds.trig().df.head()) + + def test_y(self): + """Test getting the Y portion of the dataframe""" + ds, reactions, rules = self.generate_rule_dataset() + print(ds.y().df.head()) + + def test_classification_dataset(self): + """Test making the classification dataset""" + ds, reactions, rules = self.generate_rule_dataset() + compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)] + class_ds, products = ds.classification_dataset(compounds, rules) + print(class_ds.df.head(5)) + print(products[:5]) + + def test_extra_features(self): + reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)] + applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)] + ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan]) + print(ds.shape) + + def test_to_arff(self): + """Test exporting the arff version of the dataset""" + ds, reactions, rules = self.generate_rule_dataset() + ds.to_arff("dataset_arff_test.arff") + + def test_save_load(self): + """Test saving and loading dataset""" + with TemporaryDirectory() as tmpdir: + ds, reactions, rules = self.generate_rule_dataset() + ds.save(os.path.join(tmpdir, "save_dataset.pkl")) + ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl")) + self.assertTrue(ds.df.equals(ds_loaded.df)) + + def test_dataset_example(self): + """Test with a concrete example checking dataset size""" reactions = [r for r in Reaction.objects.filter(package=self.package)] applicable_rules = 
[self.rule1] - ds = Dataset.generate_dataset(reactions, applicable_rules) + ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules) self.assertEqual(len(ds.y()), 1) - self.assertEqual(sum(ds.y()[0]), 1) + self.assertEqual(ds.y().df.item(), 1) + + def test_enviformer_dataset(self): + ds, reactions = self.generate_enviformer_dataset() + print(ds.X().head()) + print(ds.y().head()) + + def generate_rule_dataset(self): + """Generate a RuleBasedDataset from test package data""" + reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)] + applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)] + ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules) + return ds, reactions, applicable_rules + + def generate_enviformer_dataset(self): + reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)] + ds = EnviFormerDataset.generate_dataset(reactions) + return ds, reactions diff --git a/tests/test_enviformer.py b/tests/test_enviformer.py index 647433fc..b7994d4a 100644 --- a/tests/test_enviformer.py +++ b/tests/test_enviformer.py @@ -42,13 +42,11 @@ class EnviFormerTest(TestCase): threshold = float(0.5) data_package_objs = [self.BBD_SUBSET] eval_packages_objs = [self.BBD_SUBSET] - mod = EnviFormer.create( - self.package, data_package_objs, eval_packages_objs, threshold=threshold - ) + mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold) mod.build_dataset() mod.build_model() - mod.evaluate_model(True, eval_packages_objs) + mod.evaluate_model(True, eval_packages_objs, n_splits=2) mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") @@ -57,12 +55,9 @@ class EnviFormerTest(TestCase): with self.settings(MODEL_DIR=tmpdir): threshold = float(0.5) data_package_objs = [self.BBD_SUBSET] - eval_packages_objs = [self.BBD_SUBSET] mods = [] for _ in range(4): - mod = EnviFormer.create( - self.package, data_package_objs, eval_packages_objs, threshold=threshold - ) + mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold) mod.build_dataset() mod.build_model() mods.append(mod) @@ -73,15 +68,11 @@ class EnviFormerTest(TestCase): # Test pathway prediction times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)] - print( - f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}" - ) + print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}") # Test eviction by performing three prediction with every model, twice. 
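# `measure_predict` is used throughout this test module but its definition is not part of this
# hunk. A minimal sketch of such a timing helper (signature inferred from the calls above; the
# pathway-prediction branch is omitted because that API is not shown here) might look like:
import time


def measure_predict(model, pathway_pk=None, smiles="CCN(CC)C(=O)C1=CC(=CC=C1)C"):
    """Time one prediction; cold calls pay for reloading an evicted model, warm calls hit the cache."""
    start = time.perf_counter()
    # The real helper presumably predicts the stored pathway when pathway_pk is given; this sketch
    # only times a plain structure prediction, which is the call exercised by the eviction loop below.
    model.predict(smiles)
    return time.perf_counter() - start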
times = defaultdict(list) - for _ in range( - 2 - ): # Eviction should cause the second iteration here to have to reload the models + for _ in range(2): # Eviction should cause the second iteration here to have to reload the models for mod in mods: for _ in range(3): times[mod.pk].append(measure_predict(mod)) diff --git a/tests/test_model.py b/tests/test_model.py index f0355be9..50dfee19 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,7 +4,7 @@ import numpy as np from django.test import TestCase from epdb.logic import PackageManager -from epdb.models import User, MLRelativeReasoning, Package +from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning class ModelTest(TestCase): @@ -17,7 +17,7 @@ class ModelTest(TestCase): cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") cls.BBD_SUBSET = Package.objects.get(name="Fixtures") - def test_smoke(self): + def test_mlrr(self): with TemporaryDirectory() as tmpdir: with self.settings(MODEL_DIR=tmpdir): threshold = float(0.5) @@ -35,21 +35,9 @@ class ModelTest(TestCase): description="Created MLRelativeReasoning in Testcase", ) - # mod = RuleBasedRelativeReasoning.create( - # self.package, - # rule_package_objs, - # data_package_objs, - # eval_packages_objs, - # threshold=threshold, - # min_count=5, - # max_count=0, - # name='ECC - BBD - 0.5', - # description='Created MLRelativeReasoning in Testcase', - # ) - mod.build_dataset() mod.build_model() - mod.evaluate_model(True, eval_packages_objs) + mod.evaluate_model(True, eval_packages_objs, n_splits=2) results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") @@ -70,3 +58,57 @@ class ModelTest(TestCase): # from pprint import pprint # pprint(mod.eval_results) + + def test_applicability(self): + with TemporaryDirectory() as tmpdir: + with self.settings(MODEL_DIR=tmpdir): + threshold = float(0.5) + + rule_package_objs = [self.BBD_SUBSET] + data_package_objs = [self.BBD_SUBSET] + eval_packages_objs = [self.BBD_SUBSET] + + mod = MLRelativeReasoning.create( + self.package, + rule_package_objs, + data_package_objs, + threshold=threshold, + name="ECC - BBD - 0.5", + description="Created MLRelativeReasoning in Testcase", + build_app_domain=True, # To test the applicability domain this must be True + app_domain_num_neighbours=5, + app_domain_local_compatibility_threshold=0.5, + app_domain_reliability_threshold=0.5, + ) + + mod.build_dataset() + mod.build_model() + mod.evaluate_model(True, eval_packages_objs, n_splits=2) + + results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") + + def test_rbrr(self): + with TemporaryDirectory() as tmpdir: + with self.settings(MODEL_DIR=tmpdir): + threshold = float(0.5) + + rule_package_objs = [self.BBD_SUBSET] + data_package_objs = [self.BBD_SUBSET] + eval_packages_objs = [self.BBD_SUBSET] + + mod = RuleBasedRelativeReasoning.create( + self.package, + rule_package_objs, + data_package_objs, + threshold=threshold, + min_count=5, + max_count=0, + name='ECC - BBD - 0.5', + description='Created MLRelativeReasoning in Testcase', + ) + + mod.build_dataset() + mod.build_model() + mod.evaluate_model(True, eval_packages_objs, n_splits=2) + + results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") diff --git a/utilities/chem.py b/utilities/chem.py index 250ccfb6..d7a68d75 100644 --- a/utilities/chem.py +++ b/utilities/chem.py @@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING from indigo import Indigo, IndigoException, IndigoObject from indigo.renderer import IndigoRenderer from rdkit import Chem, 
rdBase -from rdkit.Chem import MACCSkeys, Descriptors +from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator from rdkit.Chem import rdChemReactions from rdkit.Chem.Draw import rdMolDraw2D from rdkit.Chem.MolStandardize import rdMolStandardize @@ -107,6 +107,13 @@ class FormatConverter(object): bitvec = MACCSkeys.GenMACCSKeys(mol) return bitvec.ToList() + @staticmethod + def morgan(smiles, radius=3, fpSize=2048): + finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize) + mol = Chem.MolFromSmiles(smiles) + fp = finger_gen.GetFingerprint(mol) + return fp.ToList() + @staticmethod def get_functional_groups(smiles: str) -> List[str]: res = list() diff --git a/utilities/ml.py b/utilities/ml.py index a93fafd9..5df5dce8 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -5,11 +5,14 @@ import logging from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import List, Dict, Set, Tuple, TYPE_CHECKING +from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable +from abc import ABC, abstractmethod import networkx as nx import numpy as np +from envipy_plugins import Descriptor from numpy.random import default_rng +import polars as pl from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.decomposition import PCA from sklearn.dummy import DummyClassifier @@ -26,70 +29,281 @@ if TYPE_CHECKING: from epdb.models import Rule, CompoundStructure, Reaction -class Dataset: - def __init__( - self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None - ): - self.columns: List[str] = columns - self.num_labels: int = num_labels - - if data is None: - self.data: List[List[str | int | float]] = list() +class Dataset(ABC): + def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None): + if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__ + self.df = data else: - self.data = data + # Build either an empty dataframe with columns or fill it with list of list data + if data is not None and len(columns) != len(data[0]): + raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns") + if columns is None: + raise ValueError("Columns can't be None if data is not already a DataFrame") + self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None) - self.num_features: int = len(columns) - self.num_labels - self._struct_features: Tuple[int, int] = self._block_indices("feature_") - self._triggered: Tuple[int, int] = self._block_indices("trig_") - self._observed: Tuple[int, int] = self._block_indices("obs_") + def add_rows(self, rows: List[List[str | int | float]]): + """Add rows to the dataset. Extends the polars dataframe stored in self""" + if len(self.columns) != len(rows[0]): + raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. 
{len(rows[0])} columns") + new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None) + self.df.extend(new_rows) - def _block_indices(self, prefix) -> Tuple[int, int]: + def add_row(self, row: List[str | int | float]): + """See add_rows""" + self.add_rows([row]) + + def block_indices(self, prefix) -> List[int]: + """Find the indexes in column labels that has the prefix""" indices: List[int] = [] for i, feature in enumerate(self.columns): if feature.startswith(prefix): indices.append(i) + return indices - return min(indices), max(indices) + @property + def columns(self) -> List[str]: + """Use the polars dataframe columns""" + return self.df.columns - def structure_id(self): - return self.data[0][0] + @property + def shape(self): + return self.df.shape - def add_row(self, row: List[str | int | float]): - if len(self.columns) != len(row): - raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}") - self.data.append(row) + @abstractmethod + def X(self, **kwargs): + pass - def times_triggered(self, rule_uuid) -> int: - idx = self.columns.index(f"trig_{rule_uuid}") + @abstractmethod + def y(self, **kwargs): + pass - times_triggered = 0 - for row in self.data: - if row[idx] == 1: - times_triggered += 1 - - return times_triggered - - def struct_features(self) -> Tuple[int, int]: - return self._struct_features - - def triggered(self) -> Tuple[int, int]: - return self._triggered - - def observed(self) -> Tuple[int, int]: - return self._observed - - def at(self, position: int) -> Dataset: - return Dataset(self.columns, self.num_labels, [self.data[position]]) - - def limit(self, limit: int) -> Dataset: - return Dataset(self.columns, self.num_labels, self.data[:limit]) + @staticmethod + @abstractmethod + def generate_dataset(reactions, *args, **kwargs): + pass def __iter__(self): - return (self.at(i) for i, _ in enumerate(self.data)) + """Use polars iter_rows for iterating over the dataset""" + return self.df.iter_rows() + + def __getitem__(self, item): + """Item is passed to polars allowing for advanced indexing. + See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__""" + res = self.df[item] + if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe + return self.__class__(data=res) + else: # If we don't get a dataframe back (likely base type, int, str, float etc.) 
return the item + return res + + def save(self, path: "Path | str"): + import pickle + + with open(path, "wb") as fh: + pickle.dump(self, fh) + + @staticmethod + def load(path: "str | Path") -> "Dataset": + import pickle + + return pickle.load(open(path, "rb")) + + def to_numpy(self): + return self.df.to_numpy() + + def __repr__(self): + return ( + f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>" + ) + + def __len__(self): + return len(self.df) + + def iter_rows(self, named=False): + return self.df.iter_rows(named=named) + + def filter(self, *predicates, **constraints): + return self.__class__(data=self.df.filter(*predicates, **constraints)) + + def select(self, *exprs, **named_exprs): + return self.__class__(data=self.df.select(*exprs, **named_exprs)) + + def with_columns(self, *exprs, **name_exprs): + return self.__class__(data=self.df.with_columns(*exprs, **name_exprs)) + + def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False): + return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last, + multithreaded=multithreaded, maintain_order=maintain_order)) + + def item(self, row=None, column=None): + return self.df.item(row, column) + + def fill_nan(self, value): + return self.__class__(data=self.df.fill_nan(value)) + + @property + def height(self): + return self.df.height + + +class RuleBasedDataset(Dataset): + def __init__(self, num_labels=None, columns=None, data=None): + super().__init__(columns, data) + # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init. + self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c]) + # Pre-calculate the ids of columns for features/labels, useful later in X and y + self._struct_features: List[int] = self.block_indices("feature_") + self._triggered: List[int] = self.block_indices("trig_") + self._observed: List[int] = self.block_indices("obs_") + self.feature_cols: List[int] = self._struct_features + self._triggered + self.num_features: int = len(self.feature_cols) + self.has_probs = False + + def times_triggered(self, rule_uuid) -> int: + """Count how many times a rule is triggered by the number of rows with one in the rules trig column""" + return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height + + def struct_features(self) -> List[int]: + return self._struct_features + + def triggered(self) -> List[int]: + return self._triggered + + def observed(self) -> List[int]: + return self._observed + + def structure_id(self, index: int): + """Get the UUID of a compound""" + return self.item(index, "structure_id") + + def X(self, exclude_id_col=True, na_replacement=0): + """Get all the feature and trig columns""" + _col_ids = self.feature_cols + if not exclude_id_col: + _col_ids = [0] + _col_ids + res = self[:, _col_ids] + if na_replacement is not None: + res.df = res.df.fill_null(na_replacement) + return res + + def trig(self, na_replacement=0): + """Get all the trig columns""" + res = self[:, self._triggered] + if na_replacement is not None: + res.df = res.df.fill_null(na_replacement) + return res + + def y(self, na_replacement=0): + """Get all the obs columns""" + res = self[:, self._observed] + if na_replacement is not None: + res.df = res.df.fill_null(na_replacement) + return res + + @staticmethod + def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None): + if feat_funcs is None: + 
feat_funcs = [FormatConverter.maccs] + _structures = set() # Get all the structures + for r in reactions: + _structures.update(r.educts.all()) + if not educts_only: + _structures.update(r.products.all()) + + compounds = sorted(_structures, key=lambda x: x.url) + triggered: Dict[str, Set[str]] = defaultdict(set) + observed: Set[str] = set() + + # Apply rules on collected compounds and store tps + for i, comp in enumerate(compounds): + logger.debug(f"{i + 1}/{len(compounds)}...") + + for rule in applicable_rules: + product_sets = rule.apply(comp.smiles) + if len(product_sets) == 0: + continue + + key = f"{rule.uuid} + {comp.uuid}" + if key in triggered: + logger.info(f"{key} already present. Duplicate reaction?") + + for prod_set in product_sets: + for smi in prod_set: + try: + smi = FormatConverter.standardize(smi, remove_stereo=True) + except Exception: + logger.debug(f"Standardizing SMILES failed for {smi}") + triggered[key].add(smi) + + for i, r in enumerate(reactions): + logger.debug(f"{i + 1}/{len(reactions)}...") + + if len(r.educts.all()) != 1: + logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") + continue + + for comp in r.educts.all(): + for rule in applicable_rules: + key = f"{rule.uuid} + {comp.uuid}" + if key not in triggered: + continue + + # standardize products from reactions for comparison + standardized_products = [] + for cs in r.products.all(): + smi = cs.smiles + try: + smi = FormatConverter.standardize(smi, remove_stereo=True) + except Exception as e: + logger.debug(f"Standardizing SMILES failed for {smi}") + standardized_products.append(smi) + if len(set(standardized_products).difference(triggered[key])) == 0: + observed.add(key) + feat_columns = [] + for feat_func in feat_funcs: + if isinstance(feat_func, Descriptor): + feats = feat_func.get_molecule_descriptors(compounds[0].smiles) + else: + feats = feat_func(compounds[0].smiles) + start_i = len(feat_columns) + feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)]) + ds_columns = (["structure_id"] + + feat_columns + + [f"trig_{r.uuid}" for r in applicable_rules] + + [f"obs_{r.uuid}" for r in applicable_rules]) + rows = [] + + for i, comp in enumerate(compounds): + # Features + feats = [] + for feat_func in feat_funcs: + if isinstance(feat_func, Descriptor): + feat = feat_func.get_molecule_descriptors(comp.smiles) + else: + feat = feat_func(comp.smiles) + feats.extend(feat) + trig = [] + obs = [] + for rule in applicable_rules: + key = f"{rule.uuid} + {comp.uuid}" + # Check triggered + if key in triggered: + trig.append(1) + else: + trig.append(0) + # Check obs + if key in observed: + obs.append(1) + elif key not in triggered: + obs.append(None) + else: + obs.append(0) + rows.append([str(comp.uuid)] + feats + trig + obs) + ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows) + return ds def classification_dataset( self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"] - ) -> Tuple[Dataset, List[List[PredictionResult]]]: + ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]: classify_data = [] classify_products = [] for struct in structures: @@ -113,186 +327,18 @@ class Dataset: else: trig.append(0) prods.append([]) - - classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) + new_row = [struct_id] + features + trig + ([-1] * len(trig)) + if self.has_probs: + new_row += [-1] * len(trig) + classify_data.append(new_row) classify_products.append(prods) + ds = RuleBasedDataset(len(applicable_rules), 
self.columns, data=classify_data) + return ds, classify_products - return Dataset( - columns=self.columns, num_labels=self.num_labels, data=classify_data - ), classify_products - - @staticmethod - def generate_dataset( - reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True - ) -> Dataset: - _structures = set() - - for r in reactions: - for e in r.educts.all(): - _structures.add(e) - - if not educts_only: - for e in r.products: - _structures.add(e) - - compounds = sorted(_structures, key=lambda x: x.url) - - triggered: Dict[str, Set[str]] = defaultdict(set) - observed: Set[str] = set() - - # Apply rules on collected compounds and store tps - for i, comp in enumerate(compounds): - logger.debug(f"{i + 1}/{len(compounds)}...") - - for rule in applicable_rules: - product_sets = rule.apply(comp.smiles) - - if len(product_sets) == 0: - continue - - key = f"{rule.uuid} + {comp.uuid}" - - if key in triggered: - logger.info(f"{key} already present. Duplicate reaction?") - - for prod_set in product_sets: - for smi in prod_set: - try: - smi = FormatConverter.standardize(smi, remove_stereo=True) - except Exception: - # :shrug: - logger.debug(f"Standardizing SMILES failed for {smi}") - pass - - triggered[key].add(smi) - - for i, r in enumerate(reactions): - logger.debug(f"{i + 1}/{len(reactions)}...") - - if len(r.educts.all()) != 1: - logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!") - continue - - for comp in r.educts.all(): - for rule in applicable_rules: - key = f"{rule.uuid} + {comp.uuid}" - - if key not in triggered: - continue - - # standardize products from reactions for comparison - standardized_products = [] - for cs in r.products.all(): - smi = cs.smiles - - try: - smi = FormatConverter.standardize(smi, remove_stereo=True) - except Exception as e: - # :shrug: - logger.debug(f"Standardizing SMILES failed for {smi}") - pass - - standardized_products.append(smi) - - if len(set(standardized_products).difference(triggered[key])) == 0: - observed.add(key) - else: - pass - - ds = None - - for i, comp in enumerate(compounds): - # Features - feat = FormatConverter.maccs(comp.smiles) - trig = [] - obs = [] - - for rule in applicable_rules: - key = f"{rule.uuid} + {comp.uuid}" - - # Check triggered - if key in triggered: - trig.append(1) - else: - trig.append(0) - - # Check obs - if key in observed: - obs.append(1) - elif key not in triggered: - obs.append(None) - else: - obs.append(0) - - if ds is None: - header = ( - ["structure_id"] - + [f"feature_{i}" for i, _ in enumerate(feat)] - + [f"trig_{r.uuid}" for r in applicable_rules] - + [f"obs_{r.uuid}" for r in applicable_rules] - ) - ds = Dataset(header, len(applicable_rules)) - - ds.add_row([str(comp.uuid)] + feat + trig + obs) - - return ds - - def X(self, exclude_id_col=True, na_replacement=0): - res = self.__getitem__( - (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)) - ) - if na_replacement is not None: - res = [[x if x is not None else na_replacement for x in row] for row in res] - return res - - def trig(self, na_replacement=0): - res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1]))) - if na_replacement is not None: - res = [[x if x is not None else na_replacement for x in row] for row in res] - return res - - def y(self, na_replacement=0): - res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) - if na_replacement is not None: - res = [[x if x is not None else na_replacement for x in 
row] for row in res] - return res - - def __getitem__(self, key): - if not isinstance(key, tuple): - raise TypeError("Dataset must be indexed with dataset[rows, columns]") - - row_key, col_key = key - - # Normalize rows - if isinstance(row_key, int): - rows = [self.data[row_key]] - else: - rows = self.data[row_key] - - # Normalize columns - if isinstance(col_key, int): - res = [row[col_key] for row in rows] - else: - res = [ - [row[i] for i in range(*col_key.indices(len(row)))] - if isinstance(col_key, slice) - else [row[i] for i in col_key] - for row in rows - ] - - return res - - def save(self, path: "Path"): - import pickle - - with open(path, "wb") as fh: - pickle.dump(self, fh) - - @staticmethod - def load(path: "Path") -> "Dataset": - import pickle - - return pickle.load(open(path, "rb")) + def add_probs(self, probs): + col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed] + self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)]) + self.has_probs = True def to_arff(self, path: "Path"): arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n" @@ -304,7 +350,7 @@ class Dataset: arff += f"@attribute {c} {{0,1}}\n" arff += "\n@data\n" - for d in self.data: + for d in self: ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]]) xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]]) arff += f"{ys},{xs}\n" @@ -313,10 +359,40 @@ class Dataset: fh.write(arff) fh.flush() - def __repr__(self): - return ( - f"" - ) + +class EnviFormerDataset(Dataset): + def __init__(self, columns=None, data=None): + super().__init__(columns, data) + + def X(self): + """Return the educts""" + return self["educts"] + + def y(self): + """Return the products""" + return self["products"] + + @staticmethod + def generate_dataset(reactions, *args, **kwargs): + # Standardise reactions for the training data + stereo = kwargs.get("stereo", False) + rows = [] + for reaction in reactions: + e = ".".join( + [ + FormatConverter.standardize(smile.smiles, remove_stereo=not stereo) + for smile in reaction.educts.all() + ] + ) + p = ".".join( + [ + FormatConverter.standardize(smile.smiles, remove_stereo=not stereo) + for smile in reaction.products.all() + ] + ) + rows.append([e, p]) + ds = EnviFormerDataset(["educts", "products"], rows) + return ds class SparseLabelECC(BaseEstimator, ClassifierMixin): @@ -498,7 +574,7 @@ class EnsembleClassifierChain: self.classifiers = [] if self.num_labels is None: - self.num_labels = len(Y[0]) + self.num_labels = Y.shape[1] for p in range(self.num_chains): logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}") @@ -529,7 +605,7 @@ class RelativeReasoning: def fit(self, X, Y): n_instances = len(Y) - n_attributes = len(Y[0]) + n_attributes = Y.shape[1] for i in range(n_attributes): for j in range(n_attributes): @@ -541,8 +617,8 @@ class RelativeReasoning: countboth = 0 for k in range(n_instances): - vi = Y[k][i] - vj = Y[k][j] + vi = Y[k, i] + vj = Y[k, j] if vi is None or vj is None: continue @@ -598,7 +674,7 @@ class ApplicabilityDomainPCA(PCA): self.min_vals = None self.max_vals = None - def build(self, train_dataset: "Dataset"): + def build(self, train_dataset: "RuleBasedDataset"): # transform X_scaled = self.scaler.fit_transform(train_dataset.X()) # fit pca @@ -612,7 +688,7 @@ class ApplicabilityDomainPCA(PCA): instances_pca = self.transform(instances_scaled) return instances_pca - def is_applicable(self, classify_instances: "Dataset"): + def 
is_applicable(self, classify_instances: "RuleBasedDataset"): instances_pca = self.__transform(classify_instances.X()) is_applicable = [] diff --git a/uv.lock b/uv.lock index f3e6d123..0cd84139 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,10 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" +resolution-markers = [ + "sys_platform == 'linux' or sys_platform == 'win32'", + "sys_platform != 'linux' and sys_platform != 'win32'", +] [[package]] name = "aiohappyeyeballs" @@ -176,6 +180,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" }, ] +[[package]] +name = "celery-stubs" +version = "0.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/14/b853ada8706a3a301396566b6dd405d1cbb24bff756236a12a01dbe766a4/celery-stubs-0.1.3.tar.gz", hash = "sha256:0fb5345820f8a2bd14e6ffcbef2d10181e12e40f8369f551d7acc99d8d514919", size = 46583, upload-time = "2023-02-10T02:20:11.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/7a/4ab2347d13f1f59d10a7337feb9beb002664119f286036785284c6bec150/celery_stubs-0.1.3-py3-none-any.whl", hash = "sha256:dfb9ad27614a8af028b2055bb4a4ae99ca5e9a8d871428a506646d62153218d7", size = 89085, upload-time = "2023-02-10T02:20:09.409Z" }, +] + [[package]] name = "certifi" version = "2025.10.5" @@ -525,13 +542,14 @@ wheels = [ [[package]] name = "enviformer" version = "0.1.0" -source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2#3f28f60cfa1df814cf7559303b5130933efa40ae" } +source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4#7094be5767748fd63d4a84a5d71f06cf02ba07f3" } dependencies = [ { name = "joblib" }, { name = "lightning" }, { name = "pytorch-lightning" }, { name = "scikit-learn" }, - { name = "torch" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] [[package]] @@ -546,7 +564,6 @@ dependencies = [ { name = "django-ninja" }, { name = "django-oauth-toolkit" }, { name = "django-polymorphic" }, - { name = "django-stubs" }, { name = "enviformer" }, { name = "envipy-additional-information" }, { name = "envipy-ambit" }, @@ -554,6 +571,7 @@ dependencies = [ { name = "epam-indigo" }, { name = "gunicorn" }, { name = "networkx" }, + { name = "polars" }, { name = "psycopg2-binary" }, { name = "python-dotenv" }, { name = "rdkit" }, @@ -566,6 +584,8 @@ dependencies = [ [package.optional-dependencies] dev = [ + { name = "celery-stubs" }, + { name = "django-stubs" }, { name = "poethepoet" }, { name = "pre-commit" }, { name = "ruff" }, @@ -577,15 +597,16 @@ ms-login = [ [package.metadata] requires-dist = [ { name = "celery", specifier = ">=5.5.2" }, + { name = "celery-stubs", marker = "extra == 'dev'", specifier = "==0.1.3" }, { name = "django", specifier = ">=5.2.1" }, { name = "django-extensions", specifier = ">=4.1" }, { name = "django-model-utils", specifier = ">=5.0.0" }, { name = "django-ninja", specifier = ">=1.4.1" }, { name = 
"django-oauth-toolkit", specifier = ">=3.0.1" }, { name = "django-polymorphic", specifier = ">=4.1.0" }, - { name = "django-stubs", specifier = ">=5.2.4" }, - { name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2" }, - { name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4" }, + { name = "django-stubs", marker = "extra == 'dev'", specifier = ">=5.2.4" }, + { name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4" }, + { name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7" }, { name = "envipy-ambit", git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }, { name = "envipy-plugins", git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git?rev=v0.1.0" }, { name = "epam-indigo", specifier = ">=1.30.1" }, @@ -593,6 +614,7 @@ requires-dist = [ { name = "msal", marker = "extra == 'ms-login'", specifier = ">=1.33.0" }, { name = "networkx", specifier = ">=3.4.2" }, { name = "poethepoet", marker = "extra == 'dev'", specifier = ">=0.37.0" }, + { name = "polars", specifier = "==1.35.1" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.3.0" }, { name = "psycopg2-binary", specifier = ">=2.9.10" }, { name = "python-dotenv", specifier = ">=1.1.0" }, @@ -608,8 +630,8 @@ provides-extras = ["ms-login", "dev"] [[package]] name = "envipy-additional-information" -version = "0.1.0" -source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4#4da604090bf7cf1f3f552d69485472dbc623030a" } +version = "0.1.7" +source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7#d02a5d5e6a931e6565ea86127813acf7e4b33a30" } dependencies = [ { name = "pydantic" }, ] @@ -865,7 +887,8 @@ dependencies = [ { name = "packaging" }, { name = "pytorch-lightning" }, { name = "pyyaml" }, - { name = "torch" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "torchmetrics" }, { name = "tqdm" }, { name = "typing-extensions" }, @@ -1074,6 +1097,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" }, ] +[[package]] +name = "mypy" +version = "1.18.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "pathspec" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/77/8f0d0001ffad290cef2f7f216f96c814866248a0b92a722365ed54648e7e/mypy-1.18.2.tar.gz", hash = "sha256:06a398102a5f203d7477b2923dda3634c36727fa5c237d8f859ef90c42a9924b", size = 3448846, upload-time = "2025-09-19T00:11:10.519Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/06/dfdd2bc60c66611dd8335f463818514733bc763e4760dee289dcc33df709/mypy-1.18.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:33eca32dd124b29400c31d7cf784e795b050ace0e1f91b8dc035672725617e34", size = 12908273, upload-time = 
"2025-09-19T00:10:58.321Z" }, + { url = "https://files.pythonhosted.org/packages/81/14/6a9de6d13a122d5608e1a04130724caf9170333ac5a924e10f670687d3eb/mypy-1.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3c47adf30d65e89b2dcd2fa32f3aeb5e94ca970d2c15fcb25e297871c8e4764", size = 11920910, upload-time = "2025-09-19T00:10:20.043Z" }, + { url = "https://files.pythonhosted.org/packages/5f/a9/b29de53e42f18e8cc547e38daa9dfa132ffdc64f7250e353f5c8cdd44bee/mypy-1.18.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d6c838e831a062f5f29d11c9057c6009f60cb294fea33a98422688181fe2893", size = 12465585, upload-time = "2025-09-19T00:10:33.005Z" }, + { url = "https://files.pythonhosted.org/packages/77/ae/6c3d2c7c61ff21f2bee938c917616c92ebf852f015fb55917fd6e2811db2/mypy-1.18.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01199871b6110a2ce984bde85acd481232d17413868c9807e95c1b0739a58914", size = 13348562, upload-time = "2025-09-19T00:10:11.51Z" }, + { url = "https://files.pythonhosted.org/packages/4d/31/aec68ab3b4aebdf8f36d191b0685d99faa899ab990753ca0fee60fb99511/mypy-1.18.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a2afc0fa0b0e91b4599ddfe0f91e2c26c2b5a5ab263737e998d6817874c5f7c8", size = 13533296, upload-time = "2025-09-19T00:10:06.568Z" }, + { url = "https://files.pythonhosted.org/packages/9f/83/abcb3ad9478fca3ebeb6a5358bb0b22c95ea42b43b7789c7fb1297ca44f4/mypy-1.18.2-cp312-cp312-win_amd64.whl", hash = "sha256:d8068d0afe682c7c4897c0f7ce84ea77f6de953262b12d07038f4d296d547074", size = 9828828, upload-time = "2025-09-19T00:10:28.203Z" }, + { url = "https://files.pythonhosted.org/packages/5f/04/7f462e6fbba87a72bc8097b93f6842499c428a6ff0c81dd46948d175afe8/mypy-1.18.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:07b8b0f580ca6d289e69209ec9d3911b4a26e5abfde32228a288eb79df129fcc", size = 12898728, upload-time = "2025-09-19T00:10:01.33Z" }, + { url = "https://files.pythonhosted.org/packages/99/5b/61ed4efb64f1871b41fd0b82d29a64640f3516078f6c7905b68ab1ad8b13/mypy-1.18.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed4482847168439651d3feee5833ccedbf6657e964572706a2adb1f7fa4dfe2e", size = 11910758, upload-time = "2025-09-19T00:10:42.607Z" }, + { url = "https://files.pythonhosted.org/packages/3c/46/d297d4b683cc89a6e4108c4250a6a6b717f5fa96e1a30a7944a6da44da35/mypy-1.18.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ad2afadd1e9fea5cf99a45a822346971ede8685cc581ed9cd4d42eaf940986", size = 12475342, upload-time = "2025-09-19T00:11:00.371Z" }, + { url = "https://files.pythonhosted.org/packages/83/45/4798f4d00df13eae3bfdf726c9244bcb495ab5bd588c0eed93a2f2dd67f3/mypy-1.18.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a431a6f1ef14cf8c144c6b14793a23ec4eae3db28277c358136e79d7d062f62d", size = 13338709, upload-time = "2025-09-19T00:11:03.358Z" }, + { url = "https://files.pythonhosted.org/packages/d7/09/479f7358d9625172521a87a9271ddd2441e1dab16a09708f056e97007207/mypy-1.18.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ab28cc197f1dd77a67e1c6f35cd1f8e8b73ed2217e4fc005f9e6a504e46e7ba", size = 13529806, upload-time = "2025-09-19T00:10:26.073Z" }, + { url = "https://files.pythonhosted.org/packages/71/cf/ac0f2c7e9d0ea3c75cd99dff7aec1c9df4a1376537cb90e4c882267ee7e9/mypy-1.18.2-cp313-cp313-win_amd64.whl", hash = "sha256:0e2785a84b34a72ba55fb5daf079a1003a34c05b22238da94fcae2bbe46f3544", 
size = 9833262, upload-time = "2025-09-19T00:10:40.035Z" }, + { url = "https://files.pythonhosted.org/packages/5a/0c/7d5300883da16f0063ae53996358758b2a2df2a09c72a5061fa79a1f5006/mypy-1.18.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:62f0e1e988ad41c2a110edde6c398383a889d95b36b3e60bcf155f5164c4fdce", size = 12893775, upload-time = "2025-09-19T00:10:03.814Z" }, + { url = "https://files.pythonhosted.org/packages/50/df/2cffbf25737bdb236f60c973edf62e3e7b4ee1c25b6878629e88e2cde967/mypy-1.18.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8795a039bab805ff0c1dfdb8cd3344642c2b99b8e439d057aba30850b8d3423d", size = 11936852, upload-time = "2025-09-19T00:10:51.631Z" }, + { url = "https://files.pythonhosted.org/packages/be/50/34059de13dd269227fb4a03be1faee6e2a4b04a2051c82ac0a0b5a773c9a/mypy-1.18.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ca1e64b24a700ab5ce10133f7ccd956a04715463d30498e64ea8715236f9c9c", size = 12480242, upload-time = "2025-09-19T00:11:07.955Z" }, + { url = "https://files.pythonhosted.org/packages/5b/11/040983fad5132d85914c874a2836252bbc57832065548885b5bb5b0d4359/mypy-1.18.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d924eef3795cc89fecf6bedc6ed32b33ac13e8321344f6ddbf8ee89f706c05cb", size = 13326683, upload-time = "2025-09-19T00:09:55.572Z" }, + { url = "https://files.pythonhosted.org/packages/e9/ba/89b2901dd77414dd7a8c8729985832a5735053be15b744c18e4586e506ef/mypy-1.18.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20c02215a080e3a2be3aa50506c67242df1c151eaba0dcbc1e4e557922a26075", size = 13514749, upload-time = "2025-09-19T00:10:44.827Z" }, + { url = "https://files.pythonhosted.org/packages/25/bc/cc98767cffd6b2928ba680f3e5bc969c4152bf7c2d83f92f5a504b92b0eb/mypy-1.18.2-cp314-cp314-win_amd64.whl", hash = "sha256:749b5f83198f1ca64345603118a6f01a4e99ad4bf9d103ddc5a3200cc4614adf", size = 9982959, upload-time = "2025-09-19T00:10:37.344Z" }, + { url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" }, +] + +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "networkx" version = "3.5" @@ -1192,7 +1256,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = 
"sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -1203,7 +1267,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -1230,9 +1294,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -1243,7 +1307,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -1308,6 +1372,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/18/a8444036c6dd65ba3624c63b734d3ba95ba63ace513078e1580590075d21/pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364", size = 5955, upload-time = "2020-09-16T19:21:11.409Z" }, ] +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, +] + [[package]] name = "pillow" version = "11.3.0" @@ -1396,6 +1469,32 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/92/1b/5337af1a6a478d25a3e3c56b9b4b42b0a160314e02f4a0498d5322c8dac4/poethepoet-0.37.0-py3-none-any.whl", hash = "sha256:861790276315abcc8df1b4bd60e28c3d48a06db273edd3092f3c94e1a46e5e22", size = 90062, upload-time = "2025-08-11T18:00:27.595Z" }, ] +[[package]] +name = "polars" +version = "1.35.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/5b/3caad788d93304026cbf0ab4c37f8402058b64a2f153b9c62f8b30f5d2ee/polars-1.35.1.tar.gz", hash = "sha256:06548e6d554580151d6ca7452d74bceeec4640b5b9261836889b8e68cfd7a62e", size = 694881, upload-time = "2025-10-30T12:12:52.294Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/4c/21a227b722534404241c2a76beceb7463469d50c775a227fc5c209eb8adc/polars-1.35.1-py3-none-any.whl", hash = "sha256:c29a933f28aa330d96a633adbd79aa5e6a6247a802a720eead9933f4613bdbf4", size = 783598, upload-time = "2025-10-30T12:11:54.668Z" }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.35.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/3e/19c252e8eb4096300c1a36ec3e50a27e5fa9a1ccaf32d3927793c16abaee/polars_runtime_32-1.35.1.tar.gz", hash = "sha256:f6b4ec9cd58b31c87af1b8c110c9c986d82345f1d50d7f7595b5d447a19dc365", size = 2696218, upload-time = "2025-10-30T12:12:53.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/2c/da339459805a26105e9d9c2f07e43ca5b8baeee55acd5457e6881487a79a/polars_runtime_32-1.35.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6f051a42f6ae2f26e3bc2cf1f170f2120602976e2a3ffb6cfba742eecc7cc620", size = 40525100, upload-time = "2025-10-30T12:11:58.098Z" }, + { url = "https://files.pythonhosted.org/packages/27/70/a0733568b3533481924d2ce68b279ab3d7334e5fa6ed259f671f650b7c5e/polars_runtime_32-1.35.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:c2232f9cf05ba59efc72d940b86c033d41fd2d70bf2742e8115ed7112a766aa9", size = 36701908, upload-time = "2025-10-30T12:12:02.166Z" }, + { url = "https://files.pythonhosted.org/packages/46/54/6c09137bef9da72fd891ba58c2962cc7c6c5cad4649c0e668d6b344a9d7b/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42f9837348557fd674477ea40a6ac8a7e839674f6dd0a199df24be91b026024c", size = 41317692, upload-time = "2025-10-30T12:12:04.928Z" }, + { url = "https://files.pythonhosted.org/packages/22/55/81c5b266a947c339edd7fbaa9e1d9614012d02418453f48b76cc177d3dd9/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:c873aeb36fed182d5ebc35ca17c7eb193fe83ae2ea551ee8523ec34776731390", size = 37853058, upload-time = "2025-10-30T12:12:08.342Z" }, + { url = "https://files.pythonhosted.org/packages/6c/58/be8b034d559eac515f52408fd6537be9bea095bc0388946a4e38910d3d50/polars_runtime_32-1.35.1-cp39-abi3-win_amd64.whl", hash = "sha256:35cde9453ca7032933f0e58e9ed4388f5a1e415dd0db2dd1e442c81d815e630c", size = 41289554, upload-time = "2025-10-30T12:12:11.104Z" }, + { url = "https://files.pythonhosted.org/packages/f4/7f/e0111b9e2a1169ea82cde3ded9c92683e93c26dfccd72aee727996a1ac5b/polars_runtime_32-1.35.1-cp39-abi3-win_arm64.whl", hash = "sha256:fd77757a6c9eb9865c4bfb7b07e22225207c6b7da382bd0b9bd47732f637105d", size = 36958878, upload-time = "2025-10-30T12:12:15.206Z" }, +] + [[package]] name = "pre-commit" version = "4.3.0" @@ -1670,7 +1769,8 @@ dependencies = [ { name = "lightning-utilities" }, { name = "packaging" }, { 
name = "pyyaml" }, - { name = "torch" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "torchmetrics" }, { name = "tqdm" }, { name = "typing-extensions" }, @@ -1754,11 +1854,11 @@ wheels = [ [[package]] name = "redis" -version = "6.4.0" +version = "7.0.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" } +sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" }, + { url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" }, ] [[package]] @@ -1963,15 +2063,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, ] +[[package]] +name = "torch" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "sys_platform != 'linux' and sys_platform != 'win32'", +] +dependencies = [ + { name = "filelock", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "fsspec", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "jinja2", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "networkx", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "setuptools", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "sympy", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = 
"2025-08-06T14:54:34.769Z" }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" }, +] + [[package]] name = "torch" version = "2.8.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } +resolution-markers = [ + "sys_platform == 'linux' or sys_platform == 'win32'", +] dependencies = [ - { name = "filelock" }, - { name = "fsspec" }, - { name = "jinja2" }, - { name = "networkx" }, + { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -1986,10 +2111,10 @@ dependencies = [ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setuptools" }, - { name = "sympy" }, + { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions" }, + { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" }, @@ -2008,7 +2133,8 @@ dependencies = [ { name = "lightning-utilities" }, { name = "numpy" }, { name = "packaging" }, - { name = "torch" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/85/2e/48a887a59ecc4a10ce9e8b35b3e3c5cef29d902c4eac143378526e7485cb/torchmetrics-1.8.2.tar.gz", hash = "sha256:cf64a901036bf107f17a524009eea7781c9c5315d130713aeca5747a686fe7a5", size = 580679, upload-time = "2025-09-03T14:00:54.077Z" } wheels = [ @@ -2032,7 +2158,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools" }, + { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" },