starting on app domain with new dataset #120

This commit is contained in:
Liam Brydon
2025-11-04 16:33:56 +13:00
parent ac5d370b18
commit 13af49488e
3 changed files with 98 additions and 58 deletions

View File

@ -28,7 +28,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
EnviFormerDataset, Dataset
logger = logging.getLogger(__name__)
@ -2184,9 +2185,9 @@ class PackageBasedModel(EPModel):
ds.save(f)
return ds
def load_dataset(self) -> "RuleBasedDataset":
def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return RuleBasedDataset.load(ds_path)
return Dataset.load(ds_path)
def retrain(self):
self.build_dataset()
@ -2196,7 +2197,7 @@ class PackageBasedModel(EPModel):
self.build_model()
@abstractmethod
def _fit_model(self, ds: RuleBasedDataset):
def _fit_model(self, ds: Dataset):
pass
@abstractmethod
@ -2337,22 +2338,22 @@ class PackageBasedModel(EPModel):
)
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
else:
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(na_replacement=np.nan).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:
ds = self.load_dataset()
if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
else:
X = np.array(ds.X(na_replacement=np.nan))
y = np.array(ds.y(na_replacement=np.nan))
X = ds.X(na_replacement=np.nan).to_numpy()
y = ds.y(na_replacement=np.nan).to_numpy()
n_splits = kwargs.get("n_splits", 20)
@ -2586,7 +2587,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
model = RelativeReasoning(
start_index=ds.triggered()[0],
end_index=ds.triggered()[1],
end_index=ds.triggered()[-1],
)
model.fit(X, y)
return model
@ -2596,7 +2597,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return {
"clz": "RuleBaseRelativeReasoning",
"start_index": ds.triggered()[0],
"end_index": ds.triggered()[1],
"end_index": ds.triggered()[-1],
}
def _save_model(self, model):
@ -2716,7 +2717,7 @@ class MLRelativeReasoning(PackageBasedModel):
start = datetime.now()
ds = self.load_dataset()
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
pred = self.model.predict_proba(np.array(classify_ds.X()))
pred = self.model.predict_proba(classify_ds.X().to_numpy())
res = MLRelativeReasoning.combine_products_and_probs(
self.applicable_rules, pred[0], classify_prods[0]
@ -2761,7 +2762,9 @@ class ApplicabilityDomain(EnviPathModel):
@cached_property
def training_set_probs(self):
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
ds = self.model.load_dataset()
col_ids = ds.block_indices("prob")
return ds[ds.columns[col_ids[0]: col_ids[1]]]
def build(self):
ds = self.model.load_dataset()
@ -2769,9 +2772,9 @@ class ApplicabilityDomain(EnviPathModel):
start = datetime.now()
# Get Trainingset probs and dump them as they're required when using the app domain
probs = self.model.model.predict_proba(ds.X())
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
joblib.dump(probs, f)
probs = self.model.model.predict_proba(ds.X().to_numpy())
ds.add_probs(probs)
ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
ad.build(ds)
@ -2816,25 +2819,21 @@ class ApplicabilityDomain(EnviPathModel):
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
# with a given assessment structure under a particular rule.
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
lambda: defaultdict(list)
)
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
feature = ds.columns[feature_index]
if feature.startswith("trig_"):
# TODO unroll loop
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
if int(cx[feature_index]) == 1:
for j, tx in enumerate(ds.X(exclude_id_col=False)):
if int(tx[feature_index]) == 1:
qualified_neighbours_per_rule[i][rule_idx].append(j)
import polars as pl
qualified_neighbours_per_rule: Dict = {}
# Select only the triggered columns
for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
# Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
# trigger that rule. Select the structure_id of the compounds in those filtered rows
train_trig = {col_name: ds.df.filter(pl.col(col_name).eq(1)).select("structure_id") for col_name, value in row.items() if value == 1}
qualified_neighbours_per_rule[i] = train_trig
probs = self.training_set_probs
# preds = self.model.model.predict_proba(assessment_ds.X())
preds = self.model.combine_products_and_probs(
self.model.applicable_rules,
self.model.model.predict_proba(assessment_ds.X())[0],
self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
assessment_prods[0],
)

