work towards #120

This commit is contained in:
Liam Brydon
2025-10-24 14:40:26 +13:00
parent 2980a75daa
commit 8166df6f39
2 changed files with 203 additions and 59 deletions

View File

@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset
logger = logging.getLogger(__name__)
@ -3088,35 +3088,17 @@ class EnviFormer(PackageBasedModel):
self.save()
start = datetime.now()
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
ds = []
for reaction in self._get_reactions():
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
ds.append(f"{educts}>>{products}")
ds = EnviFormerDataset.generate_dataset(self._get_reactions())
end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(f, "w") as d_file:
json.dump(ds, d_file)
ds.save(f)
return ds
def load_dataset(self) -> "RuleBasedDataset":
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(ds_path) as d_file:
ds = json.load(d_file)
return ds
return EnviFormerDataset.load(ds_path)
def _fit_model(self, ds):
# Call to enviFormer's fine_tune function and return the model
@ -3148,13 +3130,12 @@ class EnviFormer(PackageBasedModel):
def evaluate_sg(test_reactions, predictions, model_thresh):
# Group the true products of reactions with the same reactant together
assert len(test_reactions) == len(predictions)
true_dict = {}
for r in test_reactions:
reactant, true_product_set = r.split(">>")
true_product_set = {p for p in true_product_set.split(".")}
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
assert len(test_reactions) == len(predictions)
assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
# Group the predicted products of reactions with the same reactant together
pred_dict = {}
@ -3274,24 +3255,9 @@ class EnviFormer(PackageBasedModel):
# If there are eval packages perform single generation evaluation on them instead of random splits
if self.eval_packages.count() > 0:
ds = []
for reaction in Reaction.objects.filter(
package__in=self.eval_packages.all()
).distinct():
educts = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
products = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
ds.append(f"{educts}>>{products}")
test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
package__in=self.eval_packages.all()).distinct())
test_result = self.model.predict_batch(ds)
single_gen_result = evaluate_sg(ds, test_result, self.threshold)
self.eval_results = self.compute_averages([single_gen_result])
else:

View File

@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
import networkx as nx
import numpy as np
from numpy.random import default_rng
import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
@ -28,6 +29,37 @@ if TYPE_CHECKING:
class Dataset(ABC):
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
if isinstance(data, pl.DataFrame):
self.df = data
else:
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
self.df = pl.DataFrame(data=data, schema=columns)
def add_rows(self, rows: List[List[str | int | float]]):
if len(self.columns) != len(rows[0]):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
new_rows = pl.DataFrame(data=rows, schema=self.columns)
self.df.extend(new_rows)
def add_row(self, row: List[str | int | float]):
self.add_rows([row])
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
for i, feature in enumerate(self.columns):
if feature.startswith(prefix):
indices.append(i)
return min(indices), max(indices)
@property
def columns(self) -> List[str]:
return self.df.columns
@abstractmethod
def X(self):
pass
@ -36,7 +68,143 @@ class Dataset(ABC):
def y(self):
pass
@staticmethod
@abstractmethod
def generate_dataset(reactions, *args, **kwargs):
pass
def at(self, position: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "str | Path") -> "RuleBasedDataset":
import pickle
return pickle.load(open(path, "rb"))
class NewRuleBasedDataset(Dataset):
def __init__(self, num_labels, columns=None, data=None):
super().__init__(columns, data)
self.num_labels: int = num_labels
self.num_features: int = len(self.columns) - self.num_labels
def times_triggered(self, rule_uuid) -> int:
return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
def struct_features(self) -> Tuple[int, int]:
return self._block_indices("feature_")
def triggered(self) -> Tuple[int, int]:
return self._block_indices("trig_")
def observed(self) -> Tuple[int, int]:
return self._block_indices("obs_")
def X(self):
pass
def y(self):
pass
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
_structures = set()
for r in reactions:
_structures.update(r.educts.all())
if not educts_only:
_structures.update(r.products.all())
compounds = sorted(_structures, key=lambda x: x.url)
triggered: Dict[str, Set[str]] = defaultdict(set)
observed: Set[str] = set()
# Apply rules on collected compounds and store tps
for i, comp in enumerate(compounds):
logger.debug(f"{i + 1}/{len(compounds)}...")
for rule in applicable_rules:
product_sets = rule.apply(comp.smiles)
if len(product_sets) == 0:
continue
key = f"{rule.uuid} + {comp.uuid}"
if key in triggered:
logger.info(f"{key} already present. Duplicate reaction?")
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
logger.debug(f"Standardizing SMILES failed for {smi}")
triggered[key].add(smi)
for i, r in enumerate(reactions):
logger.debug(f"{i + 1}/{len(reactions)}...")
if len(r.educts.all()) != 1:
logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
continue
for comp in r.educts.all():
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
if key not in triggered:
continue
# standardize products from reactions for comparison
standardized_products = []
for cs in r.products.all():
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
logger.debug(f"Standardizing SMILES failed for {smi}")
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
rows = []
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
trig = []
obs = []
for rule in applicable_rules:
key = f"{rule.uuid} + {comp.uuid}"
# Check triggered
if key in triggered:
trig.append(1)
else:
trig.append(0)
# Check obs
if key in observed:
obs.append(1)
elif key not in triggered:
obs.append(None)
else:
obs.append(0)
rows.append([str(comp.uuid)] + feat + trig + obs)
ds.add_rows(rows)
return ds
def __getitem__(self, item):
pass
@ -99,9 +267,6 @@ class RuleBasedDataset(Dataset):
def limit(self, limit: int) -> RuleBasedDataset:
return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
@ -297,18 +462,6 @@ class RuleBasedDataset(Dataset):
return res
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: "Path") -> "RuleBasedDataset":
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"
@ -335,8 +488,30 @@ class RuleBasedDataset(Dataset):
class EnviFormerDataset(Dataset):
def __init__(self):
pass
def __init__(self, educts, products):
assert len(educts) == len(products), "Can't have unequal length educts and products"
@staticmethod
def generate_dataset(reactions, *args, **kwargs):
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
educts = []
products = []
for reaction in reactions:
e = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.educts.all()
]
)
p = ".".join(
[
FormatConverter.standardize(smile.smiles, remove_stereo=True)
for smile in reaction.products.all()
]
)
educts.append(e)
products.append(p)
return EnviFormerDataset(educts, products)
def X(self):
pass
@ -347,6 +522,9 @@ class EnviFormerDataset(Dataset):
def __getitem__(self, item):
pass
def __len__(self):
pass
class SparseLabelECC(BaseEstimator, ClassifierMixin):
"""