forked from enviPath/enviPy
work towards #120
@@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit

 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset

 logger = logging.getLogger(__name__)

@@ -3088,35 +3088,17 @@ class EnviFormer(PackageBasedModel):
         self.save()

         start = datetime.now()
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        ds = []
-        for reaction in self._get_reactions():
-            educts = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.educts.all()
-                ]
-            )
-            products = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.products.all()
-                ]
-            )
-            ds.append(f"{educts}>>{products}")
+        ds = EnviFormerDataset.generate_dataset(self._get_reactions())

         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
         f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(f, "w") as d_file:
-            json.dump(ds, d_file)
+        ds.save(f)
         return ds

     def load_dataset(self) -> "RuleBasedDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(ds_path) as d_file:
-            ds = json.load(d_file)
-        return ds
+        return EnviFormerDataset.load(ds_path)

     def _fit_model(self, ds):
         # Call to enviFormer's fine_tune function and return the model
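A toy illustration of the `educts>>products` strings this refactor relocates into `EnviFormerDataset.generate_dataset` — `standardize` below is a dependency-free stand-in for the project's `FormatConverter.standardize`, not its real implementation:

# Stand-in for FormatConverter.standardize: the real helper canonicalises the
# SMILES; here we only strip stereo markers so the sketch runs on its own.
def standardize(smiles: str, remove_stereo: bool = True) -> str:
    return smiles.replace("@", "").replace("/", "").replace("\\", "") if remove_stereo else smiles

educts = ".".join(standardize(s) for s in ["C[C@H](O)C(=O)O"])           # hypothetical educts
products = ".".join(standardize(s) for s in ["CC(=O)C(=O)O", "[H][H]"])  # hypothetical products
print(f"{educts}>>{products}")  # C[CH](O)C(=O)O>>CC(=O)C(=O)O.[H][H]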
@@ -3148,13 +3130,12 @@ class EnviFormer(PackageBasedModel):

         def evaluate_sg(test_reactions, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
+            assert len(test_reactions) == len(predictions)
             true_dict = {}
             for r in test_reactions:
                 reactant, true_product_set = r.split(">>")
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
-            assert len(test_reactions) == len(predictions)
-            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)

             # Group the predicted products of reactions with the same reactant together
             pred_dict = {}
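The grouping step of `evaluate_sg` runs stand-alone on toy SMIRKS; the final assert restates the invariant that the removed `assert sum(...)` line used to check:

# Reactions sharing a reactant are collected into one list of true product
# sets, exactly as in evaluate_sg above; the data is made up.
test_reactions = ["CCO>>CC=O", "CCO>>CC(=O)O", "c1ccccc1>>Oc1ccccc1"]

true_dict = {}
for r in test_reactions:
    reactant, true_product_set = r.split(">>")
    true_product_set = {p for p in true_product_set.split(".")}
    true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]

assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
print(true_dict["CCO"])  # [{'CC=O'}, {'CC(=O)O'}]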
@@ -3274,24 +3255,9 @@ class EnviFormer(PackageBasedModel):

         # If there are eval packages perform single generation evaluation on them instead of random splits
         if self.eval_packages.count() > 0:
-            ds = []
-            for reaction in Reaction.objects.filter(
-                package__in=self.eval_packages.all()
-            ).distinct():
-                educts = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.educts.all()
-                    ]
-                )
-                products = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.products.all()
-                    ]
-                )
-                ds.append(f"{educts}>>{products}")
-            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
+                package__in=self.eval_packages.all()).distinct())
+            test_result = self.model.predict_batch(ds)
             single_gen_result = evaluate_sg(ds, test_result, self.threshold)
             self.eval_results = self.compute_averages([single_gen_result])
         else:
utilities/ml.py (212 changed lines)
@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
 import networkx as nx
 import numpy as np
 from numpy.random import default_rng
+import polars as pl
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
 from sklearn.dummy import DummyClassifier
@@ -28,6 +29,37 @@ if TYPE_CHECKING:


 class Dataset(ABC):
+    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
+        if isinstance(data, pl.DataFrame):
+            self.df = data
+        else:
+            if columns is None:
+                raise ValueError("Columns can't be None if data is not already a DataFrame")
+            if data is not None and len(columns) != len(data[0]):
+                raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
+            self.df = pl.DataFrame(data=data, schema=columns)
+
+    def add_rows(self, rows: List[List[str | int | float]]):
+        if len(self.columns) != len(rows[0]):
+            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
+        new_rows = pl.DataFrame(data=rows, schema=self.columns)
+        self.df.extend(new_rows)
+
+    def add_row(self, row: List[str | int | float]):
+        self.add_rows([row])
+
+    def _block_indices(self, prefix) -> Tuple[int, int]:
+        indices: List[int] = []
+        for i, feature in enumerate(self.columns):
+            if feature.startswith(prefix):
+                indices.append(i)
+
+        return min(indices), max(indices)
+
+    @property
+    def columns(self) -> List[str]:
+        return self.df.columns
+
     @abstractmethod
     def X(self):
         pass
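For reference, a self-contained sketch of the polars mechanics this base class builds on — schema-checked construction, in-place row appends via `extend`, and prefix-based column blocks as in `_block_indices`. Column names and values are invented, and `orient="row"` is passed explicitly where the committed code lets polars infer the orientation:

import polars as pl

columns = ["structure_id", "feature_0", "feature_1", "trig_a", "obs_a"]
df = pl.DataFrame(data=[["s1", 1, 0, 1, 0]], schema=columns, orient="row")

# In-place append of schema-aligned rows, as Dataset.add_rows does.
df.extend(pl.DataFrame(data=[["s2", 0, 1, 0, 0]], schema=columns, orient="row"))

# Prefix scan over column names, as Dataset._block_indices does.
indices = [i for i, c in enumerate(df.columns) if c.startswith("feature_")]
print(min(indices), max(indices))  # 1 2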
@@ -36,7 +68,143 @@ class Dataset(ABC):
     def y(self):
         pass

+    @staticmethod
+    @abstractmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        pass
+
+    def at(self, position: int) -> RuleBasedDataset:
+        return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
+
+    def __iter__(self):
+        return (self.at(i) for i, _ in enumerate(self.data))
+
+    def save(self, path: "Path"):
+        import pickle
+
+        with open(path, "wb") as fh:
+            pickle.dump(self, fh)
+
+    @staticmethod
+    def load(path: "str | Path") -> "Dataset":
+        import pickle
+
+        return pickle.load(open(path, "rb"))
+
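`save`/`load` now live once on the base class (the copies on `RuleBasedDataset` are deleted further down). A minimal round-trip of the same pickle pattern; `ToyDataset` is a stand-in, not a project class:

import os
import pickle
import tempfile

class ToyDataset:  # stand-in for any Dataset subclass
    def __init__(self, payload):
        self.payload = payload

    def save(self, path):
        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path):
        return pickle.load(open(path, "rb"))

path = os.path.join(tempfile.mkdtemp(), "ds.pkl")
ToyDataset(["CCO>>CC=O"]).save(path)
assert ToyDataset.load(path).payload == ["CCO>>CC=O"]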
+
+class NewRuleBasedDataset(Dataset):
+    def __init__(self, num_labels, columns=None, data=None):
+        super().__init__(columns, data)
+        self.num_labels: int = num_labels
+        self.num_features: int = len(self.columns) - self.num_labels
+
+    def times_triggered(self, rule_uuid) -> int:
+        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
+
+    def struct_features(self) -> Tuple[int, int]:
+        return self._block_indices("feature_")
+
+    def triggered(self) -> Tuple[int, int]:
+        return self._block_indices("trig_")
+
+    def observed(self) -> Tuple[int, int]:
+        return self._block_indices("obs_")
+
+    def X(self):
+        pass
+
+    def y(self):
+        pass
+
+    @staticmethod
+    def generate_dataset(reactions, applicable_rules, educts_only=True):
+        _structures = set()
+        for r in reactions:
+            _structures.update(r.educts.all())
+            if not educts_only:
+                _structures.update(r.products.all())
+
+        compounds = sorted(_structures, key=lambda x: x.url)
+        triggered: Dict[str, Set[str]] = defaultdict(set)
+        observed: Set[str] = set()
+
+        # Apply rules on collected compounds and store tps
+        for i, comp in enumerate(compounds):
+            logger.debug(f"{i + 1}/{len(compounds)}...")
+
+            for rule in applicable_rules:
+                product_sets = rule.apply(comp.smiles)
+                if len(product_sets) == 0:
+                    continue
+
+                key = f"{rule.uuid} + {comp.uuid}"
+                if key in triggered:
+                    logger.info(f"{key} already present. Duplicate reaction?")
+
+                for prod_set in product_sets:
+                    for smi in prod_set:
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        triggered[key].add(smi)
+
+        for i, r in enumerate(reactions):
+            logger.debug(f"{i + 1}/{len(reactions)}...")
+
+            if len(r.educts.all()) != 1:
+                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
+                continue
+
+            for comp in r.educts.all():
+                for rule in applicable_rules:
+                    key = f"{rule.uuid} + {comp.uuid}"
+                    if key not in triggered:
+                        continue
+
+                    # standardize products from reactions for comparison
+                    standardized_products = []
+                    for cs in r.products.all():
+                        smi = cs.smiles
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        standardized_products.append(smi)
+                    if len(set(standardized_products).difference(triggered[key])) == 0:
+                        observed.add(key)
+
+        ds_columns = (["structure_id"] +
+                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
+                      [f"trig_{r.uuid}" for r in applicable_rules] +
+                      [f"obs_{r.uuid}" for r in applicable_rules])
+        ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
+        rows = []
+
+        for i, comp in enumerate(compounds):
+            # Features
+            feat = FormatConverter.maccs(comp.smiles)
+            trig = []
+            obs = []
+            for rule in applicable_rules:
+                key = f"{rule.uuid} + {comp.uuid}"
+                # Check triggered
+                if key in triggered:
+                    trig.append(1)
+                else:
+                    trig.append(0)
+                # Check obs
+                if key in observed:
+                    obs.append(1)
+                elif key not in triggered:
+                    obs.append(None)
+                else:
+                    obs.append(0)
+            rows.append([str(comp.uuid)] + feat + trig + obs)
+        ds.add_rows(rows)
+        return ds
+

     def __getitem__(self, item):
         pass

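The trig/obs encoding at the end of `generate_dataset` reduces to one 0/1 flag per rule for "triggered" and 1/0/None for "observed", with None marking rules that never fired on the compound. A toy version with invented keys:

# Keys follow the diff's f"{rule.uuid} + {comp.uuid}" convention; uuids invented.
triggered = {"rule1 + c1", "rule2 + c1"}
observed = {"rule1 + c1"}
rules, comp = ["rule1", "rule2", "rule3"], "c1"

trig = [1 if f"{r} + {comp}" in triggered else 0 for r in rules]
obs = [1 if f"{r} + {comp}" in observed
       else (None if f"{r} + {comp}" not in triggered else 0)
       for r in rules]
print(trig, obs)  # [1, 1, 0] [1, 0, None]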
@@ -99,9 +267,6 @@ class RuleBasedDataset(Dataset):
     def limit(self, limit: int) -> RuleBasedDataset:
         return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])

-    def __iter__(self):
-        return (self.at(i) for i, _ in enumerate(self.data))
-
     def classification_dataset(
         self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
@@ -297,18 +462,6 @@ class RuleBasedDataset(Dataset):

         return res

-    def save(self, path: "Path"):
-        import pickle
-
-        with open(path, "wb") as fh:
-            pickle.dump(self, fh)
-
-    @staticmethod
-    def load(path: "Path") -> "RuleBasedDataset":
-        import pickle
-
-        return pickle.load(open(path, "rb"))
-
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
@@ -335,8 +488,30 @@ class RuleBasedDataset(Dataset):


 class EnviFormerDataset(Dataset):
-    def __init__(self):
-        pass
+    def __init__(self, educts, products):
+        assert len(educts) == len(products), "Can't have unequal length educts and products"
+
+    @staticmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
+        educts = []
+        products = []
+        for reaction in reactions:
+            e = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.educts.all()
+                ]
+            )
+            p = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.products.all()
+                ]
+            )
+            educts.append(e)
+            products.append(p)
+        return EnviFormerDataset(educts, products)

     def X(self):
         pass
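`__init__` asserts but does not yet store its arguments, and `__getitem__`/`__len__` remain stubs, so `predict_batch(ds)` and `evaluate_sg(ds, ...)` above cannot iterate the dataset yet. One plausible completion — an assumption, not the committed code — given that `evaluate_sg` splits entries on `>>`:

# Hypothetical completion of the EnviFormerDataset stubs; names mirror the
# diff, but this sketch is not part of the commit.
class EnviFormerDatasetSketch:
    def __init__(self, educts, products):
        assert len(educts) == len(products), "Can't have unequal length educts and products"
        self.educts, self.products = educts, products

    def __getitem__(self, item):
        # evaluate_sg splits entries on ">>", so yield whole SMIRKS strings
        return f"{self.educts[item]}>>{self.products[item]}"

    def __len__(self):
        return len(self.educts)

ds = EnviFormerDatasetSketch(["CCO"], ["CC=O"])
assert list(ds) == ["CCO>>CC=O"]  # the sequence protocol makes it iterable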
@@ -347,6 +522,9 @@ class EnviFormerDataset(Dataset):
     def __getitem__(self, item):
         pass

+    def __len__(self):
+        pass
+

 class SparseLabelECC(BaseEstimator, ClassifierMixin):
     """