forked from enviPath/enviPy
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#216 Reviewed-by: liambrydon <lbry121@aucklanduni.ac.nz> Reviewed-by: Tobias O <tobias.olenyi@envipath.com>
84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
import json
|
|
import os
|
|
import shutil
|
|
import tarfile
|
|
from tempfile import TemporaryDirectory
|
|
|
|
from django.conf import settings as s
|
|
from django.core.management.base import BaseCommand
|
|
from django.db import transaction
|
|
|
|
from epdb.models import EnviFormer
|
|
|
|
# The Package model class is obtained through the settings-provided
# GET_PACKAGE_MODEL() callable instead of a direct import.
# NOTE(review): presumably this lets a deployment swap the Package
# implementation -- confirm against the settings module.
Package = s.GET_PACKAGE_MODEL()
|
|
|
|
|
|
class Command(BaseCommand):
    """Import an EnviFormer model dump into an existing package.

    The dump is a ``.tar.gz`` archive that contains a ``data.json`` metadata
    file and one folder holding the model files, among them a checkpoint
    named ``<uuid>.ckpt`` where ``<uuid>`` is the ``uuid`` value stored in
    ``data.json``.
    """

    help = "Load an EnviFormer model from a .tar.gz dump into the given package."

    def add_arguments(self, parser):
        """Register the two positional arguments: archive path and package UUID."""
        parser.add_argument(
            "input",
            type=str,
            help=".tar.gz file containing the Model dump.",
        )
        parser.add_argument(
            "package",
            type=str,
            help="Package UUID where the Model should be loaded to.",
        )

    def read_dict_and_folder_from_archive(self, archive_path, extract_to="extracted_folder"):
        """Extract the archive and return its metadata and model folder.

        Args:
            archive_path: Path to the gzip-compressed tar archive.
            extract_to: Directory the archive is unpacked into.

        Returns:
            A ``(data_dict, folder_path)`` tuple. ``data_dict`` is the parsed
            content of ``data.json``; ``folder_path`` is the path of the first
            extracted directory (in sorted name order) or ``None`` when the
            archive contains no directory.

        Raises:
            FileNotFoundError: If the archive has no top-level ``data.json``.
        """
        with tarfile.open(archive_path, "r:gz") as tar:
            # Use the "data" extraction filter when this Python provides it,
            # so members with absolute paths or ".." components cannot be
            # written outside ``extract_to``. Older interpreters without
            # extraction filters fall back to the plain call.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(extract_to, filter="data")
            else:
                tar.extractall(extract_to)

        dict_path = os.path.join(extract_to, "data.json")

        if not os.path.exists(dict_path):
            raise FileNotFoundError("data.json not found in the archive.")

        with open(dict_path, "r", encoding="utf-8") as f:
            data_dict = json.load(f)

        # Only real directories qualify as the model folder; stray files next
        # to data.json are ignored. Sorting makes the choice independent of
        # the arbitrary order returned by os.listdir().
        folders = sorted(
            item
            for item in os.listdir(extract_to)
            if os.path.isdir(os.path.join(extract_to, item))
        )
        folder_path = os.path.join(extract_to, folders[0]) if folders else None

        return data_dict, folder_path

    @transaction.atomic
    def handle(self, *args, **options):
        """Create an EnviFormer from the dump and copy its files into MODEL_DIR.

        Database changes run inside one transaction. If copying the model
        files fails, the partially written target folder is removed before
        the error propagates, so a rolled-back import leaves no files behind.

        Raises:
            ValueError: If the input archive does not exist.
            FileNotFoundError: If the archive lacks ``data.json`` or a model
                folder.
            FileExistsError: If the target model folder already exists.
        """
        if not os.path.exists(options["input"]):
            raise ValueError(f"Input file {options['input']} does not exist.")

        target_package = Package.objects.get(uuid=options["package"])

        with TemporaryDirectory() as tmpdir:
            data, folder = self.read_dict_and_folder_from_archive(options["input"], tmpdir)

            # Fail early with a clear message instead of letting
            # shutil.copytree() choke on ``None`` after the DB rows exist.
            if folder is None:
                raise FileNotFoundError("No model folder found in the archive.")

            model = EnviFormer()
            model.package = target_package
            # The UUID from the dump is deliberately not assigned to the new
            # model; it is only used below to locate the checkpoint file.
            model.name = data["name"]
            model.description = data["description"]
            model.kv = data["kv"]
            model.threshold = float(data["threshold"])
            model.eval_results = data["eval_results"]
            model.multigen_eval = data["multigen_eval"]
            model.model_status = data["model_status"]
            model.save()

            for p_uuid in data["data_packages_uuids"]:
                p = Package.objects.get(uuid=p_uuid)
                model.data_packages.add(p)

            for p_uuid in data["eval_packages_uuids"]:
                p = Package.objects.get(uuid=p_uuid)
                model.eval_packages.add(p)

            target_folder = os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid))

            # Refuse to touch a folder this run did not create; the cleanup
            # below must never delete pre-existing model files.
            if os.path.exists(target_folder):
                raise FileExistsError(f"Target folder {target_folder} already exists.")

            try:
                shutil.copytree(folder, target_folder)
                # The checkpoint is named after the dumped model's UUID;
                # rename it to match the UUID of the newly created model.
                os.rename(
                    os.path.join(target_folder, f"{data['uuid']}.ckpt"),
                    os.path.join(target_folder, f"{model.uuid}.ckpt"),
                )
            except Exception:
                # The transaction rolls the DB back; remove the copied files
                # too so no orphaned model directory remains.
                shutil.rmtree(target_folder, ignore_errors=True)
                raise
|