forked from enviPath/enviPy
start towards #120
This commit is contained in:
@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
|||||||
from sklearn.model_selection import ShuffleSplit
|
from sklearn.model_selection import ShuffleSplit
|
||||||
|
|
||||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -2175,7 +2175,7 @@ class PackageBasedModel(EPModel):
|
|||||||
|
|
||||||
applicable_rules = self.applicable_rules
|
applicable_rules = self.applicable_rules
|
||||||
reactions = list(self._get_reactions())
|
reactions = list(self._get_reactions())
|
||||||
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
||||||
|
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||||
@ -2184,9 +2184,9 @@ class PackageBasedModel(EPModel):
|
|||||||
ds.save(f)
|
ds.save(f)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def load_dataset(self) -> "Dataset":
|
def load_dataset(self) -> "RuleBasedDataset":
|
||||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||||
return Dataset.load(ds_path)
|
return RuleBasedDataset.load(ds_path)
|
||||||
|
|
||||||
def retrain(self):
|
def retrain(self):
|
||||||
self.build_dataset()
|
self.build_dataset()
|
||||||
@ -2196,7 +2196,7 @@ class PackageBasedModel(EPModel):
|
|||||||
self.build_model()
|
self.build_model()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _fit_model(self, ds: Dataset):
|
def _fit_model(self, ds: RuleBasedDataset):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -2335,7 +2335,7 @@ class PackageBasedModel(EPModel):
|
|||||||
eval_reactions = list(
|
eval_reactions = list(
|
||||||
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
|
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||||
)
|
)
|
||||||
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||||
if isinstance(self, RuleBasedRelativeReasoning):
|
if isinstance(self, RuleBasedRelativeReasoning):
|
||||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||||
y = np.array(ds.y(na_replacement=np.nan))
|
y = np.array(ds.y(na_replacement=np.nan))
|
||||||
@ -2582,7 +2582,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
|||||||
|
|
||||||
return rbrr
|
return rbrr
|
||||||
|
|
||||||
def _fit_model(self, ds: Dataset):
|
def _fit_model(self, ds: RuleBasedDataset):
|
||||||
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
|
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
|
||||||
model = RelativeReasoning(
|
model = RelativeReasoning(
|
||||||
start_index=ds.triggered()[0],
|
start_index=ds.triggered()[0],
|
||||||
@ -2689,7 +2689,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
|||||||
|
|
||||||
return mlrr
|
return mlrr
|
||||||
|
|
||||||
def _fit_model(self, ds: Dataset):
|
def _fit_model(self, ds: RuleBasedDataset):
|
||||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||||
|
|
||||||
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
||||||
@ -2967,7 +2967,7 @@ class ApplicabilityDomain(EnviPathModel):
|
|||||||
return distances
|
return distances
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
|
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "RuleBasedDataset"]]):
|
||||||
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
|
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
|
||||||
accuracy = 0.0
|
accuracy = 0.0
|
||||||
|
|
||||||
@ -3112,7 +3112,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
json.dump(ds, d_file)
|
json.dump(ds, d_file)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def load_dataset(self) -> "Dataset":
|
def load_dataset(self) -> "RuleBasedDataset":
|
||||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||||
with open(ds_path) as d_file:
|
with open(ds_path) as d_file:
|
||||||
ds = json.load(d_file)
|
ds = json.load(d_file)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from django.test import TestCase
|
|||||||
|
|
||||||
from epdb.logic import PackageManager
|
from epdb.logic import PackageManager
|
||||||
from epdb.models import Reaction, Compound, User, Rule
|
from epdb.models import Reaction, Compound, User, Rule
|
||||||
from utilities.ml import Dataset
|
from utilities.ml import RuleBasedDataset
|
||||||
|
|
||||||
|
|
||||||
class DatasetTest(TestCase):
|
class DatasetTest(TestCase):
|
||||||
@ -46,7 +46,7 @@ class DatasetTest(TestCase):
|
|||||||
reactions = [r for r in Reaction.objects.filter(package=self.package)]
|
reactions = [r for r in Reaction.objects.filter(package=self.package)]
|
||||||
applicable_rules = [self.rule1]
|
applicable_rules = [self.rule1]
|
||||||
|
|
||||||
ds = Dataset.generate_dataset(reactions, applicable_rules)
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
|
||||||
|
|
||||||
self.assertEqual(len(ds.y()), 1)
|
self.assertEqual(len(ds.y()), 1)
|
||||||
self.assertEqual(sum(ds.y()[0]), 1)
|
self.assertEqual(sum(ds.y()[0]), 1)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from collections import defaultdict
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
|
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -26,7 +27,21 @@ if TYPE_CHECKING:
|
|||||||
from epdb.models import Rule, CompoundStructure, Reaction
|
from epdb.models import Rule, CompoundStructure, Reaction
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def X(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def y(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __getitem__(self, item):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RuleBasedDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
|
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
|
||||||
):
|
):
|
||||||
@ -78,18 +93,18 @@ class Dataset:
|
|||||||
def observed(self) -> Tuple[int, int]:
|
def observed(self) -> Tuple[int, int]:
|
||||||
return self._observed
|
return self._observed
|
||||||
|
|
||||||
def at(self, position: int) -> Dataset:
|
def at(self, position: int) -> RuleBasedDataset:
|
||||||
return Dataset(self.columns, self.num_labels, [self.data[position]])
|
return RuleBasedDataset(self.columns, self.num_labels, [self.data[position]])
|
||||||
|
|
||||||
def limit(self, limit: int) -> Dataset:
|
def limit(self, limit: int) -> RuleBasedDataset:
|
||||||
return Dataset(self.columns, self.num_labels, self.data[:limit])
|
return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return (self.at(i) for i, _ in enumerate(self.data))
|
return (self.at(i) for i, _ in enumerate(self.data))
|
||||||
|
|
||||||
def classification_dataset(
|
def classification_dataset(
|
||||||
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
|
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
|
||||||
) -> Tuple[Dataset, List[List[PredictionResult]]]:
|
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
|
||||||
classify_data = []
|
classify_data = []
|
||||||
classify_products = []
|
classify_products = []
|
||||||
for struct in structures:
|
for struct in structures:
|
||||||
@ -117,14 +132,14 @@ class Dataset:
|
|||||||
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
|
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
|
||||||
classify_products.append(prods)
|
classify_products.append(prods)
|
||||||
|
|
||||||
return Dataset(
|
return RuleBasedDataset(
|
||||||
columns=self.columns, num_labels=self.num_labels, data=classify_data
|
columns=self.columns, num_labels=self.num_labels, data=classify_data
|
||||||
), classify_products
|
), classify_products
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_dataset(
|
def generate_dataset(
|
||||||
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
|
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
|
||||||
) -> Dataset:
|
) -> RuleBasedDataset:
|
||||||
_structures = set()
|
_structures = set()
|
||||||
|
|
||||||
for r in reactions:
|
for r in reactions:
|
||||||
@ -231,7 +246,7 @@ class Dataset:
|
|||||||
+ [f"trig_{r.uuid}" for r in applicable_rules]
|
+ [f"trig_{r.uuid}" for r in applicable_rules]
|
||||||
+ [f"obs_{r.uuid}" for r in applicable_rules]
|
+ [f"obs_{r.uuid}" for r in applicable_rules]
|
||||||
)
|
)
|
||||||
ds = Dataset(header, len(applicable_rules))
|
ds = RuleBasedDataset(header, len(applicable_rules))
|
||||||
|
|
||||||
ds.add_row([str(comp.uuid)] + feat + trig + obs)
|
ds.add_row([str(comp.uuid)] + feat + trig + obs)
|
||||||
|
|
||||||
@ -289,7 +304,7 @@ class Dataset:
|
|||||||
pickle.dump(self, fh)
|
pickle.dump(self, fh)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(path: "Path") -> "Dataset":
|
def load(path: "Path") -> "RuleBasedDataset":
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
return pickle.load(open(path, "rb"))
|
return pickle.load(open(path, "rb"))
|
||||||
@ -319,6 +334,20 @@ class Dataset:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EnviFormerDataset(Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def X(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def y(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
class SparseLabelECC(BaseEstimator, ClassifierMixin):
|
||||||
"""
|
"""
|
||||||
Ensemble of Classifier Chains with sparse label removal.
|
Ensemble of Classifier Chains with sparse label removal.
|
||||||
@ -598,7 +627,7 @@ class ApplicabilityDomainPCA(PCA):
|
|||||||
self.min_vals = None
|
self.min_vals = None
|
||||||
self.max_vals = None
|
self.max_vals = None
|
||||||
|
|
||||||
def build(self, train_dataset: "Dataset"):
|
def build(self, train_dataset: "RuleBasedDataset"):
|
||||||
# transform
|
# transform
|
||||||
X_scaled = self.scaler.fit_transform(train_dataset.X())
|
X_scaled = self.scaler.fit_transform(train_dataset.X())
|
||||||
# fit pca
|
# fit pca
|
||||||
@ -612,7 +641,7 @@ class ApplicabilityDomainPCA(PCA):
|
|||||||
instances_pca = self.transform(instances_scaled)
|
instances_pca = self.transform(instances_scaled)
|
||||||
return instances_pca
|
return instances_pca
|
||||||
|
|
||||||
def is_applicable(self, classify_instances: "Dataset"):
|
def is_applicable(self, classify_instances: "RuleBasedDataset"):
|
||||||
instances_pca = self.__transform(classify_instances.X())
|
instances_pca = self.__transform(classify_instances.X())
|
||||||
|
|
||||||
is_applicable = []
|
is_applicable = []
|
||||||
|
|||||||
Reference in New Issue
Block a user