work towards #120

Liam Brydon
2025-10-24 14:40:26 +13:00
parent 2980a75daa
commit 8166df6f39
2 changed files with 203 additions and 59 deletions

View File

@@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit
 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset

 logger = logging.getLogger(__name__)
@@ -3088,35 +3088,17 @@ class EnviFormer(PackageBasedModel):
         self.save()
         start = datetime.now()
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        ds = []
-        for reaction in self._get_reactions():
-            educts = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.educts.all()
-                ]
-            )
-            products = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.products.all()
-                ]
-            )
-            ds.append(f"{educts}>>{products}")
+        ds = EnviFormerDataset.generate_dataset(self._get_reactions())
         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
         f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(f, "w") as d_file:
-            json.dump(ds, d_file)
+        ds.save(f)
         return ds

     def load_dataset(self) -> "RuleBasedDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(ds_path) as d_file:
-            ds = json.load(d_file)
-        return ds
+        return EnviFormerDataset.load(ds_path)

     def _fit_model(self, ds):
         # Call to enviFormer's fine_tune function and return the model
@@ -3148,13 +3130,12 @@ class EnviFormer(PackageBasedModel):
         def evaluate_sg(test_reactions, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
+            assert len(test_reactions) == len(predictions)
             true_dict = {}
             for r in test_reactions:
                 reactant, true_product_set = r.split(">>")
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
-            assert len(test_reactions) == len(predictions)
-            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)

             # Group the predicted products of reactions with the same reactant together
             pred_dict = {}
@@ -3274,24 +3255,9 @@ class EnviFormer(PackageBasedModel):
         # If there are eval packages perform single generation evaluation on them instead of random splits
         if self.eval_packages.count() > 0:
-            ds = []
-            for reaction in Reaction.objects.filter(
-                package__in=self.eval_packages.all()
-            ).distinct():
-                educts = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.educts.all()
-                    ]
-                )
-                products = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.products.all()
-                    ]
-                )
-                ds.append(f"{educts}>>{products}")
-            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
+                package__in=self.eval_packages.all()).distinct())
+            test_result = self.model.predict_batch(ds)
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
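Taken together, this file's changes replace the inline SMIRKS-string building with the new EnviFormerDataset helper from utilities.ml: the dataset is generated, pickled to the existing {uuid}_ds.json path, reloaded, and passed straight to predict_batch. A minimal sketch of that flow, where `reactions`, `envi_former_model` and `uuid` are hypothetical stand-ins for the diff's self._get_reactions(), self.model and self.uuid:

    import os

    # Sketch only; the EnviFormerDataset API is the one added in utilities.ml below.
    path = os.path.join(s.MODEL_DIR, f"{uuid}_ds.json")
    ds = EnviFormerDataset.generate_dataset(reactions)   # standardised "educts>>products" pairs
    ds.save(path)                                        # pickled via Dataset.save, despite the .json suffix
    ds = EnviFormerDataset.load(path)
    predictions = envi_former_model.predict_batch(ds)    # the dataset itself is passed, no more smirk.split(">>")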

View File

@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
 import networkx as nx
 import numpy as np
 from numpy.random import default_rng
+import polars as pl
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
 from sklearn.dummy import DummyClassifier
@@ -28,6 +29,37 @@ if TYPE_CHECKING:

 class Dataset(ABC):
+    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
+        if isinstance(data, pl.DataFrame):
+            self.df = data
+        else:
+            if data is not None and len(columns) != len(data[0]):
+                raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
+            if columns is None:
+                raise ValueError("Columns can't be None if data is not already a DataFrame")
+            self.df = pl.DataFrame(data=data, schema=columns)
+
+    def add_rows(self, rows: List[List[str | int | float]]):
+        if len(self.columns) != len(rows[0]):
+            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
+        new_rows = pl.DataFrame(data=rows, schema=self.columns)
+        self.df.extend(new_rows)
+
+    def add_row(self, row: List[str | int | float]):
+        self.add_rows([row])
+
+    def _block_indices(self, prefix) -> Tuple[int, int]:
+        indices: List[int] = []
+        for i, feature in enumerate(self.columns):
+            if feature.startswith(prefix):
+                indices.append(i)
+        return min(indices), max(indices)
+
+    @property
+    def columns(self) -> List[str]:
+        return self.df.columns
+
     @abstractmethod
     def X(self):
         pass
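The Dataset base class is now backed by a polars DataFrame: the constructor either wraps an existing frame or builds one from a header plus rows, add_rows validates row width against that header before extending, and _block_indices returns the positional bounds of a column block by prefix. A rough usage sketch with a throwaway concrete subclass (ToyDataset is not part of the commit; Dataset itself is abstract):

    class ToyDataset(Dataset):
        # Minimal concrete subclass just so the abstract base can be instantiated.
        def X(self):
            pass

        def y(self):
            pass

        @staticmethod
        def generate_dataset(reactions, *args, **kwargs):
            pass

        def __getitem__(self, item):
            pass

    ds = ToyDataset(columns=["structure_id", "trig_a", "obs_a"], data=[["c1", 1, 0]])
    ds.add_rows([["c2", 0, None], ["c3", 1, 1]])   # row width is checked against the header
    ds.columns                                     # ['structure_id', 'trig_a', 'obs_a']
    ds._block_indices("obs_")                      # (2, 2): positional bounds of the obs_ block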
@@ -36,7 +68,143 @@ class Dataset(ABC):
     def y(self):
         pass

+    @staticmethod
     @abstractmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        pass
+
+    def at(self, position: int) -> RuleBasedDataset:
+        return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
+
+    def __iter__(self):
+        return (self.at(i) for i, _ in enumerate(self.data))
+
+    def save(self, path: "Path"):
+        import pickle
+        with open(path, "wb") as fh:
+            pickle.dump(self, fh)
+
+    @staticmethod
+    def load(path: "str | Path") -> "RuleBasedDataset":
+        import pickle
+        return pickle.load(open(path, "rb"))
+
+
+class NewRuleBasedDataset(Dataset):
+    def __init__(self, num_labels, columns=None, data=None):
+        super().__init__(columns, data)
+        self.num_labels: int = num_labels
+        self.num_features: int = len(self.columns) - self.num_labels
+
+    def times_triggered(self, rule_uuid) -> int:
+        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
+
+    def struct_features(self) -> Tuple[int, int]:
+        return self._block_indices("feature_")
+
+    def triggered(self) -> Tuple[int, int]:
+        return self._block_indices("trig_")
+
+    def observed(self) -> Tuple[int, int]:
+        return self._block_indices("obs_")
+
+    def X(self):
+        pass
+
+    def y(self):
+        pass
+
+    @staticmethod
+    def generate_dataset(reactions, applicable_rules, educts_only=True):
+        _structures = set()
+        for r in reactions:
+            _structures.update(r.educts.all())
+            if not educts_only:
+                _structures.update(r.products.all())
+        compounds = sorted(_structures, key=lambda x: x.url)
+
+        triggered: Dict[str, Set[str]] = defaultdict(set)
+        observed: Set[str] = set()
+
+        # Apply rules on collected compounds and store tps
+        for i, comp in enumerate(compounds):
+            logger.debug(f"{i + 1}/{len(compounds)}...")
+            for rule in applicable_rules:
+                product_sets = rule.apply(comp.smiles)
+                if len(product_sets) == 0:
+                    continue
+                key = f"{rule.uuid} + {comp.uuid}"
+                if key in triggered:
+                    logger.info(f"{key} already present. Duplicate reaction?")
+                for prod_set in product_sets:
+                    for smi in prod_set:
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        triggered[key].add(smi)
+
+        for i, r in enumerate(reactions):
+            logger.debug(f"{i + 1}/{len(reactions)}...")
+            if len(r.educts.all()) != 1:
+                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
+                continue
+            for comp in r.educts.all():
+                for rule in applicable_rules:
+                    key = f"{rule.uuid} + {comp.uuid}"
+                    if key not in triggered:
+                        continue
+                    # standardize products from reactions for comparison
+                    standardized_products = []
+                    for cs in r.products.all():
+                        smi = cs.smiles
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception as e:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        standardized_products.append(smi)
+                    if len(set(standardized_products).difference(triggered[key])) == 0:
+                        observed.add(key)
+
+        ds_columns = (["structure_id"] +
+                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
+                      [f"trig_{r.uuid}" for r in applicable_rules] +
+                      [f"obs_{r.uuid}" for r in applicable_rules])
+        ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
+        rows = []
+        for i, comp in enumerate(compounds):
+            # Features
+            feat = FormatConverter.maccs(comp.smiles)
+            trig = []
+            obs = []
+            for rule in applicable_rules:
+                key = f"{rule.uuid} + {comp.uuid}"
+                # Check triggered
+                if key in triggered:
+                    trig.append(1)
+                else:
+                    trig.append(0)
+                # Check obs
+                if key in observed:
+                    obs.append(1)
+                elif key not in triggered:
+                    obs.append(None)
+                else:
+                    obs.append(0)
+            rows.append([str(comp.uuid)] + feat + trig + obs)
+        ds.add_rows(rows)
+        return ds
+
     def __getitem__(self, item):
         pass
@@ -99,9 +267,6 @@ class RuleBasedDataset(Dataset):
     def limit(self, limit: int) -> RuleBasedDataset:
         return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])

-    def __iter__(self):
-        return (self.at(i) for i, _ in enumerate(self.data))
-
     def classification_dataset(
         self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
@@ -297,18 +462,6 @@ class RuleBasedDataset(Dataset):
         return res

-    def save(self, path: "Path"):
-        import pickle
-        with open(path, "wb") as fh:
-            pickle.dump(self, fh)
-
-    @staticmethod
-    def load(path: "Path") -> "RuleBasedDataset":
-        import pickle
-        return pickle.load(open(path, "rb"))
-
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
@@ -335,8 +488,30 @@ class RuleBasedDataset(Dataset):

 class EnviFormerDataset(Dataset):
-    def __init__(self):
-        pass
+    def __init__(self, educts, products):
+        assert len(educts) == len(products), "Can't have unequal length educts and products"
+
+    @staticmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
+        educts = []
+        products = []
+        for reaction in reactions:
+            e = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.educts.all()
+                ]
+            )
+            p = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.products.all()
+                ]
+            )
+            educts.append(e)
+            products.append(p)
+        return EnviFormerDataset(educts, products)

     def X(self):
         pass
@@ -347,6 +522,9 @@ class EnviFormerDataset(Dataset):
     def __getitem__(self, item):
         pass

+    def __len__(self):
+        pass
+

 class SparseLabelECC(BaseEstimator, ClassifierMixin):
     """