work towards #120

Author: Liam Brydon
Date: 2025-11-03 15:24:28 +13:00
parent 8166df6f39
commit ff51e48f90
5 changed files with 263 additions and 274 deletions


@@ -30,42 +30,47 @@ if TYPE_CHECKING:
class Dataset(ABC):
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
if isinstance(data, pl.DataFrame):
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
self.df = data
else:
# Build an empty dataframe from the columns alone, or fill it with list-of-lists data
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
self.df = pl.DataFrame(data=data, schema=columns)
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
def add_rows(self, rows: List[List[str | int | float]]):
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
if len(self.columns) != len(rows[0]):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
new_rows = pl.DataFrame(data=rows, schema=self.columns)
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
self.df.extend(new_rows)
def add_row(self, row: List[str | int | float]):
"""See add_rows"""
self.add_rows([row])
def _block_indices(self, prefix) -> Tuple[int, int]:
"""Find the start and end indexes in column labels that has the prefix"""
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices), max(indices)
return min(indices, default=None), max(indices, default=None)
@property
def columns(self) -> List[str]:
"""Use the polars dataframe columns"""
return self.df.columns
@abstractmethod
def X(self):
def X(self, **kwargs):
pass
@abstractmethod
def y(self):
def y(self, **kwargs):
pass
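The constructor change above hinges on two polars keyword arguments. A minimal sketch of the behaviour, with hypothetical column names and rows:

import polars as pl

columns = ["structure_id", "feature_0", "trig_r1"]
rows = [["c1", 1, 0], ["c2", 0, None]]
# orient="row" tells polars the outer list holds rows, not columns;
# infer_schema_length=None scans every row before inferring dtypes,
# so a None in an early row cannot lock in a mistyped column.
df = pl.DataFrame(data=rows, schema=columns, orient="row", infer_schema_length=None)

The default=None added to min()/max() in _block_indices is the matching guard for a prefix that matches no column: on an empty list both would otherwise raise ValueError, whereas the method now returns (None, None).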
@staticmethod
@@ -73,11 +78,26 @@ class Dataset(ABC):
def generate_dataset(reactions, *args, **kwargs):
pass
def at(self, position: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
def at(self, position: int) -> Dataset:
"""See __getitem__"""
return self[position]
def limit(self, limit: int) -> Dataset:
"""See __getitem__"""
return self[:limit]
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
"""Use polars iter_rows for iterating over the dataset"""
return self.df.iter_rows()
def __getitem__(self, item):
"""Item is passed to polars allowing for advanced indexing.
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
res = self.df[item]
if isinstance(res, pl.DataFrame): # A DataFrame result is re-wrapped as a new instance of this class
return self.__class__(data=res)
else: # Anything else (a scalar such as int, str or float) is returned as-is
return res
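Because __getitem__ now defers to polars, anything the polars DataFrame indexer accepts also works on a Dataset. A hedged usage sketch (ds and its contents are hypothetical):

subset = ds[:10]    # slice -> polars returns a DataFrame, re-wrapped as a new Dataset
row = ds.at(3)      # same as ds[3]: a one-row Dataset
head = ds.limit(5)  # same as ds[:5]
value = ds[0, 0]    # (row, col) integer pair -> polars returns the bare scalar

Note that a single column key such as ds["structure_id"] yields a polars Series, which is not a DataFrame and therefore falls through the else branch unchanged.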
def save(self, path: "Path"):
import pickle
@@ -86,35 +106,61 @@ class Dataset(ABC):
pickle.dump(self, fh)
@staticmethod
def load(path: "str | Path") -> "RuleBasedDataset":
def load(path: "str | Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def __repr__(self):
return (
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
)
class NewRuleBasedDataset(Dataset):
def __init__(self, num_labels, columns=None, data=None):
def __len__(self):
return len(self.df)
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
self.num_labels: int = num_labels
self.num_labels: int = num_labels if num_labels else sum(1 for c in self.columns if "trig_" in c)
self.num_features: int = len(self.columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def times_triggered(self, rule_uuid) -> int:
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
return self._block_indices("feature_")
return self._struct_features
def triggered(self) -> Tuple[int, int]:
return self._block_indices("trig_")
return self._triggered
def observed(self) -> Tuple[int, int]:
return self._block_indices("obs_")
return self._observed
def X(self):
pass
def structure_id(self, index: int):
return self.df.item(index, "structure_id")
def y(self):
pass
def X(self, exclude_id_col=True, na_replacement=0):
res = self[:, (1 if exclude_id_col else 0): len(self.columns) - self.num_labels]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def trig(self, na_replacement=0):
res = self[:, self._triggered[0]: self._triggered[1]]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
res = self[:, len(self.columns) - self.num_labels:]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
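To make the slicing in X() and y() concrete, here is the column layout RuleBasedDataset assumes, with hypothetical rule ids r1 and r2:

cols = (["structure_id"]
        + [f"feature_{i}" for i in range(3)]  # fingerprint bits
        + ["trig_r1", "trig_r2"]              # rule triggered?
        + ["obs_r1", "obs_r2"])               # product observed?
ds = RuleBasedDataset(columns=cols, data=[["c1", 1, 0, 1, 1, 0, None, 0]])
ds.num_labels             # 2, inferred by counting trig_ columns when not passed
ds.X()                    # feature_* plus the trig_ block, id column dropped, nulls -> 0
ds.y()                    # the trailing obs_ block, i.e. the last num_labels columns
ds.times_triggered("r1")  # 1, via a polars filter on the trig_r1 column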
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
@@ -178,7 +224,6 @@ class NewRuleBasedDataset(Dataset):
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
rows = []
for i, comp in enumerate(compounds):
@@ -201,77 +246,18 @@ class NewRuleBasedDataset(Dataset):
else:
obs.append(0)
rows.append([str(comp.uuid)] + feat + trig + obs)
ds.add_rows(rows)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
return ds
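The net effect of the generate_dataset rewrite: rows are accumulated in a plain Python list and the polars-backed dataset is constructed once at the end, instead of extending a pre-built empty dataset. A compressed sketch of the resulting shape (featurize is a hypothetical stand-in for the loop body above):

rows = []
for comp in compounds:
    feat, trig, obs = featurize(comp)  # hypothetical helper
    rows.append([str(comp.uuid)] + feat + trig + obs)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)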
def __getitem__(self, item):
pass
class RuleBasedDataset(Dataset):
def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns
self.num_labels: int = num_labels
if data is None:
self.data: List[List[str | int | float]] = list()
else:
self.data = data
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices), max(indices)
def structure_id(self):
return self.data[0][0]
def add_row(self, row: List[str | int | float]):
if len(self.columns) != len(row):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
self.data.append(row)
def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f"trig_{rule_uuid}")
times_triggered = 0
for row in self.data:
if row[idx] == 1:
times_triggered += 1
return times_triggered
def struct_features(self) -> Tuple[int, int]:
return self._struct_features
def triggered(self) -> Tuple[int, int]:
return self._triggered
def observed(self) -> Tuple[int, int]:
return self._observed
def at(self, position: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, [self.data[position]])
def limit(self, limit: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
for struct in structures:
if isinstance(struct, str):
struct_id = None
@@ -296,171 +282,8 @@ class RuleBasedDataset(Dataset):
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods)
return RuleBasedDataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod
def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> RuleBasedDataset:
_structures = set()
for r in reactions:
for e in r.educts.all():
_structures.add(e)
if not educts_only:
for e in r.products:
_structures.add(e)
compounds = sorted(_structures, key=lambda x: x.url)
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
# :shrug:
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
# :shrug:
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
else:
pass
ds = None
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
if ds is None:
header = (
["structure_id"]
+ [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = RuleBasedDataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def trig(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
row_key, col_key = key
# Normalize rows
if isinstance(row_key, int):
rows = [self.data[row_key]]
else:
rows = self.data[row_key]
# Normalize columns
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [
[row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
return ds, classify_products
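classification_dataset mirrors generate_dataset but for unlabeled input: the obs_ block is filled with -1 sentinels because labels are unknown at prediction time. A hedged usage sketch (all variables hypothetical):

clf_ds, products = train_ds.classification_dataset(structures, applicable_rules)
X_new = clf_ds.X()  # features plus trig block, ready for a classifier
# products[i] holds the PredictionResult lists from applying each rule to structures[i]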
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
@@ -472,7 +295,7 @@ class RuleBasedDataset(Dataset):
arff += f"@attribute {c} {{0,1}}\n"
arff += "\n@data\n"
for d in self.data:
for d in self:
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f"{ys},{xs}\n"
@@ -481,14 +304,10 @@ class RuleBasedDataset(Dataset):
fh.write(arff)
fh.flush()
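For reference, the output follows the Meka-style multi-label ARFF convention: "-C n" in the @relation line declares the first n attributes as targets, which is why each data row is written labels-first (ys before xs). A hypothetical two-rule, two-feature file:

@relation 'enviPy-dataset: -C 2'
@attribute obs_r1 {0,1}
@attribute obs_r2 {0,1}
@attribute feature_0 {0,1}
@attribute feature_1 {0,1}
@data
1,?,0,1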
def __repr__(self):
return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
class EnviFormerDataset(Dataset):
def __init__(self, educts, products):
super().__init__()
assert len(educts) == len(products), "Can't have unequal length educts and products"
@staticmethod