2 Commits

Author SHA1 Message Date
04f9c9252a ... 2025-11-04 10:15:06 +01:00
c2d45917ce clean up d3_json 2025-10-31 09:38:41 +01:00
15 changed files with 570 additions and 764 deletions

View File

@ -16,5 +16,3 @@ POSTGRES_PORT=
# MAIL
EMAIL_HOST_USER=
EMAIL_HOST_PASSWORD=
# MATOMO
MATOMO_SITE_ID

View File

@ -357,6 +357,3 @@ if MS_ENTRA_ENABLED:
MS_ENTRA_AUTHORITY = f"https://login.microsoftonline.com/{MS_ENTRA_TENANT_ID}"
MS_ENTRA_REDIRECT_URI = os.environ["MS_REDIRECT_URI"]
MS_ENTRA_SCOPES = os.environ.get("MS_SCOPES", "").split(",")
# Site ID 10 -> beta.envipath.org
MATOMO_SITE_ID = os.environ.get("MATOMO_SITE_ID", "10")

View File

@ -1542,7 +1542,9 @@ class SPathway(object):
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)[0]
if self.persist is not None:
n = self.snode_persist_lookup[sub]
@ -1574,7 +1576,11 @@ class SPathway(object):
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c))
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)[
0
]
)
self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment

View File

