forked from enviPath/enviPy
work towards #120
This commit is contained in:
335
utilities/ml.py
335
utilities/ml.py
@ -30,42 +30,47 @@ if TYPE_CHECKING:
|
||||
|
||||
class Dataset(ABC):
|
||||
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
|
||||
if isinstance(data, pl.DataFrame):
|
||||
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
|
||||
self.df = data
|
||||
else:
|
||||
# Build either an empty dataframe with columns or fill it with list of list data
|
||||
if data is not None and len(columns) != len(data[0]):
|
||||
raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
|
||||
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
|
||||
if columns is None:
|
||||
raise ValueError("Columns can't be None if data is not already a DataFrame")
|
||||
self.df = pl.DataFrame(data=data, schema=columns)
|
||||
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
|
||||
|
||||
def add_rows(self, rows: List[List[str | int | float]]):
|
||||
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
|
||||
if len(self.columns) != len(rows[0]):
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
|
||||
new_rows = pl.DataFrame(data=rows, schema=self.columns)
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
|
||||
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
|
||||
self.df.extend(new_rows)
|
||||
|
||||
def add_row(self, row: List[str | int | float]):
|
||||
"""See add_rows"""
|
||||
self.add_rows([row])
|
||||
|
||||
def _block_indices(self, prefix) -> Tuple[int, int]:
|
||||
"""Find the start and end indexes in column labels that has the prefix"""
|
||||
indices: List[int] = []
|
||||
for i, feature in enumerate(self.columns):
|
||||
if feature.startswith(prefix):
|
||||
indices.append(i)
|
||||
|
||||
return min(indices), max(indices)
|
||||
return min(indices, default=None), max(indices, default=None)
|
||||
|
||||
@property
|
||||
def columns(self) -> List[str]:
|
||||
"""Use the polars dataframe columns"""
|
||||
return self.df.columns
|
||||
|
||||
@abstractmethod
|
||||
def X(self):
|
||||
def X(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def y(self):
|
||||
def y(self, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@ -73,11 +78,26 @@ class Dataset(ABC):
|
||||
def generate_dataset(reactions, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def at(self, position: int) -> RuleBasedDataset:
|
||||
return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
|
||||
def at(self, position: int) -> Dataset:
|
||||
"""See __getitem__"""
|
||||
return self[position]
|
||||
|
||||
def limit(self, limit: int) -> Dataset:
|
||||
"""See __getitem__"""
|
||||
return self[:limit]
|
||||
|
||||
def __iter__(self):
|
||||
return (self.at(i) for i, _ in enumerate(self.data))
|
||||
"""Use polars iter_rows for iterating over the dataset"""
|
||||
return self.df.iter_rows()
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Item is passed to polars allowing for advanced indexing.
|
||||
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
|
||||
res = self.df[item]
|
||||
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
|
||||
return self.__class__(data=self.df[item])
|
||||
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
|
||||
return res
|
||||
|
||||
def save(self, path: "Path"):
|
||||
import pickle
|
||||
@ -86,35 +106,61 @@ class Dataset(ABC):
|
||||
pickle.dump(self, fh)
|
||||
|
||||
@staticmethod
|
||||
def load(path: "str | Path") -> "RuleBasedDataset":
|
||||
def load(path: "str | Path") -> "Dataset":
|
||||
import pickle
|
||||
|
||||
return pickle.load(open(path, "rb"))
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
|
||||
)
|
||||
|
||||
class NewRuleBasedDataset(Dataset):
|
||||
def __init__(self, num_labels, columns=None, data=None):
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
|
||||
class RuleBasedDataset(Dataset):
|
||||
def __init__(self, num_labels=None, columns=None, data=None):
|
||||
super().__init__(columns, data)
|
||||
self.num_labels: int = num_labels
|
||||
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "trig_" in c])
|
||||
self.num_features: int = len(self.columns) - self.num_labels
|
||||
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
||||
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
||||
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
|
||||
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
|
||||
|
||||
def struct_features(self) -> Tuple[int, int]:
|
||||
return self._block_indices("feature_")
|
||||
return self._struct_features
|
||||
|
||||
def triggered(self) -> Tuple[int, int]:
|
||||
return self._block_indices("trig_")
|
||||
return self._triggered
|
||||
|
||||
def observed(self) -> Tuple[int, int]:
|
||||
return self._block_indices("obs_")
|
||||
return self._observed
|
||||
|
||||
def X(self):
|
||||
pass
|
||||
def structure_id(self, index: int):
|
||||
return self.df.item(index, "structure_id")
|
||||
|
||||
def y(self):
|
||||
pass
|
||||
def X(self, exclude_id_col=True, na_replacement=0):
|
||||
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
def trig(self, na_replacement=0):
|
||||
res = self[:, self._triggered[0]: self._triggered[1]]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
res = self[:, len(self.columns) - self.num_labels:]
|
||||
if na_replacement is not None:
|
||||
res.df = res.df.fill_null(na_replacement)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True):
|
||||
@ -178,7 +224,6 @@ class NewRuleBasedDataset(Dataset):
|
||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
|
||||
rows = []
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
@ -201,77 +246,18 @@ class NewRuleBasedDataset(Dataset):
|
||||
else:
|
||||
obs.append(0)
|
||||
rows.append([str(comp.uuid)] + feat + trig + obs)
|
||||
ds.add_rows(rows)
|
||||
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
|
||||
return ds
|
||||
|
||||
|
||||
def __getitem__(self, item):
|
||||
pass
|
||||
|
||||
|
||||
class RuleBasedDataset(Dataset):
|
||||
def __init__(
|
||||
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
|
||||
):
|
||||
self.columns: List[str] = columns
|
||||
self.num_labels: int = num_labels
|
||||
|
||||
if data is None:
|
||||
self.data: List[List[str | int | float]] = list()
|
||||
else:
|
||||
self.data = data
|
||||
|
||||
self.num_features: int = len(columns) - self.num_labels
|
||||
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
||||
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
||||
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
||||
|
||||
def _block_indices(self, prefix) -> Tuple[int, int]:
|
||||
indices: List[int] = []
|
||||
for i, feature in enumerate(self.columns):
|
||||
if feature.startswith(prefix):
|
||||
indices.append(i)
|
||||
|
||||
return min(indices), max(indices)
|
||||
|
||||
def structure_id(self):
|
||||
return self.data[0][0]
|
||||
|
||||
def add_row(self, row: List[str | int | float]):
|
||||
if len(self.columns) != len(row):
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
|
||||
self.data.append(row)
|
||||
|
||||
def times_triggered(self, rule_uuid) -> int:
|
||||
idx = self.columns.index(f"trig_{rule_uuid}")
|
||||
|
||||
times_triggered = 0
|
||||
for row in self.data:
|
||||
if row[idx] == 1:
|
||||
times_triggered += 1
|
||||
|
||||
return times_triggered
|
||||
|
||||
def struct_features(self) -> Tuple[int, int]:
|
||||
return self._struct_features
|
||||
|
||||
def triggered(self) -> Tuple[int, int]:
|
||||
return self._triggered
|
||||
|
||||
def observed(self) -> Tuple[int, int]:
|
||||
return self._observed
|
||||
|
||||
def at(self, position: int) -> RuleBasedDataset:
|
||||
return RuleBasedDataset(self.columns, self.num_labels, [self.data[position]])
|
||||
|
||||
def limit(self, limit: int) -> RuleBasedDataset:
|
||||
return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
|
||||
|
||||
def classification_dataset(
|
||||
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
|
||||
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||
classify_data = []
|
||||
classify_products = []
|
||||
ds_columns = (["structure_id"] +
|
||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(structures[0].smiles))] +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
for struct in structures:
|
||||
if isinstance(struct, str):
|
||||
struct_id = None
|
||||
@ -296,171 +282,8 @@ class RuleBasedDataset(Dataset):
|
||||
|
||||
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
|
||||
classify_products.append(prods)
|
||||
|
||||
return RuleBasedDataset(
|
||||
columns=self.columns, num_labels=self.num_labels, data=classify_data
|
||||
), classify_products
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(
|
||||
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
|
||||
) -> RuleBasedDataset:
|
||||
_structures = set()
|
||||
|
||||
for r in reactions:
|
||||
for e in r.educts.all():
|
||||
_structures.add(e)
|
||||
|
||||
if not educts_only:
|
||||
for e in r.products:
|
||||
_structures.add(e)
|
||||
|
||||
compounds = sorted(_structures, key=lambda x: x.url)
|
||||
|
||||
triggered: Dict[str, Set[str]] = defaultdict(set)
|
||||
observed: Set[str] = set()
|
||||
|
||||
# Apply rules on collected compounds and store tps
|
||||
for i, comp in enumerate(compounds):
|
||||
logger.debug(f"{i + 1}/{len(compounds)}...")
|
||||
|
||||
for rule in applicable_rules:
|
||||
product_sets = rule.apply(comp.smiles)
|
||||
|
||||
if len(product_sets) == 0:
|
||||
continue
|
||||
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
|
||||
if key in triggered:
|
||||
logger.info(f"{key} already present. Duplicate reaction?")
|
||||
|
||||
for prod_set in product_sets:
|
||||
for smi in prod_set:
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception:
|
||||
# :shrug:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
pass
|
||||
|
||||
triggered[key].add(smi)
|
||||
|
||||
for i, r in enumerate(reactions):
|
||||
logger.debug(f"{i + 1}/{len(reactions)}...")
|
||||
|
||||
if len(r.educts.all()) != 1:
|
||||
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
|
||||
continue
|
||||
|
||||
for comp in r.educts.all():
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
|
||||
if key not in triggered:
|
||||
continue
|
||||
|
||||
# standardize products from reactions for comparison
|
||||
standardized_products = []
|
||||
for cs in r.products.all():
|
||||
smi = cs.smiles
|
||||
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception as e:
|
||||
# :shrug:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
pass
|
||||
|
||||
standardized_products.append(smi)
|
||||
|
||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||
observed.add(key)
|
||||
else:
|
||||
pass
|
||||
|
||||
ds = None
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
# Features
|
||||
feat = FormatConverter.maccs(comp.smiles)
|
||||
trig = []
|
||||
obs = []
|
||||
|
||||
for rule in applicable_rules:
|
||||
key = f"{rule.uuid} + {comp.uuid}"
|
||||
|
||||
# Check triggered
|
||||
if key in triggered:
|
||||
trig.append(1)
|
||||
else:
|
||||
trig.append(0)
|
||||
|
||||
# Check obs
|
||||
if key in observed:
|
||||
obs.append(1)
|
||||
elif key not in triggered:
|
||||
obs.append(None)
|
||||
else:
|
||||
obs.append(0)
|
||||
|
||||
if ds is None:
|
||||
header = (
|
||||
["structure_id"]
|
||||
+ [f"feature_{i}" for i, _ in enumerate(feat)]
|
||||
+ [f"trig_{r.uuid}" for r in applicable_rules]
|
||||
+ [f"obs_{r.uuid}" for r in applicable_rules]
|
||||
)
|
||||
ds = RuleBasedDataset(header, len(applicable_rules))
|
||||
|
||||
ds.add_row([str(comp.uuid)] + feat + trig + obs)
|
||||
|
||||
return ds
|
||||
|
||||
def X(self, exclude_id_col=True, na_replacement=0):
|
||||
res = self.__getitem__(
|
||||
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
|
||||
)
|
||||
if na_replacement is not None:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
def trig(self, na_replacement=0):
|
||||
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
|
||||
if na_replacement is not None:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
|
||||
if na_replacement is not None:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
def __getitem__(self, key):
|
||||
if not isinstance(key, tuple):
|
||||
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
|
||||
|
||||
row_key, col_key = key
|
||||
|
||||
# Normalize rows
|
||||
if isinstance(row_key, int):
|
||||
rows = [self.data[row_key]]
|
||||
else:
|
||||
rows = self.data[row_key]
|
||||
|
||||
# Normalize columns
|
||||
if isinstance(col_key, int):
|
||||
res = [row[col_key] for row in rows]
|
||||
else:
|
||||
res = [
|
||||
[row[i] for i in range(*col_key.indices(len(row)))]
|
||||
if isinstance(col_key, slice)
|
||||
else [row[i] for i in col_key]
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return res
|
||||
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
|
||||
return ds, classify_products
|
||||
|
||||
def to_arff(self, path: "Path"):
|
||||
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
|
||||
@ -472,7 +295,7 @@ class RuleBasedDataset(Dataset):
|
||||
arff += f"@attribute {c} {{0,1}}\n"
|
||||
|
||||
arff += "\n@data\n"
|
||||
for d in self.data:
|
||||
for d in self:
|
||||
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
|
||||
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
|
||||
arff += f"{ys},{xs}\n"
|
||||
@ -481,14 +304,10 @@ class RuleBasedDataset(Dataset):
|
||||
fh.write(arff)
|
||||
fh.flush()
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
|
||||
)
|
||||
|
||||
|
||||
class EnviFormerDataset(Dataset):
|
||||
def __init__(self, educts, products):
|
||||
super().__init__()
|
||||
assert len(educts) == len(products), "Can't have unequal length educts and products"
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user