[FIX] Fixed Search Output, Legacy API Model Endpoint, Handle ObjectsDoesNotExists in views (#297)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#297
This commit is contained in:
2026-01-15 20:39:54 +13:00
parent 6499a0c659
commit 54f8302104
9 changed files with 111 additions and 103 deletions

View File

@ -50,7 +50,7 @@ INSTALLED_APPS = [
# Custom
"epapi", # API endpoints (v1, etc.)
"epdb",
# "migration",
"migration",
]
TENANT = os.environ.get("TENANT", "public")

View File

@ -5,27 +5,31 @@ from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.http import HttpResponse
from django.shortcuts import redirect
from ninja import Field, Form, Router, Schema, Query
from ninja import Field, Form, Query, Router, Schema
from ninja.security import SessionAuth
from utilities.chem import FormatConverter
from utilities.misc import PackageExporter
from .logic import GroupManager, PackageManager, SettingManager, UserManager, SearchManager
from .logic import GroupManager, PackageManager, SearchManager, SettingManager, UserManager
from .models import (
Compound,
CompoundStructure,
Edge,
EnviFormer,
EPModel,
MLRelativeReasoning,
Node,
PackageBasedModel,
ParallelRule,
Pathway,
Reaction,
Rule,
RuleBasedRelativeReasoning,
Scenario,
SimpleAmbitRule,
User,
UserPackagePermission,
ParallelRule,
)
Package = s.GET_PACKAGE_MODEL()
@ -237,11 +241,11 @@ def search(request, search: Query[Search]):
if "Compound Structures" in search_res:
res["structure"] = search_res["Compound Structures"]
if "Reaction" in search_res:
res["reaction"] = search_res["Reaction"]
if "Reactions" in search_res:
res["reaction"] = search_res["Reactions"]
if "Pathway" in search_res:
res["pathway"] = search_res["Pathway"]
if "Pathways" in search_res:
res["pathway"] = search_res["Pathways"]
if "Rules" in search_res:
res["rule"] = search_res["Rules"]
@ -1753,26 +1757,46 @@ class ModelWrapper(Schema):
class ModelSchema(Schema):
aliases: List[str] = Field([], alias="aliases")
description: str = Field(None, alias="description")
evalPackages: List["SimplePackage"] = Field([])
evalPackages: List["SimplePackage"] = Field([], alias="eval_packages")
id: str = Field(None, alias="url")
identifier: str = "relative-reasoning"
# "info" : {
# "Accuracy (Single-Gen)" : "0.5932962678936605" ,
# "Area under PR-Curve (Single-Gen)" : "0.5654653182134282" ,
# "Area under ROC-Curve (Single-Gen)" : "0.8178302405034772" ,
# "Precision (Single-Gen)" : "0.6978730822873083" ,
# "Probability Threshold" : "0.5" ,
# "Recall/Sensitivity (Single-Gen)" : "0.4484149210261006"
# } ,
info: dict = Field({}, alias="info")
name: str = Field(None, alias="name")
pathwayPackages: List["SimplePackage"] = Field([])
pathwayPackages: List["SimplePackage"] = Field([], alias="pathway_packages")
reviewStatus: str = Field(None, alias="review_status")
rulePackages: List["SimplePackage"] = Field([])
rulePackages: List["SimplePackage"] = Field([], alias="rule_packages")
scenarios: List["SimpleScenario"] = Field([], alias="scenarios")
status: str
statusMessage: str
threshold: str
type: str
status: str = Field(None, alias="model_status")
statusMessage: str = Field(None, alias="status_message")
threshold: str = Field(None, alias="threshold")
type: str = Field(None, alias="model_type")
@staticmethod
def resolve_info(obj: EPModel):
return {}
@staticmethod
def resolve_status_message(obj: EPModel):
for k, v in PackageBasedModel.PROGRESS_STATUS_CHOICES.items():
if k == obj.model_status:
return v
return None
@staticmethod
def resolve_threshold(obj: EPModel):
return f"{obj.threshold:.2f}"
@staticmethod
def resolve_model_type(obj: EPModel):
if isinstance(obj, RuleBasedRelativeReasoning):
return "RULEBASED"
elif isinstance(obj, MLRelativeReasoning):
return "ECC"
elif isinstance(obj, EnviFormer):
return "ENVIFORMER"
else:
return None
@router.get("/model", response={200: ModelWrapper, 403: Error})

View File

@ -1353,12 +1353,13 @@ class SimpleAmbitRule(SimpleRule):
def get_rule_identifier(self) -> str:
return "simple-rule"
def apply(self, smiles):
def apply(self, smiles, *args, **kwargs):
return FormatConverter.apply(
smiles,
self.smirks,
reactant_filter_smarts=self.reactant_filter_smarts,
product_filter_smarts=self.product_filter_smarts,
**kwargs,
)
@property
@ -1388,8 +1389,8 @@ class SimpleAmbitRule(SimpleRule):
class SimpleRDKitRule(SimpleRule):
reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS")
def apply(self, smiles):
return FormatConverter.apply(smiles, self.reaction_smarts)
def apply(self, smiles, *args, **kwargs):
return FormatConverter.apply(smiles, self.reaction_smarts, **kwargs)
def _url(self):
return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)
@ -1410,10 +1411,10 @@ class ParallelRule(Rule):
def srs(self) -> QuerySet:
return self.simple_rules.all()
def apply(self, structure):
def apply(self, structure, *args, **kwargs):
res = list()
for simple_rule in self.srs:
res.extend(simple_rule.apply(structure))
res.extend(simple_rule.apply(structure, **kwargs))
return list(set(res))
@ -1518,11 +1519,11 @@ class SequentialRule(Rule):
def srs(self):
return self.simple_rules.all()
def apply(self, structure):
def apply(self, structure, *args, **kwargs):
# TODO determine levels or see java implementation
res = set()
for simple_rule in self.srs:
res.union(set(simple_rule.apply(structure)))
res.union(set(simple_rule.apply(structure, **kwargs)))
return res

