forked from enviPath/enviPy
starting on app domain with new dataset #120
This commit is contained in:
@ -51,14 +51,13 @@ class Dataset(ABC):
|
||||
"""See add_rows"""
|
||||
self.add_rows([row])
|
||||
|
||||
def _block_indices(self, prefix) -> Tuple[int, int]:
|
||||
def block_indices(self, prefix) -> List[int]:
|
||||
"""Find the start and end indexes in column labels that has the prefix"""
|
||||
indices: List[int] = []
|
||||
for i, feature in enumerate(self.columns):
|
||||
if feature.startswith(prefix):
|
||||
indices.append(i)
|
||||
|
||||
return min(indices, default=None), max(indices, default=None)
|
||||
return indices
|
||||
|
||||
@property
|
||||
def columns(self) -> List[str]:
|
||||
@ -99,11 +98,11 @@ class Dataset(ABC):
|
||||
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
|
||||
res = self.df[item]
|
||||
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
|
||||
return self.__class__(data=self.df[item])
|
||||
return self.__class__(data=res)
|
||||
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
|
||||
return res
|
||||
|
||||
def save(self, path: "Path"):
|
||||
def save(self, path: "Path | str"):
|
||||
import pickle
|
||||
|
||||
with open(path, "wb") as fh:
|
||||
@ -115,6 +114,9 @@ class Dataset(ABC):
|
||||
|
||||
return pickle.load(open(path, "rb"))
|
||||
|
||||
def to_numpy(self):
|
||||
return self.df.to_numpy()
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
|
||||
@ -123,28 +125,34 @@ class Dataset(ABC):
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def iter_rows(self, named=False):
|
||||
return self.df.iter_rows(named=named)
|
||||
|
||||
|
||||
class RuleBasedDataset(Dataset):
|
||||
def __init__(self, num_labels=None, columns=None, data=None):
|
||||
super().__init__(columns, data)
|
||||
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
|
||||
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
|
||||
self.num_features: int = len(self.columns) - self.num_labels
|
||||
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
||||
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
||||
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
||||
# Pre-calculate the ids of columns for features/labels, useful later in X and y
|
||||
self._struct_features: List[int] = self.block_indices("feature_")
|
||||
self._triggered: List[int] = self.block_indices("trig_")
|
||||
self._observed: List[int] = self.block_indices("obs_")
|
||||
self.feature_cols: List[int] = self._struct_features + self._triggered
|
||||
self.num_features: int = len(self.feature_cols)
|
||||
self.has_probs = False
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
|
||||
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
|
||||
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
|
||||
|
||||
def struct_features(self) -> Tuple[int, int]:
|
||||
def struct_features(self) -> List[int]:
|
||||
return self._struct_features
|
||||
|
||||
def triggered(self) -> Tuple[int, int]:
|
||||
def triggered(self) -> List[int]:
|
||||
return self._triggered
|
||||
|
||||
def observed(self) -> Tuple[int, int]:
|
||||
def observed(self) -> List[int]:
|
||||
return self._observed
|
||||
|
||||
def structure_id(self, index: int):
|
||||
@ -153,21 +161,24 @@ class RuleBasedDataset(Dataset):
|
||||
|
||||
def X(self, exclude_id_col=True, na_replacement=0):
|
||||
"""Get all the feature and trig columns"""
|
||||
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
|
||||
_col_ids = self.feature_cols
|
||||
if not exclude_id_col:
|
||||
_col_ids = [0] + _col_ids
|
||||
res = self[:, _col_ids]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
def trig(self, na_replacement=0):
|
||||
"""Get all the trig columns"""
|
||||
res = self[:, self._triggered[0]: self._triggered[1]]
|
||||
res = self[:, self._triggered]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
"""Get all the obs columns"""
|
||||
res = self[:, len(self.columns) - self.num_labels:]
|
||||
res = self[:, self._observed]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
@ -264,14 +275,6 @@ class RuleBasedDataset(Dataset):
|
||||
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||
classify_data = []
|
||||
classify_products = []
|
||||
if isinstance(structures[0], str):
|
||||
struct_smiles = structures[0]
|
||||
else:
|
||||
struct_smiles = structures[0].smiles
|
||||
ds_columns = (["structure_id"] +
|
||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
for struct in structures:
|
||||
if isinstance(struct, str):
|
||||
struct_id = None
|
||||
@ -293,12 +296,19 @@ class RuleBasedDataset(Dataset):
|
||||
else:
|
||||
trig.append(0)
|
||||
prods.append([])
|
||||
|
||||
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
|
||||
new_row = [struct_id] + features + trig + ([-1] * len(trig))
|
||||
if self.has_probs:
|
||||
new_row += [-1] * len(trig)
|
||||
classify_data.append(new_row)
|
||||
classify_products.append(prods)
|
||||
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
|
||||
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
|
||||
return ds, classify_products
|
||||
|
||||
def add_probs(self, probs):
|
||||
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
|
||||
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
|
||||
self.has_probs = True
|
||||
|
||||
def to_arff(self, path: "Path"):
|
||||
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
|
||||
arff += "\n"
|
||||
|
||||
Reference in New Issue
Block a user