diff --git a/epdb/management/commands/dump_enviformer.py b/epdb/management/commands/dump_enviformer.py new file mode 100644 index 00000000..e333248a --- /dev/null +++ b/epdb/management/commands/dump_enviformer.py @@ -0,0 +1,59 @@ +import json +import os +import tarfile +from tempfile import TemporaryDirectory + +from django.conf import settings as s +from django.core.management.base import BaseCommand +from django.db import transaction + +from epdb.models import EnviFormer + + +class Command(BaseCommand): + def add_arguments(self, parser): + parser.add_argument( + "model", + type=str, + help="Model UUID of the Model to Dump", + ) + parser.add_argument("--output", type=str) + + def package_dict_and_folder(self, dict_data, folder_path, output_path): + with TemporaryDirectory() as tmpdir: + dict_filename = os.path.join(tmpdir, "data.json") + + with open(dict_filename, "w", encoding="utf-8") as f: + json.dump(dict_data, f, indent=2) + + with tarfile.open(output_path, "w:gz") as tar: + tar.add(dict_filename, arcname="data.json") + tar.add(folder_path, arcname=os.path.basename(folder_path)) + + os.remove(dict_filename) + + @transaction.atomic + def handle(self, *args, **options): + output = options["output"] + + if os.path.exists(output): + raise ValueError(f"Output file {output} already exists") + + model = EnviFormer.objects.get(uuid=options["model"]) + + data = { + "uuid": str(model.uuid), + "name": model.name, + "description": model.description, + "kv": model.kv, + "data_packages_uuids": [str(p.uuid) for p in model.data_packages.all()], + "eval_packages_uuids": [str(p.uuid) for p in model.data_packages.all()], + "threshold": model.threshold, + "eval_results": model.eval_results, + "multigen_eval": model.multigen_eval, + "model_status": model.model_status, + } + + model_folder = os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid)) + + self.package_dict_and_folder(data, model_folder, output) diff --git a/epdb/management/commands/load_enviformer.py b/epdb/management/commands/load_enviformer.py new file mode 100644 index 00000000..b2f9c3e3 --- /dev/null +++ b/epdb/management/commands/load_enviformer.py @@ -0,0 +1,81 @@ +import json +import os +import shutil +import tarfile +from tempfile import TemporaryDirectory + +from django.conf import settings as s +from django.core.management.base import BaseCommand +from django.db import transaction + +from epdb.models import EnviFormer, Package + + +class Command(BaseCommand): + def add_arguments(self, parser): + parser.add_argument( + "input", + type=str, + help=".tar.gz file containing the Model dump.", + ) + parser.add_argument( + "package", + type=str, + help="Package UUID where the Model should be loaded to.", + ) + + def read_dict_and_folder_from_archive(self, archive_path, extract_to="extracted_folder"): + with tarfile.open(archive_path, "r:gz") as tar: + tar.extractall(extract_to) + + dict_path = os.path.join(extract_to, "data.json") + + if not os.path.exists(dict_path): + raise FileNotFoundError("data.json not found in the archive.") + + with open(dict_path, "r", encoding="utf-8") as f: + data_dict = json.load(f) + + extracted_items = os.listdir(extract_to) + folders = [item for item in extracted_items if item != "data.json"] + folder_path = os.path.join(extract_to, folders[0]) if folders else None + + return data_dict, folder_path + + @transaction.atomic + def handle(self, *args, **options): + if not os.path.exists(options["input"]): + raise ValueError(f"Input file {options['input']} does not exist.") + + target_package = Package.objects.get(uuid=options["package"]) + + with TemporaryDirectory() as tmpdir: + data, folder = self.read_dict_and_folder_from_archive(options["input"], tmpdir) + + model = EnviFormer() + model.package = target_package + # model.uuid = data["uuid"] + model.name = data["name"] + model.description = data["description"] + model.kv = data["kv"] + model.threshold = float(data["threshold"]) + model.eval_results = data["eval_results"] + model.multigen_eval = data["multigen_eval"] + model.model_status = data["model_status"] + model.save() + + for p_uuid in data["data_packages_uuids"]: + p = Package.objects.get(uuid=p_uuid) + model.data_packages.add(p) + + for p_uuid in data["eval_packages_uuids"]: + p = Package.objects.get(uuid=p_uuid) + model.eval_packages.add(p) + + target_folder = os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid)) + + shutil.copytree(folder, target_folder) + os.rename( + os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid), f"{data['uuid']}.ckpt"), + os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid), f"{model.uuid}.ckpt"), + )