forked from enviPath/enviPy
Compare commits
14 Commits
fix/missin ... enhancemen

| SHA1 |
|---|
| 9ec5e433ea |
| dddea79daf |
| cfd8d7440b |
| 6a5413b492 |
| 8282855975 |
| 09ddd46d69 |
| 9f0e396437 |
| 5dc4c822c4 |
| f1f7ce344c |
| 13af49488e |
| ac5d370b18 |
| ff51e48f90 |
| 8166df6f39 |
| 2980a75daa |
@ -1542,9 +1542,7 @@ class SPathway(object):
if sub.app_domain_assessment is None:
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
sub.smiles
)[0]
app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)

if self.persist is not None:
n = self.snode_persist_lookup[sub]

@ -1576,11 +1574,7 @@ class SPathway(object):
app_domain_assessment = None
if self.prediction_setting.model:
if self.prediction_setting.model.app_domain:
app_domain_assessment = (
self.prediction_setting.model.app_domain.assess(c)[
0
]
)
app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c))

self.smiles_to_node[c] = SNode(
c, sub.depth + 1, app_domain_assessment

307 epdb/models.py
@ -28,7 +28,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit

from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
EnviFormerDataset, Dataset

logger = logging.getLogger(__name__)

@ -2175,7 +2176,7 @@ class PackageBasedModel(EPModel):

applicable_rules = self.applicable_rules
reactions = list(self._get_reactions())
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)

end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@ -2184,7 +2185,7 @@ class PackageBasedModel(EPModel):
|
||||
ds.save(f)
|
||||
return ds
|
||||
|
||||
def load_dataset(self) -> "Dataset":
|
||||
def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||
return Dataset.load(ds_path)
|
||||
|
||||
@ -2225,7 +2226,7 @@ class PackageBasedModel(EPModel):
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||
|
||||
@ -2343,37 +2344,37 @@ class PackageBasedModel(EPModel):
|
||||
eval_reactions = list(
|
||||
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||
)
|
||||
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
|
||||
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||
else:
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
X = ds.X(na_replacement=np.nan).to_numpy()
|
||||
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
|
||||
self.eval_results = self.compute_averages([single_gen_result])
|
||||
else:
|
||||
ds = self.load_dataset()
|
||||
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
|
||||
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||
else:
|
||||
X = np.array(ds.X(na_replacement=np.nan))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
X = ds.X(na_replacement=np.nan).to_numpy()
|
||||
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||
|
||||
n_splits = 20
|
||||
n_splits = kwargs.get("n_splits", 20)
|
||||
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
splits = list(shuff.split(X))
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
models = Parallel(n_jobs=10)(
|
||||
models = Parallel(n_jobs=min(10, len(splits)))(
|
||||
delayed(train_func)(X, y, train_index, self._model_args())
|
||||
for train_index, _ in splits
|
||||
)
|
||||
evaluations = Parallel(n_jobs=10)(
|
||||
evaluations = Parallel(n_jobs=min(10, len(splits)))(
|
||||
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
||||
for model, (_, test_index) in zip(models, splits)
|
||||
)
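
For orientation, a minimal sketch of the ShuffleSplit-plus-joblib pattern used in the evaluation above; the classifier, data, and scoring below are placeholders, not the project's train_func/evaluate_sg.

```python
# Illustrative only: stand-in data and model for the split-then-parallel-fit pattern.
import numpy as np
from joblib import Parallel, delayed
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import ShuffleSplit

X = np.random.rand(100, 8)
y = np.random.randint(0, 2, size=100)

shuff = ShuffleSplit(n_splits=4, test_size=0.25, random_state=42)
splits = list(shuff.split(X))

def fit_one(X, y, train_index):
    # Each split gets its own independently fitted model.
    return DummyClassifier(strategy="most_frequent").fit(X[train_index], y[train_index])

models = Parallel(n_jobs=min(10, len(splits)))(
    delayed(fit_one)(X, y, train_index) for train_index, _ in splits
)
scores = [m.score(X[test], y[test]) for m, (_, test) in zip(models, splits)]
```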
|
||||
@ -2585,11 +2586,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
|
||||
return rbrr
|
||||
|
||||
def _fit_model(self, ds: Dataset):
|
||||
def _fit_model(self, ds: RuleBasedDataset):
|
||||
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
|
||||
model = RelativeReasoning(
|
||||
start_index=ds.triggered()[0],
|
||||
end_index=ds.triggered()[1],
|
||||
end_index=ds.triggered()[-1],
|
||||
)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
@ -2599,7 +2600,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
return {
|
||||
"clz": "RuleBaseRelativeReasoning",
|
||||
"start_index": ds.triggered()[0],
|
||||
"end_index": ds.triggered()[1],
|
||||
"end_index": ds.triggered()[-1],
|
||||
}
|
||||
|
||||
def _save_model(self, model):
|
||||
@ -2687,11 +2688,11 @@ class MLRelativeReasoning(PackageBasedModel):
|
||||
|
||||
return mlrr
|
||||
|
||||
def _fit_model(self, ds: Dataset):
|
||||
def _fit_model(self, ds: RuleBasedDataset):
|
||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||
|
||||
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
||||
model.fit(X, y)
|
||||
model.fit(X.to_numpy(), y.to_numpy())
|
||||
return model
|
||||
|
||||
def _model_args(self):
|
||||
@ -2714,7 +2715,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
||||
start = datetime.now()
|
||||
ds = self.load_dataset()
|
||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||
pred = self.model.predict_proba(classify_ds.X())
|
||||
pred = self.model.predict_proba(classify_ds.X().to_numpy())
|
||||
|
||||
res = MLRelativeReasoning.combine_products_and_probs(
|
||||
self.applicable_rules, pred[0], classify_prods[0]
|
||||
@ -2759,7 +2760,9 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
|
||||
@cached_property
|
||||
def training_set_probs(self):
|
||||
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
|
||||
ds = self.model.load_dataset()
|
||||
col_ids = ds.block_indices("prob")
|
||||
return ds[:, col_ids]
|
||||
|
||||
def build(self):
|
||||
ds = self.model.load_dataset()
|
||||
@ -2767,9 +2770,9 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
start = datetime.now()
|
||||
|
||||
# Get Trainingset probs and dump them as they're required when using the app domain
|
||||
probs = self.model.model.predict_proba(ds.X())
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
|
||||
joblib.dump(probs, f)
|
||||
probs = self.model.model.predict_proba(ds.X().to_numpy())
|
||||
ds.add_probs(probs)
|
||||
ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
|
||||
|
||||
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
|
||||
ad.build(ds)
|
||||
@ -2792,16 +2795,19 @@ class ApplicabilityDomain(EnviPathModel):
joblib.dump(ad, f)

def assess(self, structure: Union[str, "CompoundStructure"]):
return self.assess_batch([structure])[0]

def assess_batch(self, structures: List["CompoundStructure | str"]):
ds = self.model.load_dataset()

if isinstance(structure, CompoundStructure):
smiles = structure.smiles
else:
smiles = structure
smiles = []
for struct in structures:
if isinstance(struct, CompoundStructure):
smiles.append(struct.smiles)
else:
smiles.append(struct)

assessment_ds, assessment_prods = ds.classification_dataset(
[structure], self.model.applicable_rules
)
assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
|
||||
|
||||
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
||||
# {
|
||||
@ -2814,82 +2820,46 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
|
||||
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
|
||||
# with a given assessment structure under a particular rule.
|
||||
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
qualified_neighbours_per_rule: Dict = {}
|
||||
|
||||
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
|
||||
feature = ds.columns[feature_index]
|
||||
if feature.startswith("trig_"):
|
||||
# TODO unroll loop
|
||||
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
|
||||
if int(cx[feature_index]) == 1:
|
||||
for j, tx in enumerate(ds.X(exclude_id_col=False)):
|
||||
if int(tx[feature_index]) == 1:
|
||||
qualified_neighbours_per_rule[i][rule_idx].append(j)
|
||||
|
||||
probs = self.training_set_probs
|
||||
# preds = self.model.model.predict_proba(assessment_ds.X())
|
||||
import polars as pl
|
||||
# Select only the triggered columns
|
||||
for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
|
||||
# Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
|
||||
# trigger that rule.
|
||||
train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
|
||||
for trig_uuid, value in row.items() if value == 1}
|
||||
qualified_neighbours_per_rule[i] = train_trig
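
A toy illustration of the neighbour lookup built above, with made-up column names: for every trig_* column the assessed structure activates, the training frame is filtered to rows that also activate it.

```python
import polars as pl

# Hypothetical training rows with trigger flags per rule (names invented for the example).
train = pl.DataFrame({
    "structure_id": ["a", "b", "c"],
    "trig_rule1": [1, 0, 1],
    "trig_rule2": [0, 1, 1],
})
assessed_row = {"trig_rule1": 1, "trig_rule2": 0}  # the structure only triggers rule1

neighbours_per_rule = {
    col.split("_")[-1]: train.filter(pl.col(col).eq(1))
    for col, value in assessed_row.items() if value == 1
}
print(neighbours_per_rule["rule1"]["structure_id"].to_list())  # ['a', 'c']
```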
|
||||
rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}
|
||||
preds = self.model.combine_products_and_probs(
|
||||
self.model.applicable_rules,
|
||||
self.model.model.predict_proba(assessment_ds.X())[0],
|
||||
self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
|
||||
assessment_prods[0],
|
||||
)
|
||||
|
||||
assessments = list()
|
||||
|
||||
# loop through our assessment dataset
|
||||
for i, instance in enumerate(assessment_ds):
|
||||
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
|
||||
rule_reliabilities = dict()
|
||||
local_compatibilities = dict()
|
||||
neighbours_per_rule = dict()
|
||||
neighbor_probs_per_rule = dict()
|
||||
|
||||
# loop through rule indices together with the collected neighbours indices from train dataset
|
||||
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
|
||||
# collect the train dataset instances and store it along with the index (a.k.a. row number) of the
|
||||
# train dataset
|
||||
train_instances = []
|
||||
for v in vals:
|
||||
train_instances.append((v, ds.at(v)))
|
||||
|
||||
# sf is a tuple with start/end index of the features
|
||||
sf = ds.struct_features()
|
||||
|
||||
# compute tanimoto distance for all neighbours
|
||||
# result is a list of tuples with train index and computed distance
|
||||
dists = self._compute_distances(
|
||||
instance.X()[0][sf[0] : sf[1]],
|
||||
[ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
|
||||
)
|
||||
|
||||
dists_with_index = list()
|
||||
for ti, dist in zip(train_instances, dists):
|
||||
dists_with_index.append((ti[0], dist[1]))
|
||||
for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
|
||||
# compute tanimoto distance for all neighbours and add to dataset
|
||||
dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
|
||||
train_instances[:, train_instances.struct_features()].to_numpy())
|
||||
train_instances = train_instances.with_columns(dist=pl.Series(dists))
|
||||
|
||||
# sort them in a descending way and take at most `self.num_neighbours`
|
||||
dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
|
||||
dists_with_index = dists_with_index[: self.num_neighbours]
|
||||
|
||||
train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
|
||||
# compute average distance
|
||||
rule_reliabilities[rule_idx] = (
|
||||
sum([d[1] for d in dists_with_index]) / len(dists_with_index)
|
||||
if len(dists_with_index) > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
|
||||
# for local_compatibility we'll need the datasets for the indices having the highest similarity
|
||||
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
|
||||
local_compatibilities[rule_idx] = self._compute_compatibility(
|
||||
rule_idx, probs, neighbour_datasets
|
||||
)
|
||||
neighbours_per_rule[rule_idx] = [
|
||||
CompoundStructure.objects.get(uuid=ds[1].structure_id())
|
||||
for ds in neighbour_datasets
|
||||
]
|
||||
neighbor_probs_per_rule[rule_idx] = [
|
||||
probs[d[0]][rule_idx] for d in dists_with_index
|
||||
]
|
||||
local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
|
||||
neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
|
||||
neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()
|
||||
|
||||
ad_res = {
|
||||
"ad_params": {
|
||||
@ -2900,23 +2870,21 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
"local_compatibility_threshold": self.local_compatibilty_threshold,
|
||||
},
|
||||
"assessment": {
|
||||
"smiles": smiles,
|
||||
"inside_app_domain": self.pca.is_applicable(instance)[0],
|
||||
"smiles": smiles[i],
|
||||
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
|
||||
},
|
||||
}
|
||||
|
||||
transformations = list()
|
||||
for rule_idx in rule_reliabilities.keys():
|
||||
rule = Rule.objects.get(
|
||||
uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
|
||||
)
|
||||
for rule_uuid in rule_reliabilities.keys():
|
||||
rule = Rule.objects.get(uuid=rule_uuid)
|
||||
|
||||
rule_data = rule.simple_json()
|
||||
rule_data["image"] = f"{rule.url}?image=svg"
|
||||
|
||||
neighbors = []
|
||||
for n, n_prob in zip(
|
||||
neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
|
||||
neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
|
||||
):
|
||||
neighbor = n.simple_json()
|
||||
neighbor["image"] = f"{n.url}?image=svg"
|
||||
@ -2933,14 +2901,14 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
|
||||
transformation = {
|
||||
"rule": rule_data,
|
||||
"reliability": rule_reliabilities[rule_idx],
|
||||
"reliability": rule_reliabilities[rule_uuid],
|
||||
# We're setting it here to False, as we don't know whether "assess" is called during pathway
|
||||
# prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
|
||||
"is_predicted": False,
|
||||
"local_compatibility": local_compatibilities[rule_idx],
|
||||
"probability": preds[rule_idx].probability,
|
||||
"local_compatibility": local_compatibilities[rule_uuid],
|
||||
"probability": preds[rule_to_i[rule_uuid]].probability,
|
||||
"transformation_products": [
|
||||
x.product_set for x in preds[rule_idx].product_sets
|
||||
x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
|
||||
],
|
||||
"times_triggered": ds.times_triggered(str(rule.uuid)),
|
||||
"neighbors": neighbors,
|
||||
@ -2958,32 +2926,21 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
|
||||
from utilities.ml import tanimoto_distance
|
||||
|
||||
distances = [
|
||||
(i, tanimoto_distance(classify_instance, train))
|
||||
for i, train in enumerate(train_instances)
|
||||
]
|
||||
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
|
||||
return distances
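
A minimal sketch of the Tanimoto measure on binary fingerprints, assuming utilities.ml.tanimoto_distance behaves roughly like this (set-bit intersection over union); the project's exact implementation may differ.

```python
def tanimoto(a, b):
    # Intersection over union of set bits in two equally long 0/1 vectors.
    inter = sum(1 for x, y in zip(a, b) if x == 1 and y == 1)
    union = sum(1 for x, y in zip(a, b) if x == 1 or y == 1)
    return inter / union if union else 0.0

print(tanimoto([1, 1, 0, 1], [1, 0, 0, 1]))  # 2 shared bits / 3 set bits = 0.666...
```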
|
||||
|
||||
@staticmethod
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
def _compute_compatibility(self, rule_idx: str, neighbours: "RuleBasedDataset"):
accuracy = 0.0

for n in neighbours:
obs = n[1].y()[0][rule_idx]
pred = preds[n[0]][rule_idx]
if obs and pred:
tp += 1
elif not obs and pred:
fp += 1
elif obs and not pred:
fn += 1
else:
tn += 1
# Accuracy over the neighbours (fraction where observation and prediction agree)
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)

import polars as pl
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
# Compute tp, tn, fp, fn using polars expressions
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height
if tp + tn > 0.0:
accuracy = (tp + tn) / (tp + tn + fp + fn)
return accuracy
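
A self-contained run of the local-compatibility computation above with invented observations and probabilities; 0.5 stands in for self.model.threshold.

```python
import polars as pl

neighbours = pl.DataFrame({"obs_r": [1, 0, 1, 0], "prob_r": [0.9, 0.7, 0.2, 0.1]})
obs_pred = neighbours.select(obs=pl.col("obs_r").cast(pl.Boolean), pred=pl.col("prob_r") >= 0.5)

tp = obs_pred.filter(pl.col("obs") & pl.col("pred")).height    # observed and predicted
tn = obs_pred.filter(~pl.col("obs") & ~pl.col("pred")).height  # neither observed nor predicted
fp = obs_pred.filter(~pl.col("obs") & pl.col("pred")).height
fn = obs_pred.filter(pl.col("obs") & ~pl.col("pred")).height
accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn) > 0 else 0.0
print(accuracy)  # 2 of 4 neighbours agree -> 0.5
```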
|
||||
|
||||
|
||||
@ -3084,44 +3041,24 @@ class EnviFormer(PackageBasedModel):
|
||||
self.save()
|
||||
|
||||
start = datetime.now()
|
||||
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
|
||||
co2 = {"C(=O)=O", "O=C=O"}
|
||||
ds = []
|
||||
for reaction in self._get_reactions():
|
||||
educts = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
||||
for smile in reaction.educts.all()
|
||||
]
|
||||
)
|
||||
products = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
||||
for smile in reaction.products.all()
|
||||
]
|
||||
)
|
||||
if products not in co2:
|
||||
ds.append(f"{educts}>>{products}")
|
||||
ds = EnviFormerDataset.generate_dataset(self._get_reactions())
|
||||
|
||||
end = datetime.now()
|
||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||
with open(f, "w") as d_file:
|
||||
json.dump(ds, d_file)
|
||||
ds.save(f)
|
||||
return ds
|
||||
|
||||
def load_dataset(self) -> "Dataset":
|
||||
def load_dataset(self):
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||
with open(ds_path) as d_file:
|
||||
ds = json.load(d_file)
|
||||
return ds
|
||||
return EnviFormerDataset.load(ds_path)
|
||||
|
||||
def _fit_model(self, ds):
|
||||
# Call to enviFormer's fine_tune function and return the model
|
||||
from enviformer.finetune import fine_tune
|
||||
|
||||
start = datetime.now()
|
||||
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
|
||||
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
|
||||
end = datetime.now()
|
||||
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
||||
return model
|
||||
@ -3137,7 +3074,7 @@ class EnviFormer(PackageBasedModel):
|
||||
args = {"clz": "EnviFormer"}
|
||||
return args
|
||||
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||
|
||||
@ -3152,21 +3089,20 @@ class EnviFormer(PackageBasedModel):
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
def evaluate_sg(test_reactions, predictions, model_thresh):
|
||||
def evaluate_sg(test_ds, predictions, model_thresh):
|
||||
# Group the true products of reactions with the same reactant together
|
||||
assert len(test_ds) == len(predictions)
|
||||
true_dict = {}
|
||||
for r in test_reactions:
|
||||
reactant, true_product_set = r.split(">>")
|
||||
for r in test_ds:
|
||||
reactant, true_product_set = r
|
||||
true_product_set = {p for p in true_product_set.split(".")}
|
||||
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
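
A tiny example of the grouping step above: reactions sharing a reactant collapse into one key holding a list of true product sets (the SMILES are invented).

```python
true_dict = {}
reactions = [("CCO", {"CC=O"}), ("CCO", {"C=C"}), ("c1ccccc1", {"Oc1ccccc1"})]
for reactant, true_product_set in reactions:
    true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
print(true_dict["CCO"])  # [{'CC=O'}, {'C=C'}]
```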
|
||||
assert len(test_reactions) == len(predictions)
|
||||
assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
|
||||
|
||||
# Group the predicted products of reactions with the same reactant together
|
||||
pred_dict = {}
|
||||
for k, pred in enumerate(predictions):
|
||||
pred_smiles, pred_proba = zip(*pred.items())
|
||||
reactant, true_product = test_reactions[k].split(">>")
|
||||
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
|
||||
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
|
||||
for smiles, proba in zip(pred_smiles, pred_proba):
|
||||
smiles = set(smiles.split("."))
|
||||
@ -3201,7 +3137,7 @@ class EnviFormer(PackageBasedModel):
|
||||
break
|
||||
|
||||
# Recall is TP (correct) / TP + FN (len(test_reactions))
|
||||
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
|
||||
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
|
||||
# Precision is TP (correct) / TP + FP (predicted)
|
||||
prec = {
|
||||
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
|
||||
@ -3280,47 +3216,32 @@ class EnviFormer(PackageBasedModel):
|
||||
|
||||
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||
if self.eval_packages.count() > 0:
|
||||
ds = []
|
||||
for reaction in Reaction.objects.filter(
|
||||
package__in=self.eval_packages.all()
|
||||
).distinct():
|
||||
educts = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
||||
for smile in reaction.educts.all()
|
||||
]
|
||||
)
|
||||
products = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
||||
for smile in reaction.products.all()
|
||||
]
|
||||
)
|
||||
ds.append(f"{educts}>>{products}")
|
||||
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
|
||||
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
|
||||
package__in=self.eval_packages.all()).distinct())
|
||||
test_result = self.model.predict_batch(ds.X())
|
||||
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
||||
self.eval_results = self.compute_averages([single_gen_result])
|
||||
else:
|
||||
from enviformer.finetune import fine_tune
|
||||
|
||||
ds = self.load_dataset()
|
||||
n_splits = 20
|
||||
n_splits = kwargs.get("n_splits", 20)
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
|
||||
|
||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||
# this helps reduce the memory footprint.
|
||||
single_gen_results = []
|
||||
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
||||
train = [ds[i] for i in train_index]
|
||||
test = [ds[i] for i in test_index]
|
||||
train = ds[train_index]
|
||||
test = ds[test_index]
|
||||
start = datetime.now()
|
||||
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||
end = datetime.now()
|
||||
logger.debug(
|
||||
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
|
||||
)
|
||||
model.to(s.ENVIFORMER_DEVICE)
|
||||
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
||||
test_result = model.predict_batch(test.X())
|
||||
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
||||
|
||||
self.eval_results = self.compute_averages(single_gen_results)
|
||||
@ -3391,31 +3312,15 @@ class EnviFormer(PackageBasedModel):
|
||||
for pathway in train_pathways:
|
||||
for reaction in pathway.edges:
|
||||
reaction = reaction.edge_label
|
||||
if any(
|
||||
[
|
||||
educt in test_educts
|
||||
for educt in reaction_to_educts[str(reaction.uuid)]
|
||||
]
|
||||
):
|
||||
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
||||
overlap += 1
|
||||
continue
|
||||
educts = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
||||
for smile in reaction.educts.all()
|
||||
]
|
||||
)
|
||||
products = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
||||
for smile in reaction.products.all()
|
||||
]
|
||||
)
|
||||
train_reactions.append(f"{educts}>>{products}")
|
||||
train_reactions.append(reaction)
|
||||
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
|
||||
logging.debug(
|
||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
|
||||
)
|
||||
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
||||
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
|
||||
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
||||
|
||||
self.eval_results.update(
|
||||
|
||||
@ -892,7 +892,7 @@ def package_model(request, package_uuid, model_uuid):
|
||||
return JsonResponse(res, safe=False)
|
||||
|
||||
else:
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
|
||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
|
||||
return JsonResponse(app_domain_assessment, safe=False)
|
||||
|
||||
context = get_base_context(request)
|
||||
|
||||
@ -27,10 +27,11 @@ dependencies = [
|
||||
"scikit-learn>=1.6.1",
|
||||
"sentry-sdk[django]>=2.32.0",
|
||||
"setuptools>=80.8.0",
|
||||
"polars==1.35.1",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" }
|
||||
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" }
|
||||
envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
|
||||
envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"}
|
||||
envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import os.path
|
||||
from tempfile import TemporaryDirectory
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import Reaction, Compound, User, Rule
|
||||
from utilities.ml import Dataset
|
||||
from epdb.models import Reaction, Compound, User, Rule, Package
|
||||
from utilities.chem import FormatConverter
|
||||
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
||||
|
||||
|
||||
class DatasetTest(TestCase):
|
||||
@ -41,12 +43,108 @@ class DatasetTest(TestCase):
|
||||
super(DatasetTest, cls).setUpClass()
|
||||
cls.user = User.objects.get(username="anonymous")
|
||||
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
||||
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
||||
|
||||
def test_smoke(self):
|
||||
def test_generate_dataset(self):
|
||||
"""Test generating dataset does not crash"""
|
||||
self.generate_rule_dataset()
|
||||
|
||||
def test_indexing(self):
|
||||
"""Test indexing a few different ways to check for crashes"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds[5])
|
||||
print(ds[2, 5])
|
||||
print(ds[3:6, 2:8])
|
||||
print(ds[:2, "structure_id"])
|
||||
|
||||
def test_add_rows(self):
|
||||
"""Test adding one row and adding multiple rows"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
ds.add_row(list(ds.df.row(1)))
|
||||
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
|
||||
|
||||
def test_times_triggered(self):
|
||||
"""Check getting times triggered for a rule id"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.times_triggered(rules[0].uuid))
|
||||
|
||||
def test_block_indices(self):
|
||||
"""Test the usages of _block_indices"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.struct_features())
|
||||
print(ds.triggered())
|
||||
print(ds.observed())
|
||||
|
||||
def test_structure_id(self):
|
||||
"""Check getting a structure id from row index"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.structure_id(0))
|
||||
|
||||
def test_x(self):
|
||||
"""Test getting X portion of the dataframe"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.X().df.head())
|
||||
|
||||
def test_trig(self):
|
||||
"""Test getting the triggered portion of the dataframe"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.trig().df.head())
|
||||
|
||||
def test_y(self):
|
||||
"""Test getting the Y portion of the dataframe"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
print(ds.y().df.head())
|
||||
|
||||
def test_classification_dataset(self):
|
||||
"""Test making the classification dataset"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
|
||||
class_ds, products = ds.classification_dataset(compounds, rules)
|
||||
print(class_ds.df.head(5))
|
||||
print(products[:5])
|
||||
|
||||
def test_extra_features(self):
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
|
||||
print(ds.shape)
|
||||
|
||||
def test_to_arff(self):
|
||||
"""Test exporting the arff version of the dataset"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
ds.to_arff("dataset_arff_test.arff")
|
||||
|
||||
def test_save_load(self):
|
||||
"""Test saving and loading dataset"""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||
ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||
self.assertTrue(ds.df.equals(ds_loaded.df))
|
||||
|
||||
def test_dataset_example(self):
|
||||
"""Test with a concrete example checking dataset size"""
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.package)]
|
||||
applicable_rules = [self.rule1]
|
||||
|
||||
ds = Dataset.generate_dataset(reactions, applicable_rules)
|
||||
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||
|
||||
self.assertEqual(len(ds.y()), 1)
|
||||
self.assertEqual(sum(ds.y()[0]), 1)
|
||||
self.assertEqual(ds.y().df.item(), 1)
|
||||
|
||||
def test_enviformer_dataset(self):
|
||||
ds, reactions = self.generate_enviformer_dataset()
|
||||
print(ds.X().head())
|
||||
print(ds.y().head())
|
||||
|
||||
def generate_rule_dataset(self):
|
||||
"""Generate a RuleBasedDataset from test package data"""
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||
return ds, reactions, applicable_rules
|
||||
|
||||
def generate_enviformer_dataset(self):
|
||||
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||
ds = EnviFormerDataset.generate_dataset(reactions)
|
||||
return ds, reactions
|
||||
|
||||
@ -42,13 +42,11 @@ class EnviFormerTest(TestCase):
|
||||
threshold = float(0.5)
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = [self.BBD_SUBSET]
|
||||
mod = EnviFormer.create(
|
||||
self.package, data_package_objs, eval_packages_objs, threshold=threshold
|
||||
)
|
||||
mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.evaluate_model(True, eval_packages_objs)
|
||||
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||
|
||||
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||
|
||||
@ -57,12 +55,9 @@ class EnviFormerTest(TestCase):
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = [self.BBD_SUBSET]
|
||||
mods = []
|
||||
for _ in range(4):
|
||||
mod = EnviFormer.create(
|
||||
self.package, data_package_objs, eval_packages_objs, threshold=threshold
|
||||
)
|
||||
mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mods.append(mod)
|
||||
@ -73,15 +68,11 @@ class EnviFormerTest(TestCase):
|
||||
|
||||
# Test pathway prediction
|
||||
times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
|
||||
print(
|
||||
f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}"
|
||||
)
|
||||
print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
|
||||
|
||||
# Test eviction by performing three prediction with every model, twice.
|
||||
times = defaultdict(list)
|
||||
for _ in range(
|
||||
2
|
||||
): # Eviction should cause the second iteration here to have to reload the models
|
||||
for _ in range(2): # Eviction should cause the second iteration here to have to reload the models
|
||||
for mod in mods:
|
||||
for _ in range(3):
|
||||
times[mod.pk].append(measure_predict(mod))
|
||||
|
||||
@ -4,7 +4,7 @@ import numpy as np
|
||||
from django.test import TestCase
|
||||
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import User, MLRelativeReasoning, Package
|
||||
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning
|
||||
|
||||
|
||||
class ModelTest(TestCase):
|
||||
@ -17,7 +17,7 @@ class ModelTest(TestCase):
|
||||
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
||||
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
||||
|
||||
def test_smoke(self):
|
||||
def test_mlrr(self):
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
@ -35,21 +35,9 @@ class ModelTest(TestCase):
|
||||
description="Created MLRelativeReasoning in Testcase",
|
||||
)
|
||||
|
||||
# mod = RuleBasedRelativeReasoning.create(
|
||||
# self.package,
|
||||
# rule_package_objs,
|
||||
# data_package_objs,
|
||||
# eval_packages_objs,
|
||||
# threshold=threshold,
|
||||
# min_count=5,
|
||||
# max_count=0,
|
||||
# name='ECC - BBD - 0.5',
|
||||
# description='Created MLRelativeReasoning in Testcase',
|
||||
# )
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.evaluate_model(True, eval_packages_objs)
|
||||
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||
|
||||
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||
|
||||
@ -70,3 +58,57 @@ class ModelTest(TestCase):
|
||||
|
||||
# from pprint import pprint
|
||||
# pprint(mod.eval_results)
|
||||
|
||||
def test_applicability(self):
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = [self.BBD_SUBSET]
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
threshold=threshold,
|
||||
name="ECC - BBD - 0.5",
|
||||
description="Created MLRelativeReasoning in Testcase",
|
||||
build_app_domain=True, # To test the applicability domain this must be True
|
||||
app_domain_num_neighbours=5,
|
||||
app_domain_local_compatibility_threshold=0.5,
|
||||
app_domain_reliability_threshold=0.5,
|
||||
)
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||
|
||||
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||
|
||||
def test_rbrr(self):
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
with self.settings(MODEL_DIR=tmpdir):
|
||||
threshold = float(0.5)
|
||||
|
||||
rule_package_objs = [self.BBD_SUBSET]
|
||||
data_package_objs = [self.BBD_SUBSET]
|
||||
eval_packages_objs = [self.BBD_SUBSET]
|
||||
|
||||
mod = RuleBasedRelativeReasoning.create(
|
||||
self.package,
|
||||
rule_package_objs,
|
||||
data_package_objs,
|
||||
threshold=threshold,
|
||||
min_count=5,
|
||||
max_count=0,
|
||||
name='ECC - BBD - 0.5',
|
||||
description='Created MLRelativeReasoning in Testcase',
|
||||
)
|
||||
|
||||
mod.build_dataset()
|
||||
mod.build_model()
|
||||
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||
|
||||
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
|
||||
from indigo import Indigo, IndigoException, IndigoObject
|
||||
from indigo.renderer import IndigoRenderer
|
||||
from rdkit import Chem, rdBase
|
||||
from rdkit.Chem import MACCSkeys, Descriptors
|
||||
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
|
||||
from rdkit.Chem import rdChemReactions
|
||||
from rdkit.Chem.Draw import rdMolDraw2D
|
||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||
@ -107,6 +107,13 @@ class FormatConverter(object):
|
||||
bitvec = MACCSkeys.GenMACCSKeys(mol)
|
||||
return bitvec.ToList()
|
||||
|
||||
@staticmethod
|
||||
def morgan(smiles, radius=3, fpSize=2048):
|
||||
finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
|
||||
mol = Chem.MolFromSmiles(smiles)
|
||||
fp = finger_gen.GetFingerprint(mol)
|
||||
return fp.ToList()
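
A standalone look at what the new morgan() helper returns, using only RDKit (assumes RDKit is installed; "CCO" is just an example SMILES).

```python
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=2048)
fp = gen.GetFingerprint(Chem.MolFromSmiles("CCO"))
bits = fp.ToList()
print(len(bits), sum(bits))  # 2048 positions, only a handful set for ethanol
```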
|
||||
|
||||
@staticmethod
|
||||
def get_functional_groups(smiles: str) -> List[str]:
|
||||
res = list()
|
||||
|
||||
550 utilities/ml.py
@ -5,11 +5,14 @@ import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
|
||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from envipy_plugins import Descriptor
|
||||
from numpy.random import default_rng
|
||||
import polars as pl
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.dummy import DummyClassifier
|
||||
@ -26,70 +29,281 @@ if TYPE_CHECKING:
|
||||
from epdb.models import Rule, CompoundStructure, Reaction
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __init__(
|
||||
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
|
||||
):
|
||||
self.columns: List[str] = columns
|
||||
self.num_labels: int = num_labels
|
||||
|
||||
if data is None:
|
||||
self.data: List[List[str | int | float]] = list()
|
||||
class Dataset(ABC):
|
||||
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
|
||||
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
|
||||
self.df = data
|
||||
else:
|
||||
self.data = data
|
||||
# Build either an empty dataframe with columns or fill it with list-of-lists data
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
|
||||
|
||||
self.num_features: int = len(columns) - self.num_labels
|
||||
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
||||
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
||||
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
||||
def add_rows(self, rows: List[List[str | int | float]]):
|
||||
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
|
||||
if len(self.columns) != len(rows[0]):
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
|
||||
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
|
||||
self.df.extend(new_rows)
|
||||
|
||||
def _block_indices(self, prefix) -> Tuple[int, int]:
|
||||
def add_row(self, row: List[str | int | float]):
|
||||
"""See add_rows"""
|
||||
self.add_rows([row])
|
||||
|
||||
def block_indices(self, prefix) -> List[int]:
|
||||
"""Find the indexes in column labels that has the prefix"""
|
||||
indices: List[int] = []
|
||||
for i, feature in enumerate(self.columns):
|
||||
if feature.startswith(prefix):
|
||||
indices.append(i)
|
||||
return indices
|
||||
|
||||
return min(indices), max(indices)
|
||||
@property
|
||||
def columns(self) -> List[str]:
|
||||
"""Use the polars dataframe columns"""
|
||||
return self.df.columns
|
||||
|
||||
def structure_id(self):
|
||||
return self.data[0][0]
|
||||
@property
|
||||
def shape(self):
|
||||
return self.df.shape
|
||||
|
||||
def add_row(self, row: List[str | int | float]):
|
||||
if len(self.columns) != len(row):
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
|
||||
self.data.append(row)
|
||||
@abstractmethod
|
||||
def X(self, **kwargs):
|
||||
pass
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
|
||||
idx = self.columns.index(f"trig_{rule_uuid}")
|
||||
@abstractmethod
|
||||
def y(self, **kwargs):
|
||||
pass
|
||||
|
||||
times_triggered = 0
|
||||
for row in self.data:
|
||||
if row[idx] == 1:
|
||||
times_triggered += 1
|
||||
|
||||
return times_triggered
|
||||
|
||||
def struct_features(self) -> Tuple[int, int]:
|
||||
return self._struct_features
|
||||
|
||||
def triggered(self) -> Tuple[int, int]:
|
||||
return self._triggered
|
||||
|
||||
def observed(self) -> Tuple[int, int]:
|
||||
return self._observed
|
||||
|
||||
def at(self, position: int) -> Dataset:
|
||||
return Dataset(self.columns, self.num_labels, [self.data[position]])
|
||||
|
||||
def limit(self, limit: int) -> Dataset:
|
||||
return Dataset(self.columns, self.num_labels, self.data[:limit])
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def generate_dataset(reactions, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
return (self.at(i) for i, _ in enumerate(self.data))
|
||||
"""Use polars iter_rows for iterating over the dataset"""
|
||||
return self.df.iter_rows()
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Item is passed to polars allowing for advanced indexing.
|
||||
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
|
||||
res = self.df[item]
|
||||
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
|
||||
return self.__class__(data=res)
|
||||
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
|
||||
return res
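
A quick demonstration of the polars __getitem__ semantics the wrapper relies on (column names invented): slices come back as a DataFrame and are re-wrapped, single cells come back as plain values.

```python
import polars as pl

df = pl.DataFrame({"structure_id": ["a", "b", "c"], "feature_0": [0, 1, 1]})
print(type(df[0:2]))       # polars DataFrame -> the wrapper re-wraps it as a Dataset
print(df[1, "feature_0"])  # plain Python int (1) -> returned unchanged
```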
|
||||
|
||||
def save(self, path: "Path | str"):
|
||||
import pickle
|
||||
|
||||
with open(path, "wb") as fh:
|
||||
pickle.dump(self, fh)
|
||||
|
||||
@staticmethod
|
||||
def load(path: "str | Path") -> "Dataset":
|
||||
import pickle
|
||||
|
||||
return pickle.load(open(path, "rb"))
|
||||
|
||||
def to_numpy(self):
|
||||
return self.df.to_numpy()
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def iter_rows(self, named=False):
|
||||
return self.df.iter_rows(named=named)
|
||||
|
||||
def filter(self, *predicates, **constraints):
|
||||
return self.__class__(data=self.df.filter(*predicates, **constraints))
|
||||
|
||||
def select(self, *exprs, **named_exprs):
|
||||
return self.__class__(data=self.df.select(*exprs, **named_exprs))
|
||||
|
||||
def with_columns(self, *exprs, **name_exprs):
|
||||
return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
|
||||
|
||||
def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
|
||||
return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
|
||||
multithreaded=multithreaded, maintain_order=maintain_order))
|
||||
|
||||
def item(self, row=None, column=None):
|
||||
return self.df.item(row, column)
|
||||
|
||||
def fill_nan(self, value):
|
||||
return self.__class__(data=self.df.fill_nan(value))
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.df.height
|
||||
|
||||
|
||||
class RuleBasedDataset(Dataset):
|
||||
def __init__(self, num_labels=None, columns=None, data=None):
|
||||
super().__init__(columns, data)
|
||||
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
|
||||
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
|
||||
# Pre-calculate the ids of columns for features/labels, useful later in X and y
|
||||
self._struct_features: List[int] = self.block_indices("feature_")
|
||||
self._triggered: List[int] = self.block_indices("trig_")
|
||||
self._observed: List[int] = self.block_indices("obs_")
|
||||
self.feature_cols: List[int] = self._struct_features + self._triggered
|
||||
self.num_features: int = len(self.feature_cols)
|
||||
self.has_probs = False
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
|
||||
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
|
||||
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
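
The same counting logic on a toy frame (rule UUID fabricated): times_triggered is just the number of rows whose trig column equals 1.

```python
import polars as pl

df = pl.DataFrame({"trig_1234": [1, 0, 1, 1]})
print(df.filter(pl.col("trig_1234") == 1).height)  # 3
```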
|
||||
|
||||
def struct_features(self) -> List[int]:
|
||||
return self._struct_features
|
||||
|
||||
def triggered(self) -> List[int]:
|
||||
return self._triggered
|
||||
|
||||
def observed(self) -> List[int]:
|
||||
return self._observed
|
||||
|
||||
def structure_id(self, index: int):
|
||||
"""Get the UUID of a compound"""
|
||||
return self.item(index, "structure_id")
|
||||
|
||||
def X(self, exclude_id_col=True, na_replacement=0):
|
||||
"""Get all the feature and trig columns"""
|
||||
_col_ids = self.feature_cols
|
||||
if not exclude_id_col:
|
||||
_col_ids = [0] + _col_ids
|
||||
res = self[:, _col_ids]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
def trig(self, na_replacement=0):
|
||||
"""Get all the trig columns"""
|
||||
res = self[:, self._triggered]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
"""Get all the obs columns"""
|
||||
res = self[:, self._observed]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
|
||||
if feat_funcs is None:
|
||||
feat_funcs = [FormatConverter.maccs]
|
||||
_structures = set() # Get all the structures
|
||||
for r in reactions:
|
||||
_structures.update(r.educts.all())
|
||||
if not educts_only:
|
||||
_structures.update(r.products.all())
|
||||
|
||||
compounds = sorted(_structures, key=lambda x: x.url)
|
||||
triggered: Dict[str, Set[str]] = defaultdict(set)
|
||||
observed: Set[str] = set()
|
||||
|
||||
# Apply rules on collected compounds and store tps
|
||||
for i, comp in enumerate(compounds):
|
||||
logger.debug(f"{i + 1}/{len(compounds)}...")
|
||||
|
||||
for rule in applicable_rules:
|
||||
product_sets = rule.apply(comp.smiles)
|
||||
if len(product_sets) == 0:
|
||||
continue
|
||||
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
if key in triggered:
|
||||
logger.info(f"{key} already present. Duplicate reaction?")
|
||||
|
||||
for prod_set in product_sets:
|
||||
for smi in prod_set:
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
triggered[key].add(smi)
|
||||
|
||||
for i, r in enumerate(reactions):
|
||||
logger.debug(f"{i + 1}/{len(reactions)}...")
|
||||
|
||||
if len(r.educts.all()) != 1:
|
||||
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
|
||||
continue
|
||||
|
||||
for comp in r.educts.all():
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
if key not in triggered:
|
||||
continue
|
||||
|
||||
# standardize products from reactions for comparison
|
||||
standardized_products = []
|
||||
for cs in r.products.all():
|
||||
smi = cs.smiles
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception as e:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
standardized_products.append(smi)
|
||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||
observed.add(key)
|
||||
feat_columns = []
|
||||
for feat_func in feat_funcs:
|
||||
if isinstance(feat_func, Descriptor):
|
||||
feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
|
||||
else:
|
||||
feats = feat_func(compounds[0].smiles)
|
||||
start_i = len(feat_columns)
|
||||
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
|
||||
ds_columns = (["structure_id"] +
|
||||
feat_columns +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
rows = []
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
# Features
|
||||
feats = []
|
||||
for feat_func in feat_funcs:
|
||||
if isinstance(feat_func, Descriptor):
|
||||
feat = feat_func.get_molecule_descriptors(comp.smiles)
|
||||
else:
|
||||
feat = feat_func(comp.smiles)
|
||||
feats.extend(feat)
|
||||
trig = []
|
||||
obs = []
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
# Check triggered
|
||||
if key in triggered:
|
||||
trig.append(1)
|
||||
else:
|
||||
trig.append(0)
|
||||
# Check obs
|
||||
if key in observed:
|
||||
obs.append(1)
|
||||
elif key not in triggered:
|
||||
obs.append(None)
|
||||
else:
|
||||
obs.append(0)
|
||||
rows.append([str(comp.uuid)] + feats + trig + obs)
|
||||
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
|
||||
return ds
|
||||
|
||||
def classification_dataset(
|
||||
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
|
||||
) -> Tuple[Dataset, List[List[PredictionResult]]]:
|
||||
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||
classify_data = []
|
||||
classify_products = []
|
||||
for struct in structures:
|
||||
@ -113,186 +327,18 @@ class Dataset:
|
||||
else:
|
||||
trig.append(0)
|
||||
prods.append([])
|
||||
|
||||
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
|
||||
new_row = [struct_id] + features + trig + ([-1] * len(trig))
|
||||
if self.has_probs:
|
||||
new_row += [-1] * len(trig)
|
||||
classify_data.append(new_row)
|
||||
classify_products.append(prods)
|
||||
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
|
||||
return ds, classify_products
|
||||
|
||||
return Dataset(
|
||||
columns=self.columns, num_labels=self.num_labels, data=classify_data
|
||||
), classify_products
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(
|
||||
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
|
||||
) -> Dataset:
|
||||
_structures = set()
|
||||
|
||||
for r in reactions:
|
||||
for e in r.educts.all():
|
||||
_structures.add(e)
|
||||
|
||||
if not educts_only:
|
||||
for e in r.products:
|
||||
_structures.add(e)
|
||||
|
||||
compounds = sorted(_structures, key=lambda x: x.url)
|
||||
|
||||
triggered: Dict[str, Set[str]] = defaultdict(set)
|
||||
observed: Set[str] = set()
|
||||
|
||||
# Apply rules on collected compounds and store tps
|
||||
for i, comp in enumerate(compounds):
|
||||
logger.debug(f"{i + 1}/{len(compounds)}...")
|
||||
|
||||
for rule in applicable_rules:
|
||||
product_sets = rule.apply(comp.smiles)
|
||||
|
||||
if len(product_sets) == 0:
|
||||
continue
|
||||
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
|
||||
if key in triggered:
|
||||
logger.info(f"{key} already present. Duplicate reaction?")
|
||||
|
||||
for prod_set in product_sets:
|
||||
for smi in prod_set:
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception:
|
||||
# :shrug:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
pass
|
||||
|
||||
triggered[key].add(smi)
|
||||
|
||||
for i, r in enumerate(reactions):
|
||||
logger.debug(f"{i + 1}/{len(reactions)}...")
|
||||
|
||||
if len(r.educts.all()) != 1:
|
||||
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
|
||||
continue
|
||||
|
||||
for comp in r.educts.all():
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
|
||||
if key not in triggered:
|
||||
continue
|
||||
|
||||
# standardize products from reactions for comparison
|
||||
standardized_products = []
|
||||
for cs in r.products.all():
|
||||
smi = cs.smiles
|
||||
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception as e:
|
||||
# :shrug:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
pass
|
||||
|
||||
standardized_products.append(smi)
|
||||
|
||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||
observed.add(key)
|
||||
else:
|
||||
pass
|
||||
|
||||
ds = None
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
# Features
|
||||
feat = FormatConverter.maccs(comp.smiles)
|
||||
trig = []
|
||||
obs = []
|
||||
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
|
||||
# Check triggered
|
||||
if key in triggered:
|
||||
trig.append(1)
|
||||
else:
|
||||
trig.append(0)
|
||||
|
||||
# Check obs
|
||||
if key in observed:
|
||||
obs.append(1)
|
||||
elif key not in triggered:
|
||||
obs.append(None)
|
||||
else:
|
||||
obs.append(0)
|
||||
|
||||
if ds is None:
|
||||
header = (
|
||||
["structure_id"]
|
||||
+ [f"feature_{i}" for i, _ in enumerate(feat)]
|
||||
+ [f"trig_{r.uuid}" for r in applicable_rules]
|
||||
+ [f"obs_{r.uuid}" for r in applicable_rules]
|
||||
)
|
||||
ds = Dataset(header, len(applicable_rules))
|
||||
|
||||
ds.add_row([str(comp.uuid)] + feat + trig + obs)
|
||||
|
||||
return ds
|
||||
|
||||
    def X(self, exclude_id_col=True, na_replacement=0):
        res = self.__getitem__(
            (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
        )
        if na_replacement is not None:
            res = [[x if x is not None else na_replacement for x in row] for row in res]
        return res

    def trig(self, na_replacement=0):
        res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
        if na_replacement is not None:
            res = [[x if x is not None else na_replacement for x in row] for row in res]
        return res

    def y(self, na_replacement=0):
        res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
        if na_replacement is not None:
            res = [[x if x is not None else na_replacement for x in row] for row in res]
        return res

    def __getitem__(self, key):
        if not isinstance(key, tuple):
            raise TypeError("Dataset must be indexed with dataset[rows, columns]")

        row_key, col_key = key

        # Normalize rows
        if isinstance(row_key, int):
            rows = [self.data[row_key]]
        else:
            rows = self.data[row_key]

        # Normalize columns
        if isinstance(col_key, int):
            res = [row[col_key] for row in rows]
        else:
            res = [
                [row[i] for i in range(*col_key.indices(len(row)))]
                if isinstance(col_key, slice)
                else [row[i] for i in col_key]
                for row in rows
            ]

        return res

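# A minimal sketch of the dataset[rows, columns] convention implemented by __getitem__
# above, assuming `ds` is an already built instance (column positions are illustrative):
#
#   cell = ds[0, 2]            # int row, int column -> [value] for that cell of row 0
#   block = ds[:, 1:4]         # two slices -> every row restricted to columns 1..3
#   picked = ds[0:2, [0, 5]]   # slice plus index list -> columns 0 and 5 of the first two rows
#   ds[0]                      # a non-tuple key raises TypeError
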
    def save(self, path: "Path"):
        import pickle

        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: "Path") -> "Dataset":
        import pickle

        with open(path, "rb") as fh:
            return pickle.load(fh)

    def add_probs(self, probs):
        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
        self.has_probs = True

    def to_arff(self, path: "Path"):
        arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
@ -304,7 +350,7 @@ class Dataset:
            arff += f"@attribute {c} {{0,1}}\n"

        arff += "\n@data\n"
        for d in self.data:
        for d in self:
            ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
            xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
            arff += f"{ys},{xs}\n"
@ -313,10 +359,40 @@ class Dataset:
            fh.write(arff)
            fh.flush()

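# A hedged sketch of a @data line written by to_arff above: the num_labels label values
# come first, the num_features feature values follow, and missing values are written as
# "?", e.g. for two labels and three features:
#
#   1,?,0,1,1
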
    def __repr__(self):
        return (
            f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
        )


class EnviFormerDataset(Dataset):
    def __init__(self, columns=None, data=None):
        super().__init__(columns, data)

    def X(self):
        """Return the educts"""
        return self["educts"]

    def y(self):
        """Return the products"""
        return self["products"]

    @staticmethod
    def generate_dataset(reactions, *args, **kwargs):
        # Standardise reactions for the training data
        stereo = kwargs.get("stereo", False)
        rows = []
        for reaction in reactions:
            e = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.educts.all()
                ]
            )
            p = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.products.all()
                ]
            )
            rows.append([e, p])
        ds = EnviFormerDataset(["educts", "products"], rows)
        return ds
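# A minimal, self-contained illustration of the row format built by generate_dataset
# above: the educt and product SMILES of one reaction are each joined with "." into a
# single string (the SMILES values are made up).
educt_smiles = ["CCO", "[O]"]
product_smiles = ["CC=O"]
example_row = [".".join(educt_smiles), ".".join(product_smiles)]
# -> ["CCO.[O]", "CC=O"], i.e. one ("educts", "products") pair per reaction
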

class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -498,7 +574,7 @@ class EnsembleClassifierChain:
        self.classifiers = []

        if self.num_labels is None:
            self.num_labels = len(Y[0])
            self.num_labels = Y.shape[1]

        for p in range(self.num_chains):
            logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@ -529,7 +605,7 @@ class RelativeReasoning:

    def fit(self, X, Y):
        n_instances = len(Y)
        n_attributes = len(Y[0])
        n_attributes = Y.shape[1]

        for i in range(n_attributes):
            for j in range(n_attributes):
@ -541,8 +617,8 @@ class RelativeReasoning:
                countboth = 0

                for k in range(n_instances):
                    vi = Y[k][i]
                    vj = Y[k][j]
                    vi = Y[k, i]
                    vj = Y[k, j]

                    if vi is None or vj is None:
                        continue
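# A minimal sketch of the array indexing assumed by the changes above: Y.shape[1] and
# Y[k, i] require the label matrix to be a 2-D NumPy-style array (object dtype keeps
# missing labels as None) rather than a list of lists.
import numpy as np

Y_example = np.array([[1, 0, None], [0, 1, 1]], dtype=object)
n_attributes_example = Y_example.shape[1]                  # 3 label columns
vi_example, vj_example = Y_example[0, 0], Y_example[0, 2]  # 1 and None; None pairs are skipped
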
@ -598,7 +674,7 @@ class ApplicabilityDomainPCA(PCA):
        self.min_vals = None
        self.max_vals = None

    def build(self, train_dataset: "Dataset"):
    def build(self, train_dataset: "RuleBasedDataset"):
        # transform
        X_scaled = self.scaler.fit_transform(train_dataset.X())
        # fit pca
@ -612,7 +688,7 @@ class ApplicabilityDomainPCA(PCA):
        instances_pca = self.transform(instances_scaled)
        return instances_pca

    def is_applicable(self, classify_instances: "Dataset"):
    def is_applicable(self, classify_instances: "RuleBasedDataset"):
        instances_pca = self.__transform(classify_instances.X())

        is_applicable = []
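# The min_vals / max_vals attributes above suggest a bounding-box style check in PCA
# space; a hedged sketch of that idea (not necessarily the exact implementation):
def _inside_pca_box(instances_pca, min_vals, max_vals):
    # An instance is inside the domain if every PCA coordinate lies within the
    # per-component range observed on the training data.
    return [
        all(lo <= v <= hi for v, lo, hi in zip(row, min_vals, max_vals))
        for row in instances_pca
    ]
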
184
uv.lock
generated
184
uv.lock
generated
@ -1,6 +1,10 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.12"
resolution-markers = [
    "sys_platform == 'linux' or sys_platform == 'win32'",
    "sys_platform != 'linux' and sys_platform != 'win32'",
]

[[package]]
name = "aiohappyeyeballs"
@ -176,6 +180,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" },
]

[[package]]
name = "celery-stubs"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "mypy" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/14/b853ada8706a3a301396566b6dd405d1cbb24bff756236a12a01dbe766a4/celery-stubs-0.1.3.tar.gz", hash = "sha256:0fb5345820f8a2bd14e6ffcbef2d10181e12e40f8369f551d7acc99d8d514919", size = 46583, upload-time = "2023-02-10T02:20:11.837Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1c/7a/4ab2347d13f1f59d10a7337feb9beb002664119f286036785284c6bec150/celery_stubs-0.1.3-py3-none-any.whl", hash = "sha256:dfb9ad27614a8af028b2055bb4a4ae99ca5e9a8d871428a506646d62153218d7", size = 89085, upload-time = "2023-02-10T02:20:09.409Z" },
]

[[package]]
name = "certifi"
version = "2025.10.5"
@ -525,13 +542,14 @@ wheels = [
[[package]]
name = "enviformer"
version = "0.1.0"
source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2#3f28f60cfa1df814cf7559303b5130933efa40ae" }
source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4#7094be5767748fd63d4a84a5d71f06cf02ba07f3" }
dependencies = [
    { name = "joblib" },
    { name = "lightning" },
    { name = "pytorch-lightning" },
    { name = "scikit-learn" },
    { name = "torch" },
    { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

[[package]]
@ -546,7 +564,6 @@ dependencies = [
    { name = "django-ninja" },
    { name = "django-oauth-toolkit" },
    { name = "django-polymorphic" },
    { name = "django-stubs" },
    { name = "enviformer" },
    { name = "envipy-additional-information" },
    { name = "envipy-ambit" },
@ -554,6 +571,7 @@ dependencies = [
    { name = "epam-indigo" },
    { name = "gunicorn" },
    { name = "networkx" },
    { name = "polars" },
    { name = "psycopg2-binary" },
    { name = "python-dotenv" },
    { name = "rdkit" },
@ -566,6 +584,8 @@ dependencies = [

[package.optional-dependencies]
dev = [
    { name = "celery-stubs" },
    { name = "django-stubs" },
    { name = "poethepoet" },
    { name = "pre-commit" },
    { name = "ruff" },
@ -577,15 +597,16 @@ ms-login = [
[package.metadata]
requires-dist = [
    { name = "celery", specifier = ">=5.5.2" },
    { name = "celery-stubs", marker = "extra == 'dev'", specifier = "==0.1.3" },
    { name = "django", specifier = ">=5.2.1" },
    { name = "django-extensions", specifier = ">=4.1" },
    { name = "django-model-utils", specifier = ">=5.0.0" },
    { name = "django-ninja", specifier = ">=1.4.1" },
    { name = "django-oauth-toolkit", specifier = ">=3.0.1" },
    { name = "django-polymorphic", specifier = ">=4.1.0" },
    { name = "django-stubs", specifier = ">=5.2.4" },
    { name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2" },
    { name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4" },
    { name = "django-stubs", marker = "extra == 'dev'", specifier = ">=5.2.4" },
    { name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4" },
    { name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7" },
    { name = "envipy-ambit", git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" },
    { name = "envipy-plugins", git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git?rev=v0.1.0" },
    { name = "epam-indigo", specifier = ">=1.30.1" },
@ -593,6 +614,7 @@ requires-dist = [
    { name = "msal", marker = "extra == 'ms-login'", specifier = ">=1.33.0" },
    { name = "networkx", specifier = ">=3.4.2" },
    { name = "poethepoet", marker = "extra == 'dev'", specifier = ">=0.37.0" },
    { name = "polars", specifier = "==1.35.1" },
    { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.3.0" },
    { name = "psycopg2-binary", specifier = ">=2.9.10" },
    { name = "python-dotenv", specifier = ">=1.1.0" },
@ -608,8 +630,8 @@ provides-extras = ["ms-login", "dev"]

[[package]]
name = "envipy-additional-information"
version = "0.1.0"
source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4#4da604090bf7cf1f3f552d69485472dbc623030a" }
version = "0.1.7"
source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7#d02a5d5e6a931e6565ea86127813acf7e4b33a30" }
dependencies = [
    { name = "pydantic" },
]
@ -865,7 +887,8 @@ dependencies = [
    { name = "packaging" },
    { name = "pytorch-lightning" },
    { name = "pyyaml" },
    { name = "torch" },
    { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "torchmetrics" },
    { name = "tqdm" },
    { name = "typing-extensions" },
@ -1074,6 +1097,47 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
]

[[package]]
name = "mypy"
version = "1.18.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "mypy-extensions" },
    { name = "pathspec" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/77/8f0d0001ffad290cef2f7f216f96c814866248a0b92a722365ed54648e7e/mypy-1.18.2.tar.gz", hash = "sha256:06a398102a5f203d7477b2923dda3634c36727fa5c237d8f859ef90c42a9924b", size = 3448846, upload-time = "2025-09-19T00:11:10.519Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/07/06/dfdd2bc60c66611dd8335f463818514733bc763e4760dee289dcc33df709/mypy-1.18.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:33eca32dd124b29400c31d7cf784e795b050ace0e1f91b8dc035672725617e34", size = 12908273, upload-time = "2025-09-19T00:10:58.321Z" },
    { url = "https://files.pythonhosted.org/packages/81/14/6a9de6d13a122d5608e1a04130724caf9170333ac5a924e10f670687d3eb/mypy-1.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3c47adf30d65e89b2dcd2fa32f3aeb5e94ca970d2c15fcb25e297871c8e4764", size = 11920910, upload-time = "2025-09-19T00:10:20.043Z" },
    { url = "https://files.pythonhosted.org/packages/5f/a9/b29de53e42f18e8cc547e38daa9dfa132ffdc64f7250e353f5c8cdd44bee/mypy-1.18.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d6c838e831a062f5f29d11c9057c6009f60cb294fea33a98422688181fe2893", size = 12465585, upload-time = "2025-09-19T00:10:33.005Z" },
    { url = "https://files.pythonhosted.org/packages/77/ae/6c3d2c7c61ff21f2bee938c917616c92ebf852f015fb55917fd6e2811db2/mypy-1.18.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01199871b6110a2ce984bde85acd481232d17413868c9807e95c1b0739a58914", size = 13348562, upload-time = "2025-09-19T00:10:11.51Z" },
    { url = "https://files.pythonhosted.org/packages/4d/31/aec68ab3b4aebdf8f36d191b0685d99faa899ab990753ca0fee60fb99511/mypy-1.18.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a2afc0fa0b0e91b4599ddfe0f91e2c26c2b5a5ab263737e998d6817874c5f7c8", size = 13533296, upload-time = "2025-09-19T00:10:06.568Z" },
    { url = "https://files.pythonhosted.org/packages/9f/83/abcb3ad9478fca3ebeb6a5358bb0b22c95ea42b43b7789c7fb1297ca44f4/mypy-1.18.2-cp312-cp312-win_amd64.whl", hash = "sha256:d8068d0afe682c7c4897c0f7ce84ea77f6de953262b12d07038f4d296d547074", size = 9828828, upload-time = "2025-09-19T00:10:28.203Z" },
    { url = "https://files.pythonhosted.org/packages/5f/04/7f462e6fbba87a72bc8097b93f6842499c428a6ff0c81dd46948d175afe8/mypy-1.18.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:07b8b0f580ca6d289e69209ec9d3911b4a26e5abfde32228a288eb79df129fcc", size = 12898728, upload-time = "2025-09-19T00:10:01.33Z" },
    { url = "https://files.pythonhosted.org/packages/99/5b/61ed4efb64f1871b41fd0b82d29a64640f3516078f6c7905b68ab1ad8b13/mypy-1.18.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed4482847168439651d3feee5833ccedbf6657e964572706a2adb1f7fa4dfe2e", size = 11910758, upload-time = "2025-09-19T00:10:42.607Z" },
    { url = "https://files.pythonhosted.org/packages/3c/46/d297d4b683cc89a6e4108c4250a6a6b717f5fa96e1a30a7944a6da44da35/mypy-1.18.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ad2afadd1e9fea5cf99a45a822346971ede8685cc581ed9cd4d42eaf940986", size = 12475342, upload-time = "2025-09-19T00:11:00.371Z" },
    { url = "https://files.pythonhosted.org/packages/83/45/4798f4d00df13eae3bfdf726c9244bcb495ab5bd588c0eed93a2f2dd67f3/mypy-1.18.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a431a6f1ef14cf8c144c6b14793a23ec4eae3db28277c358136e79d7d062f62d", size = 13338709, upload-time = "2025-09-19T00:11:03.358Z" },
    { url = "https://files.pythonhosted.org/packages/d7/09/479f7358d9625172521a87a9271ddd2441e1dab16a09708f056e97007207/mypy-1.18.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ab28cc197f1dd77a67e1c6f35cd1f8e8b73ed2217e4fc005f9e6a504e46e7ba", size = 13529806, upload-time = "2025-09-19T00:10:26.073Z" },
    { url = "https://files.pythonhosted.org/packages/71/cf/ac0f2c7e9d0ea3c75cd99dff7aec1c9df4a1376537cb90e4c882267ee7e9/mypy-1.18.2-cp313-cp313-win_amd64.whl", hash = "sha256:0e2785a84b34a72ba55fb5daf079a1003a34c05b22238da94fcae2bbe46f3544", size = 9833262, upload-time = "2025-09-19T00:10:40.035Z" },
    { url = "https://files.pythonhosted.org/packages/5a/0c/7d5300883da16f0063ae53996358758b2a2df2a09c72a5061fa79a1f5006/mypy-1.18.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:62f0e1e988ad41c2a110edde6c398383a889d95b36b3e60bcf155f5164c4fdce", size = 12893775, upload-time = "2025-09-19T00:10:03.814Z" },
    { url = "https://files.pythonhosted.org/packages/50/df/2cffbf25737bdb236f60c973edf62e3e7b4ee1c25b6878629e88e2cde967/mypy-1.18.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8795a039bab805ff0c1dfdb8cd3344642c2b99b8e439d057aba30850b8d3423d", size = 11936852, upload-time = "2025-09-19T00:10:51.631Z" },
    { url = "https://files.pythonhosted.org/packages/be/50/34059de13dd269227fb4a03be1faee6e2a4b04a2051c82ac0a0b5a773c9a/mypy-1.18.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ca1e64b24a700ab5ce10133f7ccd956a04715463d30498e64ea8715236f9c9c", size = 12480242, upload-time = "2025-09-19T00:11:07.955Z" },
    { url = "https://files.pythonhosted.org/packages/5b/11/040983fad5132d85914c874a2836252bbc57832065548885b5bb5b0d4359/mypy-1.18.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d924eef3795cc89fecf6bedc6ed32b33ac13e8321344f6ddbf8ee89f706c05cb", size = 13326683, upload-time = "2025-09-19T00:09:55.572Z" },
    { url = "https://files.pythonhosted.org/packages/e9/ba/89b2901dd77414dd7a8c8729985832a5735053be15b744c18e4586e506ef/mypy-1.18.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20c02215a080e3a2be3aa50506c67242df1c151eaba0dcbc1e4e557922a26075", size = 13514749, upload-time = "2025-09-19T00:10:44.827Z" },
    { url = "https://files.pythonhosted.org/packages/25/bc/cc98767cffd6b2928ba680f3e5bc969c4152bf7c2d83f92f5a504b92b0eb/mypy-1.18.2-cp314-cp314-win_amd64.whl", hash = "sha256:749b5f83198f1ca64345603118a6f01a4e99ad4bf9d103ddc5a3200cc4614adf", size = 9982959, upload-time = "2025-09-19T00:10:37.344Z" },
    { url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" },
]

[[package]]
name = "mypy-extensions"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
]

[[package]]
name = "networkx"
version = "3.5"
@ -1192,7 +1256,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-cublas-cu12" },
    { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@ -1203,7 +1267,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-nvjitlink-cu12" },
    { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@ -1230,9 +1294,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-cublas-cu12" },
    { name = "nvidia-cusparse-cu12" },
    { name = "nvidia-nvjitlink-cu12" },
    { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@ -1243,7 +1307,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "nvidia-nvjitlink-cu12" },
    { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@ -1308,6 +1372,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/aa/18/a8444036c6dd65ba3624c63b734d3ba95ba63ace513078e1580590075d21/pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364", size = 5955, upload-time = "2020-09-16T19:21:11.409Z" },
]

[[package]]
name = "pathspec"
version = "0.12.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
]

[[package]]
name = "pillow"
version = "11.3.0"
@ -1396,6 +1469,32 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/92/1b/5337af1a6a478d25a3e3c56b9b4b42b0a160314e02f4a0498d5322c8dac4/poethepoet-0.37.0-py3-none-any.whl", hash = "sha256:861790276315abcc8df1b4bd60e28c3d48a06db273edd3092f3c94e1a46e5e22", size = 90062, upload-time = "2025-08-11T18:00:27.595Z" },
]

[[package]]
name = "polars"
version = "1.35.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "polars-runtime-32" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/5b/3caad788d93304026cbf0ab4c37f8402058b64a2f153b9c62f8b30f5d2ee/polars-1.35.1.tar.gz", hash = "sha256:06548e6d554580151d6ca7452d74bceeec4640b5b9261836889b8e68cfd7a62e", size = 694881, upload-time = "2025-10-30T12:12:52.294Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9f/4c/21a227b722534404241c2a76beceb7463469d50c775a227fc5c209eb8adc/polars-1.35.1-py3-none-any.whl", hash = "sha256:c29a933f28aa330d96a633adbd79aa5e6a6247a802a720eead9933f4613bdbf4", size = 783598, upload-time = "2025-10-30T12:11:54.668Z" },
]

[[package]]
name = "polars-runtime-32"
version = "1.35.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/df/3e/19c252e8eb4096300c1a36ec3e50a27e5fa9a1ccaf32d3927793c16abaee/polars_runtime_32-1.35.1.tar.gz", hash = "sha256:f6b4ec9cd58b31c87af1b8c110c9c986d82345f1d50d7f7595b5d447a19dc365", size = 2696218, upload-time = "2025-10-30T12:12:53.479Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/08/2c/da339459805a26105e9d9c2f07e43ca5b8baeee55acd5457e6881487a79a/polars_runtime_32-1.35.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6f051a42f6ae2f26e3bc2cf1f170f2120602976e2a3ffb6cfba742eecc7cc620", size = 40525100, upload-time = "2025-10-30T12:11:58.098Z" },
    { url = "https://files.pythonhosted.org/packages/27/70/a0733568b3533481924d2ce68b279ab3d7334e5fa6ed259f671f650b7c5e/polars_runtime_32-1.35.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:c2232f9cf05ba59efc72d940b86c033d41fd2d70bf2742e8115ed7112a766aa9", size = 36701908, upload-time = "2025-10-30T12:12:02.166Z" },
    { url = "https://files.pythonhosted.org/packages/46/54/6c09137bef9da72fd891ba58c2962cc7c6c5cad4649c0e668d6b344a9d7b/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42f9837348557fd674477ea40a6ac8a7e839674f6dd0a199df24be91b026024c", size = 41317692, upload-time = "2025-10-30T12:12:04.928Z" },
    { url = "https://files.pythonhosted.org/packages/22/55/81c5b266a947c339edd7fbaa9e1d9614012d02418453f48b76cc177d3dd9/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:c873aeb36fed182d5ebc35ca17c7eb193fe83ae2ea551ee8523ec34776731390", size = 37853058, upload-time = "2025-10-30T12:12:08.342Z" },
    { url = "https://files.pythonhosted.org/packages/6c/58/be8b034d559eac515f52408fd6537be9bea095bc0388946a4e38910d3d50/polars_runtime_32-1.35.1-cp39-abi3-win_amd64.whl", hash = "sha256:35cde9453ca7032933f0e58e9ed4388f5a1e415dd0db2dd1e442c81d815e630c", size = 41289554, upload-time = "2025-10-30T12:12:11.104Z" },
    { url = "https://files.pythonhosted.org/packages/f4/7f/e0111b9e2a1169ea82cde3ded9c92683e93c26dfccd72aee727996a1ac5b/polars_runtime_32-1.35.1-cp39-abi3-win_arm64.whl", hash = "sha256:fd77757a6c9eb9865c4bfb7b07e22225207c6b7da382bd0b9bd47732f637105d", size = 36958878, upload-time = "2025-10-30T12:12:15.206Z" },
]

[[package]]
name = "pre-commit"
version = "4.3.0"
@ -1670,7 +1769,8 @@ dependencies = [
    { name = "lightning-utilities" },
    { name = "packaging" },
    { name = "pyyaml" },
    { name = "torch" },
    { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "torchmetrics" },
    { name = "tqdm" },
    { name = "typing-extensions" },
@ -1754,11 +1854,11 @@ wheels = [

[[package]]
name = "redis"
version = "6.4.0"
version = "7.0.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" }
sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" },
    { url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" },
]

[[package]]
@ -1963,15 +2063,40 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
]

[[package]]
name = "torch"
version = "2.8.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
    "sys_platform != 'linux' and sys_platform != 'win32'",
]
dependencies = [
    { name = "filelock", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "fsspec", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "jinja2", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "networkx", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "setuptools", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "sympy", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" },
    { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" },
    { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" },
]

[[package]]
name = "torch"
version = "2.8.0+cu128"
source = { registry = "https://download.pytorch.org/whl/cu128" }
resolution-markers = [
    "sys_platform == 'linux' or sys_platform == 'win32'",
]
dependencies = [
    { name = "filelock" },
    { name = "fsspec" },
    { name = "jinja2" },
    { name = "networkx" },
    { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
    { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
    { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
@ -1986,10 +2111,10 @@ dependencies = [
    { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
    { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
    { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
    { name = "setuptools" },
    { name = "sympy" },
    { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
    { name = "typing-extensions" },
    { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
    { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" },
@ -2008,7 +2133,8 @@ dependencies = [
    { name = "lightning-utilities" },
    { name = "numpy" },
    { name = "packaging" },
    { name = "torch" },
    { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/2e/48a887a59ecc4a10ce9e8b35b3e3c5cef29d902c4eac143378526e7485cb/torchmetrics-1.8.2.tar.gz", hash = "sha256:cf64a901036bf107f17a524009eea7781c9c5315d130713aeca5747a686fe7a5", size = 580679, upload-time = "2025-09-03T14:00:54.077Z" }
wheels = [
@ -2032,7 +2158,7 @@ name = "triton"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "setuptools" },
    { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" },