Files
enviPy-bayer/tests/test_enviformer.py
jebus afeb56622c [Chore] Linted Files (#150)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
2025-10-09 07:25:13 +13:00

36 lines
1.3 KiB
Python

from tempfile import TemporaryDirectory
from django.test import TestCase, tag
from epdb.logic import PackageManager
from epdb.models import User, EnviFormer, Package
@tag("slow")
class EnviFormerTest(TestCase):
    """Slow end-to-end smoke test for the EnviFormer model workflow.

    The ``test_fixtures.jsonl.gz`` fixture is loaded for this test case, and
    its package named "Fixtures" is used as both the training data and the
    evaluation data of the model under test.
    """

    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpTestData(cls):
        """Create the objects shared by every test in this class.

        ``setUpTestData`` is Django's dedicated hook for class-level test
        data: it is invoked by ``TestCase.setUpClass`` once the fixtures
        have been loaded, so there is no need to override ``setUpClass``
        (and call ``super()``) just to build data.
        """
        # Presumably the "anonymous" user and the "Fixtures" package are
        # provided by the fixture file -- verify against the fixture contents.
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

    def test_model_flow(self):
        """Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference"""
        # MODEL_DIR is pointed at a throwaway directory for the duration of
        # the test; the directory is removed again when the block exits.
        with TemporaryDirectory() as tmpdir, self.settings(MODEL_DIR=tmpdir):
            threshold = 0.5
            data_package_objs = [self.BBD_SUBSET]
            eval_packages_objs = [self.BBD_SUBSET]

            mod = EnviFormer.create(
                self.package, data_package_objs, eval_packages_objs, threshold=threshold
            )

            mod.build_dataset()
            mod.build_model()

            # The flag is persisted before evaluation is started.
            mod.multigen_eval = True
            mod.save()
            mod.evaluate_model()

            # NOTE(review): the results of evaluate_model() and predict() are
            # not asserted on -- this test only checks that the whole flow
            # runs without raising.
            mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")