View File

@ -62,6 +62,37 @@ class ModelTest(TestCase):
# from pprint import pprint
# pprint(mod.eval_results)
def test_applicability(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
rule_package_objs = [self.BBD_SUBSET]
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mod = MLRelativeReasoning.create(
self.package,
rule_package_objs,
data_package_objs,
eval_packages_objs,
threshold=threshold,
name="ECC - BBD - 0.5",
description="Created MLRelativeReasoning in Testcase",
build_app_domain=True, # To test the applicability domain this must be True
app_domain_num_neighbours=5,
app_domain_local_compatibility_threshold=0.5,
app_domain_reliability_threshold=0.5,
)
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
mod.evaluate_model(n_splits=2)
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
def test_rbrr(self):
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):

View File

@ -51,14 +51,13 @@ class Dataset(ABC):
"""See add_rows"""
self.add_rows([row])
def _block_indices(self, prefix) -> Tuple[int, int]:
def block_indices(self, prefix) -> List[int]:
"""Find the start and end indexes in column labels that has the prefix"""
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices, default=None), max(indices, default=None)
return indices
@property
def columns(self) -> List[str]:
@ -99,11 +98,11 @@ class Dataset(ABC):
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
res = self.df[item]
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
return self.__class__(data=self.df[item])
return self.__class__(data=res)
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
return res
def save(self, path: "Path"):
def save(self, path: "Path | str"):
import pickle
with open(path, "wb") as fh:
@ -115,6 +114,9 @@ class Dataset(ABC):
return pickle.load(open(path, "rb"))
def to_numpy(self):
return self.df.to_numpy()
def __repr__(self):
return (
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
@ -123,28 +125,34 @@ class Dataset(ABC):
def __len__(self):
return len(self.df)
def iter_rows(self, named=False):
return self.df.iter_rows(named=named)
class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
self.num_features: int = len(self.columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
# Pre-calculate the ids of columns for features/labels, useful later in X and y
self._struct_features: List[int] = self.block_indices("feature_")
self._triggered: List[int] = self.block_indices("trig_")
self._observed: List[int] = self.block_indices("obs_")
self.feature_cols: List[int] = self._struct_features + self._triggered
self.num_features: int = len(self.feature_cols)
self.has_probs = False
def times_triggered(self, rule_uuid) -> int:
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
def struct_features(self) -> List[int]:
return self._struct_features
def triggered(self) -> Tuple[int, int]:
def triggered(self) -> List[int]:
return self._triggered
def observed(self) -> Tuple[int, int]:
def observed(self) -> List[int]:
return self._observed
def structure_id(self, index: int):
@ -153,21 +161,24 @@ class RuleBasedDataset(Dataset):
def X(self, exclude_id_col=True, na_replacement=0):
"""Get all the feature and trig columns"""
res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
_col_ids = self.feature_cols
if not exclude_id_col:
_col_ids = [0] + _col_ids
res = self[:, _col_ids]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def trig(self, na_replacement=0):
"""Get all the trig columns"""
res = self[:, self._triggered[0]: self._triggered[1]]
res = self[:, self._triggered]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
def y(self, na_replacement=0):
"""Get all the obs columns"""
res = self[:, len(self.columns) - self.num_labels:]
res = self[:, self._observed]
if na_replacement is not None:
res.df = res.df.fill_null(na_replacement)
return res
@ -264,14 +275,6 @@ class RuleBasedDataset(Dataset):
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
if isinstance(structures[0], str):
struct_smiles = structures[0]
else:
struct_smiles = structures[0].smiles
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
for struct in structures:
if isinstance(struct, str):
struct_id = None
@ -293,12 +296,19 @@ class RuleBasedDataset(Dataset):
else:
trig.append(0)
prods.append([])
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
new_row = [struct_id] + features + trig + ([-1] * len(trig))
if self.has_probs:
new_row += [-1] * len(trig)
classify_data.append(new_row)
classify_products.append(prods)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
return ds, classify_products
def add_probs(self, probs):
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
self.has_probs = True
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"