starting on app domain with new dataset #120

This commit is contained in:
Liam Brydon
2025-11-04 16:33:56 +13:00
parent ac5d370b18
commit 13af49488e
3 changed files with 98 additions and 58 deletions

View File

@ -51,14 +51,13 @@ class Dataset(ABC):
"""See add_rows"""
self.add_rows([row])
def _block_indices(self, prefix) -> Tuple[int, int]:
def block_indices(self, prefix) -> List[int]:
"""Find the indexes of the column labels that start with the prefix"""
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices, default=None), max(indices, default=None)
return indices
@property
def columns(self) -> List[str]:
@ -99,11 +98,11 @@ class Dataset(ABC):
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
res = self.df[item]
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
return self.__class__(data=self.df[item])
return self.__class__(data=res)
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
return res
def save(self, path: "Path"):
def save(self, path: "Path | str"):
import pickle
with open(path, "wb") as fh:
@ -115,6 +114,9 @@ class Dataset(ABC):
return pickle.load(open(path, "rb"))
def to_numpy(self):
return self.df.to_numpy()
def __repr__(self):
return (
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
@ -123,28 +125,34 @@ class Dataset(ABC):
def __len__(self):
return len(self.df)
def iter_rows(self, named=False):
return self.df.iter_rows(named=named)
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
self.num_features: int = len(self.columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
# Pre-calculate the ids of columns for features/labels, useful later in X and y
self._struct_features: List[int] = self.block_indices("feature_")
self._triggered: List[int] = self.block_indices("trig_")
self._observed: List[int] = self.block_indices("obs_")
self.feature_cols: List[int] = self._struct_features + self._triggered
self.num_features: int = len(self.feature_cols)
self.has_probs = False
def times_triggered(self, rule_uuid) -> int:
"""Count how many times a rule is triggered, i.e. the number of rows with a one in the rule's trig column"""
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
def struct_features(self) -> List[int]:
return self._struct_features
def triggered(self) -> Tuple[int, int]:
def triggered(self) -> List[int]:
return self._triggered
def observed(self) -> Tuple[int, int]:
def observed(self) -> List[int]:
return self._observed
def structure_id(self, index: int):
@ -153,21 +161,24 @@ class RuleBasedDataset(Dataset):
def X(self, exclude_id_col=True, na_replacement=0):
"""Get all the feature and trig columns"""
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
_col_ids = self.feature_cols
if not exclude_id_col:
_col_ids = [0] + _col_ids
res = self[:, _col_ids]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def trig(self, na_replacement=0):
"""Get all the trig columns"""
res = self[:, self._triggered[0]: self._triggered[1]]
res = self[:, self._triggered]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
"""Get all the obs columns"""
res = self[:, len(self.columns) - self.num_labels:]
res = self[:, self._observed]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
@ -264,14 +275,6 @@ class RuleBasedDataset(Dataset):
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
if isinstance(structures[0], str):
struct_smiles = structures[0]
else:
struct_smiles = structures[0].smiles
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
for struct in structures:
if isinstance(struct, str):
struct_id = None
@ -293,12 +296,19 @@ class RuleBasedDataset(Dataset):
else:
trig.append(0)
prods.append([])
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
new_row = [struct_id] + features + trig + ([-1] * len(trig))
if self.has_probs:
new_row += [-1] * len(trig)
classify_data.append(new_row)
classify_products.append(prods)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
return ds, classify_products
def add_probs(self, probs):
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
self.has_probs = True
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"