Files
enviPy-bayer/epdb/management/commands/create_ml_models.py
jebus afeb56622c [Chore] Linted Files (#150)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
2025-10-09 07:25:13 +13:00

106 lines
4.2 KiB
Python

from django.conf import settings as s
from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.models import MLRelativeReasoning, EnviFormer, Package
class Command(BaseCommand):
    """Create, train, evaluate and save ML prediction models.

    This command can be run with
    `python manage.py create_ml_models [model_names] -d [data_packages] OPTIONAL: -e [eval_packages] -r [rule_packages]`

    For example, to train both EnviFormer and MLRelativeReasoning on BBD and SOIL
    and evaluate them on SLUDGE the below command would be used:
    `python manage.py create_ml_models enviformer mlrr -d bbd soil -e sludge`
    """

    # Name of the package the newly created models are added to.
    _TARGET_PACKAGE_NAME = "Public Prediction Models"

    # Command-line alias -> name of the package stored in the database.
    _PACKAGE_ALIASES = {
        "bbd": "EAWAG-BBD",
        "soil": "EAWAG-SOIL",
        "sludge": "EAWAG-SLUDGE",
        "sediment": "EAWAG-SEDIMENT",
    }

    # Model types this command knows how to create.
    _MODEL_NAMES = ("enviformer", "mlrr")

    def add_arguments(self, parser):
        """Register the positional model names and the package options.

        Args:
            parser: The argument parser handed in by Django.
        """
        parser.add_argument(
            "model_names",
            nargs="+",
            type=str,
            help="The names of models to train. Options are: enviformer, mlrr",
        )
        # Training data is mandatory; without it handle() has nothing to decode.
        parser.add_argument(
            "-d",
            "--data-packages",
            nargs="+",
            type=str,
            required=True,
            help="Packages for training",
        )
        parser.add_argument(
            "-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[]
        )
        parser.add_argument(
            "-r",
            "--rule-packages",
            nargs="*",
            type=str,
            help="Rule Packages mandatory for MLRR",
            default=[],
        )

    @transaction.atomic
    def handle(self, *args, **options):
        """Create every requested model, then build, train, evaluate and save it.

        Args:
            *args: Unused positional arguments.
            **options: Parsed command options; reads "model_names",
                "data_packages", "eval_packages" and "rule_packages".

        Raises:
            IndexError: If the target package or one of the EAWAG packages
                cannot be found in the database.
            ValueError: If a package alias or model name is unknown, if no
                data packages were supplied, or if EnviFormer is requested
                while `settings.ENVIFORMER_PRESENT` is not enabled.
        """
        # Find Public Prediction Models package to add new models to, together
        # with every package that can be referenced by alias on the command line.
        try:
            pack = Package.objects.filter(name=self._TARGET_PACKAGE_NAME)[0]
            known_packages = {
                alias: Package.objects.filter(name=package_name)[0]
                for alias, package_name in self._PACKAGE_ALIASES.items()
            }
        except IndexError as err:
            raise IndexError(
                "Can't find correct packages. They should be created with the bootstrap command"
            ) from err

        def decode_packages(package_list):
            """Decode package strings into their respective packages"""
            packages = []
            for alias in package_list:
                alias = alias.lower()
                if alias not in known_packages:
                    raise ValueError(f"Unknown package {alias}")
                packages.append(known_packages[alias])
            return packages

        self.stdout.write(f"Creating models: {options['model_names']}")

        # argparse enforces -d on the command line; this guards direct callers
        # that pass no data packages at all.
        requested_data = options.get("data_packages")
        if requested_data is None:
            raise ValueError(
                "No data packages given, pass at least one with -d/--data-packages"
            )
        data_packages = decode_packages(requested_data)
        eval_packages = decode_packages(options.get("eval_packages") or [])
        rule_packages = decode_packages(options.get("rule_packages") or [])
        # NOTE(review): the -r help text calls rule packages mandatory for MLRR,
        # but nothing here enforces it - confirm whether MLRelativeReasoning.create does.

        # Validate every requested model before any work starts, so a typo in a
        # later name does not surface only after earlier models were trained.
        model_names = [name.lower() for name in options["model_names"]]
        for model_name in model_names:
            if model_name not in self._MODEL_NAMES:
                raise ValueError(f"Cannot create model of type {model_name}, unknown model type")
            if model_name == "enviformer" and not s.ENVIFORMER_PRESENT:
                raise ValueError(
                    "Cannot create model of type enviformer, "
                    "settings.ENVIFORMER_PRESENT is not enabled"
                )

        # Iteratively create models in options["model_names"]
        for model_name in model_names:
            if model_name == "enviformer":
                model = EnviFormer.create(
                    pack,
                    data_packages=data_packages,
                    eval_packages=eval_packages,
                    threshold=0.5,
                    name="EnviFormer - T0.5",
                    description="EnviFormer transformer",
                )
            else:
                # Only "mlrr" can reach this branch; names were validated above.
                model = MLRelativeReasoning.create(
                    package=pack,
                    rule_packages=rule_packages,
                    data_packages=data_packages,
                    eval_packages=eval_packages,
                    threshold=0.5,
                    name="ECC - BBD - T0.5",
                    description="ML Relative Reasoning",
                )

            # Build the dataset for the model, train it, evaluate it and save it
            self.stdout.write(f"Building dataset for {model_name}")
            model.build_dataset()
            self.stdout.write(f"Training {model_name}")
            model.build_model()
            self.stdout.write(f"Evaluating {model_name}")
            model.evaluate_model()
            self.stdout.write(f"Saving {model_name}")
            model.save()