[FIX] Fixed Search Output, Legacy API Model Endpoint, Handle ObjectsDoesNotExists in views (#297)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#297
This commit is contained in:
2026-01-15 20:39:54 +13:00
parent 6499a0c659
commit 54f8302104
9 changed files with 111 additions and 103 deletions

View File

@ -5,27 +5,31 @@ from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.http import HttpResponse
from django.shortcuts import redirect
from ninja import Field, Form, Router, Schema, Query
from ninja import Field, Form, Query, Router, Schema
from ninja.security import SessionAuth
from utilities.chem import FormatConverter
from utilities.misc import PackageExporter
from .logic import GroupManager, PackageManager, SettingManager, UserManager, SearchManager
from .logic import GroupManager, PackageManager, SearchManager, SettingManager, UserManager
from .models import (
Compound,
CompoundStructure,
Edge,
EnviFormer,
EPModel,
MLRelativeReasoning,
Node,
PackageBasedModel,
ParallelRule,
Pathway,
Reaction,
Rule,
RuleBasedRelativeReasoning,
Scenario,
SimpleAmbitRule,
User,
UserPackagePermission,
ParallelRule,
)
Package = s.GET_PACKAGE_MODEL()
@ -237,11 +241,11 @@ def search(request, search: Query[Search]):
if "Compound Structures" in search_res:
res["structure"] = search_res["Compound Structures"]
if "Reaction" in search_res:
res["reaction"] = search_res["Reaction"]
if "Reactions" in search_res:
res["reaction"] = search_res["Reactions"]
if "Pathway" in search_res:
res["pathway"] = search_res["Pathway"]
if "Pathways" in search_res:
res["pathway"] = search_res["Pathways"]
if "Rules" in search_res:
res["rule"] = search_res["Rules"]
@ -1753,26 +1757,46 @@ class ModelWrapper(Schema):
class ModelSchema(Schema):
aliases: List[str] = Field([], alias="aliases")
description: str = Field(None, alias="description")
evalPackages: List["SimplePackage"] = Field([])
evalPackages: List["SimplePackage"] = Field([], alias="eval_packages")
id: str = Field(None, alias="url")
identifier: str = "relative-reasoning"
# "info" : {
# "Accuracy (Single-Gen)" : "0.5932962678936605" ,
# "Area under PR-Curve (Single-Gen)" : "0.5654653182134282" ,
# "Area under ROC-Curve (Single-Gen)" : "0.8178302405034772" ,
# "Precision (Single-Gen)" : "0.6978730822873083" ,
# "Probability Threshold" : "0.5" ,
# "Recall/Sensitivity (Single-Gen)" : "0.4484149210261006"
# } ,
info: dict = Field({}, alias="info")
name: str = Field(None, alias="name")
pathwayPackages: List["SimplePackage"] = Field([])
pathwayPackages: List["SimplePackage"] = Field([], alias="pathway_packages")
reviewStatus: str = Field(None, alias="review_status")
rulePackages: List["SimplePackage"] = Field([])
rulePackages: List["SimplePackage"] = Field([], alias="rule_packages")
scenarios: List["SimpleScenario"] = Field([], alias="scenarios")
status: str
statusMessage: str
threshold: str
type: str
status: str = Field(None, alias="model_status")
statusMessage: str = Field(None, alias="status_message")
threshold: str = Field(None, alias="threshold")
type: str = Field(None, alias="model_type")
@staticmethod
def resolve_info(obj: EPModel):
return {}
@staticmethod
def resolve_status_message(obj: EPModel):
for k, v in PackageBasedModel.PROGRESS_STATUS_CHOICES.items():
if k == obj.model_status:
return v
return None
@staticmethod
def resolve_threshold(obj: EPModel):
return f"{obj.threshold:.2f}"
@staticmethod
def resolve_model_type(obj: EPModel):
if isinstance(obj, RuleBasedRelativeReasoning):
return "RULEBASED"
elif isinstance(obj, MLRelativeReasoning):
return "ECC"
elif isinstance(obj, EnviFormer):
return "ENVIFORMER"
else:
return None
@router.get("/model", response={200: ModelWrapper, 403: Error})

View File

@ -1353,12 +1353,13 @@ class SimpleAmbitRule(SimpleRule):
def get_rule_identifier(self) -> str:
return "simple-rule"
def apply(self, smiles):
def apply(self, smiles, *args, **kwargs):
return FormatConverter.apply(
smiles,
self.smirks,
reactant_filter_smarts=self.reactant_filter_smarts,
product_filter_smarts=self.product_filter_smarts,
**kwargs,
)
@property
@ -1388,8 +1389,8 @@ class SimpleAmbitRule(SimpleRule):
class SimpleRDKitRule(SimpleRule):
reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS")
def apply(self, smiles):
return FormatConverter.apply(smiles, self.reaction_smarts)
def apply(self, smiles, *args, **kwargs):
return FormatConverter.apply(smiles, self.reaction_smarts, **kwargs)
def _url(self):
return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)
@ -1410,10 +1411,10 @@ class ParallelRule(Rule):
def srs(self) -> QuerySet:
return self.simple_rules.all()
def apply(self, structure):
def apply(self, structure, *args, **kwargs):
res = list()
for simple_rule in self.srs:
res.extend(simple_rule.apply(structure))
res.extend(simple_rule.apply(structure, **kwargs))
return list(set(res))
@ -1518,11 +1519,11 @@ class SequentialRule(Rule):
def srs(self):
return self.simple_rules.all()
def apply(self, structure):
def apply(self, structure, *args, **kwargs):
# TODO determine levels or see java implementation
res = set()
for simple_rule in self.srs:
res.union(set(simple_rule.apply(structure)))
res.union(set(simple_rule.apply(structure, **kwargs)))
return res

