forked from enviPath/enviPy
Merge remote-tracking branch 'origin/develop' into fix/xss
# Conflicts: # pyproject.toml # uv.lock
This commit is contained in:
@ -1552,9 +1552,7 @@ class SPathway(object):
|
|||||||
if sub.app_domain_assessment is None:
|
if sub.app_domain_assessment is None:
|
||||||
if self.prediction_setting.model:
|
if self.prediction_setting.model:
|
||||||
if self.prediction_setting.model.app_domain:
|
if self.prediction_setting.model.app_domain:
|
||||||
app_domain_assessment = self.prediction_setting.model.app_domain.assess(
|
app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)
|
||||||
sub.smiles
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
if self.persist is not None:
|
if self.persist is not None:
|
||||||
n = self.snode_persist_lookup[sub]
|
n = self.snode_persist_lookup[sub]
|
||||||
@ -1586,11 +1584,7 @@ class SPathway(object):
|
|||||||
app_domain_assessment = None
|
app_domain_assessment = None
|
||||||
if self.prediction_setting.model:
|
if self.prediction_setting.model:
|
||||||
if self.prediction_setting.model.app_domain:
|
if self.prediction_setting.model.app_domain:
|
||||||
app_domain_assessment = (
|
app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c))
|
||||||
self.prediction_setting.model.app_domain.assess(c)[
|
|
||||||
0
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.smiles_to_node[c] = SNode(
|
self.smiles_to_node[c] = SNode(
|
||||||
c, sub.depth + 1, app_domain_assessment
|
c, sub.depth + 1, app_domain_assessment
|
||||||
|
|||||||
308
epdb/models.py
308
epdb/models.py
@ -29,7 +29,8 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
|||||||
from sklearn.model_selection import ShuffleSplit
|
from sklearn.model_selection import ShuffleSplit
|
||||||
|
|
||||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
|
||||||
|
EnviFormerDataset, Dataset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -2178,7 +2179,7 @@ class PackageBasedModel(EPModel):
|
|||||||
|
|
||||||
applicable_rules = self.applicable_rules
|
applicable_rules = self.applicable_rules
|
||||||
reactions = list(self._get_reactions())
|
reactions = list(self._get_reactions())
|
||||||
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
||||||
|
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||||
@ -2187,7 +2188,7 @@ class PackageBasedModel(EPModel):
|
|||||||
ds.save(f)
|
ds.save(f)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def load_dataset(self) -> "Dataset":
|
def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
|
||||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||||
return Dataset.load(ds_path)
|
return Dataset.load(ds_path)
|
||||||
|
|
||||||
@ -2228,7 +2229,7 @@ class PackageBasedModel(EPModel):
|
|||||||
self.model_status = self.BUILT_NOT_EVALUATED
|
self.model_status = self.BUILT_NOT_EVALUATED
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
|
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||||
|
|
||||||
@ -2346,37 +2347,37 @@ class PackageBasedModel(EPModel):
|
|||||||
eval_reactions = list(
|
eval_reactions = list(
|
||||||
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
|
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||||
)
|
)
|
||||||
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||||
if isinstance(self, RuleBasedRelativeReasoning):
|
if isinstance(self, RuleBasedRelativeReasoning):
|
||||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||||
else:
|
else:
|
||||||
X = np.array(ds.X(na_replacement=np.nan))
|
X = ds.X(na_replacement=np.nan).to_numpy()
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||||
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
|
single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
|
||||||
self.eval_results = self.compute_averages([single_gen_result])
|
self.eval_results = self.compute_averages([single_gen_result])
|
||||||
else:
|
else:
|
||||||
ds = self.load_dataset()
|
ds = self.load_dataset()
|
||||||
|
|
||||||
if isinstance(self, RuleBasedRelativeReasoning):
|
if isinstance(self, RuleBasedRelativeReasoning):
|
||||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||||
else:
|
else:
|
||||||
X = np.array(ds.X(na_replacement=np.nan))
|
X = ds.X(na_replacement=np.nan).to_numpy()
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
y = ds.y(na_replacement=np.nan).to_numpy()
|
||||||
|
|
||||||
n_splits = 20
|
n_splits = kwargs.get("n_splits", 20)
|
||||||
|
|
||||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||||
splits = list(shuff.split(X))
|
splits = list(shuff.split(X))
|
||||||
|
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
|
|
||||||
models = Parallel(n_jobs=10)(
|
models = Parallel(n_jobs=min(10, len(splits)))(
|
||||||
delayed(train_func)(X, y, train_index, self._model_args())
|
delayed(train_func)(X, y, train_index, self._model_args())
|
||||||
for train_index, _ in splits
|
for train_index, _ in splits
|
||||||
)
|
)
|
||||||
evaluations = Parallel(n_jobs=10)(
|
evaluations = Parallel(n_jobs=min(10, len(splits)))(
|
||||||
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
|
||||||
for model, (_, test_index) in zip(models, splits)
|
for model, (_, test_index) in zip(models, splits)
|
||||||
)
|
)
|
||||||
@ -2588,11 +2589,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
|||||||
|
|
||||||
return rbrr
|
return rbrr
|
||||||
|
|
||||||
def _fit_model(self, ds: Dataset):
|
def _fit_model(self, ds: RuleBasedDataset):
|
||||||
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
|
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
|
||||||
model = RelativeReasoning(
|
model = RelativeReasoning(
|
||||||
start_index=ds.triggered()[0],
|
start_index=ds.triggered()[0],
|
||||||
end_index=ds.triggered()[1],
|
end_index=ds.triggered()[-1],
|
||||||
)
|
)
|
||||||
model.fit(X, y)
|
model.fit(X, y)
|
||||||
return model
|
return model
|
||||||
@ -2602,7 +2603,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
|||||||
return {
|
return {
|
||||||
"clz": "RuleBaseRelativeReasoning",
|
"clz": "RuleBaseRelativeReasoning",
|
||||||
"start_index": ds.triggered()[0],
|
"start_index": ds.triggered()[0],
|
||||||
"end_index": ds.triggered()[1],
|
"end_index": ds.triggered()[-1],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _save_model(self, model):
|
def _save_model(self, model):
|
||||||
@ -2690,11 +2691,11 @@ class MLRelativeReasoning(PackageBasedModel):
|
|||||||
|
|
||||||
return mlrr
|
return mlrr
|
||||||
|
|
||||||
def _fit_model(self, ds: Dataset):
|
def _fit_model(self, ds: RuleBasedDataset):
|
||||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||||
|
|
||||||
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
||||||
model.fit(X, y)
|
model.fit(X.to_numpy(), y.to_numpy())
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _model_args(self):
|
def _model_args(self):
|
||||||
@ -2717,7 +2718,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
|||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
ds = self.load_dataset()
|
ds = self.load_dataset()
|
||||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||||
pred = self.model.predict_proba(classify_ds.X())
|
pred = self.model.predict_proba(classify_ds.X().to_numpy())
|
||||||
|
|
||||||
res = MLRelativeReasoning.combine_products_and_probs(
|
res = MLRelativeReasoning.combine_products_and_probs(
|
||||||
self.applicable_rules, pred[0], classify_prods[0]
|
self.applicable_rules, pred[0], classify_prods[0]
|
||||||
@ -2762,7 +2763,9 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def training_set_probs(self):
|
def training_set_probs(self):
|
||||||
return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
|
ds = self.model.load_dataset()
|
||||||
|
col_ids = ds.block_indices("prob")
|
||||||
|
return ds[:, col_ids]
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
ds = self.model.load_dataset()
|
ds = self.model.load_dataset()
|
||||||
@ -2770,9 +2773,9 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
|
|
||||||
# Get Trainingset probs and dump them as they're required when using the app domain
|
# Get Trainingset probs and dump them as they're required when using the app domain
|
||||||
probs = self.model.model.predict_proba(ds.X())
|
probs = self.model.model.predict_proba(ds.X().to_numpy())
|
||||||
f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
|
ds.add_probs(probs)
|
||||||
joblib.dump(probs, f)
|
ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))
|
||||||
|
|
||||||
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
|
ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
|
||||||
ad.build(ds)
|
ad.build(ds)
|
||||||
@ -2795,16 +2798,19 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
joblib.dump(ad, f)
|
joblib.dump(ad, f)
|
||||||
|
|
||||||
def assess(self, structure: Union[str, "CompoundStructure"]):
|
def assess(self, structure: Union[str, "CompoundStructure"]):
|
||||||
|
return self.assess_batch([structure])[0]
|
||||||
|
|
||||||
|
def assess_batch(self, structures: List["CompoundStructure | str"]):
|
||||||
ds = self.model.load_dataset()
|
ds = self.model.load_dataset()
|
||||||
|
|
||||||
if isinstance(structure, CompoundStructure):
|
smiles = []
|
||||||
smiles = structure.smiles
|
for struct in structures:
|
||||||
else:
|
if isinstance(struct, CompoundStructure):
|
||||||
smiles = structure
|
smiles.append(structures.smiles)
|
||||||
|
else:
|
||||||
|
smiles.append(structures)
|
||||||
|
|
||||||
assessment_ds, assessment_prods = ds.classification_dataset(
|
assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
|
||||||
[structure], self.model.applicable_rules
|
|
||||||
)
|
|
||||||
|
|
||||||
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
# qualified_neighbours_per_rule is a nested dictionary structured as:
|
||||||
# {
|
# {
|
||||||
@ -2817,82 +2823,47 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
|
# it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
|
||||||
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
|
# This is used to find "qualified neighbours" — training examples that share the same triggered feature
|
||||||
# with a given assessment structure under a particular rule.
|
# with a given assessment structure under a particular rule.
|
||||||
qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
|
qualified_neighbours_per_rule: Dict = {}
|
||||||
lambda: defaultdict(list)
|
|
||||||
)
|
|
||||||
|
|
||||||
for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
|
import polars as pl
|
||||||
feature = ds.columns[feature_index]
|
# Select only the triggered columns
|
||||||
if feature.startswith("trig_"):
|
for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
|
||||||
# TODO unroll loop
|
# Find the rules the structure triggers. For each rule, filter the training dataset to rows that also
|
||||||
for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
|
# trigger that rule.
|
||||||
if int(cx[feature_index]) == 1:
|
train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
|
||||||
for j, tx in enumerate(ds.X(exclude_id_col=False)):
|
for trig_uuid, value in row.items() if value == 1}
|
||||||
if int(tx[feature_index]) == 1:
|
qualified_neighbours_per_rule[i] = train_trig
|
||||||
qualified_neighbours_per_rule[i][rule_idx].append(j)
|
rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}
|
||||||
|
|
||||||
probs = self.training_set_probs
|
|
||||||
# preds = self.model.model.predict_proba(assessment_ds.X())
|
|
||||||
preds = self.model.combine_products_and_probs(
|
preds = self.model.combine_products_and_probs(
|
||||||
self.model.applicable_rules,
|
self.model.applicable_rules,
|
||||||
self.model.model.predict_proba(assessment_ds.X())[0],
|
self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
|
||||||
assessment_prods[0],
|
assessment_prods[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
assessments = list()
|
assessments = list()
|
||||||
|
|
||||||
# loop through our assessment dataset
|
# loop through our assessment dataset
|
||||||
for i, instance in enumerate(assessment_ds):
|
for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
|
||||||
rule_reliabilities = dict()
|
rule_reliabilities = dict()
|
||||||
local_compatibilities = dict()
|
local_compatibilities = dict()
|
||||||
neighbours_per_rule = dict()
|
neighbours_per_rule = dict()
|
||||||
neighbor_probs_per_rule = dict()
|
neighbor_probs_per_rule = dict()
|
||||||
|
|
||||||
# loop through rule indices together with the collected neighbours indices from train dataset
|
# loop through rule indices together with the collected neighbours indices from train dataset
|
||||||
for rule_idx, vals in qualified_neighbours_per_rule[i].items():
|
for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
|
||||||
# collect the train dataset instances and store it along with the index (a.k.a. row number) of the
|
# compute tanimoto distance for all neighbours and add to dataset
|
||||||
# train dataset
|
dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
|
||||||
train_instances = []
|
train_instances[:, train_instances.struct_features()].to_numpy())
|
||||||
for v in vals:
|
train_instances = train_instances.with_columns(dist=pl.Series(dists))
|
||||||
train_instances.append((v, ds.at(v)))
|
|
||||||
|
|
||||||
# sf is a tuple with start/end index of the features
|
|
||||||
sf = ds.struct_features()
|
|
||||||
|
|
||||||
# compute tanimoto distance for all neighbours
|
|
||||||
# result ist a list of tuples with train index and computed distance
|
|
||||||
dists = self._compute_distances(
|
|
||||||
instance.X()[0][sf[0] : sf[1]],
|
|
||||||
[ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
|
|
||||||
)
|
|
||||||
|
|
||||||
dists_with_index = list()
|
|
||||||
for ti, dist in zip(train_instances, dists):
|
|
||||||
dists_with_index.append((ti[0], dist[1]))
|
|
||||||
|
|
||||||
# sort them in a descending way and take at most `self.num_neighbours`
|
# sort them in a descending way and take at most `self.num_neighbours`
|
||||||
dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
|
# TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending)
|
||||||
dists_with_index = dists_with_index[: self.num_neighbours]
|
train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
|
||||||
|
|
||||||
# compute average distance
|
# compute average distance
|
||||||
rule_reliabilities[rule_idx] = (
|
rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
|
||||||
sum([d[1] for d in dists_with_index]) / len(dists_with_index)
|
|
||||||
if len(dists_with_index) > 0
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# for local_compatibility we'll need the datasets for the indices having the highest similarity
|
# for local_compatibility we'll need the datasets for the indices having the highest similarity
|
||||||
neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
|
local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
|
||||||
local_compatibilities[rule_idx] = self._compute_compatibility(
|
neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
|
||||||
rule_idx, probs, neighbour_datasets
|
neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()
|
||||||
)
|
|
||||||
neighbours_per_rule[rule_idx] = [
|
|
||||||
CompoundStructure.objects.get(uuid=ds[1].structure_id())
|
|
||||||
for ds in neighbour_datasets
|
|
||||||
]
|
|
||||||
neighbor_probs_per_rule[rule_idx] = [
|
|
||||||
probs[d[0]][rule_idx] for d in dists_with_index
|
|
||||||
]
|
|
||||||
|
|
||||||
ad_res = {
|
ad_res = {
|
||||||
"ad_params": {
|
"ad_params": {
|
||||||
@ -2903,23 +2874,21 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
"local_compatibility_threshold": self.local_compatibilty_threshold,
|
"local_compatibility_threshold": self.local_compatibilty_threshold,
|
||||||
},
|
},
|
||||||
"assessment": {
|
"assessment": {
|
||||||
"smiles": smiles,
|
"smiles": smiles[i],
|
||||||
"inside_app_domain": self.pca.is_applicable(instance)[0],
|
"inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
transformations = list()
|
transformations = list()
|
||||||
for rule_idx in rule_reliabilities.keys():
|
for rule_uuid in rule_reliabilities.keys():
|
||||||
rule = Rule.objects.get(
|
rule = Rule.objects.get(uuid=rule_uuid)
|
||||||
uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
rule_data = rule.simple_json()
|
rule_data = rule.simple_json()
|
||||||
rule_data["image"] = f"{rule.url}?image=svg"
|
rule_data["image"] = f"{rule.url}?image=svg"
|
||||||
|
|
||||||
neighbors = []
|
neighbors = []
|
||||||
for n, n_prob in zip(
|
for n, n_prob in zip(
|
||||||
neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
|
neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
|
||||||
):
|
):
|
||||||
neighbor = n.simple_json()
|
neighbor = n.simple_json()
|
||||||
neighbor["image"] = f"{n.url}?image=svg"
|
neighbor["image"] = f"{n.url}?image=svg"
|
||||||
@ -2936,14 +2905,14 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
|
|
||||||
transformation = {
|
transformation = {
|
||||||
"rule": rule_data,
|
"rule": rule_data,
|
||||||
"reliability": rule_reliabilities[rule_idx],
|
"reliability": rule_reliabilities[rule_uuid],
|
||||||
# We're setting it here to False, as we don't know whether "assess" is called during pathway
|
# We're setting it here to False, as we don't know whether "assess" is called during pathway
|
||||||
# prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
|
# prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
|
||||||
"is_predicted": False,
|
"is_predicted": False,
|
||||||
"local_compatibility": local_compatibilities[rule_idx],
|
"local_compatibility": local_compatibilities[rule_uuid],
|
||||||
"probability": preds[rule_idx].probability,
|
"probability": preds[rule_to_i[rule_uuid]].probability,
|
||||||
"transformation_products": [
|
"transformation_products": [
|
||||||
x.product_set for x in preds[rule_idx].product_sets
|
x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
|
||||||
],
|
],
|
||||||
"times_triggered": ds.times_triggered(str(rule.uuid)),
|
"times_triggered": ds.times_triggered(str(rule.uuid)),
|
||||||
"neighbors": neighbors,
|
"neighbors": neighbors,
|
||||||
@ -2961,32 +2930,21 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
|
def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
|
||||||
from utilities.ml import tanimoto_distance
|
from utilities.ml import tanimoto_distance
|
||||||
|
|
||||||
distances = [
|
distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
|
||||||
(i, tanimoto_distance(classify_instance, train))
|
|
||||||
for i, train in enumerate(train_instances)
|
|
||||||
]
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@staticmethod
|
def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
|
||||||
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
|
|
||||||
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
|
|
||||||
accuracy = 0.0
|
accuracy = 0.0
|
||||||
|
import polars as pl
|
||||||
for n in neighbours:
|
obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
|
||||||
obs = n[1].y()[0][rule_idx]
|
pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
|
||||||
pred = preds[n[0]][rule_idx]
|
# Compute tp, tn, fp, fn using polars expressions
|
||||||
if obs and pred:
|
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
|
||||||
tp += 1
|
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
|
||||||
elif not obs and pred:
|
fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
|
||||||
fp += 1
|
fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height
|
||||||
elif obs and not pred:
|
if tp + tn > 0.0:
|
||||||
fn += 1
|
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
||||||
else:
|
|
||||||
tn += 1
|
|
||||||
# Jaccard Index
|
|
||||||
if tp + tn > 0.0:
|
|
||||||
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
|
||||||
|
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
|
||||||
@ -3087,44 +3045,24 @@ class EnviFormer(PackageBasedModel):
|
|||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
|
ds = EnviFormerDataset.generate_dataset(self._get_reactions())
|
||||||
co2 = {"C(=O)=O", "O=C=O"}
|
|
||||||
ds = []
|
|
||||||
for reaction in self._get_reactions():
|
|
||||||
educts = ".".join(
|
|
||||||
[
|
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.educts.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
products = ".".join(
|
|
||||||
[
|
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.products.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if products not in co2:
|
|
||||||
ds.append(f"{educts}>>{products}")
|
|
||||||
|
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||||
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||||
with open(f, "w") as d_file:
|
ds.save(f)
|
||||||
json.dump(ds, d_file)
|
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def load_dataset(self) -> "Dataset":
|
def load_dataset(self):
|
||||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||||
with open(ds_path) as d_file:
|
return EnviFormerDataset.load(ds_path)
|
||||||
ds = json.load(d_file)
|
|
||||||
return ds
|
|
||||||
|
|
||||||
def _fit_model(self, ds):
|
def _fit_model(self, ds):
|
||||||
# Call to enviFormer's fine_tune function and return the model
|
# Call to enviFormer's fine_tune function and return the model
|
||||||
from enviformer.finetune import fine_tune
|
from enviformer.finetune import fine_tune
|
||||||
|
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
|
model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
|
||||||
return model
|
return model
|
||||||
@ -3140,7 +3078,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
args = {"clz": "EnviFormer"}
|
args = {"clz": "EnviFormer"}
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
|
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||||
|
|
||||||
@ -3155,21 +3093,20 @@ class EnviFormer(PackageBasedModel):
|
|||||||
self.model_status = self.EVALUATING
|
self.model_status = self.EVALUATING
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def evaluate_sg(test_reactions, predictions, model_thresh):
|
def evaluate_sg(test_ds, predictions, model_thresh):
|
||||||
# Group the true products of reactions with the same reactant together
|
# Group the true products of reactions with the same reactant together
|
||||||
|
assert len(test_ds) == len(predictions)
|
||||||
true_dict = {}
|
true_dict = {}
|
||||||
for r in test_reactions:
|
for r in test_ds:
|
||||||
reactant, true_product_set = r.split(">>")
|
reactant, true_product_set = r
|
||||||
true_product_set = {p for p in true_product_set.split(".")}
|
true_product_set = {p for p in true_product_set.split(".")}
|
||||||
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
|
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
|
||||||
assert len(test_reactions) == len(predictions)
|
|
||||||
assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
|
|
||||||
|
|
||||||
# Group the predicted products of reactions with the same reactant together
|
# Group the predicted products of reactions with the same reactant together
|
||||||
pred_dict = {}
|
pred_dict = {}
|
||||||
for k, pred in enumerate(predictions):
|
for k, pred in enumerate(predictions):
|
||||||
pred_smiles, pred_proba = zip(*pred.items())
|
pred_smiles, pred_proba = zip(*pred.items())
|
||||||
reactant, true_product = test_reactions[k].split(">>")
|
reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
|
||||||
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
|
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
|
||||||
for smiles, proba in zip(pred_smiles, pred_proba):
|
for smiles, proba in zip(pred_smiles, pred_proba):
|
||||||
smiles = set(smiles.split("."))
|
smiles = set(smiles.split("."))
|
||||||
@ -3204,7 +3141,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Recall is TP (correct) / TP + FN (len(test_reactions))
|
# Recall is TP (correct) / TP + FN (len(test_reactions))
|
||||||
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
|
rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
|
||||||
# Precision is TP (correct) / TP + FP (predicted)
|
# Precision is TP (correct) / TP + FP (predicted)
|
||||||
prec = {
|
prec = {
|
||||||
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
|
f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
|
||||||
@ -3283,47 +3220,32 @@ class EnviFormer(PackageBasedModel):
|
|||||||
|
|
||||||
# If there are eval packages perform single generation evaluation on them instead of random splits
|
# If there are eval packages perform single generation evaluation on them instead of random splits
|
||||||
if self.eval_packages.count() > 0:
|
if self.eval_packages.count() > 0:
|
||||||
ds = []
|
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
|
||||||
for reaction in Reaction.objects.filter(
|
package__in=self.eval_packages.all()).distinct())
|
||||||
package__in=self.eval_packages.all()
|
test_result = self.model.predict_batch(ds.X())
|
||||||
).distinct():
|
|
||||||
educts = ".".join(
|
|
||||||
[
|
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.educts.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
products = ".".join(
|
|
||||||
[
|
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.products.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
ds.append(f"{educts}>>{products}")
|
|
||||||
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
|
|
||||||
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
|
||||||
self.eval_results = self.compute_averages([single_gen_result])
|
self.eval_results = self.compute_averages([single_gen_result])
|
||||||
else:
|
else:
|
||||||
from enviformer.finetune import fine_tune
|
from enviformer.finetune import fine_tune
|
||||||
|
|
||||||
ds = self.load_dataset()
|
ds = self.load_dataset()
|
||||||
n_splits = 20
|
n_splits = kwargs.get("n_splits", 20)
|
||||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
|
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
|
||||||
|
|
||||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||||
# this helps reduce the memory footprint.
|
# this helps reduce the memory footprint.
|
||||||
single_gen_results = []
|
single_gen_results = []
|
||||||
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
|
||||||
train = [ds[i] for i in train_index]
|
train = ds[train_index]
|
||||||
test = [ds[i] for i in test_index]
|
test = ds[test_index]
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
|
f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
|
||||||
)
|
)
|
||||||
model.to(s.ENVIFORMER_DEVICE)
|
model.to(s.ENVIFORMER_DEVICE)
|
||||||
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
|
test_result = model.predict_batch(test.X())
|
||||||
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
|
||||||
|
|
||||||
self.eval_results = self.compute_averages(single_gen_results)
|
self.eval_results = self.compute_averages(single_gen_results)
|
||||||
@ -3394,31 +3316,15 @@ class EnviFormer(PackageBasedModel):
|
|||||||
for pathway in train_pathways:
|
for pathway in train_pathways:
|
||||||
for reaction in pathway.edges:
|
for reaction in pathway.edges:
|
||||||
reaction = reaction.edge_label
|
reaction = reaction.edge_label
|
||||||
if any(
|
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
|
||||||
[
|
|
||||||
educt in test_educts
|
|
||||||
for educt in reaction_to_educts[str(reaction.uuid)]
|
|
||||||
]
|
|
||||||
):
|
|
||||||
overlap += 1
|
overlap += 1
|
||||||
continue
|
continue
|
||||||
educts = ".".join(
|
train_reactions.append(reaction)
|
||||||
[
|
train_ds = EnviFormerDataset.generate_dataset(train_reactions)
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.educts.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
products = ".".join(
|
|
||||||
[
|
|
||||||
FormatConverter.standardize(smile.smiles, remove_stereo=True)
|
|
||||||
for smile in reaction.products.all()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
train_reactions.append(f"{educts}>>{products}")
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
|
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
|
||||||
)
|
)
|
||||||
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
|
model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
|
||||||
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
|
||||||
|
|
||||||
self.eval_results.update(
|
self.eval_results.update(
|
||||||
|
|||||||
@ -894,7 +894,7 @@ def package_model(request, package_uuid, model_uuid):
|
|||||||
return JsonResponse(res, safe=False)
|
return JsonResponse(res, safe=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
|
app_domain_assessment = current_model.app_domain.assess(stand_smiles)
|
||||||
return JsonResponse(app_domain_assessment, safe=False)
|
return JsonResponse(app_domain_assessment, safe=False)
|
||||||
|
|
||||||
context = get_base_context(request)
|
context = get_base_context(request)
|
||||||
|
|||||||
@ -27,11 +27,12 @@ dependencies = [
|
|||||||
"scikit-learn>=1.6.1",
|
"scikit-learn>=1.6.1",
|
||||||
"sentry-sdk[django]>=2.32.0",
|
"sentry-sdk[django]>=2.32.0",
|
||||||
"setuptools>=80.8.0",
|
"setuptools>=80.8.0",
|
||||||
"nh3==0.3.2"
|
"nh3==0.3.2",
|
||||||
|
"polars==1.35.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" }
|
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" }
|
||||||
envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
|
envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
|
||||||
envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"}
|
envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"}
|
||||||
envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }
|
envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
|
import os.path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from epdb.logic import PackageManager
|
from epdb.logic import PackageManager
|
||||||
from epdb.models import Reaction, Compound, User, Rule
|
from epdb.models import Reaction, Compound, User, Rule, Package
|
||||||
from utilities.ml import Dataset
|
from utilities.chem import FormatConverter
|
||||||
|
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
||||||
|
|
||||||
|
|
||||||
class DatasetTest(TestCase):
|
class DatasetTest(TestCase):
|
||||||
@ -41,12 +43,108 @@ class DatasetTest(TestCase):
|
|||||||
super(DatasetTest, cls).setUpClass()
|
super(DatasetTest, cls).setUpClass()
|
||||||
cls.user = User.objects.get(username="anonymous")
|
cls.user = User.objects.get(username="anonymous")
|
||||||
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
||||||
|
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
||||||
|
|
||||||
def test_smoke(self):
|
def test_generate_dataset(self):
|
||||||
|
"""Test generating dataset does not crash"""
|
||||||
|
self.generate_rule_dataset()
|
||||||
|
|
||||||
|
def test_indexing(self):
|
||||||
|
"""Test indexing a few different ways to check for crashes"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds[5])
|
||||||
|
print(ds[2, 5])
|
||||||
|
print(ds[3:6, 2:8])
|
||||||
|
print(ds[:2, "structure_id"])
|
||||||
|
|
||||||
|
def test_add_rows(self):
|
||||||
|
"""Test adding one row and adding multiple rows"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
ds.add_row(list(ds.df.row(1)))
|
||||||
|
ds.add_rows([list(ds.df.row(i)) for i in range(5)])
|
||||||
|
|
||||||
|
def test_times_triggered(self):
|
||||||
|
"""Check getting times triggered for a rule id"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds.times_triggered(rules[0].uuid))
|
||||||
|
|
||||||
|
def test_block_indices(self):
|
||||||
|
"""Test the usages of _block_indices"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds.struct_features())
|
||||||
|
print(ds.triggered())
|
||||||
|
print(ds.observed())
|
||||||
|
|
||||||
|
def test_structure_id(self):
|
||||||
|
"""Check getting a structure id from row index"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds.structure_id(0))
|
||||||
|
|
||||||
|
def test_x(self):
|
||||||
|
"""Test getting X portion of the dataframe"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds.X().df.head())
|
||||||
|
|
||||||
|
def test_trig(self):
|
||||||
|
"""Test getting the triggered portion of the dataframe"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds.trig().df.head())
|
||||||
|
|
||||||
|
def test_y(self):
|
||||||
|
"""Test getting the Y portion of the dataframe"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
print(ds.y().df.head())
|
||||||
|
|
||||||
|
def test_classification_dataset(self):
|
||||||
|
"""Test making the classification dataset"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
class_ds, products = ds.classification_dataset(compounds, rules)
|
||||||
|
print(class_ds.df.head(5))
|
||||||
|
print(products[:5])
|
||||||
|
|
||||||
|
def test_extra_features(self):
|
||||||
|
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
|
||||||
|
print(ds.shape)
|
||||||
|
|
||||||
|
def test_to_arff(self):
|
||||||
|
"""Test exporting the arff version of the dataset"""
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
ds.to_arff("dataset_arff_test.arff")
|
||||||
|
|
||||||
|
def test_save_load(self):
|
||||||
|
"""Test saving and loading dataset"""
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||||
|
ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
|
||||||
|
self.assertTrue(ds.df.equals(ds_loaded.df))
|
||||||
|
|
||||||
|
def test_dataset_example(self):
|
||||||
|
"""Test with a concrete example checking dataset size"""
|
||||||
reactions = [r for r in Reaction.objects.filter(package=self.package)]
|
reactions = [r for r in Reaction.objects.filter(package=self.package)]
|
||||||
applicable_rules = [self.rule1]
|
applicable_rules = [self.rule1]
|
||||||
|
|
||||||
ds = Dataset.generate_dataset(reactions, applicable_rules)
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||||
|
|
||||||
self.assertEqual(len(ds.y()), 1)
|
self.assertEqual(len(ds.y()), 1)
|
||||||
self.assertEqual(sum(ds.y()[0]), 1)
|
self.assertEqual(ds.y().df.item(), 1)
|
||||||
|
|
||||||
|
def test_enviformer_dataset(self):
|
||||||
|
ds, reactions = self.generate_enviformer_dataset()
|
||||||
|
print(ds.X().head())
|
||||||
|
print(ds.y().head())
|
||||||
|
|
||||||
|
def generate_rule_dataset(self):
|
||||||
|
"""Generate a RuleBasedDataset from test package data"""
|
||||||
|
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||||
|
return ds, reactions, applicable_rules
|
||||||
|
|
||||||
|
def generate_enviformer_dataset(self):
|
||||||
|
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
ds = EnviFormerDataset.generate_dataset(reactions)
|
||||||
|
return ds, reactions
|
||||||
|
|||||||
@ -42,13 +42,11 @@ class EnviFormerTest(TestCase):
|
|||||||
threshold = float(0.5)
|
threshold = float(0.5)
|
||||||
data_package_objs = [self.BBD_SUBSET]
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
eval_packages_objs = [self.BBD_SUBSET]
|
eval_packages_objs = [self.BBD_SUBSET]
|
||||||
mod = EnviFormer.create(
|
mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
|
||||||
self.package, data_package_objs, eval_packages_objs, threshold=threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
mod.build_dataset()
|
mod.build_dataset()
|
||||||
mod.build_model()
|
mod.build_model()
|
||||||
mod.evaluate_model(True, eval_packages_objs)
|
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||||
|
|
||||||
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|
||||||
@ -57,12 +55,9 @@ class EnviFormerTest(TestCase):
|
|||||||
with self.settings(MODEL_DIR=tmpdir):
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
threshold = float(0.5)
|
threshold = float(0.5)
|
||||||
data_package_objs = [self.BBD_SUBSET]
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
eval_packages_objs = [self.BBD_SUBSET]
|
|
||||||
mods = []
|
mods = []
|
||||||
for _ in range(4):
|
for _ in range(4):
|
||||||
mod = EnviFormer.create(
|
mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
|
||||||
self.package, data_package_objs, eval_packages_objs, threshold=threshold
|
|
||||||
)
|
|
||||||
mod.build_dataset()
|
mod.build_dataset()
|
||||||
mod.build_model()
|
mod.build_model()
|
||||||
mods.append(mod)
|
mods.append(mod)
|
||||||
@ -73,15 +68,11 @@ class EnviFormerTest(TestCase):
|
|||||||
|
|
||||||
# Test pathway prediction
|
# Test pathway prediction
|
||||||
times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
|
times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
|
||||||
print(
|
print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")
|
||||||
f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test eviction by performing three prediction with every model, twice.
|
# Test eviction by performing three prediction with every model, twice.
|
||||||
times = defaultdict(list)
|
times = defaultdict(list)
|
||||||
for _ in range(
|
for _ in range(2): # Eviction should cause the second iteration here to have to reload the models
|
||||||
2
|
|
||||||
): # Eviction should cause the second iteration here to have to reload the models
|
|
||||||
for mod in mods:
|
for mod in mods:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
times[mod.pk].append(measure_predict(mod))
|
times[mod.pk].append(measure_predict(mod))
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from epdb.logic import PackageManager
|
from epdb.logic import PackageManager
|
||||||
from epdb.models import User, MLRelativeReasoning, Package
|
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning
|
||||||
|
|
||||||
|
|
||||||
class ModelTest(TestCase):
|
class ModelTest(TestCase):
|
||||||
@ -17,7 +17,7 @@ class ModelTest(TestCase):
|
|||||||
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
|
||||||
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
|
||||||
|
|
||||||
def test_smoke(self):
|
def test_mlrr(self):
|
||||||
with TemporaryDirectory() as tmpdir:
|
with TemporaryDirectory() as tmpdir:
|
||||||
with self.settings(MODEL_DIR=tmpdir):
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
threshold = float(0.5)
|
threshold = float(0.5)
|
||||||
@ -35,21 +35,9 @@ class ModelTest(TestCase):
|
|||||||
description="Created MLRelativeReasoning in Testcase",
|
description="Created MLRelativeReasoning in Testcase",
|
||||||
)
|
)
|
||||||
|
|
||||||
# mod = RuleBasedRelativeReasoning.create(
|
|
||||||
# self.package,
|
|
||||||
# rule_package_objs,
|
|
||||||
# data_package_objs,
|
|
||||||
# eval_packages_objs,
|
|
||||||
# threshold=threshold,
|
|
||||||
# min_count=5,
|
|
||||||
# max_count=0,
|
|
||||||
# name='ECC - BBD - 0.5',
|
|
||||||
# description='Created MLRelativeReasoning in Testcase',
|
|
||||||
# )
|
|
||||||
|
|
||||||
mod.build_dataset()
|
mod.build_dataset()
|
||||||
mod.build_model()
|
mod.build_model()
|
||||||
mod.evaluate_model(True, eval_packages_objs)
|
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||||
|
|
||||||
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|
||||||
@ -70,3 +58,57 @@ class ModelTest(TestCase):
|
|||||||
|
|
||||||
# from pprint import pprint
|
# from pprint import pprint
|
||||||
# pprint(mod.eval_results)
|
# pprint(mod.eval_results)
|
||||||
|
|
||||||
|
def test_applicability(self):
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
|
threshold = float(0.5)
|
||||||
|
|
||||||
|
rule_package_objs = [self.BBD_SUBSET]
|
||||||
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
|
eval_packages_objs = [self.BBD_SUBSET]
|
||||||
|
|
||||||
|
mod = MLRelativeReasoning.create(
|
||||||
|
self.package,
|
||||||
|
rule_package_objs,
|
||||||
|
data_package_objs,
|
||||||
|
threshold=threshold,
|
||||||
|
name="ECC - BBD - 0.5",
|
||||||
|
description="Created MLRelativeReasoning in Testcase",
|
||||||
|
build_app_domain=True, # To test the applicability domain this must be True
|
||||||
|
app_domain_num_neighbours=5,
|
||||||
|
app_domain_local_compatibility_threshold=0.5,
|
||||||
|
app_domain_reliability_threshold=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
mod.build_dataset()
|
||||||
|
mod.build_model()
|
||||||
|
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||||
|
|
||||||
|
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|
||||||
|
def test_rbrr(self):
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
with self.settings(MODEL_DIR=tmpdir):
|
||||||
|
threshold = float(0.5)
|
||||||
|
|
||||||
|
rule_package_objs = [self.BBD_SUBSET]
|
||||||
|
data_package_objs = [self.BBD_SUBSET]
|
||||||
|
eval_packages_objs = [self.BBD_SUBSET]
|
||||||
|
|
||||||
|
mod = RuleBasedRelativeReasoning.create(
|
||||||
|
self.package,
|
||||||
|
rule_package_objs,
|
||||||
|
data_package_objs,
|
||||||
|
threshold=threshold,
|
||||||
|
min_count=5,
|
||||||
|
max_count=0,
|
||||||
|
name='ECC - BBD - 0.5',
|
||||||
|
description='Created MLRelativeReasoning in Testcase',
|
||||||
|
)
|
||||||
|
|
||||||
|
mod.build_dataset()
|
||||||
|
mod.build_model()
|
||||||
|
mod.evaluate_model(True, eval_packages_objs, n_splits=2)
|
||||||
|
|
||||||
|
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
|
|||||||
from indigo import Indigo, IndigoException, IndigoObject
|
from indigo import Indigo, IndigoException, IndigoObject
|
||||||
from indigo.renderer import IndigoRenderer
|
from indigo.renderer import IndigoRenderer
|
||||||
from rdkit import Chem, rdBase
|
from rdkit import Chem, rdBase
|
||||||
from rdkit.Chem import MACCSkeys, Descriptors
|
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
|
||||||
from rdkit.Chem import rdChemReactions
|
from rdkit.Chem import rdChemReactions
|
||||||
from rdkit.Chem.Draw import rdMolDraw2D
|
from rdkit.Chem.Draw import rdMolDraw2D
|
||||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||||
@ -107,6 +107,13 @@ class FormatConverter(object):
|
|||||||
bitvec = MACCSkeys.GenMACCSKeys(mol)
|
bitvec = MACCSkeys.GenMACCSKeys(mol)
|
||||||
return bitvec.ToList()
|
return bitvec.ToList()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def morgan(smiles, radius=3, fpSize=2048):
|
||||||
|
finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
|
||||||
|
mol = Chem.MolFromSmiles(smiles)
|
||||||
|
fp = finger_gen.GetFingerprint(mol)
|
||||||
|
return fp.ToList()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_functional_groups(smiles: str) -> List[str]:
|
def get_functional_groups(smiles: str) -> List[str]:
|
||||||
res = list()
|
res = list()
|
||||||
|
|||||||
550
utilities/ml.py
550
utilities/ml.py
@ -5,11 +5,14 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
|
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from envipy_plugins import Descriptor
|
||||||
from numpy.random import default_rng
|
from numpy.random import default_rng
|
||||||
|
import polars as pl
|
||||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
from sklearn.dummy import DummyClassifier
|
from sklearn.dummy import DummyClassifier
|
||||||
@ -26,70 +29,281 @@ if TYPE_CHECKING:
|
|||||||
from epdb.models import Rule, CompoundStructure, Reaction
|
from epdb.models import Rule, CompoundStructure, Reaction
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset(ABC):
|
||||||
def __init__(
|
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
|
||||||
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
|
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
|
||||||
):
|
self.df = data
|
||||||
self.columns: List[str] = columns
|
|
||||||
self.num_labels: int = num_labels
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
self.data: List[List[str | int | float]] = list()
|
|
||||||
else:
|
else:
|
||||||
self.data = data
|
# Build either an empty dataframe with columns or fill it with list of list data
|
||||||
|
if data is not None and len(columns) != len(data[0]):
|
||||||
|
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
|
||||||
|
if columns is None:
|
||||||
|
raise ValueError("Columns can't be None if data is not already a DataFrame")
|
||||||
|
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
|
||||||
|
|
||||||
self.num_features: int = len(columns) - self.num_labels
|
def add_rows(self, rows: List[List[str | int | float]]):
|
||||||
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
|
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
|
||||||
self._triggered: Tuple[int, int] = self._block_indices("trig_")
|
if len(self.columns) != len(rows[0]):
|
||||||
self._observed: Tuple[int, int] = self._block_indices("obs_")
|
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
|
||||||
|
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
|
||||||
|
self.df.extend(new_rows)
|
||||||
|
|
||||||
def _block_indices(self, prefix) -> Tuple[int, int]:
|
def add_row(self, row: List[str | int | float]):
|
||||||
|
"""See add_rows"""
|
||||||
|
self.add_rows([row])
|
||||||
|
|
||||||
|
def block_indices(self, prefix) -> List[int]:
|
||||||
|
"""Find the indexes in column labels that has the prefix"""
|
||||||
indices: List[int] = []
|
indices: List[int] = []
|
||||||
for i, feature in enumerate(self.columns):
|
for i, feature in enumerate(self.columns):
|
||||||
if feature.startswith(prefix):
|
if feature.startswith(prefix):
|
||||||
indices.append(i)
|
indices.append(i)
|
||||||
|
return indices
|
||||||
|
|
||||||
return min(indices), max(indices)
|
@property
|
||||||
|
def columns(self) -> List[str]:
|
||||||
|
"""Use the polars dataframe columns"""
|
||||||
|
return self.df.columns
|
||||||
|
|
||||||
def structure_id(self):
|
@property
|
||||||
return self.data[0][0]
|
def shape(self):
|
||||||
|
return self.df.shape
|
||||||
|
|
||||||
def add_row(self, row: List[str | int | float]):
|
@abstractmethod
|
||||||
if len(self.columns) != len(row):
|
def X(self, **kwargs):
|
||||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
|
pass
|
||||||
self.data.append(row)
|
|
||||||
|
|
||||||
def times_triggered(self, rule_uuid) -> int:
|
@abstractmethod
|
||||||
idx = self.columns.index(f"trig_{rule_uuid}")
|
def y(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
times_triggered = 0
|
@staticmethod
|
||||||
for row in self.data:
|
@abstractmethod
|
||||||
if row[idx] == 1:
|
def generate_dataset(reactions, *args, **kwargs):
|
||||||
times_triggered += 1
|
pass
|
||||||
|
|
||||||
return times_triggered
|
|
||||||
|
|
||||||
def struct_features(self) -> Tuple[int, int]:
|
|
||||||
return self._struct_features
|
|
||||||
|
|
||||||
def triggered(self) -> Tuple[int, int]:
|
|
||||||
return self._triggered
|
|
||||||
|
|
||||||
def observed(self) -> Tuple[int, int]:
|
|
||||||
return self._observed
|
|
||||||
|
|
||||||
def at(self, position: int) -> Dataset:
|
|
||||||
return Dataset(self.columns, self.num_labels, [self.data[position]])
|
|
||||||
|
|
||||||
def limit(self, limit: int) -> Dataset:
|
|
||||||
return Dataset(self.columns, self.num_labels, self.data[:limit])
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return (self.at(i) for i, _ in enumerate(self.data))
|
"""Use polars iter_rows for iterating over the dataset"""
|
||||||
|
return self.df.iter_rows()
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
"""Item is passed to polars allowing for advanced indexing.
|
||||||
|
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
|
||||||
|
res = self.df[item]
|
||||||
|
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
|
||||||
|
return self.__class__(data=res)
|
||||||
|
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
|
||||||
|
return res
|
||||||
|
|
||||||
|
def save(self, path: "Path | str"):
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
with open(path, "wb") as fh:
|
||||||
|
pickle.dump(self, fh)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(path: "str | Path") -> "Dataset":
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
return pickle.load(open(path, "rb"))
|
||||||
|
|
||||||
|
def to_numpy(self):
|
||||||
|
return self.df.to_numpy()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.df)
|
||||||
|
|
||||||
|
def iter_rows(self, named=False):
|
||||||
|
return self.df.iter_rows(named=named)
|
||||||
|
|
||||||
|
def filter(self, *predicates, **constraints):
|
||||||
|
return self.__class__(data=self.df.filter(*predicates, **constraints))
|
||||||
|
|
||||||
|
def select(self, *exprs, **named_exprs):
|
||||||
|
return self.__class__(data=self.df.select(*exprs, **named_exprs))
|
||||||
|
|
||||||
|
def with_columns(self, *exprs, **name_exprs):
|
||||||
|
return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
|
||||||
|
|
||||||
|
def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
|
||||||
|
return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
|
||||||
|
multithreaded=multithreaded, maintain_order=maintain_order))
|
||||||
|
|
||||||
|
def item(self, row=None, column=None):
|
||||||
|
return self.df.item(row, column)
|
||||||
|
|
||||||
|
def fill_nan(self, value):
|
||||||
|
return self.__class__(data=self.df.fill_nan(value))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self.df.height
|
||||||
|
|
||||||
|
|
||||||
|
class RuleBasedDataset(Dataset):
|
||||||
|
def __init__(self, num_labels=None, columns=None, data=None):
|
||||||
|
super().__init__(columns, data)
|
||||||
|
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
|
||||||
|
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
|
||||||
|
# Pre-calculate the ids of columns for features/labels, useful later in X and y
|
||||||
|
self._struct_features: List[int] = self.block_indices("feature_")
|
||||||
|
self._triggered: List[int] = self.block_indices("trig_")
|
||||||
|
self._observed: List[int] = self.block_indices("obs_")
|
||||||
|
self.feature_cols: List[int] = self._struct_features + self._triggered
|
||||||
|
self.num_features: int = len(self.feature_cols)
|
||||||
|
self.has_probs = False
|
||||||
|
|
||||||
|
def times_triggered(self, rule_uuid) -> int:
|
||||||
|
"""Count how many times a rule is triggered by the number of rows with one in the rules trig column"""
|
||||||
|
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
|
||||||
|
|
||||||
|
def struct_features(self) -> List[int]:
|
||||||
|
return self._struct_features
|
||||||
|
|
||||||
|
def triggered(self) -> List[int]:
|
||||||
|
return self._triggered
|
||||||
|
|
||||||
|
def observed(self) -> List[int]:
|
||||||
|
return self._observed
|
||||||
|
|
||||||
|
def structure_id(self, index: int):
|
||||||
|
"""Get the UUID of a compound"""
|
||||||
|
return self.item(index, "structure_id")
|
||||||
|
|
||||||
|
def X(self, exclude_id_col=True, na_replacement=0):
|
||||||
|
"""Get all the feature and trig columns"""
|
||||||
|
_col_ids = self.feature_cols
|
||||||
|
if not exclude_id_col:
|
||||||
|
_col_ids = [0] + _col_ids
|
||||||
|
res = self[:, _col_ids]
|
||||||
|
if na_replacement is not None:
|
||||||
|
res.df = res.df.fill_null(na_replacement)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def trig(self, na_replacement=0):
|
||||||
|
"""Get all the trig columns"""
|
||||||
|
res = self[:, self._triggered]
|
||||||
|
if na_replacement is not None:
|
||||||
|
res.df = res.df.fill_null(na_replacement)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def y(self, na_replacement=0):
|
||||||
|
"""Get all the obs columns"""
|
||||||
|
res = self[:, self._observed]
|
||||||
|
if na_replacement is not None:
|
||||||
|
res.df = res.df.fill_null(na_replacement)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
|
||||||
|
if feat_funcs is None:
|
||||||
|
feat_funcs = [FormatConverter.maccs]
|
||||||
|
_structures = set() # Get all the structures
|
||||||
|
for r in reactions:
|
||||||
|
_structures.update(r.educts.all())
|
||||||
|
if not educts_only:
|
||||||
|
_structures.update(r.products.all())
|
||||||
|
|
||||||
|
compounds = sorted(_structures, key=lambda x: x.url)
|
||||||
|
triggered: Dict[str, Set[str]] = defaultdict(set)
|
||||||
|
observed: Set[str] = set()
|
||||||
|
|
||||||
|
# Apply rules on collected compounds and store tps
|
||||||
|
for i, comp in enumerate(compounds):
|
||||||
|
logger.debug(f"{i + 1}/{len(compounds)}...")
|
||||||
|
|
||||||
|
for rule in applicable_rules:
|
||||||
|
product_sets = rule.apply(comp.smiles)
|
||||||
|
if len(product_sets) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = f"{rule.uuid} + {comp.uuid}"
|
||||||
|
if key in triggered:
|
||||||
|
logger.info(f"{key} already present. Duplicate reaction?")
|
||||||
|
|
||||||
|
for prod_set in product_sets:
|
||||||
|
for smi in prod_set:
|
||||||
|
try:
|
||||||
|
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||||
|
triggered[key].add(smi)
|
||||||
|
|
||||||
|
for i, r in enumerate(reactions):
|
||||||
|
logger.debug(f"{i + 1}/{len(reactions)}...")
|
||||||
|
|
||||||
|
if len(r.educts.all()) != 1:
|
||||||
|
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for comp in r.educts.all():
|
||||||
|
for rule in applicable_rules:
|
||||||
|
key = f"{rule.uuid} + {comp.uuid}"
|
||||||
|
if key not in triggered:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# standardize products from reactions for comparison
|
||||||
|
standardized_products = []
|
||||||
|
for cs in r.products.all():
|
||||||
|
smi = cs.smiles
|
||||||
|
try:
|
||||||
|
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||||
|
standardized_products.append(smi)
|
||||||
|
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||||
|
observed.add(key)
|
||||||
|
feat_columns = []
|
||||||
|
for feat_func in feat_funcs:
|
||||||
|
if isinstance(feat_func, Descriptor):
|
||||||
|
feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
|
||||||
|
else:
|
||||||
|
feats = feat_func(compounds[0].smiles)
|
||||||
|
start_i = len(feat_columns)
|
||||||
|
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
|
||||||
|
ds_columns = (["structure_id"] +
|
||||||
|
feat_columns +
|
||||||
|
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||||
|
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
for i, comp in enumerate(compounds):
|
||||||
|
# Features
|
||||||
|
feats = []
|
||||||
|
for feat_func in feat_funcs:
|
||||||
|
if isinstance(feat_func, Descriptor):
|
||||||
|
feat = feat_func.get_molecule_descriptors(comp.smiles)
|
||||||
|
else:
|
||||||
|
feat = feat_func(comp.smiles)
|
||||||
|
feats.extend(feat)
|
||||||
|
trig = []
|
||||||
|
obs = []
|
||||||
|
for rule in applicable_rules:
|
||||||
|
key = f"{rule.uuid} + {comp.uuid}"
|
||||||
|
# Check triggered
|
||||||
|
if key in triggered:
|
||||||
|
trig.append(1)
|
||||||
|
else:
|
||||||
|
trig.append(0)
|
||||||
|
# Check obs
|
||||||
|
if key in observed:
|
||||||
|
obs.append(1)
|
||||||
|
elif key not in triggered:
|
||||||
|
obs.append(None)
|
||||||
|
else:
|
||||||
|
obs.append(0)
|
||||||
|
rows.append([str(comp.uuid)] + feats + trig + obs)
|
||||||
|
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
|
||||||
|
return ds
|
||||||
|
|
||||||
def classification_dataset(
|
def classification_dataset(
|
||||||
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
|
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
|
||||||
) -> Tuple[Dataset, List[List[PredictionResult]]]:
|
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||||
classify_data = []
|
classify_data = []
|
||||||
classify_products = []
|
classify_products = []
|
||||||
for struct in structures:
|
for struct in structures:
|
||||||
@ -113,186 +327,18 @@ class Dataset:
|
|||||||
else:
|
else:
|
||||||
trig.append(0)
|
trig.append(0)
|
||||||
prods.append([])
|
prods.append([])
|
||||||
|
new_row = [struct_id] + features + trig + ([-1] * len(trig))
|
||||||
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
|
if self.has_probs:
|
||||||
|
new_row += [-1] * len(trig)
|
||||||
|
classify_data.append(new_row)
|
||||||
classify_products.append(prods)
|
classify_products.append(prods)
|
||||||
|
ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
|
||||||
|
return ds, classify_products
|
||||||
|
|
||||||
return Dataset(
|
def add_probs(self, probs):
|
||||||
columns=self.columns, num_labels=self.num_labels, data=classify_data
|
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
|
||||||
), classify_products
|
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
|
||||||
|
self.has_probs = True
|
||||||
@staticmethod
|
|
||||||
def generate_dataset(
|
|
||||||
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
|
|
||||||
) -> Dataset:
|
|
||||||
_structures = set()
|
|
||||||
|
|
||||||
for r in reactions:
|
|
||||||
for e in r.educts.all():
|
|
||||||
_structures.add(e)
|
|
||||||
|
|
||||||
if not educts_only:
|
|
||||||
for e in r.products:
|
|
||||||
_structures.add(e)
|
|
||||||
|
|
||||||
compounds = sorted(_structures, key=lambda x: x.url)
|
|
||||||
|
|
||||||
triggered: Dict[str, Set[str]] = defaultdict(set)
|
|
||||||
observed: Set[str] = set()
|
|
||||||
|
|
||||||
# Apply rules on collected compounds and store tps
|
|
||||||
for i, comp in enumerate(compounds):
|
|
||||||
logger.debug(f"{i + 1}/{len(compounds)}...")
|
|
||||||
|
|
||||||
for rule in applicable_rules:
|
|
||||||
product_sets = rule.apply(comp.smiles)
|
|
||||||
|
|
||||||
if len(product_sets) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
key = f"{rule.uuid} + {comp.uuid}"
|
|
||||||
|
|
||||||
if key in triggered:
|
|
||||||
logger.info(f"{key} already present. Duplicate reaction?")
|
|
||||||
|
|
||||||
for prod_set in product_sets:
|
|
||||||
for smi in prod_set:
|
|
||||||
try:
|
|
||||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
|
||||||
except Exception:
|
|
||||||
# :shrug:
|
|
||||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
triggered[key].add(smi)
|
|
||||||
|
|
||||||
for i, r in enumerate(reactions):
|
|
||||||
logger.debug(f"{i + 1}/{len(reactions)}...")
|
|
||||||
|
|
||||||
if len(r.educts.all()) != 1:
|
|
||||||
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
|
|
||||||
continue
|
|
||||||
|
|
||||||
for comp in r.educts.all():
|
|
||||||
for rule in applicable_rules:
|
|
||||||
key = f"{rule.uuid} + {comp.uuid}"
|
|
||||||
|
|
||||||
if key not in triggered:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# standardize products from reactions for comparison
|
|
||||||
standardized_products = []
|
|
||||||
for cs in r.products.all():
|
|
||||||
smi = cs.smiles
|
|
||||||
|
|
||||||
try:
|
|
||||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
|
||||||
except Exception as e:
|
|
||||||
# :shrug:
|
|
||||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
standardized_products.append(smi)
|
|
||||||
|
|
||||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
|
||||||
observed.add(key)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
ds = None
|
|
||||||
|
|
||||||
for i, comp in enumerate(compounds):
|
|
||||||
# Features
|
|
||||||
feat = FormatConverter.maccs(comp.smiles)
|
|
||||||
trig = []
|
|
||||||
obs = []
|
|
||||||
|
|
||||||
for rule in applicable_rules:
|
|
||||||
key = f"{rule.uuid} + {comp.uuid}"
|
|
||||||
|
|
||||||
# Check triggered
|
|
||||||
if key in triggered:
|
|
||||||
trig.append(1)
|
|
||||||
else:
|
|
||||||
trig.append(0)
|
|
||||||
|
|
||||||
# Check obs
|
|
||||||
if key in observed:
|
|
||||||
obs.append(1)
|
|
||||||
elif key not in triggered:
|
|
||||||
obs.append(None)
|
|
||||||
else:
|
|
||||||
obs.append(0)
|
|
||||||
|
|
||||||
if ds is None:
|
|
||||||
header = (
|
|
||||||
["structure_id"]
|
|
||||||
+ [f"feature_{i}" for i, _ in enumerate(feat)]
|
|
||||||
+ [f"trig_{r.uuid}" for r in applicable_rules]
|
|
||||||
+ [f"obs_{r.uuid}" for r in applicable_rules]
|
|
||||||
)
|
|
||||||
ds = Dataset(header, len(applicable_rules))
|
|
||||||
|
|
||||||
ds.add_row([str(comp.uuid)] + feat + trig + obs)
|
|
||||||
|
|
||||||
return ds
|
|
||||||
|
|
||||||
def X(self, exclude_id_col=True, na_replacement=0):
|
|
||||||
res = self.__getitem__(
|
|
||||||
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
|
|
||||||
)
|
|
||||||
if na_replacement is not None:
|
|
||||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
|
||||||
return res
|
|
||||||
|
|
||||||
def trig(self, na_replacement=0):
|
|
||||||
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
|
|
||||||
if na_replacement is not None:
|
|
||||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
|
||||||
return res
|
|
||||||
|
|
||||||
def y(self, na_replacement=0):
|
|
||||||
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
|
|
||||||
if na_replacement is not None:
|
|
||||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
|
||||||
return res
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
if not isinstance(key, tuple):
|
|
||||||
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
|
|
||||||
|
|
||||||
row_key, col_key = key
|
|
||||||
|
|
||||||
# Normalize rows
|
|
||||||
if isinstance(row_key, int):
|
|
||||||
rows = [self.data[row_key]]
|
|
||||||
else:
|
|
||||||
rows = self.data[row_key]
|
|
||||||
|
|
||||||
# Normalize columns
|
|
||||||
if isinstance(col_key, int):
|
|
||||||
res = [row[col_key] for row in rows]
|
|
||||||
else:
|
|
||||||
res = [
|
|
||||||
[row[i] for i in range(*col_key.indices(len(row)))]
|
|
||||||
if isinstance(col_key, slice)
|
|
||||||
else [row[i] for i in col_key]
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def save(self, path: "Path"):
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
with open(path, "wb") as fh:
|
|
||||||
pickle.dump(self, fh)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(path: "Path") -> "Dataset":
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
return pickle.load(open(path, "rb"))
|
|
||||||
|
|
||||||
def to_arff(self, path: "Path"):
|
def to_arff(self, path: "Path"):
|
||||||
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
|
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
|
||||||
@ -304,7 +350,7 @@ class Dataset:
|
|||||||
arff += f"@attribute {c} {{0,1}}\n"
|
arff += f"@attribute {c} {{0,1}}\n"
|
||||||
|
|
||||||
arff += "\n@data\n"
|
arff += "\n@data\n"
|
||||||
for d in self.data:
|
for d in self:
|
||||||
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
|
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
|
||||||
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
|
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
|
||||||
arff += f"{ys},{xs}\n"
|
arff += f"{ys},{xs}\n"
|
||||||
@ -313,10 +359,40 @@ class Dataset:
|
|||||||
fh.write(arff)
|
fh.write(arff)
|
||||||
fh.flush()
|
fh.flush()
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
class EnviFormerDataset(Dataset):
|
||||||
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
|
def __init__(self, columns=None, data=None):
|
||||||
)
|
super().__init__(columns, data)
|
||||||
|
|
||||||
|
def X(self):
|
||||||
|
"""Return the educts"""
|
||||||
|
return self["educts"]
|
||||||
|
|
||||||
|
def y(self):
|
||||||
|
"""Return the products"""
|
||||||
|
return self["products"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_dataset(reactions, *args, **kwargs):
|
||||||
|
# Standardise reactions for the training data
|
||||||
|
stereo = kwargs.get("stereo", False)
|
||||||
|
rows = []
|
||||||
|
for reaction in reactions:
|
||||||
|
e = ".".join(
|
||||||
|
[
|
||||||
|
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
|
||||||
|
for smile in reaction.educts.all()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
p = ".".join(
|
||||||
|
[
|
||||||
|
FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
|
||||||
|
for smile in reaction.products.all()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rows.append([e, p])
|
||||||
|
ds = EnviFormerDataset(["educts", "products"], rows)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||||
@ -498,7 +574,7 @@ class EnsembleClassifierChain:
|
|||||||
self.classifiers = []
|
self.classifiers = []
|
||||||
|
|
||||||
if self.num_labels is None:
|
if self.num_labels is None:
|
||||||
self.num_labels = len(Y[0])
|
self.num_labels = Y.shape[1]
|
||||||
|
|
||||||
for p in range(self.num_chains):
|
for p in range(self.num_chains):
|
||||||
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")
|
||||||
@ -529,7 +605,7 @@ class RelativeReasoning:
|
|||||||
|
|
||||||
def fit(self, X, Y):
|
def fit(self, X, Y):
|
||||||
n_instances = len(Y)
|
n_instances = len(Y)
|
||||||
n_attributes = len(Y[0])
|
n_attributes = Y.shape[1]
|
||||||
|
|
||||||
for i in range(n_attributes):
|
for i in range(n_attributes):
|
||||||
for j in range(n_attributes):
|
for j in range(n_attributes):
|
||||||
@ -541,8 +617,8 @@ class RelativeReasoning:
|
|||||||
countboth = 0
|
countboth = 0
|
||||||
|
|
||||||
for k in range(n_instances):
|
for k in range(n_instances):
|
||||||
vi = Y[k][i]
|
vi = Y[k, i]
|
||||||
vj = Y[k][j]
|
vj = Y[k, j]
|
||||||
|
|
||||||
if vi is None or vj is None:
|
if vi is None or vj is None:
|
||||||
continue
|
continue
|
||||||
@ -598,7 +674,7 @@ class ApplicabilityDomainPCA(PCA):
|
|||||||
self.min_vals = None
|
self.min_vals = None
|
||||||
self.max_vals = None
|
self.max_vals = None
|
||||||
|
|
||||||
def build(self, train_dataset: "Dataset"):
|
def build(self, train_dataset: "RuleBasedDataset"):
|
||||||
# transform
|
# transform
|
||||||
X_scaled = self.scaler.fit_transform(train_dataset.X())
|
X_scaled = self.scaler.fit_transform(train_dataset.X())
|
||||||
# fit pca
|
# fit pca
|
||||||
@ -612,7 +688,7 @@ class ApplicabilityDomainPCA(PCA):
|
|||||||
instances_pca = self.transform(instances_scaled)
|
instances_pca = self.transform(instances_scaled)
|
||||||
return instances_pca
|
return instances_pca
|
||||||
|
|
||||||
def is_applicable(self, classify_instances: "Dataset"):
|
def is_applicable(self, classify_instances: "RuleBasedDataset"):
|
||||||
instances_pca = self.__transform(classify_instances.X())
|
instances_pca = self.__transform(classify_instances.X())
|
||||||
|
|
||||||
is_applicable = []
|
is_applicable = []
|
||||||
|
|||||||
Reference in New Issue
Block a user