@ -28,8 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
EnviFormerDataset, Dataset
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
logger = logging.getLogger(__name__)
@ -1574,14 +1573,12 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
while len(queue):
current = queue.pop()
processed.add(current)
nodes.append(current.d3_json())
for e in self.edges:
if current in e.start_nodes.all():
for prod in e.end_nodes.all():
if prod not in queue and prod not in processed:
queue.append(prod)
for e in self.edges.filter(start_nodes=current).distinct():
for prod in e.end_nodes.all():
if prod not in queue and prod not in processed:
queue.append(prod)
# We shouldn't lose or make up nodes...
assert len(nodes) == len(self.nodes)
@ -2176,7 +2173,7 @@ class PackageBasedModel(EPModel):
applicable_rules = self.applicable_rules
reactions = list(self._get_reactions())
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@ -2185,7 +2182,7 @@ class PackageBasedModel(EPModel):
ds.save(f)
return ds
def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
def load_dataset(self) -> "Dataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return Dataset.load(ds_path)
@ -2226,7 +2223,7 @@ class PackageBasedModel(EPModel):
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@ -2344,37 +2341,37 @@ class PackageBasedModel(EPModel):
eval_reactions = list(
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
)
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
if isinstance(self, RuleBasedRelativeReasoning):
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
else:
X = ds.X(na_replacement=np.nan).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
ds = self.load_dataset()
if isinstance(self, RuleBasedRelativeReasoning):
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
else:
X = ds.X(na_replacement=np.nan).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
n_splits = kwargs.get("n_splits", 20)
n_splits = 20
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
splits = list(shuff.split(X))
from joblib import Parallel, delayed
models = Parallel(n_jobs=min(10, len(splits)))(
models = Parallel(n_jobs=10)(
delayed(train_func)(X, y, train_index, self._model_args())
for train_index, _ in splits
)
evaluations = Parallel(n_jobs=min(10, len(splits)))(
evaluations = Parallel(n_jobs=10)(
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
for model, (_, test_index) in zip(models, splits)
)
@ -2586,11 +2583,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return rbrr
def _fit_model(self, ds: RuleBasedDataset):
def _fit_model(self, ds: Dataset):
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
model = RelativeReasoning(
start_index=ds.triggered()[0],
end_index=ds.triggered()[-1],
end_index=ds.triggered()[1],
)
model.fit(X, y)
return model
@ -2600,7 +2597,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return {
"clz": "RuleBaseRelativeReasoning",
"start_index": ds.triggered()[0],
"end_index": ds.triggered()[-1],
"end_index": ds.triggered()[1],
}
def _save_model(self, model):
@ -2688,11 +2685,11 @@ class MLRelativeReasoning(PackageBasedModel):
return mlrr
def _fit_model(self, ds: RuleBasedDataset):
def _fit_model(self, ds: Dataset):
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
model.fit(X.to_numpy(), y.to_numpy())
model.fit(X, y)
return model
def _model_args(self):
@ -2715,7 +2712,7 @@ class MLRelativeReasoning(PackageBasedModel):
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
pred = self.model.predict_proba(classify_ds.X().to_numpy())
pred = self.model.predict_proba(classify_ds.X())
res = MLRelativeReasoning.combine_products_and_probs(
self.applicable_rules, pred[0], classify_prods[0]
@ -2760,9 +2757,7 @@ class ApplicabilityDomain(EnviPathModel):
@cached_property
def training_set_probs(self):
ds = self.model.load_dataset()
col_ids = ds.block_indices("prob")
return ds[:, col_ids]
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
def build(self):
ds = self.model.load_dataset()
@ -2770,9 +2765,9 @@ class ApplicabilityDomain(EnviPathModel):
start = datetime.now()
# Get Trainingset probs and dump them as they're required when using the app domain
probs = self.model.model.predict_proba(ds.X().to_numpy())
ds.add_probs(probs)
ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
probs = self.model.model.predict_proba(ds.X())
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
joblib.dump(probs, f)
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
ad.build(ds)
@ -2795,19 +2790,16 @@ class ApplicabilityDomain(EnviPathModel):
joblib.dump(ad, f)
def assess(self, structure: Union[str, "CompoundStructure"]):
return self.assess_batch([structure])[0]
def assess_batch(self, structures: List["CompoundStructure | str"]):
ds = self.model.load_dataset()
smiles = []
for struct in structures:
if isinstance(struct, CompoundStructure):
smiles.append(structures.smiles)
else:
smiles.append(structures)
if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
assessment_ds, assessment_prods = ds.classification_dataset(
[structure], self.model.applicable_rules
)
# qualified_neighbours_per_rule is a nested dictionary structured as:
# {
@ -2820,46 +2812,82 @@ class ApplicabilityDomain(EnviPathModel):
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
# with a given assessment structure under a particular rule.
qualified_neighbours_per_rule: Dict = {}
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
lambda: defaultdict(list)
)
import polars as pl
# Select only the triggered columns
for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
# Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
# trigger that rule.
train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
for trig_uuid, value in row.items() if value == 1}
qualified_neighbours_per_rule[i] = train_trig
rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
feature = ds.columns[feature_index]
if feature.startswith("trig_"):
# TODO unroll loop
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
if int(cx[feature_index]) == 1:
for j, tx in enumerate(ds.X(exclude_id_col=False)):
if int(tx[feature_index]) == 1:
qualified_neighbours_per_rule[i][rule_idx].append(j)
probs = self.training_set_probs
# preds = self.model.model.predict_proba(assessment_ds.X())
preds = self.model.combine_products_and_probs(
self.model.applicable_rules,
self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
self.model.model.predict_proba(assessment_ds.X())[0],
assessment_prods[0],
)
assessments = list()
# loop through our assessment dataset
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
for i, instance in enumerate(assessment_ds):
rule_reliabilities = dict()
local_compatibilities = dict()
neighbours_per_rule = dict()
neighbor_probs_per_rule = dict()
# loop through rule indices together with the collected neighbours indices from train dataset
for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
# compute tanimoto distance for all neighbours and add to dataset
dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
train_instances[:, train_instances.struct_features()].to_numpy())
train_instances = train_instances.with_columns(dist=pl.Series(dists))
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
# collect the train dataset instances and store it along with the index (a.k.a. row number) of the
# train dataset
train_instances = []
for v in vals:
train_instances.append((v, ds.at(v)))
# sf is a tuple with start/end index of the features
sf = ds.struct_features()
# compute tanimoto distance for all neighbours
# result ist a list of tuples with train index and computed distance
dists = self._compute_distances(
instance.X()[0][sf[0] : sf[1]],
[ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
)
dists_with_index = list()
for ti, dist in zip(train_instances, dists):
dists_with_index.append((ti[0], dist[1]))
# sort them in a descending way and take at most `self.num_neighbours`
train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
dists_with_index = dists_with_index[: self.num_neighbours]
# compute average distance
rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
rule_reliabilities[rule_idx] = (
sum([d[1] for d in dists_with_index]) / len(dists_with_index)
if len(dists_with_index) > 0
else 0.0
)
# for local_compatibility we'll need the datasets for the indices having the highest similarity
local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
local_compatibilities[rule_idx] = self._compute_compatibility(
rule_idx, probs, neighbour_datasets
)
neighbours_per_rule[rule_idx] = [
CompoundStructure.objects.get(uuid=ds[1].structure_id())
for ds in neighbour_datasets
]
neighbor_probs_per_rule[rule_idx] = [
probs[d[0]][rule_idx] for d in dists_with_index
]
ad_res = {
"ad_params": {
@ -2870,21 +2898,23 @@ class ApplicabilityDomain(EnviPathModel):
"local_compatibility_threshold": self.local_compatibilty_threshold,
},
"assessment": {
"smiles": smiles[i],
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
"smiles": smiles,
"inside_app_domain": self.pca.is_applicable(instance)[0],
},
}
transformations = list()
for rule_uuid in rule_reliabilities.keys():
rule = Rule.objects.get(uuid=rule_uuid)
for rule_idx in rule_reliabilities.keys():
rule = Rule.objects.get(
uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
)
rule_data = rule.simple_json()
rule_data["image"] = f"{rule.url}?image=svg"
neighbors = []
for n, n_prob in zip(
neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
):
neighbor = n.simple_json()
neighbor["image"] = f"{n.url}?image=svg"
@ -2901,14 +2931,14 @@ class ApplicabilityDomain(EnviPathModel):
transformation = {
"rule": rule_data,
"reliability": rule_reliabilities[rule_uuid],
"reliability": rule_reliabilities[rule_idx],
# We're setting it here to False, as we don't know whether "assess" is called during pathway
# prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
"is_predicted": False,
"local_compatibility": local_compatibilities[rule_uuid],
"probability": preds[rule_to_i[rule_uuid]].probability,
"local_compatibility": local_compatibilities[rule_idx],
"probability": preds[rule_idx].probability,
"transformation_products": [
x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
x.product_set for x in preds[rule_idx].product_sets
],
"times_triggered": ds.times_triggered(str(rule.uuid)),
"neighbors": neighbors,
@ -2926,21 +2956,32 @@ class ApplicabilityDomain(EnviPathModel):
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
from utilities.ml import tanimoto_distance
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
distances = [
(i, tanimoto_distance(classify_instance, train))
for i, train in enumerate(train_instances)
]
return distances
def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
@staticmethod
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
accuracy = 0.0
import polars as pl
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
# Compute tp, tn, fp, fn using polars expressions
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)
for n in neighbours:
obs = n[1].y()[0][rule_idx]
pred = preds[n[0]][rule_idx]
if obs and pred:
tp += 1
elif not obs and pred:
fp += 1
elif obs and not pred:
fn += 1
else:
tn += 1
# Jaccard Index
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)
return accuracy
@ -3041,24 +3082,44 @@ class EnviFormer(PackageBasedModel):
self.save()
start = datetime.now()
ds = EnviFormerDataset.generate_dataset(self._get_reactions())
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
co2 = {"C(=O)=O", "O=C=O"}
ds = []
for reaction in self._get_reactions():
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
if products not in co2:
ds.append(f"{educts}>>{products}")
end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
ds.save(f)
with open(f, "w") as d_file:
json.dump(ds, d_file)
return ds
def load_dataset(self):
def load_dataset(self) -> "Dataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
return EnviFormerDataset.load(ds_path)
with open(ds_path) as d_file:
ds = json.load(d_file)
return ds
def _fit_model(self, ds):
# Call to enviFormer's fine_tune function and return the model
from enviformer.finetune import fine_tune
start = datetime.now()
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
return model
@ -3074,7 +3135,7 @@ class EnviFormer(PackageBasedModel):
args = {"clz": "EnviFormer"}
return args
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@ -3089,20 +3150,21 @@ class EnviFormer(PackageBasedModel):
self.model_status = self.EVALUATING
self.save()
def evaluate_sg(test_ds, predictions, model_thresh):
def evaluate_sg(test_reactions, predictions, model_thresh):
# Group the true products of reactions with the same reactant together
assert len(test_ds) == len(predictions)
true_dict = {}
for r in test_ds:
reactant, true_product_set = r
for r in test_reactions:
reactant, true_product_set = r.split(">>")
true_product_set = {p for p in true_product_set.split(".")}
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
assert len(test_reactions) == len(predictions)
assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
# Group the predicted products of reactions with the same reactant together
pred_dict = {}
for k, pred in enumerate(predictions):
pred_smiles, pred_proba = zip(*pred.items())
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
reactant, true_product = test_reactions[k].split(">>")
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
for smiles, proba in zip(pred_smiles, pred_proba):
smiles = set(smiles.split("."))
@ -3137,7 +3199,7 @@ class EnviFormer(PackageBasedModel):
break
# Recall is TP (correct) / TP + FN (len(test_reactions))
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
# Precision is TP (correct) / TP + FP (predicted)
prec = {
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
@ -3216,32 +3278,47 @@ class EnviFormer(PackageBasedModel):
# If there are eval packages perform single generation evaluation on them instead of random splits
if self.eval_packages.count() > 0:
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
package__in=self.eval_packages.all()).distinct())
test_result = self.model.predict_batch(ds.X())
ds = []
for reaction in Reaction.objects.filter(
package__in=self.eval_packages.all()
).distinct():
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
ds.append(f"{educts}>>{products}")
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
from enviformer.finetune import fine_tune
ds = self.load_dataset()
n_splits = kwargs.get("n_splits", 20)
n_splits = 20
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
# this helps reduce the memory footprint.
single_gen_results = []
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
train = ds[train_index]
test = ds[test_index]
train = [ds[i] for i in train_index]
test = [ds[i] for i in test_index]
start = datetime.now()
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
)
model.to(s.ENVIFORMER_DEVICE)
test_result = model.predict_batch(test.X())
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
self.eval_results = self.compute_averages(single_gen_results)
@ -3312,15 +3389,31 @@ class EnviFormer(PackageBasedModel):
for pathway in train_pathways:
for reaction in pathway.edges:
reaction = reaction.edge_label
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
if any(
[
educt in test_educts
for educt in reaction_to_educts[str(reaction.uuid)]
]
):
overlap += 1
continue
train_reactions.append(reaction)
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
train_reactions.append(f"{educts}>>{products}")
logging.debug(
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
)
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
self.eval_results.update(

View File

@ -172,6 +172,7 @@ def predict(
except Exception as e:
pw.kv.update({"status": "failed"})
pw.kv.update(**{"error": str(e)})
pw.save()
if JobLog.objects.filter(task_id=self.request.id).exists():

View File

@ -237,7 +237,6 @@ def get_base_context(request, for_user=None) -> Dict[str, Any]:
"enabled_features": s.FLAGS,
"debug": s.DEBUG,
"external_databases": ExternalDatabase.get_databases(),
"site_id": s.MATOMO_SITE_ID,
},
}
@ -892,7 +891,7 @@ def package_model(request, package_uuid, model_uuid):
return JsonResponse(res, safe=False)
else:
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
return JsonResponse(app_domain_assessment, safe=False)
context = get_base_context(request)

