start towards #120

This commit is contained in:
Liam Brydon
2025-10-22 08:22:29 +13:00
parent 376fd65785
commit 2980a75daa
3 changed files with 53 additions and 24 deletions

View File

@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -2175,7 +2175,7 @@ class PackageBasedModel(EPModel):
applicable_rules = self.applicable_rules applicable_rules = self.applicable_rules
reactions = list(self._get_reactions()) reactions = list(self._get_reactions())
ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True) ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)
end = datetime.now() end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds") logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@ -2184,9 +2184,9 @@ class PackageBasedModel(EPModel):
ds.save(f) ds.save(f)
return ds return ds
def load_dataset(self) -> "Dataset": def load_dataset(self) -> "RuleBasedDataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl") ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
return Dataset.load(ds_path) return RuleBasedDataset.load(ds_path)
def retrain(self): def retrain(self):
self.build_dataset() self.build_dataset()
@ -2196,7 +2196,7 @@ class PackageBasedModel(EPModel):
self.build_model() self.build_model()
@abstractmethod @abstractmethod
def _fit_model(self, ds: Dataset): def _fit_model(self, ds: RuleBasedDataset):
pass pass
@abstractmethod @abstractmethod
@ -2335,7 +2335,7 @@ class PackageBasedModel(EPModel):
eval_reactions = list( eval_reactions = list(
Reaction.objects.filter(package__in=self.eval_packages.all()).distinct() Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
) )
ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True) ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
if isinstance(self, RuleBasedRelativeReasoning): if isinstance(self, RuleBasedRelativeReasoning):
X = np.array(ds.X(exclude_id_col=False, na_replacement=None)) X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
y = np.array(ds.y(na_replacement=np.nan)) y = np.array(ds.y(na_replacement=np.nan))
@ -2582,7 +2582,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
return rbrr return rbrr
def _fit_model(self, ds: Dataset): def _fit_model(self, ds: RuleBasedDataset):
X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None) X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
model = RelativeReasoning( model = RelativeReasoning(
start_index=ds.triggered()[0], start_index=ds.triggered()[0],
@ -2689,7 +2689,7 @@ class MLRelativeReasoning(PackageBasedModel):
return mlrr return mlrr
def _fit_model(self, ds: Dataset): def _fit_model(self, ds: RuleBasedDataset):
X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan) X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS) model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
@ -2967,7 +2967,7 @@ class ApplicabilityDomain(EnviPathModel):
return distances return distances
@staticmethod @staticmethod
def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]): def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "RuleBasedDataset"]]):
tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0 tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
accuracy = 0.0 accuracy = 0.0
@ -3112,7 +3112,7 @@ class EnviFormer(PackageBasedModel):
json.dump(ds, d_file) json.dump(ds, d_file)
return ds return ds
def load_dataset(self) -> "Dataset": def load_dataset(self) -> "RuleBasedDataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json") ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(ds_path) as d_file: with open(ds_path) as d_file:
ds = json.load(d_file) ds = json.load(d_file)

View File

@ -2,7 +2,7 @@ from django.test import TestCase
from epdb.logic import PackageManager from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule from epdb.models import Reaction, Compound, User, Rule
from utilities.ml import Dataset from utilities.ml import RuleBasedDataset
class DatasetTest(TestCase): class DatasetTest(TestCase):
@ -46,7 +46,7 @@ class DatasetTest(TestCase):
reactions = [r for r in Reaction.objects.filter(package=self.package)] reactions = [r for r in Reaction.objects.filter(package=self.package)]
applicable_rules = [self.rule1] applicable_rules = [self.rule1]
ds = Dataset.generate_dataset(reactions, applicable_rules) ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
self.assertEqual(len(ds.y()), 1) self.assertEqual(len(ds.y()), 1)
self.assertEqual(sum(ds.y()[0]), 1) self.assertEqual(sum(ds.y()[0]), 1)

View File

@ -6,6 +6,7 @@ from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING from typing import List, Dict, Set, Tuple, TYPE_CHECKING
from abc import ABC, abstractmethod
import networkx as nx import networkx as nx
import numpy as np import numpy as np
@ -26,7 +27,21 @@ if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction from epdb.models import Rule, CompoundStructure, Reaction
class Dataset: class Dataset(ABC):
@abstractmethod
def X(self):
pass
@abstractmethod
def y(self):
pass
@abstractmethod
def __getitem__(self, item):
pass
class RuleBasedDataset(Dataset):
def __init__( def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
): ):
@ -78,18 +93,18 @@ class Dataset:
def observed(self) -> Tuple[int, int]: def observed(self) -> Tuple[int, int]:
return self._observed return self._observed
def at(self, position: int) -> Dataset: def at(self, position: int) -> RuleBasedDataset:
return Dataset(self.columns, self.num_labels, [self.data[position]]) return RuleBasedDataset(self.columns, self.num_labels, [self.data[position]])
def limit(self, limit: int) -> Dataset: def limit(self, limit: int) -> RuleBasedDataset:
return Dataset(self.columns, self.num_labels, self.data[:limit]) return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
def __iter__(self): def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data)) return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset( def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"] self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[Dataset, List[List[PredictionResult]]]: ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
classify_data = [] classify_data = []
classify_products = [] classify_products = []
for struct in structures: for struct in structures:
@ -117,14 +132,14 @@ class Dataset:
classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods) classify_products.append(prods)
return Dataset( return RuleBasedDataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products ), classify_products
@staticmethod @staticmethod
def generate_dataset( def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset: ) -> RuleBasedDataset:
_structures = set() _structures = set()
for r in reactions: for r in reactions:
@ -231,7 +246,7 @@ class Dataset:
+ [f"trig_{r.uuid}" for r in applicable_rules] + [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules] + [f"obs_{r.uuid}" for r in applicable_rules]
) )
ds = Dataset(header, len(applicable_rules)) ds = RuleBasedDataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs) ds.add_row([str(comp.uuid)] + feat + trig + obs)
@ -289,7 +304,7 @@ class Dataset:
pickle.dump(self, fh) pickle.dump(self, fh)
@staticmethod @staticmethod
def load(path: "Path") -> "Dataset": def load(path: "Path") -> "RuleBasedDataset":
import pickle import pickle
return pickle.load(open(path, "rb")) return pickle.load(open(path, "rb"))
@ -319,6 +334,20 @@ class Dataset:
) )
class EnviFormerDataset(Dataset):
def __init__(self):
pass
def X(self):
pass
def y(self):
pass
def __getitem__(self, item):
pass
class SparseLabelECC(BaseEstimator, ClassifierMixin): class SparseLabelECC(BaseEstimator, ClassifierMixin):
""" """
Ensemble of Classifier Chains with sparse label removal. Ensemble of Classifier Chains with sparse label removal.
@ -598,7 +627,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None self.min_vals = None
self.max_vals = None self.max_vals = None
def build(self, train_dataset: "Dataset"): def build(self, train_dataset: "RuleBasedDataset"):
# transform # transform
X_scaled = self.scaler.fit_transform(train_dataset.X()) X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca # fit pca
@ -612,7 +641,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled) instances_pca = self.transform(instances_scaled)
return instances_pca return instances_pca
def is_applicable(self, classify_instances: "Dataset"): def is_applicable(self, classify_instances: "RuleBasedDataset"):
instances_pca = self.__transform(classify_instances.X()) instances_pca = self.__transform(classify_instances.X())
is_applicable = [] is_applicable = []