forked from enviPath/enviPy
[Enhancement] Create ML Models (#173)
## Changes - Ability to change the threshold from a command line argument. - Names of data packages included in model name - Names of data, rule and eval packages included in the model description - EnviFormer models are now viewable on the admin site - Ignore CO2 for training and evaluating EnviFormer Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com> Reviewed-on: enviPath/enviPy#173 Reviewed-by: jebus <lorsbach@envipath.com> Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz> Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
This commit is contained in:
@ -3092,6 +3092,7 @@ class EnviFormer(PackageBasedModel):
|
||||
|
||||
start = datetime.now()
|
||||
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
|
||||
co2 = {"C(=O)=O", "O=C=O"}
|
||||
ds = []
|
||||
for reaction in self._get_reactions():
|
||||
educts = ".".join(
|
||||
@ -3106,7 +3107,8 @@ class EnviFormer(PackageBasedModel):
|
||||
for smile in reaction.products.all()
|
||||
]
|
||||
)
|
||||
ds.append(f"{educts}>>{products}")
|
||||
if products not in co2:
|
||||
ds.append(f"{educts}>>{products}")
|
||||
|
||||
end = datetime.now()
|
||||
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
|
||||
@ -3302,7 +3304,7 @@ class EnviFormer(PackageBasedModel):
|
||||
|
||||
ds = self.load_dataset()
|
||||
n_splits = 20
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
|
||||
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
|
||||
|
||||
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
|
||||
# this helps reduce the memory footprint.
|
||||
@ -3370,7 +3372,7 @@ class EnviFormer(PackageBasedModel):
|
||||
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
|
||||
# iteration instead of storing all trained models.
|
||||
for split_id, (train, test) in enumerate(
|
||||
ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)
|
||||
ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42).split(pathways)
|
||||
):
|
||||
train_pathways = [pathways[i] for i in train]
|
||||
test_pathways = [pathways[i] for i in test]
|
||||
|
||||
Reference in New Issue
Block a user