import json
import os
import shutil
import tarfile
from tempfile import TemporaryDirectory

from django.conf import settings as s
from django.core.management.base import BaseCommand
from django.db import transaction

from epdb.models import EnviFormer, Package


class Command(BaseCommand):
    """Import an EnviFormer model dump (.tar.gz) into an existing Package.

    The archive must contain a top-level ``data.json`` with the model
    metadata and a folder holding the model files, including a checkpoint
    named ``<uuid from data.json>.ckpt``.
    """

    help = "Load an EnviFormer model dump (.tar.gz) into the given Package."

    def add_arguments(self, parser):
        """Register the two positional arguments: archive path and package UUID."""
        parser.add_argument(
            "input",
            type=str,
            help=".tar.gz file containing the Model dump.",
        )
        parser.add_argument(
            "package",
            type=str,
            help="Package UUID where the Model should be loaded to.",
        )

    def read_dict_and_folder_from_archive(self, archive_path, extract_to="extracted_folder"):
        """Extract the archive and return its metadata and model folder.

        Args:
            archive_path: Path to the gzip-compressed tar archive.
            extract_to: Directory the archive is extracted into.

        Returns:
            A ``(data_dict, folder_path)`` tuple. ``data_dict`` is the parsed
            content of ``data.json``; ``folder_path`` is the path of the first
            extracted directory (in sorted order), or ``None`` if the archive
            contains no directory.

        Raises:
            FileNotFoundError: If the archive has no top-level ``data.json``.
        """
        with tarfile.open(archive_path, "r:gz") as tar:
            # Use the "data" extraction filter where the running Python
            # supports it: it rejects members that would land outside
            # ``extract_to`` and avoids the DeprecationWarning emitted for
            # unfiltered extraction on Python 3.12+.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(extract_to, filter="data")
            else:
                tar.extractall(extract_to)

        dict_path = os.path.join(extract_to, "data.json")
        if not os.path.exists(dict_path):
            raise FileNotFoundError("data.json not found in the archive.")

        with open(dict_path, "r", encoding="utf-8") as f:
            data_dict = json.load(f)

        # Only real directories qualify as the model folder; sorting makes the
        # choice deterministic because os.listdir() order is arbitrary.
        folders = sorted(
            item
            for item in os.listdir(extract_to)
            if os.path.isdir(os.path.join(extract_to, item))
        )
        folder_path = os.path.join(extract_to, folders[0]) if folders else None

        return data_dict, folder_path

    @transaction.atomic
    def handle(self, *args, **options):
        """Create an EnviFormer from the dump and copy its files into MODEL_DIR.

        Database changes run inside a single transaction. If copying the model
        files fails, the partially written target directory is removed so the
        file system stays consistent with the rolled-back database.

        Raises:
            ValueError: If the input archive does not exist.
            FileNotFoundError: If the archive lacks ``data.json`` or a model
                folder.
            FileExistsError: If the target model directory already exists.
        """
        if not os.path.exists(options["input"]):
            raise ValueError(f"Input file {options['input']} does not exist.")

        target_package = Package.objects.get(uuid=options["package"])

        with TemporaryDirectory() as tmpdir:
            data, folder = self.read_dict_and_folder_from_archive(options["input"], tmpdir)

            # Fail early with a clear message instead of a TypeError from
            # shutil.copytree(None, ...) after the database was modified.
            if folder is None:
                raise FileNotFoundError("No model folder found in the archive.")

            model = EnviFormer()
            model.package = target_package
            # The UUID stored in the dump is intentionally not assigned to the
            # new model; the checkpoint file is renamed below to match the
            # model's own UUID instead.
            model.name = data["name"]
            model.description = data["description"]
            model.kv = data["kv"]
            model.threshold = float(data["threshold"])
            model.eval_results = data["eval_results"]
            model.multigen_eval = data["multigen_eval"]
            model.model_status = data["model_status"]
            model.save()

            for p_uuid in data["data_packages_uuids"]:
                model.data_packages.add(Package.objects.get(uuid=p_uuid))

            for p_uuid in data["eval_packages_uuids"]:
                model.eval_packages.add(Package.objects.get(uuid=p_uuid))

            target_folder = os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid))

            # Checked up front so the cleanup below can never delete a
            # directory that existed before this command ran.
            if os.path.exists(target_folder):
                raise FileExistsError(f"Target folder {target_folder} already exists.")

            try:
                shutil.copytree(folder, target_folder)
                os.rename(
                    os.path.join(target_folder, f"{data['uuid']}.ckpt"),
                    os.path.join(target_folder, f"{model.uuid}.ckpt"),
                )
            except BaseException:
                # The transaction rolls back the database rows on any error
                # (including interrupts); remove the copied files as well so
                # no orphaned model directory is left behind.
                shutil.rmtree(target_folder, ignore_errors=True)
                raise