new RuleBasedDataset and EnviFormer dataset working for respective models #120

This commit is contained in:
Liam Brydon
2025-11-04 10:58:16 +13:00
parent ff51e48f90
commit ac5d370b18
5 changed files with 126 additions and 101 deletions

View File

@ -65,6 +65,10 @@ class Dataset(ABC):
"""Use the polars dataframe columns"""
return self.df.columns
@property
def shape(self):
return self.df.shape
@abstractmethod
def X(self, **kwargs):
pass
@ -123,13 +127,15 @@ class Dataset(ABC):
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "trig_" in c])
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
self.num_features: int = len(self.columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def times_triggered(self, rule_uuid) -> int:
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
@ -142,21 +148,25 @@ class RuleBasedDataset(Dataset):
return self._observed
def structure_id(self, index: int):
"""Get the UUID of a compound"""
return self.df.item(index, "structure_id")
def X(self, exclude_id_col=True, na_replacement=0):
"""Get all the feature and trig columns"""
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def trig(self, na_replacement=0):
"""Get all the trig columns"""
res = self[:, self._triggered[0]: self._triggered[1]]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
"""Get all the obs columns"""
res = self[:, len(self.columns) - self.num_labels:]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
@ -164,7 +174,7 @@ class RuleBasedDataset(Dataset):
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
_structures = set()
_structures = set() # Get all the structures
for r in reactions:
_structures.update(r.educts.all())
if not educts_only:
@ -254,8 +264,12 @@ class RuleBasedDataset(Dataset):
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
if isinstance(structures[0], str):
struct_smiles = structures[0]
else:
struct_smiles = structures[0].smiles
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
for struct in structures:
@ -306,43 +320,38 @@ class RuleBasedDataset(Dataset):
class EnviFormerDataset(Dataset):
def __init__(self, educts, products):
super().__init__()
assert len(educts) == len(products), "Can't have unequal length educts and products"
def __init__(self, columns=None, data=None):
super().__init__(columns, data)
def X(self):
"""Return the educts"""
return self["educts"]
def y(self):
"""Return the products"""
return self["products"]
@staticmethod
def generate_dataset(reactions, *args, **kwargs):
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
educts = []
products = []
# Standardise reactions for the training data
stereo = kwargs.get("stereo", False)
rows = []
for reaction in reactions:
e = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.educts.all()
]
)
p = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
for smile in reaction.products.all()
]
)
educts.append(e)
products.append(p)
return EnviFormerDataset(educts, products)
def X(self):
pass
def y(self):
pass
def __getitem__(self, item):
pass
def __len__(self):
pass
rows.append([e, p])
ds = EnviFormerDataset(["educts", "products"], rows)
return ds
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -524,7 +533,7 @@ class EnsembleClassifierChain:
self.classifiers = []
if self.num_labels is None:
self.num_labels = len(Y[0])
self.num_labels = Y.shape[1]
for p in range(self.num_chains):
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
@ -555,7 +564,7 @@ class RelativeReasoning:
def fit(self, X, Y):
n_instances = len(Y)
n_attributes = len(Y[0])
n_attributes = Y.shape[1]
for i in range(n_attributes):
for j in range(n_attributes):
@ -567,8 +576,8 @@ class RelativeReasoning:
countboth = 0
for k in range(n_instances):
vi = Y[k][i]
vj = Y[k][j]
vi = Y[k, i]
vj = Y[k, j]
if vi is None or vj is None:
continue