forked from enviPath/enviPy
[Feature] Show Multi Gen Eval + Batch Prediction (#267)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#267
This commit is contained in:
100
epdb/views.py
100
epdb/views.py
@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from datetime import datetime
|
||||
|
||||
import nh3
|
||||
from django.conf import settings as s
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import BadRequest
|
||||
from django.core.exceptions import BadRequest, PermissionDenied
|
||||
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
|
||||
from django.shortcuts import redirect, render
|
||||
from django.urls import reverse
|
||||
@ -50,6 +51,7 @@ from .models import (
|
||||
SimpleAmbitRule,
|
||||
User,
|
||||
UserPackagePermission,
|
||||
ExpansionSchemeChoice,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -438,6 +440,18 @@ def predict_pathway(request):
|
||||
return render(request, "predict_pathway.html", context)
|
||||
|
||||
|
||||
def batch_predict_pathway(request):
|
||||
"""Top-level predict pathway view using user's default package."""
|
||||
if request.method != "GET":
|
||||
return HttpResponseNotAllowed(["GET"])
|
||||
|
||||
context = get_base_context(request)
|
||||
context["title"] = "enviPath - Batch Predict Pathway"
|
||||
context["meta"]["current_package"] = context["meta"]["user"].default_package
|
||||
|
||||
return render(request, "batch_predict_pathway.html", context)
|
||||
|
||||
|
||||
@package_permission_required()
|
||||
def package_predict_pathway(request, package_uuid):
|
||||
"""Package-specific predict pathway view."""
|
||||
@ -1967,7 +1981,7 @@ def package_pathways(request, package_uuid):
|
||||
|
||||
if pw_mode == "predict" or pw_mode == "incremental":
|
||||
# unlimited pred (will be handled by setting)
|
||||
limit = -1
|
||||
limit = None
|
||||
|
||||
# For incremental predict first level and return
|
||||
if pw_mode == "incremental":
|
||||
@ -2877,15 +2891,25 @@ def settings(request):
|
||||
)
|
||||
|
||||
if not PackageManager.readable(current_user, params["model"].package):
|
||||
raise ValueError("")
|
||||
raise PermissionDenied("You're not allowed to access this model!")
|
||||
|
||||
expansion_scheme = request.POST.get(
|
||||
"model-based-prediction-setting-expansion-scheme", "BFS"
|
||||
)
|
||||
|
||||
if expansion_scheme not in ExpansionSchemeChoice.values:
|
||||
raise BadRequest(f"Unknown expansion scheme: {expansion_scheme}")
|
||||
|
||||
params["expansion_scheme"] = ExpansionSchemeChoice(expansion_scheme)
|
||||
|
||||
elif tp_gen_method == "rule-based-prediction-setting":
|
||||
rule_packages = request.POST.getlist("rule-based-prediction-setting-packages")
|
||||
params["rule_packages"] = [
|
||||
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
||||
]
|
||||
|
||||
else:
|
||||
raise ValueError("")
|
||||
raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!")
|
||||
|
||||
created_setting = SettingManager.create_setting(
|
||||
current_user,
|
||||
@ -2963,6 +2987,59 @@ def jobs(request):
|
||||
)
|
||||
|
||||
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
|
||||
|
||||
elif job_name == "batch-predict":
|
||||
substrates = request.POST.get("substrates")
|
||||
prediction_setting_url = request.POST.get("prediction-setting")
|
||||
num_tps = request.POST.get("num-tps")
|
||||
|
||||
if substrates is None or substrates.strip() == "":
|
||||
raise BadRequest("No substrates provided.")
|
||||
|
||||
pred_data = []
|
||||
for pair in substrates.split("\n"):
|
||||
parts = pair.split(",")
|
||||
|
||||
try:
|
||||
smiles = FormatConverter.standardize(parts[0])
|
||||
except ValueError:
|
||||
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
|
||||
|
||||
# name is optional
|
||||
name = parts[1] if len(parts) > 1 else None
|
||||
pred_data.append([smiles, name])
|
||||
|
||||
max_tps = 50
|
||||
if num_tps is not None and num_tps.strip() != "":
|
||||
try:
|
||||
num_tps = int(num_tps)
|
||||
max_tps = max(min(num_tps, 50), 1)
|
||||
except ValueError:
|
||||
raise BadRequest(f"Parameter for num-tps {num_tps} is not a valid integer.")
|
||||
|
||||
batch_predict_setting = SettingManager.get_setting_by_url(
|
||||
current_user, prediction_setting_url
|
||||
)
|
||||
|
||||
target_package = PackageManager.create_package(
|
||||
current_user,
|
||||
f"Autogenerated Package for Batch Prediction {datetime.now()}",
|
||||
"This Package was generated automatically for the batch prediction task.",
|
||||
)
|
||||
|
||||
from .tasks import dispatch, batch_predict
|
||||
|
||||
res = dispatch(
|
||||
current_user,
|
||||
batch_predict,
|
||||
pred_data,
|
||||
batch_predict_setting.pk,
|
||||
target_package.pk,
|
||||
num_tps=max_tps,
|
||||
)
|
||||
|
||||
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
|
||||
|
||||
else:
|
||||
raise BadRequest(f"Job {job_name} is not supported!")
|
||||
else:
|
||||
@ -2982,6 +3059,21 @@ def job(request, job_uuid):
|
||||
# No op if status is already in a terminal state
|
||||
job.check_for_update()
|
||||
|
||||
if request.GET.get("download", False) == "true":
|
||||
if not job.is_result_downloadable():
|
||||
raise BadRequest("Result is not downloadable!")
|
||||
|
||||
if job.job_name == "batch_predict":
|
||||
filename = f"{job.job_name.replace(' ', '_')}_{job.task_id}.csv"
|
||||
else:
|
||||
raise BadRequest("Result is not downloadable!")
|
||||
|
||||
res_str = job.task_result
|
||||
response = HttpResponse(res_str, content_type="text/csv")
|
||||
response["Content-Disposition"] = f'attachment; filename="{filename}"'
|
||||
|
||||
return response
|
||||
|
||||
context["object_type"] = "joblog"
|
||||
context["breadcrumbs"] = [
|
||||
{"Home": s.SERVER_URL},
|
||||
|
||||
Reference in New Issue
Block a user