View File

@ -8,7 +8,7 @@ from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.core.exceptions import BadRequest, PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
from django.shortcuts import redirect, render
from django.shortcuts import get_object_or_404, redirect, render
from django.urls import reverse
from django.views.decorators.csrf import csrf_exempt
from envipy_additional_information import NAME_MAPPING
@ -880,7 +880,7 @@ def package_models(request, package_uuid):
def package_model(request, package_uuid, model_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_model = EPModel.objects.get(package=current_package, uuid=model_uuid)
current_model = get_object_or_404(EPModel, package=current_package, uuid=model_uuid)
if request.method == "GET":
classify = request.GET.get("classify", False)
@ -1212,7 +1212,7 @@ def package_compounds(request, package_uuid):
def package_compound(request, package_uuid, compound_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid)
current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -1346,9 +1346,9 @@ def package_compound_structures(request, package_uuid, compound_uuid):
def package_compound_structure(request, package_uuid, compound_uuid, structure_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid)
current_structure = CompoundStructure.objects.get(
compound=current_compound, uuid=structure_uuid
current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
current_structure = get_object_or_404(
CompoundStructure, compound=current_compound, uuid=structure_uuid
)
if request.method == "GET":
@ -1534,7 +1534,7 @@ def package_rules(request, package_uuid):
def package_rule(request, package_uuid, rule_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_rule = Rule.objects.get(package=current_package, uuid=rule_uuid)
current_rule = get_object_or_404(Rule, package=current_package, uuid=rule_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -1729,7 +1729,7 @@ def package_reactions(request, package_uuid):
def package_reaction(request, package_uuid, reaction_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_reaction = Reaction.objects.get(package=current_package, uuid=reaction_uuid)
current_reaction = get_object_or_404(Reaction, package=current_package, uuid=reaction_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -1924,7 +1924,9 @@ def package_pathways(request, package_uuid):
def package_pathway(request, package_uuid, pathway_uuid):
current_user: User = _anonymous_or_real(request)
current_package: Package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway: Pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_pathway: Pathway = get_object_or_404(
Pathway, package=current_package, uuid=pathway_uuid
)
if request.method == "GET":
if request.GET.get("last_modified", False):
@ -2079,7 +2081,7 @@ def package_pathway(request, package_uuid, pathway_uuid):
def package_pathway_nodes(request, package_uuid, pathway_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -2138,8 +2140,8 @@ def package_pathway_nodes(request, package_uuid, pathway_uuid):
def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_node = Node.objects.get(pathway=current_pathway, uuid=node_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
current_node = get_object_or_404(Node, pathway=current_pathway, uuid=node_uuid)
if request.method == "GET":
is_image_request = request.GET.get("image")
@ -2243,7 +2245,7 @@ def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
def package_pathway_edges(request, package_uuid, pathway_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
if request.method == "GET":
context = get_base_context(request)
@ -2312,8 +2314,8 @@ def package_pathway_edges(request, package_uuid, pathway_uuid):
def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
current_edge = Edge.objects.get(pathway=current_pathway, uuid=edge_uuid)
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
current_edge = get_object_or_404(Edge, pathway=current_pathway, uuid=edge_uuid)
if request.method == "GET":
is_image_request = request.GET.get("image")
@ -2493,7 +2495,7 @@ def package_scenarios(request, package_uuid):
def package_scenario(request, package_uuid, scenario_uuid):
current_user = _anonymous_or_real(request)
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
current_scenario = Scenario.objects.get(package=current_package, uuid=scenario_uuid)
current_scenario = get_object_or_404(Scenario, package=current_package, uuid=scenario_uuid)
if request.method == "GET":
context = get_base_context(request)

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -76,9 +76,7 @@ def migration(request):
open(s.BASE_DIR / "fixtures" / "migration_status_per_rule.json")
)
else:
BBD = Package.objects.get(
url="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1"
)
BBD = Package.objects.get(uuid="32de3cf4-e3e6-4168-956e-32fa5ddb0ce1")
ALL_SMILES = [
cs.smiles for cs in CompoundStructure.objects.filter(compound__package=BBD)
]
@ -147,7 +145,7 @@ def migration_detail(request, package_uuid, rule_uuid):
if request.method == "GET":
context = get_base_context(request)
BBD = Package.objects.get(name="EAWAG-BBD")
BBD = Package.objects.get(uuid="32de3cf4-e3e6-4168-956e-32fa5ddb0ce1")
STRUCTURES = CompoundStructure.objects.filter(compound__package=BBD)
rule = Rule.objects.get(package=BBD, uuid=rule_uuid)

View File

@ -37,7 +37,7 @@ class RuleApplicationTest(TestCase):
def tearDown(self):
super().tearDown()
print(f"\nTotal errors {self.total_errors}")
# print(f"\nTotal errors {self.total_errors}")
@staticmethod
def normalize_smiles(smiles):

View File

@ -2,14 +2,12 @@ import logging
import re
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Dict, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
from rdkit.Chem import rdchem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Descriptors, MACCSkeys, rdchem, rdChemReactions, rdFingerprintGenerator
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.rdmolops import GetMolFrags
@ -335,9 +333,14 @@ class FormatConverter(object):
# Inplace
if preprocess_smiles:
# from rdkit.Chem.rdmolops import AROMATICITY_RDKIT
# Chem.SetAromaticity(mol, AROMATICITY_RDKIT)
Chem.SanitizeMol(mol)
mol = Chem.AddHs(mol)
# for std in BASIC:
# mol = std.standardize(mol)
# Check if reactant_filter_smarts matches and we shouldn't apply the rule
if reactant_filter_smarts and FormatConverter.smarts_matches(
mol, reactant_filter_smarts
@ -376,29 +379,6 @@ class FormatConverter(object):
prods.append(p)
# if kekulize:
# # from rdkit.Chem import MolStandardize
# #
# # # Attempt re-sanitization via standardizer
# # cleaner = MolStandardize.rdMolStandardize.Cleanup()
# # mol = cleaner.cleanup(product)
# # # Fixes
# # # [2025-01-30 23:00:50] ERROR chem - Sanitizing and converting failed:
# # # non-ring atom 3 marked aromatic
# # # But does not improve overall performance
# # # for a in product.GetAtoms():
# # # if (not a.IsInRing()) and a.GetIsAromatic():
# # # a.SetIsAromatic(False)
# # #
# # # for b in product.GetBonds():
# # # if (not b.IsInRing()) and b.GetIsAromatic():
# # # b.SetIsAromatic(False)
# # for atom in product.GetAtoms():
# # atom.SetIsAromatic(False)
# # for bond in product.GetBonds():
# # bond.SetIsAromatic(False)
# Chem.Kekulize(product)
except ValueError as e:
logger.error(f"Sanitizing and converting failed:\n{e}")
continue
@ -524,8 +504,8 @@ class Standardizer(ABC):
def __init__(self, name):
self.name = name
def standardize(self, smiles: str) -> str:
return FormatConverter.normalize(smiles)
def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
return mol
class RuleStandardizer(Standardizer):
@ -533,18 +513,20 @@ class RuleStandardizer(Standardizer):
super().__init__(name)
self.smirks = smirks
def standardize(self, smiles: str) -> str:
standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks)))
def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
rxn = rdChemReactions.ReactionFromSmarts(self.smirks)
sites = rxn.RunReactants((mol,))
if len(standardized_smiles) > 1:
logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
standardized_smiles = standardized_smiles[:1]
if len(sites) == 1:
sites = sites[0]
if standardized_smiles:
smiles = standardized_smiles[0]
if len(sites) > 1:
logger.warning(f"{self.smirks} generated more than 1 compound {sites}")
print(f"{self.smirks} generated more than 1 compound {sites}")
return super().standardize(smiles)
mol = sites[0]
return mol
class RegExStandardizer(Standardizer):
@ -552,19 +534,20 @@ class RegExStandardizer(Standardizer):
super().__init__(name)
self.replacements = replacements
def standardize(self, smiles: str) -> str:
smi = smiles
mod_smi = smiles
for k, v in self.replacements.items():
mod_smi = smi.replace(k, v)
while mod_smi != smi:
mod_smi = smi
for k, v in self.replacements.items():
smi = smi.replace(k, v)
return super().standardize(smi)
def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
# smi = smiles
# mod_smi = smiles
#
# for k, v in self.replacements.items():
# mod_smi = smi.replace(k, v)
#
# while mod_smi != smi:
# mod_smi = smi
# for k, v in self.replacements.items():
# smi = smi.replace(k, v)
#
# return super().standardize(smi)
raise ValueError("Not implemented yet!")
FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]