forked from enviPath/enviPy
starting on app domain with new dataset #120
@@ -28,7 +28,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
+    EnviFormerDataset, Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -2184,9 +2185,9 @@ class PackageBasedModel(EPModel):
             ds.save(f)
         return ds
 
-    def load_dataset(self) -> "RuleBasedDataset":
+    def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
-        return RuleBasedDataset.load(ds_path)
+        return Dataset.load(ds_path)
 
     def retrain(self):
         self.build_dataset()
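Note: the widened return annotation on load_dataset matches what pickle already does at runtime: unpickling restores whichever concrete Dataset subclass was saved, even when load() is called through the base class. A minimal sketch with toy classes (a file object stands in for the on-disk path):

```python
import io
import pickle

class Dataset:
    @classmethod
    def load(cls, fh):
        # pickle restores the object's recorded class, regardless of cls
        return pickle.load(fh)

class RuleBasedDataset(Dataset):
    pass

buf = io.BytesIO()
pickle.dump(RuleBasedDataset(), buf)  # save a subclass instance
buf.seek(0)
print(type(Dataset.load(buf)).__name__)  # RuleBasedDataset
```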
@@ -2196,7 +2197,7 @@ class PackageBasedModel(EPModel):
         self.build_model()
 
     @abstractmethod
-    def _fit_model(self, ds: RuleBasedDataset):
+    def _fit_model(self, ds: Dataset):
         pass
 
     @abstractmethod
@@ -2337,22 +2338,22 @@ class PackageBasedModel(EPModel):
             )
             ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
             ds = self.load_dataset()
 
             if isinstance(self, RuleBasedRelativeReasoning):
-                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
             else:
-                X = np.array(ds.X(na_replacement=np.nan))
-                y = np.array(ds.y(na_replacement=np.nan))
+                X = ds.X(na_replacement=np.nan).to_numpy()
+                y = ds.y(na_replacement=np.nan).to_numpy()
 
             n_splits = kwargs.get("n_splits", 20)
 
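Note: X() and y() now hand back Dataset wrappers around a polars frame, so the conversion to a matrix goes through polars' own to_numpy() rather than np.array(...). A small illustration of the polars side with a toy frame:

```python
import numpy as np
import polars as pl

# fill_null + to_numpy mirrors the na_replacement=np.nan path in the diff:
# nulls become NaN and the result is a float ndarray.
df = pl.DataFrame({"trig_a": [1.0, None, 0.0], "trig_b": [0.0, 1.0, None]})
arr = df.fill_null(np.nan).to_numpy()
print(arr.dtype, arr.shape)  # float64 (3, 2)
```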
@@ -2586,7 +2587,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
         model = RelativeReasoning(
             start_index=ds.triggered()[0],
-            end_index=ds.triggered()[1],
+            end_index=ds.triggered()[-1],
         )
         model.fit(X, y)
         return model
@@ -2596,7 +2597,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         return {
             "clz": "RuleBaseRelativeReasoning",
             "start_index": ds.triggered()[0],
-            "end_index": ds.triggered()[1],
+            "end_index": ds.triggered()[-1],
         }
 
     def _save_model(self, model):
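Note: both end_index call sites move from [1] to [-1] because triggered() now returns the full list of matching column positions instead of a (min, max) tuple (see the block_indices hunk further down). A toy illustration with hypothetical column names:

```python
# Old contract: (min, max) tuple, end index read as triggered()[1].
# New contract: list of all matching positions, end index read as [-1].
columns = ["structure_id", "feature_0", "trig_a", "trig_b", "trig_c", "obs_a"]

triggered = [i for i, c in enumerate(columns) if c.startswith("trig_")]
print(triggered[0], triggered[-1])  # 2 4
```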
@@ -2716,7 +2717,7 @@ class MLRelativeReasoning(PackageBasedModel):
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
-        pred = self.model.predict_proba(np.array(classify_ds.X()))
+        pred = self.model.predict_proba(classify_ds.X().to_numpy())
 
         res = MLRelativeReasoning.combine_products_and_probs(
             self.applicable_rules, pred[0], classify_prods[0]
@@ -2761,7 +2762,9 @@ class ApplicabilityDomain(EnviPathModel):
 
     @cached_property
     def training_set_probs(self):
-        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
+        ds = self.model.load_dataset()
+        col_ids = ds.block_indices("prob")
+        return ds[ds.columns[col_ids[0]: col_ids[1]]]
 
     def build(self):
         ds = self.model.load_dataset()
@@ -2769,9 +2772,9 @@
         start = datetime.now()
 
         # Get Trainingset probs and dump them as they're required when using the app domain
-        probs = self.model.model.predict_proba(ds.X())
-        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
-        joblib.dump(probs, f)
+        probs = self.model.model.predict_proba(ds.X().to_numpy())
+        ds.add_probs(probs)
+        ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
 
         ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
         ad.build(ds)
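Note: the training-set probabilities previously went to a separate *_train_probs.pkl via joblib; they are now attached to the dataset as prob_* columns (add_probs, added later in this commit) and persisted with the dataset itself, which is what the reworked training_set_probs reads back. A hedged sketch of the column-append step with a toy frame and made-up rule names:

```python
import numpy as np
import polars as pl

df = pl.DataFrame({"structure_id": ["s1", "s2"], "obs_r1": [1, 0], "obs_r2": [0, 1]})
probs = np.array([[0.9, 0.2], [0.1, 0.8]])  # predict_proba output, one column per rule

# Append per-rule probabilities as prob_<rule> columns on the polars frame
col_names = ["prob_r1", "prob_r2"]
df = df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
print(df.columns)  # ['structure_id', 'obs_r1', 'obs_r2', 'prob_r1', 'prob_r2']
```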
@@ -2816,25 +2819,21 @@ class ApplicabilityDomain(EnviPathModel):
         # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
         # This is used to find "qualified neighbours" — training examples that share the same triggered feature
         # with a given assessment structure under a particular rule.
-        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
-            lambda: defaultdict(list)
-        )
 
-        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
-            feature = ds.columns[feature_index]
-            if feature.startswith("trig_"):
-                # TODO unroll loop
-                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
-                    if int(cx[feature_index]) == 1:
-                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
-                            if int(tx[feature_index]) == 1:
-                                qualified_neighbours_per_rule[i][rule_idx].append(j)
+        import polars as pl
+        qualified_neighbours_per_rule: Dict = {}
+        # Select only the triggered columns
+        for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
+            # Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
+            # trigger that rule. Select the structure_id of the compounds in those filtered rows
+            train_trig = {col_name: ds.df.filter(pl.col(col_name).eq(1)).select("structure_id") for col_name, value in row.items() if value == 1}
+            qualified_neighbours_per_rule[i] = train_trig
 
         probs = self.training_set_probs
         # preds = self.model.model.predict_proba(assessment_ds.X())
         preds = self.model.combine_products_and_probs(
             self.model.applicable_rules,
-            self.model.model.predict_proba(assessment_ds.X())[0],
+            self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
             assessment_prods[0],
         )
 
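Note: the old quadruple loop over assessment rows, rules, and training rows collapses into one polars filter per triggered column. A standalone sketch of the pattern with toy data and hypothetical rule names:

```python
import polars as pl

train = pl.DataFrame({
    "structure_id": ["t1", "t2", "t3"],
    "trig_r1": [1, 0, 1],
    "trig_r2": [0, 1, 1],
})
assessment = pl.DataFrame({"trig_r1": [1, 0], "trig_r2": [1, 1]})

# For each assessment row, keep the training structure_ids that trigger the
# same trig_* columns that the row triggers.
qualified = {}
for i, row in enumerate(assessment.iter_rows(named=True)):
    qualified[i] = {
        col: train.filter(pl.col(col).eq(1)).select("structure_id")
        for col, value in row.items() if value == 1
    }

print(qualified[0]["trig_r1"]["structure_id"].to_list())  # ['t1', 't3']
```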
@@ -62,6 +62,37 @@ class ModelTest(TestCase):
         # from pprint import pprint
         # pprint(mod.eval_results)
 
+    def test_applicability(self):
+        with TemporaryDirectory() as tmpdir:
+            with self.settings(MODEL_DIR=tmpdir):
+                threshold = float(0.5)
+
+                rule_package_objs = [self.BBD_SUBSET]
+                data_package_objs = [self.BBD_SUBSET]
+                eval_packages_objs = [self.BBD_SUBSET]
+
+                mod = MLRelativeReasoning.create(
+                    self.package,
+                    rule_package_objs,
+                    data_package_objs,
+                    eval_packages_objs,
+                    threshold=threshold,
+                    name="ECC - BBD - 0.5",
+                    description="Created MLRelativeReasoning in Testcase",
+                    build_app_domain=True,  # To test the applicability domain this must be True
+                    app_domain_num_neighbours=5,
+                    app_domain_local_compatibility_threshold=0.5,
+                    app_domain_reliability_threshold=0.5,
+                )
+
+                mod.build_dataset()
+                mod.build_model()
+                mod.multigen_eval = True
+                mod.save()
+                mod.evaluate_model(n_splits=2)
+
+                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
+
     def test_rbrr(self):
         with TemporaryDirectory() as tmpdir:
             with self.settings(MODEL_DIR=tmpdir):
@@ -51,14 +51,13 @@ class Dataset(ABC):
         """See add_rows"""
         self.add_rows([row])
 
-    def _block_indices(self, prefix) -> Tuple[int, int]:
+    def block_indices(self, prefix) -> List[int]:
         """Find the start and end indexes in column labels that has the prefix"""
         indices: List[int] = []
         for i, feature in enumerate(self.columns):
             if feature.startswith(prefix):
                 indices.append(i)
-
-        return min(indices, default=None), max(indices, default=None)
+        return indices
 
     @property
     def columns(self) -> List[str]:
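Note: _block_indices becomes the public block_indices and returns every matching position rather than collapsing to (min, max), so non-contiguous blocks survive and the list can drive positional selection on the polars frame directly. A toy demonstration with hypothetical column names:

```python
import polars as pl

df = pl.DataFrame({"structure_id": ["s"], "feature_0": [1], "trig_a": [0], "obs_a": [1]})

def block_indices(prefix):
    # all positions whose column label starts with the prefix
    return [i for i, c in enumerate(df.columns) if c.startswith(prefix)]

print(block_indices("trig_"))         # [2]
print(df[:, block_indices("trig_")])  # selects just the trig_ columns
```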
@@ -99,11 +98,11 @@ class Dataset(ABC):
         See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
         res = self.df[item]
         if isinstance(res, pl.DataFrame):  # If we get a dataframe back from indexing make new self with res dataframe
-            return self.__class__(data=self.df[item])
+            return self.__class__(data=res)
         else:  # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
             return res
 
-    def save(self, path: "Path"):
+    def save(self, path: "Path | str"):
         import pickle
 
         with open(path, "wb") as fh:
@@ -115,6 +114,9 @@
 
         return pickle.load(open(path, "rb"))
 
+    def to_numpy(self):
+        return self.df.to_numpy()
+
     def __repr__(self):
         return (
             f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
@@ -123,28 +125,34 @@ class Dataset(ABC):
     def __len__(self):
         return len(self.df)
 
+    def iter_rows(self, named=False):
+        return self.df.iter_rows(named=named)
+
 
 class RuleBasedDataset(Dataset):
     def __init__(self, num_labels=None, columns=None, data=None):
         super().__init__(columns, data)
         # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
         self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
-        self.num_features: int = len(self.columns) - self.num_labels
-        self._struct_features: Tuple[int, int] = self._block_indices("feature_")
-        self._triggered: Tuple[int, int] = self._block_indices("trig_")
-        self._observed: Tuple[int, int] = self._block_indices("obs_")
+        # Pre-calculate the ids of columns for features/labels, useful later in X and y
+        self._struct_features: List[int] = self.block_indices("feature_")
+        self._triggered: List[int] = self.block_indices("trig_")
+        self._observed: List[int] = self.block_indices("obs_")
+        self.feature_cols: List[int] = self._struct_features + self._triggered
+        self.num_features: int = len(self.feature_cols)
+        self.has_probs = False
 
     def times_triggered(self, rule_uuid) -> int:
         """Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
         return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
 
-    def struct_features(self) -> Tuple[int, int]:
+    def struct_features(self) -> List[int]:
         return self._struct_features
 
-    def triggered(self) -> Tuple[int, int]:
+    def triggered(self) -> List[int]:
         return self._triggered
 
-    def observed(self) -> Tuple[int, int]:
+    def observed(self) -> List[int]:
         return self._observed
 
     def structure_id(self, index: int):
@@ -153,21 +161,24 @@ class RuleBasedDataset(Dataset):
 
     def X(self, exclude_id_col=True, na_replacement=0):
         """Get all the feature and trig columns"""
-        res = self[:, 1 if exclude_id_col else 0: len(self.columns) - self.num_labels]
+        _col_ids = self.feature_cols
+        if not exclude_id_col:
+            _col_ids = [0] + _col_ids
+        res = self[:, _col_ids]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res
 
     def trig(self, na_replacement=0):
         """Get all the trig columns"""
-        res = self[:, self._triggered[0]: self._triggered[1]]
+        res = self[:, self._triggered]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res
 
     def y(self, na_replacement=0):
         """Get all the obs columns"""
-        res = self[:, len(self.columns) - self.num_labels:]
+        res = self[:, self._observed]
         if na_replacement is not None:
             res.df = res.df.fill_null(na_replacement)
         return res
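Note: with the column blocks precomputed as index lists in __init__, X() no longer assumes a contiguous slice between the id column and the label block; it selects the feature_ plus trig_ positions and prepends position 0 when the id column is requested. A sketch of that selection with toy positions:

```python
import polars as pl

df = pl.DataFrame({"structure_id": ["s1"], "feature_0": [1], "trig_a": [0], "obs_a": [1]})
feature_cols = [1, 2]  # feature_ + trig_ positions, as precomputed in __init__

col_ids = [0] + feature_cols  # the exclude_id_col=False branch
print(df[:, col_ids].columns)  # ['structure_id', 'feature_0', 'trig_a']
```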
@@ -264,14 +275,6 @@ class RuleBasedDataset(Dataset):
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
         classify_data = []
         classify_products = []
-        if isinstance(structures[0], str):
-            struct_smiles = structures[0]
-        else:
-            struct_smiles = structures[0].smiles
-        ds_columns = (["structure_id"] +
-                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(struct_smiles))] +
-                      [f"trig_{r.uuid}" for r in applicable_rules] +
-                      [f"obs_{r.uuid}" for r in applicable_rules])
         for struct in structures:
             if isinstance(struct, str):
                 struct_id = None
@@ -293,12 +296,19 @@
             else:
                 trig.append(0)
                 prods.append([])
-            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
+            new_row = [struct_id] + features + trig + ([-1] * len(trig))
+            if self.has_probs:
+                new_row += [-1] * len(trig)
+            classify_data.append(new_row)
             classify_products.append(prods)
-        ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=classify_data)
+        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
         return ds, classify_products
 
+    def add_probs(self, probs):
+        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
+        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
+        self.has_probs = True
 
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
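Note: classification rows must now line up with the training frame's full schema (self.columns), so each row carries -1 placeholders for the obs_ block and, once add_probs has run, for the prob_ block as well. A toy walk-through of the width arithmetic:

```python
# Row construction as in classification_dataset, with toy sizes.
struct_id = None
features = [1, 0, 1]  # structure fingerprint bits (toy)
trig = [1, 0]         # one entry per applicable rule
has_probs = True      # set once add_probs has appended prob_ columns

new_row = [struct_id] + features + trig + [-1] * len(trig)  # obs_ placeholders
if has_probs:
    new_row += [-1] * len(trig)                             # prob_ placeholders
print(len(new_row))  # 1 + 3 + 2 + 2 + 2 = 10
```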