forked from enviPath/enviPy
New RuleBasedDataset and EnviFormer dataset, both working with their respective models (#120)
This commit is contained in:
@ -65,6 +65,10 @@ class Dataset(ABC):
|
||||
"""Use the polars dataframe columns"""
|
||||
return self.df.columns
|
||||
|
||||
@property
def shape(self):
    """Return the shape of the underlying dataframe.

    This simply delegates to ``self.df.shape``; for the polars dataframe
    the sibling ``columns`` docstring refers to, that is presumably a
    ``(rows, columns)`` tuple -- confirm against the polars version in use.
    """
    return self.df.shape
|
||||
|
||||
@abstractmethod
def X(self, **kwargs):
    """Return the model-input part of the dataset.

    Abstract hook: concrete datasets override this (for example to return
    feature/trigger columns or educts). The base implementation does
    nothing and returns ``None``.

    Args:
        **kwargs: Subclass-specific options; ignored here.
    """
|
||||
@ -123,13 +127,15 @@ class Dataset(ABC):
|
||||
class RuleBasedDataset(Dataset):
|
||||
def __init__(self, num_labels=None, columns=None, data=None):
    """Initialise the dataset and cache the layout of its column blocks.

    Args:
        num_labels: Number of label columns. When falsy (``None`` or ``0``)
            it is derived by counting the columns whose name contains
            ``"obs_"``.
        columns: Column names, forwarded to the base ``Dataset`` initialiser.
        data: Row data, forwarded to the base ``Dataset`` initialiser.
    """
    super().__init__(columns, data)
    # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
    self.num_labels: int = num_labels if num_labels else sum(1 for c in self.columns if "obs_" in c)
    # Everything that is not a label column is counted as a feature column.
    self.num_features: int = len(self.columns) - self.num_labels
    # Cache the index pairs of the three prefixed column blocks.
    self._struct_features: Tuple[int, int] = self._block_indices("feature_")
    self._triggered: Tuple[int, int] = self._block_indices("trig_")
    self._observed: Tuple[int, int] = self._block_indices("obs_")
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
    """Count how many times a rule is triggered.

    The count is the number of rows holding a one in the rule's
    ``trig_<rule_uuid>`` column.

    Args:
        rule_uuid: UUID of the rule; used to build the column name.

    Returns:
        Number of rows where the rule's trig column equals 1.
    """
    triggered_rows = self.df.filter(pl.col(f"trig_{rule_uuid}") == 1)
    return triggered_rows.height
|
||||
|
||||
def struct_features(self) -> Tuple[int, int]:
|
||||
@ -142,21 +148,25 @@ class RuleBasedDataset(Dataset):
|
||||
return self._observed
|
||||
|
||||
def structure_id(self, index: int):
    """Get the UUID of a compound.

    Args:
        index: Row index of the compound in the dataframe.

    Returns:
        The value stored in the ``structure_id`` column at that row.
    """
    return self.df.item(index, "structure_id")
|
||||
|
||||
def X(self, exclude_id_col=True, na_replacement=0):
    """Get all the feature and trig columns.

    Args:
        exclude_id_col: When true the slice starts at column 1, skipping
            the first column; otherwise it starts at column 0.
        na_replacement: Value used to fill nulls in the result. Pass
            ``None`` to leave nulls untouched.

    Returns:
        The column-sliced dataset produced by indexing ``self``.
    """
    first_col = 1 if exclude_id_col else 0
    # The label columns sit at the end, so stop right before them.
    last_col = len(self.columns) - self.num_labels
    res = self[:, first_col:last_col]
    if na_replacement is not None:
        res.df = res.df.fill_null(na_replacement)
    return res
|
||||
|
||||
def trig(self, na_replacement=0):
    """Get all the trig columns.

    Args:
        na_replacement: Value used to fill nulls in the result. Pass
            ``None`` to leave nulls untouched.

    Returns:
        The dataset sliced to the column range cached in ``self._triggered``.
    """
    start, stop = self._triggered[0], self._triggered[1]
    res = self[:, start:stop]
    if na_replacement is not None:
        res.df = res.df.fill_null(na_replacement)
    return res
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
"""Get all the obs columns"""
|
||||
res = self[:, len(self.columns) - self.num_labels:]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
@ -164,7 +174,7 @@ class RuleBasedDataset(Dataset):
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True):
|
||||
_structures = set()
|
||||
_structures = set() # Get all the structures
|
||||
for r in reactions:
|
||||
_structures.update(r.educts.all())
|
||||
if not educts_only:
|
||||
@ -254,8 +264,12 @@ class RuleBasedDataset(Dataset):
|
||||
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||
classify_data = []
|
||||
classify_products = []
|
||||
if isinstance(structures[0], str):
|
||||
struct_smiles = structures[0]
|
||||
else:
|
||||
struct_smiles = structures[0].smiles
|
||||
ds_columns = (["structure_id"] +
|
||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
|
||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
for struct in structures:
|
||||
@ -306,43 +320,38 @@ class RuleBasedDataset(Dataset):
|
||||
|
||||
|
||||
class EnviFormerDataset(Dataset):
    """Dataset of reactions stored as an ``educts`` and a ``products`` column.

    Each row holds two dot-joined SMILES strings: the standardised educts
    and the standardised products of one reaction.
    """

    def __init__(self, columns=None, data=None):
        """Forward ``columns`` and ``data`` to the base ``Dataset`` initialiser."""
        super().__init__(columns, data)

    def X(self):
        """Return the educts"""
        return self["educts"]

    def y(self):
        """Return the products"""
        return self["products"]

    @staticmethod
    def generate_dataset(reactions, *args, **kwargs):
        """Build an ``EnviFormerDataset`` from an iterable of reactions.

        Args:
            reactions: Iterable of reaction objects exposing ``educts.all()``
                and ``products.all()``, whose items carry a ``smiles`` attribute.
            *args: Accepted for signature compatibility; unused.
            **kwargs: ``stereo`` (default ``False``) -- when false,
                stereochemistry is removed during standardisation.

        Returns:
            A dataset with the columns ``educts`` and ``products`` and one
            row per reaction.
        """
        # Standardise reactions for the training data
        stereo = kwargs.get("stereo", False)

        def _joined_smiles(structures):
            # Standardise every structure and join them into one dot-separated SMILES.
            return ".".join(
                FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                for smile in structures
            )

        rows = []
        for reaction in reactions:
            e = _joined_smiles(reaction.educts.all())
            p = _joined_smiles(reaction.products.all())
            rows.append([e, p])
        ds = EnviFormerDataset(["educts", "products"], rows)
        return ds
|
||||
|
||||
|
||||
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||
@ -524,7 +533,7 @@ class EnsembleClassifierChain:
|
||||
self.classifiers = []
|
||||
|
||||
if self.num_labels is None:
|
||||
self.num_labels = len(Y[0])
|
||||
self.num_labels = Y.shape[1]
|
||||
|
||||
for p in range(self.num_chains):
|
||||
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||
@ -555,7 +564,7 @@ class RelativeReasoning:
|
||||
|
||||
def fit(self, X, Y):
|
||||
n_instances = len(Y)
|
||||
n_attributes = len(Y[0])
|
||||
n_attributes = Y.shape[1]
|
||||
|
||||
for i in range(n_attributes):
|
||||
for j in range(n_attributes):
|
||||
@ -567,8 +576,8 @@ class RelativeReasoning:
|
||||
countboth = 0
|
||||
|
||||
for k in range(n_instances):
|
||||
vi = Y[k][i]
|
||||
vj = Y[k][j]
|
||||
vi = Y[k, i]
|
||||
vj = Y[k, j]
|
||||
|
||||
if vi is None or vj is None:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user