forked from enviPath/enviPy
work towards #120
@@ -28,7 +28,7 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
 from sklearn.model_selection import ShuffleSplit

 from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
-from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
+from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset

 logger = logging.getLogger(__name__)

@@ -3088,35 +3088,17 @@ class EnviFormer(PackageBasedModel):
         self.save()

         start = datetime.now()
-        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
-        ds = []
-        for reaction in self._get_reactions():
-            educts = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.educts.all()
-                ]
-            )
-            products = ".".join(
-                [
-                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                    for smile in reaction.products.all()
-                ]
-            )
-            ds.append(f"{educts}>>{products}")
+        ds = EnviFormerDataset.generate_dataset(self._get_reactions())

         end = datetime.now()
         logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
         f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(f, "w") as d_file:
-            json.dump(ds, d_file)
+        ds.save(f)
         return ds

     def load_dataset(self) -> "RuleBasedDataset":
         ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
-        with open(ds_path) as d_file:
-            ds = json.load(d_file)
-            return ds
+        return EnviFormerDataset.load(ds_path)

     def _fit_model(self, ds):
         # Call to enviFormer's fine_tune function and return the model
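build_dataset now delegates SMIRKS construction to EnviFormerDataset and persists it through the pickle-based Dataset.save, so the "{uuid}_ds.json" path holds a pickle rather than JSON. A minimal sketch of the intended round trip, assuming a model object that exposes _get_reactions() and uuid as this class does; build_and_reload and model_dir are hypothetical stand-ins (model_dir plays the role of s.MODEL_DIR):

    import os

    from utilities.ml import EnviFormerDataset

    def build_and_reload(model, model_dir):
        # Build the dataset from the model's reactions, write it out, read it back.
        ds = EnviFormerDataset.generate_dataset(model._get_reactions())
        path = os.path.join(model_dir, f"{model.uuid}_ds.json")
        ds.save(path)                        # Dataset.save pickles the whole object
        return EnviFormerDataset.load(path)  # what load_dataset() now returns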
@@ -3148,13 +3130,12 @@ class EnviFormer(PackageBasedModel):

         def evaluate_sg(test_reactions, predictions, model_thresh):
             # Group the true products of reactions with the same reactant together
+            assert len(test_reactions) == len(predictions)
             true_dict = {}
             for r in test_reactions:
                 reactant, true_product_set = r.split(">>")
                 true_product_set = {p for p in true_product_set.split(".")}
                 true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
-            assert len(test_reactions) == len(predictions)
-            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)

             # Group the predicted products of reactions with the same reactant together
             pred_dict = {}
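The grouping step collapses reactions that share a reactant into one entry whose value is a list of true product sets; the length check now runs before grouping, on the raw reaction and prediction lists. A self-contained toy illustration of that grouping (the SMILES are made up for the example):

    # Toy data, not from the repository
    test_reactions = ["CCO>>CC=O", "CCO>>C.CO", "c1ccccc1>>Oc1ccccc1"]
    predictions = [[], [], []]  # placeholders; only the length is checked here
    assert len(test_reactions) == len(predictions)

    true_dict = {}
    for r in test_reactions:
        reactant, true_product_set = r.split(">>")
        true_product_set = {p for p in true_product_set.split(".")}
        true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]

    # Reactions with the same reactant end up under one key
    assert true_dict == {"CCO": [{"CC=O"}, {"C", "CO"}], "c1ccccc1": [{"Oc1ccccc1"}]}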
@@ -3274,24 +3255,9 @@ class EnviFormer(PackageBasedModel):

         # If there are eval packages perform single generation evaluation on them instead of random splits
         if self.eval_packages.count() > 0:
-            ds = []
-            for reaction in Reaction.objects.filter(
-                package__in=self.eval_packages.all()
-            ).distinct():
-                educts = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.educts.all()
-                    ]
-                )
-                products = ".".join(
-                    [
-                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
-                        for smile in reaction.products.all()
-                    ]
-                )
-                ds.append(f"{educts}>>{products}")
-            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
+            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
+                package__in=self.eval_packages.all()).distinct())
+            test_result = self.model.predict_batch(ds)
            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
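The removed branch extracted the reactant half of each "educts>>products" string before calling predict_batch; a tiny self-contained sketch of that extraction, on made-up SMIRKS strings:

    # Made-up SMIRKS strings; each entry is "educts>>products"
    ds = ["CCO>>CC=O", "CC(=O)O.OCC>>CC(=O)OCC.O"]
    reactants = [smirk.split(">>")[0] for smirk in ds]
    assert reactants == ["CCO", "CC(=O)O.OCC"]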
utilities/ml.py (212)
@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
 import networkx as nx
 import numpy as np
 from numpy.random import default_rng
+import polars as pl
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
 from sklearn.dummy import DummyClassifier
@@ -28,6 +29,37 @@ if TYPE_CHECKING:


 class Dataset(ABC):
+    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
+        if isinstance(data, pl.DataFrame):
+            self.df = data
+        else:
+            if data is not None and len(columns) != len(data[0]):
+                raise ValueError(f"Header and Data are not aligned {len(columns)} vs. {len(data[0])}")
+            if columns is None:
+                raise ValueError("Columns can't be None if data is not already a DataFrame")
+            self.df = pl.DataFrame(data=data, schema=columns)
+
+    def add_rows(self, rows: List[List[str | int | float]]):
+        if len(self.columns) != len(rows[0]):
+            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(rows[0])}")
+        new_rows = pl.DataFrame(data=rows, schema=self.columns)
+        self.df.extend(new_rows)
+
+    def add_row(self, row: List[str | int | float]):
+        self.add_rows([row])
+
+    def _block_indices(self, prefix) -> Tuple[int, int]:
+        indices: List[int] = []
+        for i, feature in enumerate(self.columns):
+            if feature.startswith(prefix):
+                indices.append(i)
+
+        return min(indices), max(indices)
+
+    @property
+    def columns(self) -> List[str]:
+        return self.df.columns
+
     @abstractmethod
     def X(self):
         pass
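The base class now wraps a polars DataFrame: rows are appended with add_rows, and _block_indices returns the first and last column index sharing a prefix. A small self-contained sketch that mirrors those two mechanics with plain polars, using made-up column names (Dataset itself is abstract, so the class is not instantiated here):

    import polars as pl

    # Mirror of Dataset.__init__/add_rows: build a row-oriented frame, then extend it
    columns = ["structure_id", "feature_0", "feature_1", "trig_a", "obs_a"]
    df = pl.DataFrame(data=[["c1", 0, 1, 1, 1]], schema=columns, orient="row")
    df.extend(pl.DataFrame(data=[["c2", 1, 0, 0, 0]], schema=columns, orient="row"))

    # Mirror of Dataset._block_indices: first/last column index sharing a prefix
    idx = [i for i, c in enumerate(columns) if c.startswith("feature_")]
    assert (min(idx), max(idx)) == (1, 2)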
@@ -36,7 +68,143 @@ class Dataset(ABC):
     def y(self):
         pass

+    @staticmethod
     @abstractmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        pass
+
+    def at(self, position: int) -> RuleBasedDataset:
+        return RuleBasedDataset(self.columns, self.num_labels, self.df[position])
+
+    def __iter__(self):
+        return (self.at(i) for i, _ in enumerate(self.data))
+
+    def save(self, path: "Path"):
+        import pickle
+
+        with open(path, "wb") as fh:
+            pickle.dump(self, fh)
+
+    @staticmethod
+    def load(path: "str | Path") -> "RuleBasedDataset":
+        import pickle
+
+        return pickle.load(open(path, "rb"))
+
+
+class NewRuleBasedDataset(Dataset):
+    def __init__(self, num_labels, columns=None, data=None):
+        super().__init__(columns, data)
+        self.num_labels: int = num_labels
+        self.num_features: int = len(self.columns) - self.num_labels
+
+    def times_triggered(self, rule_uuid) -> int:
+        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
+
+    def struct_features(self) -> Tuple[int, int]:
+        return self._block_indices("feature_")
+
+    def triggered(self) -> Tuple[int, int]:
+        return self._block_indices("trig_")
+
+    def observed(self) -> Tuple[int, int]:
+        return self._block_indices("obs_")
+
+    def X(self):
+        pass
+
+    def y(self):
+        pass
+
+    @staticmethod
+    def generate_dataset(reactions, applicable_rules, educts_only=True):
+        _structures = set()
+        for r in reactions:
+            _structures.update(r.educts.all())
+            if not educts_only:
+                _structures.update(r.products.all())
+
+        compounds = sorted(_structures, key=lambda x: x.url)
+        triggered: Dict[str, Set[str]] = defaultdict(set)
+        observed: Set[str] = set()
+
+        # Apply rules on collected compounds and store tps
+        for i, comp in enumerate(compounds):
+            logger.debug(f"{i + 1}/{len(compounds)}...")
+
+            for rule in applicable_rules:
+                product_sets = rule.apply(comp.smiles)
+                if len(product_sets) == 0:
+                    continue
+
+                key = f"{rule.uuid} + {comp.uuid}"
+                if key in triggered:
+                    logger.info(f"{key} already present. Duplicate reaction?")
+
+                for prod_set in product_sets:
+                    for smi in prod_set:
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        triggered[key].add(smi)
+
+        for i, r in enumerate(reactions):
+            logger.debug(f"{i + 1}/{len(reactions)}...")
+
+            if len(r.educts.all()) != 1:
+                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
+                continue
+
+            for comp in r.educts.all():
+                for rule in applicable_rules:
+                    key = f"{rule.uuid} + {comp.uuid}"
+                    if key not in triggered:
+                        continue
+
+                    # standardize products from reactions for comparison
+                    standardized_products = []
+                    for cs in r.products.all():
+                        smi = cs.smiles
+                        try:
+                            smi = FormatConverter.standardize(smi, remove_stereo=True)
+                        except Exception as e:
+                            logger.debug(f"Standardizing SMILES failed for {smi}")
+                        standardized_products.append(smi)
+                    if len(set(standardized_products).difference(triggered[key])) == 0:
+                        observed.add(key)
+
+        ds_columns = (["structure_id"] +
+                      [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
+                      [f"trig_{r.uuid}" for r in applicable_rules] +
+                      [f"obs_{r.uuid}" for r in applicable_rules])
+        ds = NewRuleBasedDataset(len(applicable_rules), ds_columns)
+        rows = []
+
+        for i, comp in enumerate(compounds):
+            # Features
+            feat = FormatConverter.maccs(comp.smiles)
+            trig = []
+            obs = []
+            for rule in applicable_rules:
+                key = f"{rule.uuid} + {comp.uuid}"
+                # Check triggered
+                if key in triggered:
+                    trig.append(1)
+                else:
+                    trig.append(0)
+                # Check obs
+                if key in observed:
+                    obs.append(1)
+                elif key not in triggered:
+                    obs.append(None)
+                else:
+                    obs.append(0)
+            rows.append([str(comp.uuid)] + feat + trig + obs)
+        ds.add_rows(rows)
+        return ds
+
+
     def __getitem__(self, item):
         pass

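NewRuleBasedDataset stores one row per compound: an id column, the MACCS feature block, one trig_<rule-uuid> column per rule, and one obs_<rule-uuid> column per rule (None when the rule never triggered). A toy sketch of how the resulting object is queried, with made-up column names and values, assuming the class covers the remaining abstract hooks of Dataset as shown in this diff:

    from utilities.ml import NewRuleBasedDataset

    # Two rules ("a", "b") -> columns: id, two MACCS-like features, trig_*, obs_* (toy values)
    columns = ["structure_id", "feature_0", "feature_1", "trig_a", "trig_b", "obs_a", "obs_b"]
    rows = [["c1", 0, 1, 1, 0, 1, None],
            ["c2", 1, 1, 1, 1, 0, 1]]
    ds = NewRuleBasedDataset(2, columns, rows)

    assert ds.num_features == 5            # all non-label columns, given num_labels == 2
    assert ds.times_triggered("a") == 2    # rows where trig_a == 1
    assert ds.struct_features() == (1, 2)  # feature_ block
    assert ds.triggered() == (3, 4)        # trig_ block
    assert ds.observed() == (5, 6)         # obs_ block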
@@ -99,9 +267,6 @@ class RuleBasedDataset(Dataset):
     def limit(self, limit: int) -> RuleBasedDataset:
         return RuleBasedDataset(self.columns, self.num_labels, self.data[:limit])

-    def __iter__(self):
-        return (self.at(i) for i, _ in enumerate(self.data))
-
     def classification_dataset(
         self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
     ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
@@ -297,18 +462,6 @@ class RuleBasedDataset(Dataset):

         return res

-    def save(self, path: "Path"):
-        import pickle
-
-        with open(path, "wb") as fh:
-            pickle.dump(self, fh)
-
-    @staticmethod
-    def load(path: "Path") -> "RuleBasedDataset":
-        import pickle
-
-        return pickle.load(open(path, "rb"))
-
     def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
@@ -335,8 +488,30 @@ class RuleBasedDataset(Dataset):


 class EnviFormerDataset(Dataset):
-    def __init__(self):
-        pass
+    def __init__(self, educts, products):
+        assert len(educts) == len(products), "Can't have unequal length educts and products"

+    @staticmethod
+    def generate_dataset(reactions, *args, **kwargs):
+        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
+        educts = []
+        products = []
+        for reaction in reactions:
+            e = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.educts.all()
+                ]
+            )
+            p = ".".join(
+                [
+                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
+                    for smile in reaction.products.all()
+                ]
+            )
+            educts.append(e)
+            products.append(p)
+        return EnviFormerDataset(educts, products)
+
     def X(self):
         pass
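generate_dataset builds two parallel lists of dot-joined, standardized SMILES, one entry per reaction. A toy sketch of that construction with plain strings instead of Django reaction objects, and a pass-through stand-in for FormatConverter.standardize so the example runs on its own:

    # Stand-in for FormatConverter.standardize on already-clean SMILES (toy example)
    def standardize(smiles: str) -> str:
        return smiles

    # Each reaction as (educt SMILES list, product SMILES list), made up for the example
    reactions = [(["CCO"], ["CC=O"]), (["CC(=O)O", "OCC"], ["CC(=O)OCC", "O"])]

    educts, products = [], []
    for educt_smiles, product_smiles in reactions:
        educts.append(".".join(standardize(s) for s in educt_smiles))
        products.append(".".join(standardize(s) for s in product_smiles))

    assert educts == ["CCO", "CC(=O)O.OCC"]
    assert products == ["CC=O", "CC(=O)OCC.O"]
    # EnviFormerDataset(educts, products) would then hold these two aligned lists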
@@ -347,6 +522,9 @@ class EnviFormerDataset(Dataset):
     def __getitem__(self, item):
         pass

+    def __len__(self):
+        pass
+

 class SparseLabelECC(BaseEstimator, ClassifierMixin):
     """
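evaluate_sg takes len() of what it receives and splits each entry on ">>", so the two stubs would eventually need to expose the reactions in that form. One possible shape for the bodies, purely an assumption for illustration; the stored educts/products attributes are not part of this commit:

    # Hypothetical bodies (not in this commit), assuming __init__ keeps the two lists
    def __getitem__(self, item):
        return f"{self.educts[item]}>>{self.products[item]}"

    def __len__(self):
        return len(self.educts)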