View File

@ -27,11 +27,10 @@ dependencies = [
"scikit-learn>=1.6.1",
"sentry-sdk[django]>=2.32.0",
"setuptools>=80.8.0",
"polars==1.35.1",
]
[tool.uv.sources]
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" }
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" }
envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"}
envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }

View File

@ -56,7 +56,7 @@
(function () {
var u = "//matomo.envipath.com/";
_paq.push(['setTrackerUrl', u + 'matomo.php']);
_paq.push(['setSiteId', '{{ meta.site_id }}']);
_paq.push(['setSiteId', '10']);
var d = document, g = d.createElement('script'), s = d.getElementsByTagName('script')[0];
g.async = true;
g.src = u + 'matomo.js';

View File

@ -1,10 +1,8 @@
import os.path
from tempfile import TemporaryDirectory
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule, Package
from utilities.chem import FormatConverter
from utilities.ml import RuleBasedDataset, EnviFormerDataset
from epdb.models import Reaction, Compound, User, Rule
from utilities.ml import Dataset
class DatasetTest(TestCase):
@ -43,108 +41,12 @@ class DatasetTest(TestCase):
super(DatasetTest, cls).setUpClass()
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_generate_dataset(self):
"""Test generating dataset does not crash"""
self.generate_rule_dataset()
def test_indexing(self):
"""Test indexing a few different ways to check for crashes"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds[5])
print(ds[2, 5])
print(ds[3:6, 2:8])
print(ds[:2, "structure_id"])
def test_add_rows(self):
"""Test adding one row and adding multiple rows"""
ds, reactions, rules = self.generate_rule_dataset()
ds.add_row(list(ds.df.row(1)))
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
def test_times_triggered(self):
"""Check getting times triggered for a rule id"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds.times_triggered(rules[0].uuid))
def test_block_indices(self):
"""Test the usages of _block_indices"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds.struct_features())
print(ds.triggered())
print(ds.observed())
def test_structure_id(self):
"""Check getting a structure id from row index"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds.structure_id(0))
def test_x(self):
"""Test getting X portion of the dataframe"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds.X().df.head())
def test_trig(self):
"""Test getting the triggered portion of the dataframe"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds.trig().df.head())
def test_y(self):
"""Test getting the Y portion of the dataframe"""
ds, reactions, rules = self.generate_rule_dataset()
print(ds.y().df.head())
def test_classification_dataset(self):
"""Test making the classification dataset"""
ds, reactions, rules = self.generate_rule_dataset()
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
class_ds, products = ds.classification_dataset(compounds, rules)
print(class_ds.df.head(5))
print(products[:5])
def test_extra_features(self):
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
print(ds.shape)
def test_to_arff(self):
"""Test exporting the arff version of the dataset"""
ds, reactions, rules = self.generate_rule_dataset()
ds.to_arff("dataset_arff_test.arff")
def test_save_load(self):
"""Test saving and loading dataset"""
with TemporaryDirectory() as tmpdir:
ds, reactions, rules = self.generate_rule_dataset()
ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
self.assertTrue(ds.df.equals(ds_loaded.df))
def test_dataset_example(self):
"""Test with a concrete example checking dataset size"""
def test_smoke(self):
reactions = [r for r in Reaction.objects.filter(package=self.package)]
applicable_rules = [self.rule1]
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
ds = Dataset.generate_dataset(reactions, applicable_rules)
self.assertEqual(len(ds.y()), 1)
self.assertEqual(ds.y().df.item(), 1)
def test_enviformer_dataset(self):
ds, reactions = self.generate_enviformer_dataset()
print(ds.X().head())
print(ds.y().head())
def generate_rule_dataset(self):
"""Generate a RuleBasedDataset from test package data"""
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
return ds, reactions, applicable_rules
def generate_enviformer_dataset(self):
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
ds = EnviFormerDataset.generate_dataset(reactions)
return ds, reactions
self.assertEqual(sum(ds.y()[0]), 1)

View File

@ -42,11 +42,13 @@ class EnviFormerTest(TestCase):
threshold = float(0.5)
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
mod = EnviFormer.create(
self.package, data_package_objs, eval_packages_objs, threshold=threshold
)
mod.build_dataset()
mod.build_model()
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
mod.evaluate_model(True, eval_packages_objs)
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
@ -55,9 +57,12 @@ class EnviFormerTest(TestCase):
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mods = []
for _ in range(4):
mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
mod = EnviFormer.create(
self.package, data_package_objs, eval_packages_objs, threshold=threshold
)
mod.build_dataset()
mod.build_model()
mods.append(mod)
@ -68,11 +73,15 @@ class EnviFormerTest(TestCase):
# Test pathway prediction
times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
print(
f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}"
)
# Test eviction by performing three prediction with every model, twice.
times = defaultdict(list)
for _ in range(2): # Eviction should cause the second iteration here to have to reload the models
for _ in range(
2
): # Eviction should cause the second iteration here to have to reload the models
for mod in mods:
for _ in range(3):
times[mod.pk].append(measure_predict(mod))

View File

@ -4,7 +4,7 @@ import numpy as np
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning
from epdb.models import User, MLRelativeReasoning, Package
class ModelTest(TestCase):
@ -17,7 +17,7 @@ class ModelTest(TestCase):
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_mlrr(self):
def test_smoke(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
@ -35,9 +35,21 @@ class ModelTest(TestCase):
description="Created MLRelativeReasoning in Testcase",
)
# mod = RuleBasedRelativeReasoning.create(
# self.package,
# rule_package_objs,
# data_package_objs,
# eval_packages_objs,
# threshold=threshold,
# min_count=5,
# max_count=0,
# name='ECC - BBD - 0.5',
# description='Created MLRelativeReasoning in Testcase',
# )
mod.build_dataset()
mod.build_model()
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
mod.evaluate_model(True, eval_packages_objs)
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
@ -58,57 +70,3 @@ class ModelTest(TestCase):
# from pprint import pprint
# pprint(mod.eval_results)
def test_applicability(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mod = MLRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
threshold=threshold,
name="ECC - BBD - 0.5",
description="Created MLRelativeReasoning in Testcase",
build_app_domain=True, # To test the applicability domain this must be True
app_domain_num_neighbours=5,
app_domain_local_compatibility_threshold=0.5,
app_domain_reliability_threshold=0.5,
)
mod.build_dataset()
mod.build_model()
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
def test_rbrr(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mod = RuleBasedRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
threshold=threshold,
min_count=5,
max_count=0,
name='ECC - BBD - 0.5',
description='Created MLRelativeReasoning in Testcase',
)
mod.build_dataset()
mod.build_model()
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@ -2,12 +2,13 @@ import logging
import re
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Dict, TYPE_CHECKING
from typing import List, Optional, Dict, TYPE_CHECKING, Union
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem import rdchem
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MolStandardize import rdMolStandardize
@ -90,8 +91,15 @@ class FormatConverter(object):
return Chem.MolToSmiles(mol, canonical=canonical)
@staticmethod
def InChIKey(smiles):
return Chem.MolToInchiKey(FormatConverter.from_smiles(smiles))
def InChIKey(mol_or_smiles: Union[rdchem.Mol | str]):
if isinstance(mol_or_smiles, str):
mol_or_smiles = mol_or_smiles.replace("~", "")
mol_or_smiles = FormatConverter.from_smiles(mol_or_smiles)
if mol_or_smiles is None:
return None
return Chem.MolToInchiKey(mol_or_smiles)
@staticmethod
def InChI(smiles):
@ -107,13 +115,6 @@ class FormatConverter(object):
bitvec = MACCSkeys.GenMACCSKeys(mol)
return bitvec.ToList()
@staticmethod
def morgan(smiles, radius=3, fpSize=2048):
finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
mol = Chem.MolFromSmiles(smiles)
fp = finger_gen.GetFingerprint(mol)
return fp.ToList()
@staticmethod
def get_functional_groups(smiles: str) -> List[str]:
res = list()
@ -295,7 +296,8 @@ class FormatConverter(object):
product = GetMolFrags(product, asMols=True)
for p in product:
p = FormatConverter.standardize(
Chem.MolToSmiles(p), remove_stereo=remove_stereo
Chem.MolToSmiles(p).replace("~", ""),
remove_stereo=remove_stereo,
)
prods.append(p)

View File

@ -35,6 +35,7 @@ from epdb.models import (
RuleBasedRelativeReasoning,
Scenario,
SequentialRule,
Setting,
SimpleAmbitRule,
SimpleRDKitRule,
SimpleRule,
@ -1258,3 +1259,46 @@ class PathwayUtils:
res[edge.url] = rule_chain
return res
def _find_intermediates(self, data_pathway, pred_pathway):
pass
def engineer(self, setting: "Setting"):
from epdb.logic import SPathway
# get a fresh copy
pw = Pathway.objects.get(id=self.pathway.pk)
root_nodes = [n.default_node_label.smiles for n in pw.root_nodes]
if len(root_nodes) != 1:
logger.warning(f"Pathway {pw.name} has {len(root_nodes)} root nodes")
return
spw = SPathway(root_nodes[0], None, setting)
level = 0
while not spw.done:
spw.predict_step(from_depth=level)
level += 1
# Generate Node / SMILES mapping
node_mapping = {}
from utilities.chem import FormatConverter
for node in pw.nodes:
for snode in spw.smiles_to_node.values():
data_smiles = node.default_node_label.smiles
pred_smiles = snode.smiles
data_key = FormatConverter.InChIKey(data_smiles.replace("~", ""))
pred_key = FormatConverter.InChIKey(pred_smiles.replace("~", ""))
if data_key == pred_key:
node_mapping[snode] = node
print(node_mapping)
return spw
pass

View File

@ -5,14 +5,11 @@ import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
from abc import ABC, abstractmethod
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
import networkx as nx
import numpy as np
from envipy_plugins import Descriptor
from numpy.random import default_rng
import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
@ -29,281 +26,70 @@ if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction
class Dataset(ABC):
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
self.df = data
class Dataset:
def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns
self.num_labels: int = num_labels
if data is None:
self.data: List[List[str | int | float]] = list()
else:
# Build either an empty dataframe with columns or fill it with list of list data
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
self.data = data
def add_rows(self, rows: List[List[str | int | float]]):
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
if len(self.columns) != len(rows[0]):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
self.df.extend(new_rows)
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def add_row(self, row: List[str | int | float]):
"""See add_rows"""
self.add_rows([row])
def block_indices(self, prefix) -> List[int]:
"""Find the indexes in column labels that has the prefix"""
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return indices
@property
def columns(self) -> List[str]:
"""Use the polars dataframe columns"""
return self.df.columns
return min(indices), max(indices)
@property
def shape(self):
return self.df.shape
def structure_id(self):
return self.data[0][0]
@abstractmethod
def X(self, **kwargs):
pass
@abstractmethod
def y(self, **kwargs):
pass
@staticmethod
@abstractmethod
def generate_dataset(reactions, *args, **kwargs):
pass
def __iter__(self):
"""Use polars iter_rows for iterating over the dataset"""
return self.df.iter_rows()
def __getitem__(self, item):
"""Item is passed to polars allowing for advanced indexing.
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
res = self.df[item]
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
return self.__class__(data=res)
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
return res
def save(self, path: "Path | str"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "str | Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def to_numpy(self):
return self.df.to_numpy()
def __repr__(self):
return (
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
)
def __len__(self):
return len(self.df)
def iter_rows(self, named=False):
return self.df.iter_rows(named=named)
def filter(self, *predicates, **constraints):
return self.__class__(data=self.df.filter(*predicates, **constraints))
def select(self, *exprs, **named_exprs):
return self.__class__(data=self.df.select(*exprs, **named_exprs))
def with_columns(self, *exprs, **name_exprs):
return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
multithreaded=multithreaded, maintain_order=maintain_order))
def item(self, row=None, column=None):
return self.df.item(row, column)
def fill_nan(self, value):
return self.__class__(data=self.df.fill_nan(value))
@property
def height(self):
return self.df.height
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
# Pre-calculate the ids of columns for features/labels, useful later in X and y
self._struct_features: List[int] = self.block_indices("feature_")
self._triggered: List[int] = self.block_indices("trig_")
self._observed: List[int] = self.block_indices("obs_")
self.feature_cols: List[int] = self._struct_features + self._triggered
self.num_features: int = len(self.feature_cols)
self.has_probs = False
def add_row(self, row: List[str | int | float]):
if len(self.columns) != len(row):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
self.data.append(row)
def times_triggered(self, rule_uuid) -> int:
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
idx = self.columns.index(f"trig_{rule_uuid}")
def struct_features(self) -> List[int]:
times_triggered = 0
for row in self.data:
if row[idx] == 1:
times_triggered += 1
return times_triggered
def struct_features(self) -> Tuple[int, int]:
return self._struct_features
def triggered(self) -> List[int]:
def triggered(self) -> Tuple[int, int]:
return self._triggered
def observed(self) -> List[int]:
def observed(self) -> Tuple[int, int]:
return self._observed
def structure_id(self, index: int):
"""Get the UUID of a compound"""
return self.item(index, "structure_id")
def at(self, position: int) -> Dataset:
return Dataset(self.columns, self.num_labels, [self.data[position]])
def X(self, exclude_id_col=True, na_replacement=0):
"""Get all the feature and trig columns"""
_col_ids = self.feature_cols
if not exclude_id_col:
_col_ids = [0] + _col_ids
res = self[:, _col_ids]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def limit(self, limit: int) -> Dataset:
return Dataset(self.columns, self.num_labels, self.data[:limit])
def trig(self, na_replacement=0):
"""Get all the trig columns"""
res = self[:, self._triggered]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
"""Get all the obs columns"""
res = self[:, self._observed]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
if feat_funcs is None:
feat_funcs = [FormatConverter.maccs]
_structures = set() # Get all the structures
for r in reactions:
_structures.update(r.educts.all())
if not educts_only:
_structures.update(r.products.all())
compounds = sorted(_structures, key=lambda x: x.url)
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
logger.debug(f"Standardizing SMILES failed for {smi}")
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
logger.debug(f"Standardizing SMILES failed for {smi}")
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
feat_columns = []
for feat_func in feat_funcs:
if isinstance(feat_func, Descriptor):
feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
else:
feats = feat_func(compounds[0].smiles)
start_i = len(feat_columns)
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
ds_columns = (["structure_id"] +
feat_columns +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
rows = []
for i, comp in enumerate(compounds):
# Features
feats = []
for feat_func in feat_funcs:
if isinstance(feat_func, Descriptor):
feat = feat_func.get_molecule_descriptors(comp.smiles)
else:
feat = feat_func(comp.smiles)
feats.extend(feat)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
rows.append([str(comp.uuid)] + feats + trig + obs)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
return ds
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
) -> Tuple[Dataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
for struct in structures:
@ -327,18 +113,186 @@ class RuleBasedDataset(Dataset):
else:
trig.append(0)
prods.append([])
new_row = [struct_id] + features + trig + ([-1] * len(trig))
if self.has_probs:
new_row += [-1] * len(trig)
classify_data.append(new_row)
classify_products.append(prods)
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
return ds, classify_products
def add_probs(self, probs):
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
self.has_probs = True
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods)
return Dataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod
def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset:
_structures = set()
for r in reactions:
for e in r.educts.all():
_structures.add(e)
if not educts_only:
for e in r.products:
_structures.add(e)
compounds = sorted(_structures, key=lambda x: x.url)
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
# :shrug:
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
# :shrug:
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
else:
pass
ds = None
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
if ds is None:
header = (
["structure_id"]
+ [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def trig(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
row_key, col_key = key
# Normalize rows
if isinstance(row_key, int):
rows = [self.data[row_key]]
else:
rows = self.data[row_key]
# Normalize columns
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [
[row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
@ -350,7 +304,7 @@ class RuleBasedDataset(Dataset):
arff += f"@attribute {c} {{0,1}}\n"
arff += "\n@data\n"
for d in self:
for d in self.data:
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f"{ys},{xs}\n"
@ -359,40 +313,10 @@ class RuleBasedDataset(Dataset):
fh.write(arff)
fh.flush()
class EnviFormerDataset(Dataset):
def __init__(self, columns=None, data=None):
super().__init__(columns, data)
def X(self):
"""Return the educts"""
return self["educts"]
def y(self):
"""Return the products"""
return self["products"]
@staticmethod
def generate_dataset(reactions, *args, **kwargs):
# Standardise reactions for the training data
stereo = kwargs.get("stereo", False)
rows = []
for reaction in reactions:
e = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.educts.all()
]
)
p = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.products.all()
]
)
rows.append([e, p])
ds = EnviFormerDataset(["educts", "products"], rows)
return ds
def __repr__(self):
return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -574,7 +498,7 @@ class EnsembleClassifierChain:
self.classifiers = []
if self.num_labels is None:
self.num_labels = Y.shape[1]
self.num_labels = len(Y[0])
for p in range(self.num_chains):
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@ -605,7 +529,7 @@ class RelativeReasoning:
def fit(self, X, Y):
n_instances = len(Y)
n_attributes = Y.shape[1]
n_attributes = len(Y[0])
for i in range(n_attributes):
for j in range(n_attributes):
@ -617,8 +541,8 @@ class RelativeReasoning:
countboth = 0
for k in range(n_instances):
vi = Y[k, i]
vj = Y[k, j]
vi = Y[k][i]
vj = Y[k][j]
if vi is None or vj is None:
continue
@ -674,7 +598,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None
self.max_vals = None
def build(self, train_dataset: "RuleBasedDataset"):
def build(self, train_dataset: "Dataset"):
# transform
X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca
@ -688,7 +612,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled)
return instances_pca
def is_applicable(self, classify_instances: "RuleBasedDataset"):
def is_applicable(self, classify_instances: "Dataset"):
instances_pca = self.__transform(classify_instances.X())
is_applicable = []

184
uv.lock generated
View File

@ -1,10 +1,6 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.12"
resolution-markers = [
"sys_platform == 'linux' or sys_platform == 'win32'",
"sys_platform != 'linux' and sys_platform != 'win32'",
]
[[package]]
name = "aiohappyeyeballs"
@ -180,19 +176,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" },
]
[[package]]
name = "celery-stubs"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mypy" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/14/b853ada8706a3a301396566b6dd405d1cbb24bff756236a12a01dbe766a4/celery-stubs-0.1.3.tar.gz", hash = "sha256:0fb5345820f8a2bd14e6ffcbef2d10181e12e40f8369f551d7acc99d8d514919", size = 46583, upload-time = "2023-02-10T02:20:11.837Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/7a/4ab2347d13f1f59d10a7337feb9beb002664119f286036785284c6bec150/celery_stubs-0.1.3-py3-none-any.whl", hash = "sha256:dfb9ad27614a8af028b2055bb4a4ae99ca5e9a8d871428a506646d62153218d7", size = 89085, upload-time = "2023-02-10T02:20:09.409Z" },
]
[[package]]
name = "certifi"
version = "2025.10.5"
@ -542,14 +525,13 @@ wheels = [
[[package]]
name = "enviformer"
version = "0.1.0"
source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4#7094be5767748fd63d4a84a5d71f06cf02ba07f3" }
source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2#3f28f60cfa1df814cf7559303b5130933efa40ae" }
dependencies = [
{ name = "joblib" },
{ name = "lightning" },
{ name = "pytorch-lightning" },
{ name = "scikit-learn" },
{ name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "torch" },
]
[[package]]
@ -564,6 +546,7 @@ dependencies = [
{ name = "django-ninja" },
{ name = "django-oauth-toolkit" },
{ name = "django-polymorphic" },
{ name = "django-stubs" },
{ name = "enviformer" },
{ name = "envipy-additional-information" },
{ name = "envipy-ambit" },
@ -571,7 +554,6 @@ dependencies = [
{ name = "epam-indigo" },
{ name = "gunicorn" },
{ name = "networkx" },
{ name = "polars" },
{ name = "psycopg2-binary" },
{ name = "python-dotenv" },
{ name = "rdkit" },
@ -584,8 +566,6 @@ dependencies = [
[package.optional-dependencies]
dev = [
{ name = "celery-stubs" },
{ name = "django-stubs" },
{ name = "poethepoet" },
{ name = "pre-commit" },
{ name = "ruff" },
@ -597,16 +577,15 @@ ms-login = [
[package.metadata]
requires-dist = [
{ name = "celery", specifier = ">=5.5.2" },
{ name = "celery-stubs", marker = "extra == 'dev'", specifier = "==0.1.3" },
{ name = "django", specifier = ">=5.2.1" },
{ name = "django-extensions", specifier = ">=4.1" },
{ name = "django-model-utils", specifier = ">=5.0.0" },
{ name = "django-ninja", specifier = ">=1.4.1" },
{ name = "django-oauth-toolkit", specifier = ">=3.0.1" },
{ name = "django-polymorphic", specifier = ">=4.1.0" },
{ name = "django-stubs", marker = "extra == 'dev'", specifier = ">=5.2.4" },
{ name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4" },
{ name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7" },
{ name = "django-stubs", specifier = ">=5.2.4" },
{ name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2" },
{ name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4" },
{ name = "envipy-ambit", git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" },
{ name = "envipy-plugins", git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git?rev=v0.1.0" },
{ name = "epam-indigo", specifier = ">=1.30.1" },
@ -614,7 +593,6 @@ requires-dist = [
{ name = "msal", marker = "extra == 'ms-login'", specifier = ">=1.33.0" },
{ name = "networkx", specifier = ">=3.4.2" },
{ name = "poethepoet", marker = "extra == 'dev'", specifier = ">=0.37.0" },
{ name = "polars", specifier = "==1.35.1" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.3.0" },
{ name = "psycopg2-binary", specifier = ">=2.9.10" },
{ name = "python-dotenv", specifier = ">=1.1.0" },
@ -630,8 +608,8 @@ provides-extras = ["ms-login", "dev"]
[[package]]
name = "envipy-additional-information"
version = "0.1.7"
source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7#d02a5d5e6a931e6565ea86127813acf7e4b33a30" }
version = "0.1.0"
source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4#4da604090bf7cf1f3f552d69485472dbc623030a" }
dependencies = [
{ name = "pydantic" },
]
@ -887,8 +865,7 @@ dependencies = [
{ name = "packaging" },
{ name = "pytorch-lightning" },
{ name = "pyyaml" },
{ name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "torch" },
{ name = "torchmetrics" },
{ name = "tqdm" },
{ name = "typing-extensions" },
@ -1097,47 +1074,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
]
[[package]]
name = "mypy"
version = "1.18.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mypy-extensions" },
{ name = "pathspec" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/77/8f0d0001ffad290cef2f7f216f96c814866248a0b92a722365ed54648e7e/mypy-1.18.2.tar.gz", hash = "sha256:06a398102a5f203d7477b2923dda3634c36727fa5c237d8f859ef90c42a9924b", size = 3448846, upload-time = "2025-09-19T00:11:10.519Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/06/dfdd2bc60c66611dd8335f463818514733bc763e4760dee289dcc33df709/mypy-1.18.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:33eca32dd124b29400c31d7cf784e795b050ace0e1f91b8dc035672725617e34", size = 12908273, upload-time = "2025-09-19T00:10:58.321Z" },
{ url = "https://files.pythonhosted.org/packages/81/14/6a9de6d13a122d5608e1a04130724caf9170333ac5a924e10f670687d3eb/mypy-1.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3c47adf30d65e89b2dcd2fa32f3aeb5e94ca970d2c15fcb25e297871c8e4764", size = 11920910, upload-time = "2025-09-19T00:10:20.043Z" },
{ url = "https://files.pythonhosted.org/packages/5f/a9/b29de53e42f18e8cc547e38daa9dfa132ffdc64f7250e353f5c8cdd44bee/mypy-1.18.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d6c838e831a062f5f29d11c9057c6009f60cb294fea33a98422688181fe2893", size = 12465585, upload-time = "2025-09-19T00:10:33.005Z" },
{ url = "https://files.pythonhosted.org/packages/77/ae/6c3d2c7c61ff21f2bee938c917616c92ebf852f015fb55917fd6e2811db2/mypy-1.18.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01199871b6110a2ce984bde85acd481232d17413868c9807e95c1b0739a58914", size = 13348562, upload-time = "2025-09-19T00:10:11.51Z" },
{ url = "https://files.pythonhosted.org/packages/4d/31/aec68ab3b4aebdf8f36d191b0685d99faa899ab990753ca0fee60fb99511/mypy-1.18.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a2afc0fa0b0e91b4599ddfe0f91e2c26c2b5a5ab263737e998d6817874c5f7c8", size = 13533296, upload-time = "2025-09-19T00:10:06.568Z" },
{ url = "https://files.pythonhosted.org/packages/9f/83/abcb3ad9478fca3ebeb6a5358bb0b22c95ea42b43b7789c7fb1297ca44f4/mypy-1.18.2-cp312-cp312-win_amd64.whl", hash = "sha256:d8068d0afe682c7c4897c0f7ce84ea77f6de953262b12d07038f4d296d547074", size = 9828828, upload-time = "2025-09-19T00:10:28.203Z" },
{ url = "https://files.pythonhosted.org/packages/5f/04/7f462e6fbba87a72bc8097b93f6842499c428a6ff0c81dd46948d175afe8/mypy-1.18.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:07b8b0f580ca6d289e69209ec9d3911b4a26e5abfde32228a288eb79df129fcc", size = 12898728, upload-time = "2025-09-19T00:10:01.33Z" },
{ url = "https://files.pythonhosted.org/packages/99/5b/61ed4efb64f1871b41fd0b82d29a64640f3516078f6c7905b68ab1ad8b13/mypy-1.18.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed4482847168439651d3feee5833ccedbf6657e964572706a2adb1f7fa4dfe2e", size = 11910758, upload-time = "2025-09-19T00:10:42.607Z" },
{ url = "https://files.pythonhosted.org/packages/3c/46/d297d4b683cc89a6e4108c4250a6a6b717f5fa96e1a30a7944a6da44da35/mypy-1.18.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ad2afadd1e9fea5cf99a45a822346971ede8685cc581ed9cd4d42eaf940986", size = 12475342, upload-time = "2025-09-19T00:11:00.371Z" },
{ url = "https://files.pythonhosted.org/packages/83/45/4798f4d00df13eae3bfdf726c9244bcb495ab5bd588c0eed93a2f2dd67f3/mypy-1.18.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a431a6f1ef14cf8c144c6b14793a23ec4eae3db28277c358136e79d7d062f62d", size = 13338709, upload-time = "2025-09-19T00:11:03.358Z" },
{ url = "https://files.pythonhosted.org/packages/d7/09/479f7358d9625172521a87a9271ddd2441e1dab16a09708f056e97007207/mypy-1.18.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ab28cc197f1dd77a67e1c6f35cd1f8e8b73ed2217e4fc005f9e6a504e46e7ba", size = 13529806, upload-time = "2025-09-19T00:10:26.073Z" },
{ url = "https://files.pythonhosted.org/packages/71/cf/ac0f2c7e9d0ea3c75cd99dff7aec1c9df4a1376537cb90e4c882267ee7e9/mypy-1.18.2-cp313-cp313-win_amd64.whl", hash = "sha256:0e2785a84b34a72ba55fb5daf079a1003a34c05b22238da94fcae2bbe46f3544", size = 9833262, upload-time = "2025-09-19T00:10:40.035Z" },
{ url = "https://files.pythonhosted.org/packages/5a/0c/7d5300883da16f0063ae53996358758b2a2df2a09c72a5061fa79a1f5006/mypy-1.18.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:62f0e1e988ad41c2a110edde6c398383a889d95b36b3e60bcf155f5164c4fdce", size = 12893775, upload-time = "2025-09-19T00:10:03.814Z" },
{ url = "https://files.pythonhosted.org/packages/50/df/2cffbf25737bdb236f60c973edf62e3e7b4ee1c25b6878629e88e2cde967/mypy-1.18.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8795a039bab805ff0c1dfdb8cd3344642c2b99b8e439d057aba30850b8d3423d", size = 11936852, upload-time = "2025-09-19T00:10:51.631Z" },
{ url = "https://files.pythonhosted.org/packages/be/50/34059de13dd269227fb4a03be1faee6e2a4b04a2051c82ac0a0b5a773c9a/mypy-1.18.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ca1e64b24a700ab5ce10133f7ccd956a04715463d30498e64ea8715236f9c9c", size = 12480242, upload-time = "2025-09-19T00:11:07.955Z" },
{ url = "https://files.pythonhosted.org/packages/5b/11/040983fad5132d85914c874a2836252bbc57832065548885b5bb5b0d4359/mypy-1.18.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d924eef3795cc89fecf6bedc6ed32b33ac13e8321344f6ddbf8ee89f706c05cb", size = 13326683, upload-time = "2025-09-19T00:09:55.572Z" },
{ url = "https://files.pythonhosted.org/packages/e9/ba/89b2901dd77414dd7a8c8729985832a5735053be15b744c18e4586e506ef/mypy-1.18.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20c02215a080e3a2be3aa50506c67242df1c151eaba0dcbc1e4e557922a26075", size = 13514749, upload-time = "2025-09-19T00:10:44.827Z" },
{ url = "https://files.pythonhosted.org/packages/25/bc/cc98767cffd6b2928ba680f3e5bc969c4152bf7c2d83f92f5a504b92b0eb/mypy-1.18.2-cp314-cp314-win_amd64.whl", hash = "sha256:749b5f83198f1ca64345603118a6f01a4e99ad4bf9d103ddc5a3200cc4614adf", size = 9982959, upload-time = "2025-09-19T00:10:37.344Z" },
{ url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" },
]
[[package]]
name = "mypy-extensions"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
]
[[package]]
name = "networkx"
version = "3.5"
@ -1256,7 +1192,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-cublas-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@ -1267,7 +1203,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-nvjitlink-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@ -1294,9 +1230,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-cublas-cu12" },
{ name = "nvidia-cusparse-cu12" },
{ name = "nvidia-nvjitlink-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@ -1307,7 +1243,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-nvjitlink-cu12" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@ -1372,15 +1308,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/aa/18/a8444036c6dd65ba3624c63b734d3ba95ba63ace513078e1580590075d21/pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364", size = 5955, upload-time = "2020-09-16T19:21:11.409Z" },
]
[[package]]
name = "pathspec"
version = "0.12.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
]
[[package]]
name = "pillow"
version = "11.3.0"
@ -1469,32 +1396,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/92/1b/5337af1a6a478d25a3e3c56b9b4b42b0a160314e02f4a0498d5322c8dac4/poethepoet-0.37.0-py3-none-any.whl", hash = "sha256:861790276315abcc8df1b4bd60e28c3d48a06db273edd3092f3c94e1a46e5e22", size = 90062, upload-time = "2025-08-11T18:00:27.595Z" },
]
[[package]]
name = "polars"
version = "1.35.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "polars-runtime-32" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/5b/3caad788d93304026cbf0ab4c37f8402058b64a2f153b9c62f8b30f5d2ee/polars-1.35.1.tar.gz", hash = "sha256:06548e6d554580151d6ca7452d74bceeec4640b5b9261836889b8e68cfd7a62e", size = 694881, upload-time = "2025-10-30T12:12:52.294Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/4c/21a227b722534404241c2a76beceb7463469d50c775a227fc5c209eb8adc/polars-1.35.1-py3-none-any.whl", hash = "sha256:c29a933f28aa330d96a633adbd79aa5e6a6247a802a720eead9933f4613bdbf4", size = 783598, upload-time = "2025-10-30T12:11:54.668Z" },
]
[[package]]
name = "polars-runtime-32"
version = "1.35.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/df/3e/19c252e8eb4096300c1a36ec3e50a27e5fa9a1ccaf32d3927793c16abaee/polars_runtime_32-1.35.1.tar.gz", hash = "sha256:f6b4ec9cd58b31c87af1b8c110c9c986d82345f1d50d7f7595b5d447a19dc365", size = 2696218, upload-time = "2025-10-30T12:12:53.479Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/08/2c/da339459805a26105e9d9c2f07e43ca5b8baeee55acd5457e6881487a79a/polars_runtime_32-1.35.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6f051a42f6ae2f26e3bc2cf1f170f2120602976e2a3ffb6cfba742eecc7cc620", size = 40525100, upload-time = "2025-10-30T12:11:58.098Z" },
{ url = "https://files.pythonhosted.org/packages/27/70/a0733568b3533481924d2ce68b279ab3d7334e5fa6ed259f671f650b7c5e/polars_runtime_32-1.35.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:c2232f9cf05ba59efc72d940b86c033d41fd2d70bf2742e8115ed7112a766aa9", size = 36701908, upload-time = "2025-10-30T12:12:02.166Z" },
{ url = "https://files.pythonhosted.org/packages/46/54/6c09137bef9da72fd891ba58c2962cc7c6c5cad4649c0e668d6b344a9d7b/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42f9837348557fd674477ea40a6ac8a7e839674f6dd0a199df24be91b026024c", size = 41317692, upload-time = "2025-10-30T12:12:04.928Z" },
{ url = "https://files.pythonhosted.org/packages/22/55/81c5b266a947c339edd7fbaa9e1d9614012d02418453f48b76cc177d3dd9/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:c873aeb36fed182d5ebc35ca17c7eb193fe83ae2ea551ee8523ec34776731390", size = 37853058, upload-time = "2025-10-30T12:12:08.342Z" },
{ url = "https://files.pythonhosted.org/packages/6c/58/be8b034d559eac515f52408fd6537be9bea095bc0388946a4e38910d3d50/polars_runtime_32-1.35.1-cp39-abi3-win_amd64.whl", hash = "sha256:35cde9453ca7032933f0e58e9ed4388f5a1e415dd0db2dd1e442c81d815e630c", size = 41289554, upload-time = "2025-10-30T12:12:11.104Z" },
{ url = "https://files.pythonhosted.org/packages/f4/7f/e0111b9e2a1169ea82cde3ded9c92683e93c26dfccd72aee727996a1ac5b/polars_runtime_32-1.35.1-cp39-abi3-win_arm64.whl", hash = "sha256:fd77757a6c9eb9865c4bfb7b07e22225207c6b7da382bd0b9bd47732f637105d", size = 36958878, upload-time = "2025-10-30T12:12:15.206Z" },
]
[[package]]
name = "pre-commit"
version = "4.3.0"
@ -1769,8 +1670,7 @@ dependencies = [
{ name = "lightning-utilities" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "torch" },
{ name = "torchmetrics" },
{ name = "tqdm" },
{ name = "typing-extensions" },
@ -1854,11 +1754,11 @@ wheels = [
[[package]]
name = "redis"
version = "7.0.1"
version = "6.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" }
sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" },
{ url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" },
]
[[package]]
@ -2063,40 +1963,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
]
[[package]]
name = "torch"
version = "2.8.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform != 'linux' and sys_platform != 'win32'",
]
dependencies = [
{ name = "filelock", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "fsspec", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "jinja2", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "networkx", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "setuptools", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "sympy", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" },
{ url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" },
{ url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" },
]
[[package]]
name = "torch"
version = "2.8.0+cu128"
source = { registry = "https://download.pytorch.org/whl/cu128" }
resolution-markers = [
"sys_platform == 'linux' or sys_platform == 'win32'",
]
dependencies = [
{ name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "filelock" },
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
@ -2111,10 +1986,10 @@ dependencies = [
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "setuptools" },
{ name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "typing-extensions" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" },
@ -2133,8 +2008,7 @@ dependencies = [
{ name = "lightning-utilities" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/2e/48a887a59ecc4a10ce9e8b35b3e3c5cef29d902c4eac143378526e7485cb/torchmetrics-1.8.2.tar.gz", hash = "sha256:cf64a901036bf107f17a524009eea7781c9c5315d130713aeca5747a686fe7a5", size = 580679, upload-time = "2025-09-03T14:00:54.077Z" }
wheels = [
@ -2158,7 +2032,7 @@ name = "triton"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "setuptools" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" },