forked from enviPath/enviPy
## Changes

- I have finished the backend integration of EnviFormer (#19). This includes dataset building, model fine-tuning, model evaluation and prediction with the fine-tuned model.
- `PackageBasedModel` has been made more abstract; this includes adding the `_save_model` method and making `compute_averages` a static method.
- I had to raise the Python version requirement in `pyproject.toml` from >=3.11 to >=3.12; otherwise uv failed to install EnviFormer.
- The default EnviFormer loading in `settings.py` has been removed.

## Future Fix

I noticed there is a little code in `PackageBasedModel` -> `evaluate_model` for using `eval_packages` during evaluation instead of train/test splits on `data_packages`. It doesn't seem finished. I presume we want this for all models, so I will take care of it in a new branch/pull request after this one is merged.

Also, I haven't added a POST request for fine-tuning the model; I'm not sure whether that is something we want yet.

Co-authored-by: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com>
Reviewed-on: enviPath/enviPy#141
Reviewed-by: jebus <lorsbach@envipath.com>
Co-authored-by: liambrydon <lbry121@aucklanduni.ac.nz>
Co-committed-by: liambrydon <lbry121@aucklanduni.ac.nz>
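The `PackageBasedModel` changes and the Future Fix note describe a base class with subclass hooks, a static averaging helper, and an evaluation path that should prefer `eval_packages` over train/test splits of `data_packages`. The following is a hypothetical, simplified sketch of that design, not enviPy's code. Only the names `_save_model`, `compute_averages` and `evaluate_model` come from the commit message; the class `PackageBasedModelSketch`, the hooks `_fit` and `_score`, the `(input, label)` data format, the `compute_averages` signature and the toy `MajorityLabelModel` are all assumptions made for illustration.

```python
from abc import ABC, abstractmethod
from statistics import mean
import random


class PackageBasedModelSketch(ABC):
    """Hypothetical, simplified stand-in for an abstract package-based model."""

    def __init__(self, data, eval_data=None, test_fraction=0.2, seed=0):
        self.data = list(data)                  # (input, label) pairs from the data packages
        self.eval_data = list(eval_data or [])  # optional pairs from the eval packages
        self.test_fraction = test_fraction
        self.seed = seed

    # Hooks every concrete model must provide (hypothetical names except _save_model).
    @abstractmethod
    def _fit(self, train):
        """Train on a list of (input, label) pairs."""

    @abstractmethod
    def _score(self, test):
        """Return a dict mapping metric name -> value for `test`."""

    @abstractmethod
    def _save_model(self, path):
        """Persist the trained model in a subclass-specific format."""

    @staticmethod
    def compute_averages(results):
        """Average each metric across a list of per-run result dicts (assumed signature)."""
        return {key: mean(r[key] for r in results) for key in results[0]}

    def evaluate_model(self, n_splits=3):
        if self.eval_data:
            # Eval packages given: train on all data, score once on the held-out set.
            self._fit(self.data)
            return self.compute_averages([self._score(self.eval_data)])
        # No eval packages: average over repeated random train/test splits.
        rng = random.Random(self.seed)
        n_test = max(1, int(len(self.data) * self.test_fraction))
        results = []
        for _ in range(n_splits):
            shuffled = self.data[:]
            rng.shuffle(shuffled)
            self._fit(shuffled[n_test:])
            results.append(self._score(shuffled[:n_test]))
        return self.compute_averages(results)


class MajorityLabelModel(PackageBasedModelSketch):
    """Toy subclass: always predicts the most common training label."""

    def _fit(self, train):
        labels = [label for _, label in train]
        self.label = max(set(labels), key=labels.count)

    def _score(self, test):
        correct = sum(1 for _, label in test if label == self.label)
        return {"accuracy": correct / len(test)}

    def _save_model(self, path):
        with open(path, "w") as fh:
            fh.write(self.label)
```

In this sketch, making `compute_averages` a `@staticmethod` lets it be called without an instance (for example `PackageBasedModelSketch.compute_averages(runs)`), and declaring `_save_model` abstract forces each subclass to define its own persistence.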
32 lines
1.2 KiB
Python
from tempfile import TemporaryDirectory
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, EnviFormer, Package


class EnviFormerTest(TestCase):
    fixtures = ["test_fixtures.jsonl.gz"]

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.user = User.objects.get(username='anonymous')
        cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
        # Subset of BBD data shipped in the test fixtures
        cls.BBD_SUBSET = Package.objects.get(name='Fixtures')

    def test_model_flow(self):
        """Test the full EnviFormer flow: dataset build -> model fine-tuning -> model evaluation -> model inference."""
        with TemporaryDirectory() as tmpdir:
            # Redirect MODEL_DIR so model artifacts are written to a throwaway directory
            with self.settings(MODEL_DIR=tmpdir):
                threshold = 0.5
                data_package_objs = [self.BBD_SUBSET]
                eval_packages_objs = []
                mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)

                mod.build_dataset()
                mod.build_model()
                mod.multigen_eval = True
                mod.save()
                mod.evaluate_model()
                results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
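The test isolates model artifacts by pointing `MODEL_DIR` at a `TemporaryDirectory` through Django's `self.settings(...)` override. A minimal Django-free sketch of the same isolation pattern follows, assuming a plain dict `SETTINGS` in place of `django.conf.settings` and a hypothetical `StubModel` that only mimics the method names of the real flow; it does not reproduce what EnviFormer actually trains or predicts.

```python
import os
import tempfile
import unittest
from unittest import mock

# Hypothetical stand-in for django.conf.settings; only MODEL_DIR matters here.
SETTINGS = {"MODEL_DIR": "/var/lib/models"}


class StubModel:
    """Hypothetical model that mimics the build -> evaluate -> predict flow."""

    def build_dataset(self):
        self.dataset = [("CCO", 1), ("CCN", 0)]

    def build_model(self):
        # Read MODEL_DIR at call time so an active override is respected.
        self.path = os.path.join(SETTINGS["MODEL_DIR"], "stub_model.txt")
        with open(self.path, "w") as fh:
            fh.write(str(len(self.dataset)))

    def evaluate_model(self):
        return {"n_examples": len(self.dataset)}

    def predict(self, smiles):
        return [smiles]  # placeholder "prediction"


class StubModelFlowTest(unittest.TestCase):
    def test_model_flow(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            with mock.patch.dict(SETTINGS, {"MODEL_DIR": tmpdir}):
                mod = StubModel()
                mod.build_dataset()
                mod.build_model()
                self.assertTrue(mod.path.startswith(tmpdir))
                self.assertTrue(os.path.exists(mod.path))
                self.assertEqual(mod.evaluate_model(), {"n_examples": 2})
                self.assertTrue(mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C"))
        # Both the override and the temporary directory are gone after the blocks exit.
        self.assertEqual(SETTINGS["MODEL_DIR"], "/var/lib/models")
        self.assertFalse(os.path.exists(mod.path))
```

Run it with `python -m unittest`. `mock.patch.dict` restores the dict on exit, much as Django undoes a `self.settings(...)` override when its `with` block ends. Reading the setting at call time matters: a value captured once at import time would ignore the override and write outside the temporary directory.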