[FIX] Fixed Search Output, Legacy API Model Endpoint, Handle ObjectsDoesNotExists in views (#297)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#297
This commit is contained in:
2026-01-15 20:39:54 +13:00
parent 6499a0c659
commit 54f8302104
9 changed files with 111 additions and 103 deletions

View File

@ -50,7 +50,7 @@ INSTALLED_APPS = [
# Custom # Custom
"epapi", # API endpoints (v1, etc.) "epapi", # API endpoints (v1, etc.)
"epdb", "epdb",
# "migration", "migration",
] ]
TENANT = os.environ.get("TENANT", "public") TENANT = os.environ.get("TENANT", "public")

View File

@ -5,27 +5,31 @@ from django.conf import settings as s
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.http import HttpResponse from django.http import HttpResponse
from django.shortcuts import redirect from django.shortcuts import redirect
from ninja import Field, Form, Router, Schema, Query from ninja import Field, Form, Query, Router, Schema
from ninja.security import SessionAuth from ninja.security import SessionAuth
from utilities.chem import FormatConverter from utilities.chem import FormatConverter
from utilities.misc import PackageExporter from utilities.misc import PackageExporter
from .logic import GroupManager, PackageManager, SettingManager, UserManager, SearchManager from .logic import GroupManager, PackageManager, SearchManager, SettingManager, UserManager
from .models import ( from .models import (
Compound, Compound,
CompoundStructure, CompoundStructure,
Edge, Edge,
EnviFormer,
EPModel, EPModel,
MLRelativeReasoning,
Node, Node,
PackageBasedModel,
ParallelRule,
Pathway, Pathway,
Reaction, Reaction,
Rule, Rule,
RuleBasedRelativeReasoning,
Scenario, Scenario,
SimpleAmbitRule, SimpleAmbitRule,
User, User,
UserPackagePermission, UserPackagePermission,
ParallelRule,
) )
Package = s.GET_PACKAGE_MODEL() Package = s.GET_PACKAGE_MODEL()
@ -237,11 +241,11 @@ def search(request, search: Query[Search]):
if "Compound Structures" in search_res: if "Compound Structures" in search_res:
res["structure"] = search_res["Compound Structures"] res["structure"] = search_res["Compound Structures"]
if "Reaction" in search_res: if "Reactions" in search_res:
res["reaction"] = search_res["Reaction"] res["reaction"] = search_res["Reactions"]
if "Pathway" in search_res: if "Pathways" in search_res:
res["pathway"] = search_res["Pathway"] res["pathway"] = search_res["Pathways"]
if "Rules" in search_res: if "Rules" in search_res:
res["rule"] = search_res["Rules"] res["rule"] = search_res["Rules"]
@ -1753,26 +1757,46 @@ class ModelWrapper(Schema):
class ModelSchema(Schema): class ModelSchema(Schema):
aliases: List[str] = Field([], alias="aliases") aliases: List[str] = Field([], alias="aliases")
description: str = Field(None, alias="description") description: str = Field(None, alias="description")
evalPackages: List["SimplePackage"] = Field([]) evalPackages: List["SimplePackage"] = Field([], alias="eval_packages")
id: str = Field(None, alias="url") id: str = Field(None, alias="url")
identifier: str = "relative-reasoning" identifier: str = "relative-reasoning"
# "info" : { info: dict = Field({}, alias="info")
# "Accuracy (Single-Gen)" : "0.5932962678936605" ,
# "Area under PR-Curve (Single-Gen)" : "0.5654653182134282" ,
# "Area under ROC-Curve (Single-Gen)" : "0.8178302405034772" ,
# "Precision (Single-Gen)" : "0.6978730822873083" ,
# "Probability Threshold" : "0.5" ,
# "Recall/Sensitivity (Single-Gen)" : "0.4484149210261006"
# } ,
name: str = Field(None, alias="name") name: str = Field(None, alias="name")
pathwayPackages: List["SimplePackage"] = Field([]) pathwayPackages: List["SimplePackage"] = Field([], alias="pathway_packages")
reviewStatus: str = Field(None, alias="review_status") reviewStatus: str = Field(None, alias="review_status")
rulePackages: List["SimplePackage"] = Field([]) rulePackages: List["SimplePackage"] = Field([], alias="rule_packages")
scenarios: List["SimpleScenario"] = Field([], alias="scenarios") scenarios: List["SimpleScenario"] = Field([], alias="scenarios")
status: str status: str = Field(None, alias="model_status")
statusMessage: str statusMessage: str = Field(None, alias="status_message")
threshold: str threshold: str = Field(None, alias="threshold")
type: str type: str = Field(None, alias="model_type")
@staticmethod
def resolve_info(obj: EPModel):
return {}
@staticmethod
def resolve_status_message(obj: EPModel):
for k, v in PackageBasedModel.PROGRESS_STATUS_CHOICES.items():
if k == obj.model_status:
return v
return None
@staticmethod
def resolve_threshold(obj: EPModel):
return f"{obj.threshold:.2f}"
@staticmethod
def resolve_model_type(obj: EPModel):
if isinstance(obj, RuleBasedRelativeReasoning):
return "RULEBASED"
elif isinstance(obj, MLRelativeReasoning):
return "ECC"
elif isinstance(obj, EnviFormer):
return "ENVIFORMER"
else:
return None
@router.get("/model", response={200: ModelWrapper, 403: Error}) @router.get("/model", response={200: ModelWrapper, 403: Error})

