Merge remote-tracking branch 'origin/develop' into enhancement/dataset

# Conflicts:
#	epdb/models.py
#	tests/test_enviformer.py
#	tests/test_model.py
This commit is contained in:
Liam Brydon
2025-11-07 08:28:03 +13:00
25 changed files with 1024 additions and 280 deletions

View File

@ -47,6 +47,7 @@ from .models import (
ExternalDatabase,
ExternalIdentifier,
EnzymeLink,
JobLog,
)
logger = logging.getLogger(__name__)
@ -236,6 +237,7 @@ def get_base_context(request, for_user=None) -> Dict[str, Any]:
"enabled_features": s.FLAGS,
"debug": s.DEBUG,
"external_databases": ExternalDatabase.get_databases(),
"site_id": s.MATOMO_SITE_ID,
},
}
@ -754,8 +756,8 @@ def package_models(request, package_uuid):
context["unreviewed_objects"] = unreviewed_model_qs
context["model_types"] = {
"ML Relative Reasoning": "ml-relative-reasoning",
"Rule Based Relative Reasoning": "rule-based-relative-reasoning",
"ML Relative Reasoning": "mlrr",
"Rule Based Relative Reasoning": "rbrr",
}
if s.FLAGS.get("ENVIFORMER", False):
@ -775,69 +777,67 @@ def package_models(request, package_uuid):
model_type = request.POST.get("model-type")
# Generic fields for ML and Rule Based
rule_packages = request.POST.getlist("model-rule-packages")
data_packages = request.POST.getlist("model-data-packages")
# Generic params
params = {
"package": current_package,
"name": name,
"description": description,
"data_packages": [
PackageManager.get_package_by_url(current_user, p) for p in data_packages
],
}
if model_type == "enviformer":
threshold = float(request.POST.get(f"{model_type}-threshold", 0.5))
threshold = float(request.POST.get("model-threshold", 0.5))
params["threshold"] = threshold
mod = EnviFormer.create(current_package, name, description, threshold)
mod = EnviFormer.create(**params)
elif model_type == "mlrr":
# ML Specific
threshold = float(request.POST.get("model-threshold", 0.5))
# TODO handle additional fingerprinter
# fingerprinter = request.POST.get("model-fingerprinter")
elif model_type == "ml-relative-reasoning" or model_type == "rule-based-relative-reasoning":
# Generic fields for ML and Rule Based
rule_packages = request.POST.getlist("package-based-relative-reasoning-rule-packages")
data_packages = request.POST.getlist("package-based-relative-reasoning-data-packages")
eval_packages = request.POST.getlist(
"package-based-relative-reasoning-evaluation-packages", []
)
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
# Generic params
params = {
"package": current_package,
"name": name,
"description": description,
"rule_packages": [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
],
"data_packages": [
PackageManager.get_package_by_url(current_user, p) for p in data_packages
],
"eval_packages": [
PackageManager.get_package_by_url(current_user, p) for p in eval_packages
],
}
# App Domain related parameters
build_ad = request.POST.get("build-app-domain", False) == "on"
num_neighbors = request.POST.get("num-neighbors", 5)
reliability_threshold = request.POST.get("reliability-threshold", 0.5)
local_compatibility_threshold = request.POST.get("local-compatibility-threshold", 0.5)
if model_type == "ml-relative-reasoning":
# ML Specific
threshold = float(request.POST.get(f"{model_type}-threshold", 0.5))
# TODO handle additional fingerprinter
# fingerprinter = request.POST.get(f"{model_type}-fingerprinter")
params["threshold"] = threshold
# params['fingerprinter'] = fingerprinter
params["build_app_domain"] = build_ad
params["app_domain_num_neighbours"] = num_neighbors
params["app_domain_reliability_threshold"] = reliability_threshold
params["app_domain_local_compatibility_threshold"] = local_compatibility_threshold
# App Domain related parameters
build_ad = request.POST.get("build-app-domain", False) == "on"
num_neighbors = request.POST.get("num-neighbors", 5)
reliability_threshold = request.POST.get("reliability-threshold", 0.5)
local_compatibility_threshold = request.POST.get(
"local-compatibility-threshold", 0.5
)
mod = MLRelativeReasoning.create(**params)
elif model_type == "rbrr":
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
params["threshold"] = threshold
# params['fingerprinter'] = fingerprinter
params["build_app_domain"] = build_ad
params["app_domain_num_neighbours"] = num_neighbors
params["app_domain_reliability_threshold"] = reliability_threshold
params["app_domain_local_compatibility_threshold"] = local_compatibility_threshold
mod = MLRelativeReasoning.create(**params)
else:
mod = RuleBasedRelativeReasoning.create(**params)
from .tasks import build_model
build_model.delay(mod.pk)
mod = RuleBasedRelativeReasoning.create(**params)
elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS.values():
pass
else:
return error(
request, "Invalid model type.", f'Model type "{model_type}" is not supported."'
)
return redirect(mod.url)
from .tasks import dispatch, build_model
dispatch(current_user, build_model, mod.pk)
return redirect(mod.url)
else:
return HttpResponseNotAllowed(["GET", "POST"])
@ -865,6 +865,10 @@ def package_model(request, package_uuid, model_uuid):
return JsonResponse({"error": f'"{smiles}" is not a valid SMILES'}, status=400)
if classify:
from epdb.tasks import dispatch_eager, predict_simple
res = dispatch_eager(current_user, predict_simple, current_model.pk, stand_smiles)
pred_res = current_model.predict(stand_smiles)
res = []
@ -909,9 +913,25 @@ def package_model(request, package_uuid, model_uuid):
current_model.delete()
return redirect(current_package.url + "/model")
elif hidden == "evaluate":
from .tasks import evaluate_model
from .tasks import dispatch, evaluate_model
eval_type = request.POST.get("model-evaluation-type")
if eval_type not in ["sg", "mg"]:
return error(
request,
"Invalid evaluation type",
f'Evaluation type "{eval_type}" is not supported. Only "sg" and "mg" are supported.',
)
multigen = eval_type == "mg"
eval_packages = request.POST.getlist("model-evaluation-packages")
eval_package_ids = [
PackageManager.get_package_by_url(current_user, p).id for p in eval_packages
]
dispatch(current_user, evaluate_model, current_model.pk, multigen, eval_package_ids)
evaluate_model.delay(current_model.pk)
return redirect(current_model.url)
else:
return HttpResponseBadRequest()
@ -1809,9 +1829,9 @@ def package_pathways(request, package_uuid):
pw.setting = prediction_setting
pw.save()
from .tasks import predict
from .tasks import dispatch, predict
predict.delay(pw.pk, prediction_setting.pk, limit=limit)
dispatch(current_user, predict, pw.pk, prediction_setting.pk, limit=limit)
return redirect(pw.url)
@ -1847,6 +1867,25 @@ def package_pathway(request, package_uuid, pathway_uuid):
return response
if (
request.GET.get("identify-missing-rules", False) == "true"
and request.GET.get("rule-package") is not None
):
from .tasks import dispatch_eager, identify_missing_rules
rule_package = PackageManager.get_package_by_url(
current_user, request.GET.get("rule-package")
)
res = dispatch_eager(
current_user, identify_missing_rules, [current_pathway.pk], rule_package.pk
)
filename = f"{current_pathway.name.replace(' ', '_')}_{current_pathway.uuid}.csv"
response = HttpResponse(res, content_type="text/csv")
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
# Pathway d3_json() relies on a lot of related objects (Nodes, Structures, Edges, Reaction, Rules, ...)
# we will again fetch the current pathway identified by this url, but this time together with nearly all
# related objects
@ -1930,10 +1969,16 @@ def package_pathway(request, package_uuid, pathway_uuid):
if node_url:
n = current_pathway.get_node(node_url)
from .tasks import predict
from .tasks import dispatch, predict
dispatch(
current_user,
predict,
current_pathway.pk,
current_pathway.setting.pk,
node_pk=n.pk,
)
# Don't delay?
predict(current_pathway.pk, current_pathway.setting.pk, node_pk=n.pk)
return JsonResponse({"success": current_pathway.url})
return HttpResponseBadRequest()
@ -2705,6 +2750,24 @@ def setting(request, setting_uuid):
pass
def jobs(request):
    """Render the job log overview page.

    Superusers are shown every ``JobLog`` entry; any other user only sees the
    entries whose ``user`` is that user. Entries are ordered by ``created``,
    newest first.

    Only GET is supported. Any other HTTP method is answered with a
    405 (Method Not Allowed) response rather than falling through without
    returning a response object.
    """
    # Reject unsupported methods up front so every code path returns a response.
    if request.method != "GET":
        return HttpResponseNotAllowed(["GET"])

    current_user = _anonymous_or_real(request)
    context = get_base_context(request)

    context["object_type"] = "joblog"
    context["breadcrumbs"] = [
        {"Home": s.SERVER_URL},
        {"Jobs": s.SERVER_URL + "/jobs"},
    ]

    # Superusers see all jobs; everyone else is restricted to their own.
    job_qs = JobLog.objects.all()
    if not current_user.is_superuser:
        job_qs = job_qs.filter(user=current_user)
    context["jobs"] = job_qs.order_by("-created")

    return render(request, "collections/joblog.html", context)
###########
# KETCHER #
###########