work towards #120

This commit is contained in:
Liam Brydon
2025-10-24 14:40:26 +13:00
parent 2980a75daa
commit 8166df6f39
2 changed files with 203 additions and 59 deletions

@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
import networkx as nx
import numpy as np
from numpy.random import default_rng
import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
@@ -28,6 +29,37 @@ if TYPE_CHECKING:
class Dataset(ABC):
def __init__(self, columns: List[str] | None = None, data: List[List[str | int | float]] | pl.DataFrame | None = None):
if isinstance(data, pl.DataFrame):
self.df = data
else:
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
self.df = pl.DataFrame(data=data, schema=columns)
def add_rows(self, rows: List[List[str | int | float]]):
if len(self.columns) != len(rows[0]):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
new_rows = pl.DataFrame(data=rows, schema=self.columns)
self.df.extend(new_rows)
def add_row(self, row: List[str | int | float]):
self.add_rows([row])
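# Return the (first, last) index of columns whose names share the given prefix (assumes they form a contiguous block).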
def _block_indices(self, prefix: str) -> Tuple[int, int]:
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices), max(indices)
@property
def columns(self) -> List[str]:
return self.df.columns
@abstractmethod
def X(self):
pass
@@ -36,7 +68,143 @@ class Dataset(ABC):
def y(self):
pass
@staticmethod
@abstractmethod
def generate_dataset(reactions, *args, **kwargs):
pass
def at(self, position: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
def __iter__(self):
return (self.at(i) for i in range(self.df.height))
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "str | Path") -> "RuleBasedDataset":
import pickle
with open(path, "rb") as fh:
return pickle.load(fh)
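# Example round trip (hypothetical path):
#   ds.save(Path("rules.pkl"))
#   ds = Dataset.load("rules.pkl")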
class NewRuleBasedDataset(Dataset):
def __init__(self, num_labels, columns=None, data=None):
super().__init__(columns, data)
self.num_labels: int = num_labels
self.num_features: int = len(self.columns) - self.num_labels
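# Count how many structures in the dataset triggered the rule with the given UUID.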
def times_triggered(self, rule_uuid) -> int:
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
return self._block_indices("feature_")
def triggered(self) -> Tuple[int, int]:
return self._block_indices("trig_")
def observed(self) -> Tuple[int, int]:
return self._block_indices("obs_")
def X(self):
pass
def y(self):
pass
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
_structures = set()
for r in reactions:
_structures.update(r.educts.all())
if not educts_only:
_structures.update(r.products.all())
compounds = sorted(_structures, key=lambda x: x.url)
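# triggered maps "rule_uuid + compound_uuid" keys to the standardized product SMILES the rule predicts;
# observed collects the keys whose predictions cover all experimentally reported products.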
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply the rules to the collected compounds and store the predicted transformation products
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
logger.debug(f"Standardizing SMILES failed for {smi}")
triggered[key].add(smi)
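# A triggered rule counts as observed for a reaction when every standardized product of that reaction appears among the rule's predicted products.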
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
logger.debug(f"Standardizing SMILES failed for {smi}")
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
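# Column layout: structure id, one feature_ column per MACCS fingerprint bit, then one trig_ and one obs_ column per rule.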
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
rows = []
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
trig = []
obs = []
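# Label semantics: 1 = triggered and observed, 0 = triggered but not observed, None = not triggered (unknown).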
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
rows.append([str(comp.uuid)] + feat + trig + obs)
ds.add_rows(rows)
return ds
def __getitem__(self, item):
pass
@@ -99,9 +267,6 @@ class RuleBasedDataset(Dataset):
def limit(self, limit: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
@@ -297,18 +462,6 @@ class RuleBasedDataset(Dataset):
return res
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "Path") -> "RuleBasedDataset":
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"
@@ -335,8 +488,30 @@
class EnviFormerDataset(Dataset):
def __init__(self):
pass
def __init__(self, educts, products):
assert len(educts) == len(products), "Can't have unequal length educts and products"
self.educts = educts
self.products = products
@staticmethod
def generate_dataset(reactions, *args, **kwargs):
# Standardize reactions for the training data; EnviFormer currently ignores stereochemistry
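# Multi-component educts and products are joined into single dot-separated SMILES strings.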
educts = []
products = []
for reaction in reactions:
e = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
p = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
educts.append(e)
products.append(p)
return EnviFormerDataset(educts, products)
def X(self):
pass
@@ -347,6 +522,9 @@ class EnviFormerDataset(Dataset):
def __getitem__(self, item):
pass
def __len__(self):
pass
class SparseLabelECC(BaseEstimator, ClassifierMixin):
"""