View File

@ -1353,12 +1353,13 @@ class SimpleAmbitRule(SimpleRule):
def get_rule_identifier(self) -> str: def get_rule_identifier(self) -> str:
return "simple-rule" return "simple-rule"
def apply(self, smiles): def apply(self, smiles, *args, **kwargs):
return FormatConverter.apply( return FormatConverter.apply(
smiles, smiles,
self.smirks, self.smirks,
reactant_filter_smarts=self.reactant_filter_smarts, reactant_filter_smarts=self.reactant_filter_smarts,
product_filter_smarts=self.product_filter_smarts, product_filter_smarts=self.product_filter_smarts,
**kwargs,
) )
@property @property
@ -1388,8 +1389,8 @@ class SimpleAmbitRule(SimpleRule):
class SimpleRDKitRule(SimpleRule): class SimpleRDKitRule(SimpleRule):
reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS") reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS")
def apply(self, smiles): def apply(self, smiles, *args, **kwargs):
return FormatConverter.apply(smiles, self.reaction_smarts) return FormatConverter.apply(smiles, self.reaction_smarts, **kwargs)
def _url(self): def _url(self):
return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid) return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)
@ -1410,10 +1411,10 @@ class ParallelRule(Rule):
def srs(self) -> QuerySet: def srs(self) -> QuerySet:
return self.simple_rules.all() return self.simple_rules.all()
def apply(self, structure): def apply(self, structure, *args, **kwargs):
res = list() res = list()
for simple_rule in self.srs: for simple_rule in self.srs:
res.extend(simple_rule.apply(structure)) res.extend(simple_rule.apply(structure, **kwargs))
return list(set(res)) return list(set(res))
@ -1518,11 +1519,11 @@ class SequentialRule(Rule):
def srs(self): def srs(self):
return self.simple_rules.all() return self.simple_rules.all()
def apply(self, structure): def apply(self, structure, *args, **kwargs):
# TODO determine levels or see java implementation # TODO determine levels or see java implementation
res = set() res = set()
for simple_rule in self.srs: for simple_rule in self.srs:
res.union(set(simple_rule.apply(structure))) res.union(set(simple_rule.apply(structure, **kwargs)))
return res return res

View File

