forked from enviPath/enviPy
start towards #120
This commit is contained in:
@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
||||
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
||||
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -2175,7 +2175,7 @@ class PackageBasedModel(EPModel):
|
||||
|
||||
applicable_rules = self.applicable_rules
|
||||
reactions = list(self._get_reactions())
|
||||
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
||||
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
|
||||
|
||||
end = datetime.now()
|
||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||
@ -2184,9 +2184,9 @@ class PackageBasedModel(EPModel):
|
||||
ds.save(f)
|
||||
return ds
|
||||
|
||||
def load_dataset(self) -> "Dataset":
|
||||
def load_dataset(self) -> "RuleBasedDataset":
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
|
||||
return Dataset.load(ds_path)
|
||||
return RuleBasedDataset.load(ds_path)
|
||||
|
||||
def retrain(self):
|
||||
self.build_dataset()
|
||||
@ -2196,7 +2196,7 @@ class PackageBasedModel(EPModel):
|
||||
self.build_model()
|
||||
|
||||
@abstractmethod
|
||||
def _fit_model(self, ds: Dataset):
|
||||
def _fit_model(self, ds: RuleBasedDataset):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -2335,7 +2335,7 @@ class PackageBasedModel(EPModel):
|
||||
eval_reactions = list(
|
||||
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
|
||||
)
|
||||
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||
ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
|
||||
if isinstance(self, RuleBasedRelativeReasoning):
|
||||
X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
|
||||
y = np.array(ds.y(na_replacement=np.nan))
|
||||
@ -2582,7 +2582,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
|
||||
return rbrr
|
||||
|
||||
def _fit_model(self, ds: Dataset):
|
||||
def _fit_model(self, ds: RuleBasedDataset):
|
||||
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
|
||||
model = RelativeReasoning(
|
||||
start_index=ds.triggered()[0],
|
||||
@ -2689,7 +2689,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
||||
|
||||
return mlrr
|
||||
|
||||
def _fit_model(self, ds: Dataset):
|
||||
def _fit_model(self, ds: RuleBasedDataset):
|
||||
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
|
||||
|
||||
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
|
||||
@ -2967,7 +2967,7 @@ class ApplicabilityDomain(EnviPathModel):
|
||||
return distances
|
||||
|
||||
@staticmethod
|
||||
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
|
||||
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "RuleBasedDataset"]]):
|
||||
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
|
||||
accuracy = 0.0
|
||||
|
||||
@ -3112,7 +3112,7 @@ class EnviFormer(PackageBasedModel):
|
||||
json.dump(ds, d_file)
|
||||
return ds
|
||||
|
||||
def load_dataset(self) -> "Dataset":
|
||||
def load_dataset(self) -> "RuleBasedDataset":
|
||||
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
|
||||
with open(ds_path) as d_file:
|
||||
ds = json.load(d_file)
|
||||
|
||||
Reference in New Issue
Block a user