starting on app domain with new dataset #120

This commit is contained in:
Liam Brydon
2025-11-04 16:33:56 +13:00
parent ac5d370b18
commit 13af49488e
3 changed files with 98 additions and 58 deletions

View File

@@ -28,7 +28,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
+    EnviFormerDataset, Dataset

 logger = logging.getLogger(__name__)
@@ -2184,9 +2185,9 @@ class PackageBasedModel(EPModel):
             ds.save(f)
         return ds

-    def load_dataset(self) -> "RuleBasedDataset":
+    def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
-        return RuleBasedDataset.load(ds_path)
+        return Dataset.load(ds_path)

     def retrain(self):
         self.build_dataset()
@@ -2196,7 +2197,7 @@ class PackageBasedModel(EPModel):
         self.build_model()

     @abstractmethod
-    def _fit_model(self, ds: RuleBasedDataset):
+    def _fit_model(self, ds: Dataset):
         pass

     @abstractmethod
@@ -2337,22 +2338,22 @@ class PackageBasedModel(EPModel):
             )
             ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             ds = self.load_dataset()
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()

             n_splits = kwargs.get("n_splits", 20)
@@ -2586,7 +2587,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
         model = RelativeReasoning(
             start_index=ds.triggered()[0],
-            end_index=ds.triggered()[1],
+            end_index=ds.triggered()[-1],
         )
         model.fit(X, y)
         return model
@@ -2596,7 +2597,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         return {
             "clz": "RuleBaseRelativeReasoning",
             "start_index": ds.triggered()[0],
-            "end_index": ds.triggered()[1],
+            "end_index": ds.triggered()[-1],
         }

     def _save_model(self, model):
@@ -2716,7 +2717,7 @@ class MLRelativeReasoning(PackageBasedModel):
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-        pred = self.model.predict_proba(np.array(classify_ds.X()))
+        pred = self.model.predict_proba(classify_ds.X().to_numpy())
         res = MLRelativeReasoning.combine_products_and_probs(
             self.applicable_rules, pred[0], classify_prods[0]
@@ -2761,7 +2762,9 @@ class ApplicabilityDomain(EnviPathModel):
     @cached_property
     def training_set_probs(self):
-        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
+        ds = self.model.load_dataset()
+        col_ids = ds.block_indices("prob")
+        # block_indices returns every prob column index, so slice to the end of the block
+        return ds[ds.columns[col_ids[0]: col_ids[-1] + 1]]

     def build(self):
         ds = self.model.load_dataset()
@@ -2769,9 +2772,9 @@ class ApplicabilityDomain(EnviPathModel):
         start = datetime.now()

         # Get Trainingset probs and dump them as they're required when using the app domain
-        probs = self.model.model.predict_proba(ds.X())
-        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
-        joblib.dump(probs, f)
+        probs = self.model.model.predict_proba(ds.X().to_numpy())
+        ds.add_probs(probs)
+        ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))

         ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
         ad.build(ds)
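With this change the training-set probabilities are no longer dumped to a separate `*_train_probs.pkl`; they are appended to the pickled dataset as `prob_<rule-uuid>` columns and read back by prefix. A minimal sketch of that round trip, using a toy polars frame and illustrative column names rather than real rule UUIDs:

import numpy as np
import polars as pl

# Toy training frame with one label column; an add_probs-style step appends a matching prob_ column.
df = pl.DataFrame({"structure_id": [1, 2], "trig_r1": [1, 0], "obs_r1": [1, 0]})
probs = np.array([[0.9], [0.2]])  # shape (n_rows, n_labels), e.g. from predict_proba
df = df.with_columns(pl.Series("prob_r1", probs[:, 0]))

# training_set_probs-style lookup: recover the prob block by column prefix.
prob_ids = [i for i, c in enumerate(df.columns) if c.startswith("prob_")]
train_probs = df[[df.columns[i] for i in prob_ids]]
print(train_probs.columns)  # ['prob_r1']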
@@ -2816,25 +2819,21 @@ class ApplicabilityDomain(EnviPathModel):
         # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
         # This is used to find "qualified neighbours" — training examples that share the same triggered feature
         # with a given assessment structure under a particular rule.
-        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
-            lambda: defaultdict(list)
-        )
-        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
-            feature = ds.columns[feature_index]
-            if feature.startswith("trig_"):
-                # TODO unroll loop
-                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
-                    if int(cx[feature_index]) == 1:
-                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
-                            if int(tx[feature_index]) == 1:
-                                qualified_neighbours_per_rule[i][rule_idx].append(j)
+        import polars as pl
+        qualified_neighbours_per_rule: Dict = {}
+        # Select only the triggered columns
+        for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
+            # Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
+            # trigger that rule. Select the structure_id of the compounds in those filtered rows
+            train_trig = {col_name: ds.df.filter(pl.col(col_name).eq(1)).select("structure_id") for col_name, value in row.items() if value == 1}
+            qualified_neighbours_per_rule[i] = train_trig

         probs = self.training_set_probs
         # preds = self.model.model.predict_proba(assessment_ds.X())
         preds = self.model.combine_products_and_probs(
             self.model.applicable_rules,
-            self.model.model.predict_proba(assessment_ds.X())[0],
+            self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
             assessment_prods[0],
         )
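The nested loops over `assessment_ds.X()` and `ds.X()` are replaced with per-rule polars filters: for every trig_ column an assessment structure activates, the training frame is filtered to the rows that also activate it. A self-contained sketch of that pattern, with toy frames and made-up column names:

import polars as pl

train = pl.DataFrame({"structure_id": [10, 11, 12], "trig_a": [1, 0, 1], "trig_b": [0, 1, 1]})
assessment = pl.DataFrame({"trig_a": [1, 0], "trig_b": [0, 1]})

qualified_neighbours_per_rule = {}
for i, row in enumerate(assessment.iter_rows(named=True)):
    # For each rule this structure triggers, keep the training structure_ids that trigger it too.
    qualified_neighbours_per_rule[i] = {
        col: train.filter(pl.col(col).eq(1)).select("structure_id")
        for col, value in row.items() if value == 1
    }

print(qualified_neighbours_per_rule[0]["trig_a"].to_series().to_list())  # [10, 12]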

View File

@@ -62,6 +62,37 @@ class ModelTest(TestCase):
         # from pprint import pprint
         # pprint(mod.eval_results)

+    def test_applicability(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+                rule_package_objs = [self.BBD_SUBSET]
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+                mod = MLRelativeReasoning.create(
+                    self.package,
+                    rule_package_objs,
+                    data_package_objs,
+                    eval_packages_objs,
+                    threshold=threshold,
+                    name="ECC - BBD - 0.5",
+                    description="Created MLRelativeReasoning in Testcase",
+                    build_app_domain=True,  # To test the applicability domain this must be True
+                    app_domain_num_neighbours=5,
+                    app_domain_local_compatibility_threshold=0.5,
+                    app_domain_reliability_threshold=0.5,
+                )
+                mod.build_dataset()
+                mod.build_model()
+                mod.multigen_eval = True
+                mod.save()
+                mod.evaluate_model(n_splits=2)
+                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
+
     def test_rbrr(self):
         with TemporaryDirectory() as tmpdir:
             with self.settings(MODEL_DIR=tmpdir):

View File

@@ -51,14 +51,13 @@ class Dataset(ABC):
         """See add_rows"""
         self.add_rows([row])

-    def _block_indices(self, prefix) -> Tuple[int, int]:
+    def block_indices(self, prefix) -> List[int]:
         """Find the start and end indexes in column labels that has the prefix"""
         indices: List[int] = []
         for i, feature in enumerate(self.columns):
             if feature.startswith(prefix):
                 indices.append(i)
-        return min(indices, default=None), max(indices, default=None)
+        return indices

     @property
     def columns(self) -> List[str]:
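`block_indices` now returns every matching column position instead of a `(min, max)` pair, so callers can hand the list straight to the dataset's `__getitem__` rather than building slices. A quick illustration with made-up column labels:

columns = ["structure_id", "feature_0", "trig_r1", "trig_r2", "obs_r1", "obs_r2"]

def block_indices(prefix):
    # Same idea as Dataset.block_indices: collect all positions whose label starts with the prefix.
    return [i for i, c in enumerate(columns) if c.startswith(prefix)]

print(block_indices("trig_"))  # [2, 3] -- previously returned as the tuple (2, 3)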
@@ -99,11 +98,11 @@ class Dataset(ABC):
         See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
         res = self.df[item]
         if isinstance(res, pl.DataFrame):  # If we get a dataframe back from indexing make new self with res dataframe
-            return self.__class__(data=self.df[item])
+            return self.__class__(data=res)
         else:  # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
             return res

-    def save(self, path: "Path"):
+    def save(self, path: "Path | str"):
         import pickle

         with open(path, "wb") as fh:
@@ -115,6 +114,9 @@
         return pickle.load(open(path, "rb"))

+    def to_numpy(self):
+        return self.df.to_numpy()
+
     def __repr__(self):
         return (
             f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
@@ -123,28 +125,34 @@
     def __len__(self):
         return len(self.df)

+    def iter_rows(self, named=False):
+        return self.df.iter_rows(named=named)
+

 class RuleBasedDataset(Dataset):
     def __init__(self, num_labels=None, columns=None, data=None):
         super().__init__(columns, data)
         # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
         self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
-        self.num_features: int = len(self.columns) - self.num_labels
-        self._struct_features: Tuple[int, int] = self._block_indices("feature_")
-        self._triggered: Tuple[int, int] = self._block_indices("trig_")
-        self._observed: Tuple[int, int] = self._block_indices("obs_")
+        # Pre-calculate the ids of columns for features/labels, useful later in X and y
+        self._struct_features: List[int] = self.block_indices("feature_")
+        self._triggered: List[int] = self.block_indices("trig_")
+        self._observed: List[int] = self.block_indices("obs_")
+        self.feature_cols: List[int] = self._struct_features + self._triggered
+        self.num_features: int = len(self.feature_cols)
+        self.has_probs = False

     def times_triggered(self, rule_uuid) -> int:
         """Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
         return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height

-    def struct_features(self) -> Tuple[int, int]:
+    def struct_features(self) -> List[int]:
         return self._struct_features

-    def triggered(self) -> Tuple[int, int]:
+    def triggered(self) -> List[int]:
         return self._triggered

-    def observed(self) -> Tuple[int, int]:
+    def observed(self) -> List[int]:
         return self._observed

     def structure_id(self, index: int):
@@ -153,21 +161,24 @@ class RuleBasedDataset(Dataset):
     def X(self, exclude_id_col=True, na_replacement=0):
         """Get all the feature and trig columns"""
-        res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
+        _col_ids = self.feature_cols
+        if not exclude_id_col:
+            _col_ids = [0] + _col_ids
+        res = self[:, _col_ids]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res

     def trig(self, na_replacement=0):
         """Get all the trig columns"""
-        res = self[:, self._triggered[0]: self._triggered[1]]
+        res = self[:, self._triggered]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res

     def y(self, na_replacement=0):
         """Get all the obs columns"""
-        res = self[:, len(self.columns) - self.num_labels:]
+        res = self[:, self._observed]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res
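Selecting X, trig and y by explicit index lists relies on polars accepting a list of integer positions in the column slot of `DataFrame.__getitem__`. A two-line check of that behaviour on a toy frame:

import polars as pl

df = pl.DataFrame({"structure_id": [1], "feature_0": [0.5], "trig_r1": [1], "obs_r1": [1]})
print(df[:, [1, 2]].columns)  # ['feature_0', 'trig_r1'] -- columns picked by position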
@@ -264,14 +275,6 @@
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
         classify_data = []
         classify_products = []
-        if isinstance(structures[0], str):
-            struct_smiles = structures[0]
-        else:
-            struct_smiles = structures[0].smiles
-        ds_columns = (["structure_id"] +
-                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
-                      [f"trig_{r.uuid}" for r in applicable_rules] +
-                      [f"obs_{r.uuid}" for r in applicable_rules])
         for struct in structures:
             if isinstance(struct, str):
                 struct_id = None
@@ -293,12 +296,19 @@
             else:
                 trig.append(0)
                 prods.append([])
-            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
+            new_row = [struct_id] + features + trig + ([-1] * len(trig))
+            if self.has_probs:
+                new_row += [-1] * len(trig)
+            classify_data.append(new_row)
             classify_products.append(prods)

-        ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
+        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
         return ds, classify_products

+    def add_probs(self, probs):
+        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
+        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
+        self.has_probs = True
+
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
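`add_probs` names each new column after the rule UUID carried in the corresponding obs_ column, so probabilities stay aligned with their labels. A small sketch of that naming step with placeholder UUIDs:

import numpy as np
import polars as pl

df = pl.DataFrame({"obs_abc123": [1, 0], "obs_def456": [0, 1]})
probs = np.array([[0.8, 0.1], [0.3, 0.7]])  # one column of probabilities per obs_ column

col_names = [f"prob_{c.split('_')[-1]}" for c in df.columns if c.startswith("obs_")]
df = df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
print(df.columns)  # ['obs_abc123', 'obs_def456', 'prob_abc123', 'prob_def456']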