forked from enviPath/enviPy
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#216 Reviewed-by: liambrydon <lbry121@aucklanduni.ac.nz> Reviewed-by: Tobias O <tobias.olenyi@envipath.com>
124 lines
5.2 KiB
Python
124 lines
5.2 KiB
Python
from django.conf import settings as s
|
|
from django.core.management.base import BaseCommand
|
|
from django.db import transaction
|
|
|
|
from epdb.models import EnviFormer, MLRelativeReasoning
|
|
|
|
# The Package model class is obtained through the GET_PACKAGE_MODEL() settings helper rather than imported
# from epdb.models directly (presumably so a different package model can be configured — verify in settings).
Package = s.GET_PACKAGE_MODEL()
|
|
|
|
|
|
class Command(BaseCommand):
    """Train, evaluate and save ML prediction models into the "Public Prediction Models" package.

    This command can be run with

    `python manage.py create_ml_models [model_names] -d [data_packages] FOR MLRR ONLY: -r [rule_packages]
    OPTIONAL: -e [eval_packages] -t threshold`

    For example, to train both EnviFormer and MLRelativeReasoning on BBD and SOIL (using the rules from BBD
    for MLRR) and evaluate them on SLUDGE with a threshold of 0.6, the below command would be used:

    `python manage.py create_ml_models enviformer mlrr -d bbd soil -r bbd -e sludge -t 0.6`

    Model names and package keys are case-insensitive. Valid package keys are: bbd, soil, sludge, sediment.
    All model names are validated before any model is built, and the whole run executes inside a single
    database transaction, so an error part-way through rolls back every model created by this invocation.
    """

    help = "Train, evaluate and save EnviFormer and/or MLRR prediction models (see class docstring for usage)."

    # Model types accepted as positional ``model_names`` (compared after lower-casing).
    MODEL_TYPES = ("enviformer", "mlrr")

    # Command-line package keys (compared after lower-casing) -> names of the packages created by the
    # bootstrap command. Insertion order is the order in which the packages are looked up.
    PACKAGE_NAMES = {
        "bbd": "EAWAG-BBD",
        "soil": "EAWAG-SOIL",
        "sludge": "EAWAG-SLUDGE",
        "sediment": "EAWAG-SEDIMENT",
    }

    # Name of the package that newly created models are added to.
    TARGET_PACKAGE_NAME = "Public Prediction Models"

    def add_arguments(self, parser):
        """Register the positional model names and the package/threshold options."""
        parser.add_argument(
            "model_names",
            nargs="+",
            type=str,
            help="The names of models to train. Options are: enviformer, mlrr",
        )
        # Required: every model is trained on these packages; without it the option would be None and the
        # command would fail later while iterating it.
        parser.add_argument(
            "-d",
            "--data-packages",
            nargs="+",
            type=str,
            required=True,
            help="Packages for training",
        )
        parser.add_argument(
            "-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[]
        )
        parser.add_argument(
            "-r",
            "--rule-packages",
            nargs="*",
            type=str,
            help="Rule Packages mandatory for MLRR",
            default=[],
        )
        parser.add_argument(
            "-t",
            "--threshold",
            type=float,
            help="Model prediction threshold",
            default=0.5,
        )

    @transaction.atomic
    def handle(self, *args, **options):
        """Create, train, evaluate and save each model named in ``options["model_names"]``.

        Raises:
            IndexError: if a package expected from the bootstrap command does not exist.
            ValueError: if a model name or package key is unknown, EnviFormer is requested while
                ``settings.ENVIFORMER_PRESENT`` is not set, or MLRR is requested without rule packages.
        """
        model_names = [name.lower() for name in options["model_names"]]
        # Fail fast: training is expensive and everything runs in one transaction, so an invalid model name
        # detected after the first model was trained would roll that work back.
        self._validate_model_names(model_names, options["rule_packages"])

        # Find Public Prediction Models package to add new models to, plus the known data packages.
        pack, known_packages = self._load_packages()

        self.stdout.write(
            f"Creating models: {options['model_names']}\n"
            f"Data packages: {options['data_packages']}\n"
            f"Rule Packages (only for MLRR): {options['rule_packages']}\n"
            f"Eval Packages: {options['eval_packages']}\n"
            f"Threshold: {options['threshold']:.2f}"
        )

        data_packages = self._decode_packages(options["data_packages"], known_packages)
        eval_packages = self._decode_packages(options["eval_packages"], known_packages)
        rule_packages = self._decode_packages(options["rule_packages"], known_packages)

        for model_name in model_names:
            model = self._create_model(
                model_name, pack, data_packages, rule_packages, eval_packages, options
            )

            # Build the dataset for the model, train it, evaluate it and save it
            self.stdout.write(f"Building dataset for {model_name}")
            model.build_dataset()
            self.stdout.write(f"Training {model_name}")
            model.build_model()
            self.stdout.write(f"Evaluating {model_name}")
            model.evaluate_model(False, eval_packages=eval_packages)
            self.stdout.write(f"Saving {model_name}")
            model.save()

    def _validate_model_names(self, model_names, rule_keys):
        """Raise ValueError if any (lower-cased) model name cannot be created with the given options."""
        for model_name in model_names:
            if model_name not in self.MODEL_TYPES:
                raise ValueError(f"Cannot create model of type {model_name}, unknown model type")
            if model_name == "enviformer" and not s.ENVIFORMER_PRESENT:
                raise ValueError(
                    "Cannot create model of type enviformer, EnviFormer is not available "
                    "(settings.ENVIFORMER_PRESENT is not set)"
                )
            if model_name == "mlrr" and not rule_keys:
                raise ValueError(
                    "Cannot create model of type mlrr, at least one rule package (-r/--rule-packages) "
                    "is required"
                )

    def _load_packages(self):
        """Return the target package and a dict mapping each package key to its Package object."""
        try:
            pack = Package.objects.filter(name=self.TARGET_PACKAGE_NAME)[0]
            known_packages = {
                key: Package.objects.filter(name=name)[0] for key, name in self.PACKAGE_NAMES.items()
            }
        except IndexError as err:
            raise IndexError(
                "Can't find correct packages. They should be created with the bootstrap command"
            ) from err
        return pack, known_packages

    @staticmethod
    def _decode_packages(package_list, known_packages):
        """Decode package strings (case-insensitive) into their respective packages, preserving order."""
        packages = []
        for key in package_list:
            key = key.lower()
            if key not in known_packages:
                raise ValueError(f"Unknown package {key}")
            packages.append(known_packages[key])
        return packages

    @staticmethod
    def _create_model(model_name, pack, data_packages, rule_packages, eval_packages, options):
        """Create (but do not train) a model of type ``model_name`` inside ``pack``."""
        threshold = options["threshold"]
        data_label = ", ".join(options["data_packages"])
        if model_name == "enviformer":
            return EnviFormer.create(
                pack,
                data_packages=data_packages,
                eval_packages=eval_packages,
                threshold=threshold,
                name=f"EnviFormer - {data_label} - T{threshold:.2f}",
                description=f"EnviFormer transformer trained on {options['data_packages']} "
                f"evaluated on {options['eval_packages']}.",
            )
        if model_name == "mlrr":
            return MLRelativeReasoning.create(
                package=pack,
                rule_packages=rule_packages,
                data_packages=data_packages,
                eval_packages=eval_packages,
                threshold=threshold,
                name=f"ECC - {data_label} - T{threshold:.2f}",
                description=f"ML Relative Reasoning trained on {options['data_packages']} with rules from "
                f"{options['rule_packages']} and evaluated on {options['eval_packages']}.",
            )
        raise ValueError(f"Cannot create model of type {model_name}, unknown model type")