[Feature] Show Multi Gen Eval + Batch Prediction (#267)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#267
This commit is contained in:
2025-12-15 08:48:28 +13:00
parent 648ec150a9
commit d2d475b990
18 changed files with 1102 additions and 232 deletions

View File

@ -1,11 +1,12 @@
import json
import logging
from typing import Any, Dict, List
from datetime import datetime
import nh3
from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest
from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render
from django.urls import reverse
@ -50,6 +51,7 @@ from .models import (
SimpleAmbitRule,
User,
UserPackagePermission,
ExpansionSchemeChoice,
)
logger = logging.getLogger(__name__)
@ -438,6 +440,18 @@ def predict_pathway(request):
return render(request, "predict_pathway.html", context)
def batch_predict_pathway(request):
"""Top-level predict pathway view using user's default package."""
if request.method != "GET":
return HttpResponseNotAllowed(["GET"])
context = get_base_context(request)
context["title"] = "enviPath - Batch Predict Pathway"
context["meta"]["current_package"] = context["meta"]["user"].default_package
return render(request, "batch_predict_pathway.html", context)
@package_permission_required()
def package_predict_pathway(request, package_uuid):
"""Package-specific predict pathway view."""
@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid):
if pw_mode == "predict" or pw_mode == "incremental":
# unlimited pred (will be handled by setting)
limit = -1
limit = None
# For incremental predict first level and return
if pw_mode == "incremental":
@ -2877,15 +2891,25 @@ def settings(request):
)
if not PackageManager.readable(current_user, params["model"].package):
raise ValueError("")
raise PermissionDenied("You're not allowed to access this model!")
expansion_scheme = request.POST.get(
"model-based-prediction-setting-expansion-scheme", "BFS"
)
if expansion_scheme not in ExpansionSchemeChoice.values:
raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}")
params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme)
elif tp_gen_method == "rule-based-prediction-setting":
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
else:
raise ValueError("")
raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!")
created_setting = SettingManager.create_setting(
current_user,
@ -2963,6 +2987,59 @@ def jobs(request):
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
elif job_name == "batch-predict":
substrates = request.POST.get("substrates")
prediction_setting_url = request.POST.get("prediction-setting")
num_tps = request.POST.get("num-tps")
if substrates is None or substrates.strip() == "":
raise BadRequest("No substrates provided.")
pred_data = []
for pair in substrates.split("\n"):
parts = pair.split(",")
try:
smiles = FormatConverter.standardize(parts[0])
except ValueError:
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
# name is optional
name = parts[1] if len(parts) > 1 else None
pred_data.append([smiles, name])
max_tps = 50
if num_tps is not None and num_tps.strip() != "":
try:
num_tps = int(num_tps)
max_tps = max(min(num_tps, 50), 1)
except ValueError:
raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.")
batch_predict_setting = SettingManager.get_setting_by_url(
current_user, prediction_setting_url
)
target_package = PackageManager.create_package(
current_user,
f"Autogenerated Package for Batch Prediction {datetime.now()}",
"This Package was generated automatically for the batch prediction task.",
)
from .tasks import dispatch, batch_predict
res = dispatch(
current_user,
batch_predict,
pred_data,
batch_predict_setting.pk,
target_package.pk,
num_tps=max_tps,
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
else:
raise BadRequest(f"Job {job_name} is not supported!")
else:
@ -2982,6 +3059,21 @@ def job(request, job_uuid):
# No op if status is already in a terminal state
job.check_for_update()
if request.GET.get("download", False) == "true":
if not job.is_result_downloadable():
raise BadRequest("Result is not downloadable!")
if job.job_name == "batch_predict":
filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv"
else:
raise BadRequest("Result is not downloadable!")
res_str = job.task_result
response = HttpResponse(res_str, content_type="text/csv")
response["Content-Disposition"] = f'attachment; filename="{filename}"'
return response
context["object_type"] = "joblog"
context["breadcrumbs"] = [
{"Home": s.SERVER_URL},