[Feature] EnviFormer fine-tuning and evaluation

## Changes
- I have finished the backend integration of EnviFormer (#19). This covers dataset building, model fine-tuning, model evaluation, and prediction with the fine-tuned model (see the sketch below).
- `PackageBasedModel` has been made more abstract: `_save_model` is now an abstract method that subclasses must implement, and `compute_averages` is now a static method.
- I had to bump `requires-python` in `pyproject.toml` from >=3.11 to >=3.12; otherwise uv failed to install EnviFormer.
- The default EnviFormer loading in `settings.py` has been removed.
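
For reference, here is a minimal sketch of the resulting end-to-end flow (it mirrors the new `create_ml_models` management command and the test in this PR; the package names are the ones created by the bootstrap command):

```python
# Minimal sketch of the backend flow added in this PR (assumes the bootstrap
# packages exist and ENVIFORMER_PRESENT is set).
from epdb.models import EnviFormer, Package

target = Package.objects.get(name="Public Prediction Models")
bbd = Package.objects.get(name="EAWAG-BBD")

model = EnviFormer.create(target, data_packages=[bbd], eval_packages=[],
                          threshold=0.5, name="EnviFormer - T0.5",
                          description="EnviFormer transformer")
model.build_dataset()    # standardised reaction SMILES dumped to MODEL_DIR
model.build_model()      # fine-tunes EnviFormer, checkpoint written via _save_model
model.evaluate_model()   # single-generation and optional multigen evaluation
model.save()

results = model.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
```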

## Future Fix
I noticed you have a little bit of code in `PackageBasedModel` -> `evaluate_model` for using `eval_packages` during evaluation instead of train/test splits on `data_packages`. It doesn't seem finished; I presume we want this for all models, so I will take care of it in a new branch/pull request after this one is merged.
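
As a rough illustration of where I'd take it (nothing below exists yet, the helper name is invented), the idea is to fine-tune on `data_packages` and then score against reactions pulled from `eval_packages` instead of ShuffleSplit folds:

```python
# Hypothetical sketch only -- not part of this PR. `reactions_in_packages` is an
# invented helper standing in for however we end up pulling reactions from
# eval_packages; the standardisation mirrors EnviFormer.build_dataset().
def evaluate_on_eval_packages(model):
    if not model.eval_packages.exists():
        raise ValueError("No eval_packages configured")
    held_out = []
    for reaction in reactions_in_packages(model.eval_packages.all()):  # invented helper
        educts = ".".join(FormatConverter.standardize(c.smiles, remove_stereo=True)
                          for c in reaction.educts.all())
        products = ".".join(FormatConverter.standardize(c.smiles, remove_stereo=True)
                            for c in reaction.products.all())
        held_out.append(f"{educts}>>{products}")
    # Score the fine-tuned model on the held-out reactions instead of CV folds
    predictions = model.model.predict_batch([rx.split(">>")[0] for rx in held_out])
    return held_out, predictions
```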

Also, I haven't added anything for a POST request to fine-tune the model; I'm not sure whether that is something we want right now.
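
If we do want it, one possible shape would be a thin view that hands the fine-tune off to Celery (all names below are invented, just to illustrate):

```python
# Purely illustrative -- endpoint, task and URL names are invented and not part
# of this PR.
from celery import shared_task
from django.http import JsonResponse
from django.views.decorators.http import require_POST
from epdb.models import EnviFormer

@shared_task
def finetune_enviformer(model_pk):
    model = EnviFormer.objects.get(pk=model_pk)
    model.build_dataset()
    model.build_model()
    model.evaluate_model()
    model.save()

@require_POST
def trigger_finetune(request, model_pk):
    # With FLAG_CELERY_PRESENT unset, tasks run eagerly and this call blocks
    finetune_enviformer.delay(model_pk)
    return JsonResponse({"status": "submitted", "model": model_pk})
```

With `FLAG_CELERY_PRESENT` unset the task would run eagerly and block until finished, which is probably fine for local testing.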

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#141
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
commit d2f4fdc58a (parent 3f2b046bd6), committed by jebus on 2025-10-07 21:14:10 +13:00
6 changed files with 1220 additions and 1079 deletions


@@ -247,13 +247,7 @@ LOGGING = {
# Flags
ENVIFORMER_PRESENT = os.environ.get('ENVIFORMER_PRESENT', 'False') == 'True'
if ENVIFORMER_PRESENT:
print("Loading enviFormer")
device = os.environ.get('ENVIFORMER_DEVICE', 'cpu')
from enviformer import load
ENVIFORMER_INSTANCE = load(device=device)
print("loaded")
ENVIFORMER_DEVICE = os.environ.get('ENVIFORMER_DEVICE', 'cpu')
# If celery is not present set always eager to true which will cause delayed tasks to block until finished
FLAG_CELERY_PRESENT = os.environ.get('FLAG_CELERY_PRESENT', 'False') == 'True'


@@ -0,0 +1,74 @@
from django.conf import settings as s
from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.models import MLRelativeReasoning, EnviFormer, Package
class Command(BaseCommand):
"""This command can be run with
`python manage.py create_ml_models [model_names] -d [data_packages] OPTIONAL: -e [eval_packages]`
For example, to train both EnviFormer and MLRelativeReasoning on BBD and SOIL and evaluate them on SLUDGE
the below command would be used:
`python manage.py create_ml_models enviformer mlrr -d bbd soil -e sludge`
"""
def add_arguments(self, parser):
parser.add_argument("model_names", nargs="+", type=str, help="The names of models to train. Options are: enviformer, mlrr")
parser.add_argument("-d", "--data-packages", nargs="+", type=str, help="Packages for training")
parser.add_argument("-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[])
parser.add_argument("-r", "--rule-packages", nargs="*", type=str, help="Rule Packages mandatory for MLRR", default=[])
@transaction.atomic
def handle(self, *args, **options):
# Find Public Prediction Models package to add new models to
try:
pack = Package.objects.filter(name="Public Prediction Models")[0]
bbd = Package.objects.filter(name="EAWAG-BBD")[0]
soil = Package.objects.filter(name="EAWAG-SOIL")[0]
sludge = Package.objects.filter(name="EAWAG-SLUDGE")[0]
sediment = Package.objects.filter(name="EAWAG-SEDIMENT")[0]
except IndexError:
raise IndexError("Can't find correct packages. They should be created with the bootstrap command")
def decode_packages(package_list):
"""Decode package strings into their respective packages"""
packages = []
for p in package_list:
p = p.lower()
if p == "bbd":
packages.append(bbd)
elif p == "soil":
packages.append(soil)
elif p == "sludge":
packages.append(sludge)
elif p == "sediment":
packages.append(sediment)
else:
raise ValueError(f"Unknown package {p}")
return packages
# Iteratively create models in options["model_names"]
print(f"Creating models: {options['model_names']}")
data_packages = decode_packages(options["data_packages"])
eval_packages = decode_packages(options["eval_packages"])
rule_packages = decode_packages(options["rule_packages"])
for model_name in options['model_names']:
model_name = model_name.lower()
if model_name == "enviformer" and s.ENVIFORMER_PRESENT:
model = EnviFormer.create(pack, data_packages=data_packages, eval_packages=eval_packages, threshold=0.5,
name="EnviFormer - T0.5", description="EnviFormer transformer")
elif model_name == "mlrr":
model = MLRelativeReasoning.create(package=pack, rule_packages=rule_packages,
data_packages=data_packages, eval_packages=eval_packages, threshold=0.5,
name='ECC - BBD - T0.5', description='ML Relative Reasoning')
else:
raise ValueError(f"Cannot create model of type {model_name}, unknown model type")
# Build the dataset for the model, train it, evaluate it and save it
print(f"Building dataset for {model_name}")
model.build_dataset()
print(f"Training {model_name}")
model.build_model()
print(f"Evaluating {model_name}")
model.evaluate_model()
print(f"Saving {model_name}")
model.save()


@@ -9,7 +9,7 @@ from collections import defaultdict
from datetime import datetime
from typing import Union, List, Optional, Dict, Tuple, Set, Any
from uuid import uuid4
import math
import joblib
import numpy as np
from django.conf import settings as s
@@ -2002,6 +2002,10 @@ class PackageBasedModel(EPModel):
def _model_args(self) -> Dict[str, Any]:
pass
@abstractmethod
def _save_model(self, model):
pass
def build_model(self):
self.model_status = self.BUILDING
self.save()
@@ -2010,8 +2014,7 @@
mod = self._fit_model(ds)
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(mod, f)
self._save_model(mod)
if self.app_domain is not None:
logger.debug("Building applicability domain...")
@@ -2116,7 +2119,7 @@
mg_acc = 0.0
for t in thresholds:
for true, pred in zip(test_pathways, pred_pathways):
for true, pred in zip(pathways, pred_pathways):
acc, pre, rec = multigen_eval(true, pred, t)
if abs(t - threshold) < 0.01:
mg_acc = acc
@@ -2146,29 +2149,7 @@
evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
for model, (_, test_index) in zip(models, splits))
def compute_averages(data):
num_items = len(data)
avg_first_item = sum(item[0] for item in data) / num_items
sum_dict2 = defaultdict(float)
sum_dict3 = defaultdict(float)
for _, dict2, dict3 in data:
for key in dict2:
sum_dict2[key] += dict2[key]
for key in dict3:
sum_dict3[key] += dict3[key]
avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
return {
"average_accuracy": float(avg_first_item),
"average_precision_per_threshold": avg_dict2,
"average_recall_per_threshold": avg_dict3
}
self.eval_results = compute_averages(evaluations)
self.eval_results = self.compute_averages(evaluations)
if self.multigen_eval:
@@ -2209,7 +2190,7 @@
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
# build lookup to avoid recalculation of features, labels
id_to_index = {uuid: i for i, uuid in enumerate(ds[:, 0])}
id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}
# Compute splits of the collected pathway
splits = []
@@ -2233,10 +2214,8 @@
# Ensure compounds in the training set do not appear in the test set
if educt not in test_educts:
if educt in id_to_index:
split_ids.append(id_to_index[str(educt)])
try:
split_ids.append(id_to_index[str(educt)])
except KeyError:
split_ids.append(id_to_index[educt])
else:
logger.debug(f"Couldn't find features in X for compound {educt}")
else:
overlap += 1
@@ -2260,12 +2239,34 @@
zip(trained_models, splits)
)
self.eval_results.update({f"multigen_{k}": v for k, v in compute_averages(multi_ret_vals).items()})
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()})
self.model_status = self.FINISHED
self.save()
@staticmethod
def compute_averages(data):
num_items = len(data)
avg_first_item = sum(item[0] for item in data) / num_items
sum_dict2 = defaultdict(float)
sum_dict3 = defaultdict(float)
for _, dict2, dict3 in data:
for key in dict2:
sum_dict2[key] += dict2[key]
for key in dict3:
sum_dict3[key] += dict3[key]
avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}
return {
"average_accuracy": float(avg_first_item),
"average_precision_per_threshold": avg_dict2,
"average_recall_per_threshold": avg_dict3
}
@staticmethod
def combine_products_and_probs(rules: List['Rule'], probabilities, products):
res = []
@@ -2354,6 +2355,10 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
'end_index': ds.triggered()[1],
}
def _save_model(self, model):
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(model, f)
@cached_property
def model(self) -> 'RelativeReasoning':
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
@@ -2443,6 +2448,10 @@ class MLRelativeReasoning(PackageBasedModel):
**s.DEFAULT_MODEL_PARAMS,
}
def _save_model(self, model):
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
joblib.dump(model, f)
@cached_property
def model(self) -> 'EnsembleClassifierChain':
mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
@@ -2696,57 +2705,335 @@ class ApplicabilityDomain(EnviPathModel):
return accuracy
class EnviFormer(EPModel):
threshold = models.FloatField(null=False, blank=False, default=0.5)
class EnviFormer(PackageBasedModel):
@staticmethod
@transaction.atomic
def create(package, name, description, threshold):
def create(package: 'Package', data_packages: List['Package'], eval_packages: List['Package'],
threshold: float = 0.5, name: 'str' = None, description: str = None, build_app_domain: bool = False,
app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
app_domain_local_compatibility_threshold: float = None):
mod = EnviFormer()
mod.package = package
if name is None or name.strip() == '':
name = f"EnviFormer {EnviFormer.objects.filter(package=package).count() + 1}"
mod.name = name
mod.description = description
if description is not None and description.strip() != '':
mod.description = description
if threshold is None or (threshold <= 0 or 1 <= threshold):
raise ValueError("Threshold must be a float between 0 and 1.")
mod.threshold = threshold
if len(data_packages) == 0:
raise ValueError("At least one data package must be provided.")
mod.save()
for p in data_packages:
mod.data_packages.add(p)
if eval_packages:
for p in eval_packages:
mod.eval_packages.add(p)
# if build_app_domain:
# ad = ApplicabilityDomain.create(mod, app_domain_num_neighbours, app_domain_reliability_threshold,
# app_domain_local_compatibility_threshold)
# mod.app_domain = ad
mod.save()
return mod
@cached_property
def model(self):
mod = getattr(s, 'ENVIFORMER_INSTANCE', None)
logger.info(f"Model from settings {hash(mod)}")
return mod
from enviformer import load
ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
return load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
def predict(self, smiles) -> List['PredictionResult']:
# example = {
# 'C#N': 0.46326889595136767,
# 'C#C': 0.04981685951409509,
# }
from rdkit import Chem
m = Chem.MolFromSmiles(smiles)
Chem.Kekulize(m)
kek = Chem.MolToSmiles(m, kekuleSmiles=True)
logger.info(f"Submitting {kek} to {hash(self.model)}")
products = self.model.predict(kek)
logger.info(f"Got results {products}")
return self.predict_batch([smiles])[0]
res = []
for smi, prob in products.items():
res.append(PredictionResult([ProductSet([smi])], prob, None))
def predict_batch(self, smiles_list):
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list]
logger.info(f"Submitting {canon_smiles} to {hash(self.model)}")
products_list = self.model.predict_batch(canon_smiles)
logger.info(f"Got results {products_list}")
return res
results = []
for products in products_list:
res = []
for smi, prob in products.items():
try:
smi = ".".join([FormatConverter.standardize(smile, remove_stereo=True) for smile in smi.split(".")])
except ValueError: # This occurs when the predicted string is an invalid SMILES
logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
continue
res.append(PredictionResult([ProductSet([smi])], prob, None))
results.append(res)
return results
def build_dataset(self):
self.model_status = self.INITIALIZING
self.save()
start = datetime.now()
# Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
ds = []
for reaction in self._get_reactions():
educts = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
ds.append(f"{educts}>>{products}")
end = datetime.now()
logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(f, "w") as d_file:
json.dump(ds, d_file)
return ds
def load_dataset(self) -> 'Dataset':
ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
with open(ds_path) as d_file:
ds = json.load(d_file)
return ds
def _fit_model(self, ds):
# Call to enviFormer's fine_tune function and return the model
from enviformer.finetune import fine_tune
start = datetime.now()
model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
return model
def _save_model(self, model):
from enviformer.utils import save_model
save_model(model, os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt"))
def _model_args(self) -> Dict[str, Any]:
args = {"clz": "EnviFormer"}
return args
def evaluate_model(self):
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
self.model_status = self.EVALUATING
self.save()
def evaluate_sg(test_reactions, predictions, model_thresh):
# Group the true products of reactions with the same reactant together
true_dict = {}
for r in test_reactions:
reactant, true_product_set = r.split(">>")
true_product_set = {p for p in true_product_set.split(".")}
true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
assert len(test_reactions) == len(predictions)
assert sum(len(v) for v in true_dict.values()) == len(test_reactions)
# Group the predicted products of reactions with the same reactant together
pred_dict = {}
for k, pred in enumerate(predictions):
pred_smiles, pred_proba = zip(*pred.items())
reactant, true_product = test_reactions[k].split(">>")
pred_dict.setdefault(reactant, {"predict": [], "scores": []})
for smiles, proba in zip(pred_smiles, pred_proba):
smiles = set(smiles.split("."))
if smiles not in pred_dict[reactant]["predict"]:
pred_dict[reactant]["predict"].append(smiles)
pred_dict[reactant]["scores"].append(proba)
# EnviFormer thresholds need to be different from other models due to the probabilities often being closer to zero.
thresholds = set()
thresholds.update({i / 5 for i in range(-75, -10, 15)})
thresholds.update({i / 50 for i in range(-100, -10, 10)})
thresholds = {math.exp(t) for t in thresholds}
thresholds.add(model_thresh)
thresholds = sorted(thresholds)
# Calculate the number correct and predicted for each threshold and at each top-k
correct = {t: 0 for t in thresholds}
predicted = {t: 0 for t in thresholds}
for reactant, product_sets in true_dict.items():
pred_smiles = pred_dict[reactant]["predict"]
pred_scores = pred_dict[reactant]["scores"]
for true_set in product_sets:
for threshold in correct:
pred_s = [s for i, s in enumerate(pred_smiles) if pred_scores[i] > threshold]
predicted[threshold] += len(pred_s)
for pred_set in pred_s:
if len(true_set - pred_set) == 0:
correct[threshold] += 1
break
# Recall is TP (correct) / TP + FN (len(test_reactions))
rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
# Precision is TP (correct) / TP + FP (predicted)
prec = {f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()}
# Accuracy for EnviFormer is just recall
return rec[f"{model_thresh:.2f}"], prec, rec
def evaluate_mg(model, pathways, threshold):
# EnviFormer thresholds need to be different from other models due to the probabilities often being closer to zero.
thresholds = set()
thresholds.update({i / 5 for i in range(-75, -10, 15)})
thresholds.update({i / 50 for i in range(-100, -10, 10)})
thresholds = {math.exp(t) for t in thresholds}
thresholds.add(threshold)
thresholds = sorted(thresholds)
precision = {f"{t:.2f}": [] for t in thresholds}
recall = {f"{t:.2f}": [] for t in thresholds}
# Note: only one root compound supported at this time
root_compounds = []
for p in pathways:
root_node = p.root_nodes
if len(root_node) > 1:
logging.warning(f"Pathway {p.name} has more than one root compound, only {root_node[0]} will be used")
root_node = ".".join([FormatConverter.standardize(smile) for smile in root_node[0].default_node_label.smiles.split(".")])
root_compounds.append(root_node)
# As we need a Model Instance in our setting, get a fresh copy from db, overwrite the serialized mode and
# pass it to the setting used in prediction
mod = EnviFormer.objects.get(pk=self.pk)
mod.model = model
s = Setting()
s.model = mod
s.model_threshold = min(thresholds)
s.max_depth = 10
s.max_nodes = 50
from epdb.logic import SPathway
from utilities.ml import multigen_eval
# Predict pathways from each root compound
pred_pathways = []
for i, root in enumerate(root_compounds):
logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")
spw = SPathway(root_nodes=root, prediction_setting=s)
level = 0
while not spw.done:
spw.predict_step(from_depth=level)
level += 1
pred_pathways.append(spw)
mg_acc = 0.0
for t in thresholds:
for true, pred in zip(pathways, pred_pathways):
# Calculate multigen statistics
acc, pre, rec = multigen_eval(true, pred, t)
if t == threshold:
mg_acc = acc
precision[f"{t:.2f}"].append(pre)
recall[f"{t:.2f}"].append(rec)
precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
return mg_acc, precision, recall
from enviformer.finetune import fine_tune
ds = self.load_dataset()
n_splits = 20
shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
# Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
# this helps reduce the memory footprint.
single_gen_results = []
for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
train = [ds[i] for i in train_index]
test = [ds[i] for i in test_index]
start = datetime.now()
model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
end = datetime.now()
logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
model.to(s.ENVIFORMER_DEVICE)
test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
single_gen_results.append(evaluate_sg(test, test_result, self.threshold))
self.eval_results = self.compute_averages(single_gen_results)
if self.multigen_eval:
pathway_qs = Pathway.objects.prefetch_related(
'node_set',
'node_set__out_edges',
'node_set__default_node_label',
'node_set__scenarios',
'edge_set',
'edge_set__start_nodes',
'edge_set__end_nodes',
'edge_set__edge_label',
'edge_set__scenarios'
).filter(package__in=self.data_packages.all()).distinct()
pathways = []
for pathway in pathway_qs:
# There is one pathway with no root compounds, so this check is required
if len(pathway.root_nodes) > 0:
pathways.append(pathway)
else:
logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation")
# build lookup reaction -> {uuid1, uuid2} for overlap check
reaction_to_educts = defaultdict(set)
for pathway in pathways:
for reaction in pathway.edges:
for e in reaction.edge_label.educts.all():
reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))
multi_gen_results = []
# Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
# iteration instead of storing all trained models.
for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)):
train_pathways = [pathways[i] for i in train]
test_pathways = [pathways[i] for i in test]
# Collect structures from test pathways
test_educts = set()
for pathway in test_pathways:
for reaction in pathway.edges:
test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])
train_reactions = []
overlap = 0
# Collect indices of the structures contained in train pathways iff they're not present in any of
# the test pathways
for pathway in train_pathways:
for reaction in pathway.edges:
reaction = reaction.edge_label
if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
overlap += 1
continue
educts = ".".join(
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()])
products = ".".join(
[FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()])
train_reactions.append(f"{educts}>>{products}")
logging.debug(
f"{overlap} compounds had to be removed from multigen split due to overlap within pathways")
model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()})
self.model_status = self.FINISHED
self.save()
@cached_property
def applicable_rules(self):
return []
def status(self):
return "Model is built and can be used for predictions, Model is not evaluated yet."
def ready_for_prediction(self) -> bool:
return True
class PluginModel(EPModel):
pass


@@ -3,7 +3,7 @@ name = "envipy"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
requires-python = ">=3.12"
dependencies = [
"celery>=5.5.2",
"django>=5.2.1",
@@ -31,7 +31,7 @@ dependencies = [
]
[tool.uv.sources]
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.0" }
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" }
envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.4"}
envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }

tests/test_enviformer.py (new file, 31 lines)

@@ -0,0 +1,31 @@
from tempfile import TemporaryDirectory
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, EnviFormer, Package
class EnviFormerTest(TestCase):
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):
super(EnviFormerTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
def test_model_flow(self):
"""Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference"""
with TemporaryDirectory() as tmpdir:
with self.settings(MODEL_DIR=tmpdir):
threshold = float(0.5)
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = []
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')

uv.lock (generated, 1769 lines changed): file diff suppressed because it is too large.