View File

@ -8,7 +8,7 @@ from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render
from django.shortcuts import get_object_or_404, redirect, render
from django.urls import reverse
from django.views.decorators.csrf import csrf_exempt
from envipy_additional_information import NAME_MAPPING
@ -880,7 +880,7 @@ def package_models(request, package_uuid):
def package_model(request, package_uuid, model_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_model = EPModel.objects.get(package=current_package, uuid=model_uuid)
current_model = get_object_or_404(EPModel, package=current_package, uuid=model_uuid)
if request.method == "GET":
classify = request.GET.get("classify", False)
@ -1212,7 +1212,7 @@ def package_compounds(request, package_uuid):
def package_compound(request, package_uuid, compound_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid)
current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -1346,9 +1346,9 @@ def package_compound_structures(request, package_uuid, compound_uuid):
def package_compound_structure(request, package_uuid, compound_uuid, structure_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid)
current_structure = CompoundStructure.objects.get(
compound=current_compound, uuid=structure_uuid
current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
current_structure = get_object_or_404(
CompoundStructure, compound=current_compound, uuid=structure_uuid
)
if request.method == "GET":
@ -1534,7 +1534,7 @@ def package_rules(request, package_uuid):
def package_rule(request, package_uuid, rule_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_rule = Rule.objects.get(package=current_package, uuid=rule_uuid)
current_rule = get_object_or_404(Rule, package=current_package, uuid=rule_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -1729,7 +1729,7 @@ def package_reactions(request, package_uuid):
def package_reaction(request, package_uuid, reaction_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_reaction = Reaction.objects.get(package=current_package, uuid=reaction_uuid)
current_reaction = get_object_or_404(Reaction, package=current_package, uuid=reaction_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -1924,7 +1924,9 @@ def package_pathways(request, package_uuid):
def package_pathway(request, package_uuid, pathway_uuid):
current_user: User = _anonymous_or_real(request)
current_package: Package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway: Pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_pathway: Pathway = get_object_or_404(
Pathway, package=current_package, uuid=pathway_uuid
)
if request.method == "GET":
if request.GET.get("last_modified", False):
@ -2079,7 +2081,7 @@ def package_pathway(request, package_uuid, pathway_uuid):
def package_pathway_nodes(request, package_uuid, pathway_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -2138,8 +2140,8 @@ def package_pathway_nodes(request, package_uuid, pathway_uuid):
def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_node = Node.objects.get(pathway=current_pathway, uuid=node_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
current_node = get_object_or_404(Node, pathway=current_pathway, uuid=node_uuid)
if request.method == "GET":
is_image_request = request.GET.get("image")
@ -2243,7 +2245,7 @@ def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
def package_pathway_edges(request, package_uuid, pathway_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -2312,8 +2314,8 @@ def package_pathway_edges(request, package_uuid, pathway_uuid):
def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_edge = Edge.objects.get(pathway=current_pathway, uuid=edge_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
current_edge = get_object_or_404(Edge, pathway=current_pathway, uuid=edge_uuid)
if request.method == "GET":
is_image_request = request.GET.get("image")
@ -2493,7 +2495,7 @@ def package_scenarios(request, package_uuid):
def package_scenario(request, package_uuid, scenario_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_scenario = Scenario.objects.get(package=current_package, uuid=scenario_uuid)
current_scenario = get_object_or_404(Scenario, package=current_package, uuid=scenario_uuid)
if request.method == "GET":
context = get_base_context(request)