forked from enviPath/enviPy
start towards #120
@@ -6,6 +6,7 @@ from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
 from typing import List, Dict, Set, Tuple, TYPE_CHECKING
+from abc import ABC, abstractmethod
 
 import networkx as nx
 import numpy as np
@@ -26,7 +27,21 @@ if TYPE_CHECKING:
     from epdb.models import Rule, CompoundStructure, Reaction
 
 
-class Dataset:
+class Dataset(ABC):
+    @abstractmethod
+    def X(self):
+        pass
+
+    @abstractmethod
+    def y(self):
+        pass
+
+    @abstractmethod
+    def __getitem__(self, item):
+        pass
+
+
+class RuleBasedDataset(Dataset):
     def __init__(
         self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
     ):
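For readers skimming the diff: the point of this split is that downstream code can now depend on the abstract Dataset interface (X, y, __getitem__) instead of the rule-based implementation, which is what makes room for the EnviFormerDataset stub added further down. A minimal, self-contained sketch of that idea (not code from this commit; ToyDataset and summarize are hypothetical names used only for illustration):

# Illustrative sketch, not code from this commit: consumers can program against
# the abstract Dataset interface and stay agnostic of the concrete subclass.
from abc import ABC, abstractmethod


class Dataset(ABC):
    @abstractmethod
    def X(self): ...

    @abstractmethod
    def y(self): ...

    @abstractmethod
    def __getitem__(self, item): ...


class ToyDataset(Dataset):  # stand-in for RuleBasedDataset or EnviFormerDataset
    def __init__(self, rows, labels):
        self._rows, self._labels = rows, labels

    def X(self):
        return self._rows

    def y(self):
        return self._labels

    def __getitem__(self, item):
        return self._rows[item], self._labels[item]


def summarize(ds: Dataset) -> str:
    # Only touches the ABC methods, so any Dataset subclass works here.
    return f"{len(ds.X())} rows, first label: {ds.y()[0]}"


print(summarize(ToyDataset([[0.1, 0.2], [0.3, 0.4]], [1, 0])))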
@@ -78,18 +93,18 @@ class Dataset:
     def observed(self) -> Tuple[int, int]:
         return self._observed
 
-    def at(self, position: int) -> Dataset:
-        return Dataset(self.columns, self.num_labels, [self.data[position]])
+    def at(self, position: int) -> RuleBasedDataset:
+        return RuleBasedDataset(self.columns, self.num_labels, [self.data[position]])
 
-    def limit(self, limit: int) -> Dataset:
-        return Dataset(self.columns, self.num_labels, self.data[:limit])
+    def limit(self, limit: int) -> RuleBasedDataset:
+        return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
 
     def __iter__(self):
         return (self.at(i) for i, _ in enumerate(self.data))
 
     def classification_dataset(
         self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
-    ) -> Tuple[Dataset, List[List[PredictionResult]]]:
+    ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
         classify_data = []
         classify_products = []
         for struct in structures:
@@ -117,14 +132,14 @@ class Dataset:
             classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
             classify_products.append(prods)
 
-        return Dataset(
+        return RuleBasedDataset(
             columns=self.columns, num_labels=self.num_labels, data=classify_data
         ), classify_products
 
     @staticmethod
     def generate_dataset(
         reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
-    ) -> Dataset:
+    ) -> RuleBasedDataset:
         _structures = set()
 
         for r in reactions:
@@ -231,7 +246,7 @@ class Dataset:
             + [f"trig_{r.uuid}" for r in applicable_rules]
             + [f"obs_{r.uuid}" for r in applicable_rules]
         )
-        ds = Dataset(header, len(applicable_rules))
+        ds = RuleBasedDataset(header, len(applicable_rules))
 
         ds.add_row([str(comp.uuid)] + feat + trig + obs)
 
@@ -289,7 +304,7 @@ class Dataset:
             pickle.dump(self, fh)
 
     @staticmethod
-    def load(path: "Path") -> "Dataset":
+    def load(path: "Path") -> "RuleBasedDataset":
         import pickle
 
         return pickle.load(open(path, "rb"))
@@ -319,6 +334,20 @@ class Dataset:
         )
 
 
+class EnviFormerDataset(Dataset):
+    def __init__(self):
+        pass
+
+    def X(self):
+        pass
+
+    def y(self):
+        pass
+
+    def __getitem__(self, item):
+        pass
+
+
 class SparseLabelECC(BaseEstimator, ClassifierMixin):
     """
     Ensemble of Classifier Chains with sparse label removal.
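The SparseLabelECC docstring above describes an ensemble of classifier chains with sparse label removal. As a generic illustration of that technique (not the implementation in this repository), one can drop labels with too few positive examples and ensemble scikit-learn ClassifierChain models over random chain orders; the threshold, base estimator, and function names below are assumptions:

# Generic sketch of "ensemble of classifier chains with sparse label removal".
# This is NOT the SparseLabelECC implementation from this file.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain


def fit_sparse_ecc(X, Y, n_chains=5, min_positives=3, random_state=0):
    Y = np.asarray(Y)
    # "Sparse label removal": drop labels with too few positive examples.
    keep = Y.sum(axis=0) >= min_positives
    rng = np.random.RandomState(random_state)
    chains = [
        ClassifierChain(
            LogisticRegression(max_iter=1000),
            order="random",
            random_state=rng.randint(1_000_000),
        )
        for _ in range(n_chains)
    ]
    for chain in chains:
        chain.fit(X, Y[:, keep])
    return chains, keep


def predict_sparse_ecc(chains, keep, X):
    # Average chain probabilities, re-inserting zero columns for removed labels.
    probs = np.mean([chain.predict_proba(X) for chain in chains], axis=0)
    full = np.zeros((len(X), keep.size))
    full[:, keep] = probs
    return full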
@@ -598,7 +627,7 @@ class ApplicabilityDomainPCA(PCA):
         self.min_vals = None
         self.max_vals = None
 
-    def build(self, train_dataset: "Dataset"):
+    def build(self, train_dataset: "RuleBasedDataset"):
         # transform
         X_scaled = self.scaler.fit_transform(train_dataset.X())
         # fit pca
@@ -612,7 +641,7 @@ class ApplicabilityDomainPCA(PCA):
         instances_pca = self.transform(instances_scaled)
         return instances_pca
 
-    def is_applicable(self, classify_instances: "Dataset"):
+    def is_applicable(self, classify_instances: "RuleBasedDataset"):
         instances_pca = self.__transform(classify_instances.X())
 
         is_applicable = []
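The two hunks above touch ApplicabilityDomainPCA, which scales instances, projects them with PCA, and keeps per-component min_vals/max_vals. A minimal, generic sketch of that kind of PCA bounding-box applicability-domain check, inferred from the attribute names visible here rather than from the project's actual class (which subclasses PCA directly):

# Minimal sketch of a PCA bounding-box applicability domain, assumed from the
# min_vals/max_vals attributes and scaler usage visible in the diff above.
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler


class BoundingBoxAD:
    def __init__(self, n_components=2):
        self.scaler = MinMaxScaler()
        self.pca = PCA(n_components=n_components)
        self.min_vals = None
        self.max_vals = None

    def build(self, X_train):
        # Fit scaler + PCA on training data and remember the per-component range.
        Z = self.pca.fit_transform(self.scaler.fit_transform(X_train))
        self.min_vals = Z.min(axis=0)
        self.max_vals = Z.max(axis=0)

    def is_applicable(self, X):
        # An instance is "in domain" if every PCA component falls inside the
        # min/max box spanned by the training data.
        Z = self.pca.transform(self.scaler.transform(X))
        return np.all((Z >= self.min_vals) & (Z <= self.max_vals), axis=1)


# Example: a point far outside the training cloud is flagged as out of domain.
rng = np.random.default_rng(0)
ad = BoundingBoxAD()
ad.build(rng.random((50, 4)))
print(ad.is_applicable(np.array([[0.5, 0.5, 0.5, 0.5], [10.0, 10.0, 10.0, 10.0]])))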