@ -8,7 +8,7 @@ from django.conf import settings as s
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest, PermissionDenied from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render from django.shortcuts import get_object_or_404, redirect, render
from django.urls import reverse from django.urls import reverse
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from envipy_additional_information import NAME_MAPPING from envipy_additional_information import NAME_MAPPING
@ -880,7 +880,7 @@ def package_models(request, package_uuid):
def package_model(request, package_uuid, model_uuid): def package_model(request, package_uuid, model_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_model = EPModel.objects.get(package=current_package, uuid=model_uuid) current_model = get_object_or_404(EPModel, package=current_package, uuid=model_uuid)
if request.method == "GET": if request.method == "GET":
classify = request.GET.get("classify", False) classify = request.GET.get("classify", False)
@ -1212,7 +1212,7 @@ def package_compounds(request, package_uuid):
def package_compound(request, package_uuid, compound_uuid): def package_compound(request, package_uuid, compound_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid) current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)
@ -1346,9 +1346,9 @@ def package_compound_structures(request, package_uuid, compound_uuid):
def package_compound_structure(request, package_uuid, compound_uuid, structure_uuid): def package_compound_structure(request, package_uuid, compound_uuid, structure_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid) current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
current_structure = CompoundStructure.objects.get( current_structure = get_object_or_404(
compound=current_compound, uuid=structure_uuid CompoundStructure, compound=current_compound, uuid=structure_uuid
) )
if request.method == "GET": if request.method == "GET":
@ -1534,7 +1534,7 @@ def package_rules(request, package_uuid):
def package_rule(request, package_uuid, rule_uuid): def package_rule(request, package_uuid, rule_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_rule = Rule.objects.get(package=current_package, uuid=rule_uuid) current_rule = get_object_or_404(Rule, package=current_package, uuid=rule_uuid)
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)
@ -1729,7 +1729,7 @@ def package_reactions(request, package_uuid):
def package_reaction(request, package_uuid, reaction_uuid): def package_reaction(request, package_uuid, reaction_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_reaction = Reaction.objects.get(package=current_package, uuid=reaction_uuid) current_reaction = get_object_or_404(Reaction, package=current_package, uuid=reaction_uuid)
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)
@ -1924,7 +1924,9 @@ def package_pathways(request, package_uuid):
def package_pathway(request, package_uuid, pathway_uuid): def package_pathway(request, package_uuid, pathway_uuid):
current_user: User = _anonymous_or_real(request) current_user: User = _anonymous_or_real(request)
current_package: Package = PackageManager.get_package_by_id(current_user, package_uuid) current_package: Package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway: Pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_pathway: Pathway = get_object_or_404(
Pathway, package=current_package, uuid=pathway_uuid
)
if request.method == "GET": if request.method == "GET":
if request.GET.get("last_modified", False): if request.GET.get("last_modified", False):
@ -2079,7 +2081,7 @@ def package_pathway(request, package_uuid, pathway_uuid):
def package_pathway_nodes(request, package_uuid, pathway_uuid): def package_pathway_nodes(request, package_uuid, pathway_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)
@ -2138,8 +2140,8 @@ def package_pathway_nodes(request, package_uuid, pathway_uuid):
def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid): def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
current_node = Node.objects.get(pathway=current_pathway, uuid=node_uuid) current_node = get_object_or_404(Node, pathway=current_pathway, uuid=node_uuid)
if request.method == "GET": if request.method == "GET":
is_image_request = request.GET.get("image") is_image_request = request.GET.get("image")
@ -2243,7 +2245,7 @@ def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
def package_pathway_edges(request, package_uuid, pathway_uuid): def package_pathway_edges(request, package_uuid, pathway_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)
@ -2312,8 +2314,8 @@ def package_pathway_edges(request, package_uuid, pathway_uuid):
def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid): def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
current_edge = Edge.objects.get(pathway=current_pathway, uuid=edge_uuid) current_edge = get_object_or_404(Edge, pathway=current_pathway, uuid=edge_uuid)
if request.method == "GET": if request.method == "GET":
is_image_request = request.GET.get("image") is_image_request = request.GET.get("image")
@ -2493,7 +2495,7 @@ def package_scenarios(request, package_uuid):
def package_scenario(request, package_uuid, scenario_uuid): def package_scenario(request, package_uuid, scenario_uuid):
current_user = _anonymous_or_real(request) current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_scenario = Scenario.objects.get(package=current_package, uuid=scenario_uuid) current_scenario = get_object_or_404(Scenario, package=current_package, uuid=scenario_uuid)
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -76,9 +76,7 @@ def migration(request):
open(s.BASE_DIR / "fixtures" / "migration_status_per_rule.json") open(s.BASE_DIR / "fixtures" / "migration_status_per_rule.json")
) )
else: else:
BBD = Package.objects.get( BBD = Package.objects.get(uuid="32de3cf4-e3e6-4168-956e-32fa5ddb0ce1")
url="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1"
)
ALL_SMILES = [ ALL_SMILES = [
cs.smiles for cs in CompoundStructure.objects.filter(compound__package=BBD) cs.smiles for cs in CompoundStructure.objects.filter(compound__package=BBD)
] ]
@ -147,7 +145,7 @@ def migration_detail(request, package_uuid, rule_uuid):
if request.method == "GET": if request.method == "GET":
context = get_base_context(request) context = get_base_context(request)
BBD = Package.objects.get(name="EAWAG-BBD") BBD = Package.objects.get(uuid="32de3cf4-e3e6-4168-956e-32fa5ddb0ce1")
STRUCTURES = CompoundStructure.objects.filter(compound__package=BBD) STRUCTURES = CompoundStructure.objects.filter(compound__package=BBD)
rule = Rule.objects.get(package=BBD, uuid=rule_uuid) rule = Rule.objects.get(package=BBD, uuid=rule_uuid)

View File

@ -37,7 +37,7 @@ class RuleApplicationTest(TestCase):
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
print(f"\nTotal errors {self.total_errors}") # print(f"\nTotal errors {self.total_errors}")
@staticmethod @staticmethod
def normalize_smiles(smiles): def normalize_smiles(smiles):

View File

@ -2,14 +2,12 @@ import logging
import re import re
from abc import ABC from abc import ABC
from collections import defaultdict from collections import defaultdict
from typing import List, Optional, Dict, TYPE_CHECKING, Union from typing import TYPE_CHECKING, Dict, List, Optional, Union
from indigo import Indigo, IndigoException, IndigoObject from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer from indigo.renderer import IndigoRenderer
from rdkit import Chem, rdBase from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator from rdkit.Chem import Descriptors, MACCSkeys, rdchem, rdChemReactions, rdFingerprintGenerator
from rdkit.Chem import rdchem
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MolStandardize import rdMolStandardize from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.rdmolops import GetMolFrags from rdkit.Chem.rdmolops import GetMolFrags
@ -335,9 +333,14 @@ class FormatConverter(object):
# Inplace # Inplace
if preprocess_smiles: if preprocess_smiles:
# from rdkit.Chem.rdmolops import AROMATICITY_RDKIT
# Chem.SetAromaticity(mol, AROMATICITY_RDKIT)
Chem.SanitizeMol(mol) Chem.SanitizeMol(mol)
mol = Chem.AddHs(mol) mol = Chem.AddHs(mol)
# for std in BASIC:
# mol = std.standardize(mol)
# Check if reactant_filter_smarts matches and we shouldn't apply the rule # Check if reactant_filter_smarts matches and we shouldn't apply the rule
if reactant_filter_smarts and FormatConverter.smarts_matches( if reactant_filter_smarts and FormatConverter.smarts_matches(
mol, reactant_filter_smarts mol, reactant_filter_smarts
@ -376,29 +379,6 @@ class FormatConverter(object):
prods.append(p) prods.append(p)
# if kekulize:
# # from rdkit.Chem import MolStandardize
# #
# # # Attempt re-sanitization via standardizer
# # cleaner = MolStandardize.rdMolStandardize.Cleanup()
# # mol = cleaner.cleanup(product)
# # # Fixes
# # # [2025-01-30 23:00:50] ERROR chem - Sanitizing and converting failed:
# # # non-ring atom 3 marked aromatic
# # # But does not improve overall performance
# # # for a in product.GetAtoms():
# # # if (not a.IsInRing()) and a.GetIsAromatic():
# # # a.SetIsAromatic(False)
# # #
# # # for b in product.GetBonds():
# # # if (not b.IsInRing()) and b.GetIsAromatic():
# # # b.SetIsAromatic(False)
# # for atom in product.GetAtoms():
# # atom.SetIsAromatic(False)
# # for bond in product.GetBonds():
# # bond.SetIsAromatic(False)
# Chem.Kekulize(product)
except ValueError as e: except ValueError as e:
logger.error(f"Sanitizing and converting failed:\n{e}") logger.error(f"Sanitizing and converting failed:\n{e}")
continue continue
@ -524,8 +504,8 @@ class Standardizer(ABC):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
def standardize(self, smiles: str) -> str: def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
return FormatConverter.normalize(smiles) return mol
class RuleStandardizer(Standardizer): class RuleStandardizer(Standardizer):
@ -533,18 +513,20 @@ class RuleStandardizer(Standardizer):
super().__init__(name) super().__init__(name)
self.smirks = smirks self.smirks = smirks
def standardize(self, smiles: str) -> str: def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks))) rxn = rdChemReactions.ReactionFromSmarts(self.smirks)
sites = rxn.RunReactants((mol,))
if len(standardized_smiles) > 1: if len(sites) == 1:
logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}") sites = sites[0]
print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
standardized_smiles = standardized_smiles[:1]
if standardized_smiles: if len(sites) > 1:
smiles = standardized_smiles[0] logger.warning(f"{self.smirks} generated more than 1 compound {sites}")
print(f"{self.smirks} generated more than 1 compound {sites}")
return super().standardize(smiles) mol = sites[0]
return mol
class RegExStandardizer(Standardizer): class RegExStandardizer(Standardizer):
@ -552,19 +534,20 @@ class RegExStandardizer(Standardizer):
super().__init__(name) super().__init__(name)
self.replacements = replacements self.replacements = replacements
def standardize(self, smiles: str) -> str: def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
smi = smiles # smi = smiles
mod_smi = smiles # mod_smi = smiles
#
for k, v in self.replacements.items(): # for k, v in self.replacements.items():
mod_smi = smi.replace(k, v) # mod_smi = smi.replace(k, v)
#
while mod_smi != smi: # while mod_smi != smi:
mod_smi = smi # mod_smi = smi
for k, v in self.replacements.items(): # for k, v in self.replacements.items():
smi = smi.replace(k, v) # smi = smi.replace(k, v)
#
return super().standardize(smi) # return super().standardize(smi)
raise ValueError("Not implemented yet!")
FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})] FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]