forked from enviPath/enviPy
[FIX] Fixed Search Output, Legacy API Model Endpoint, Handle ObjectsDoesNotExists in views (#297)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#297
This commit is contained in:
@ -50,7 +50,7 @@ INSTALLED_APPS = [
|
||||
# Custom
|
||||
"epapi", # API endpoints (v1, etc.)
|
||||
"epdb",
|
||||
# "migration",
|
||||
"migration",
|
||||
]
|
||||
|
||||
TENANT = os.environ.get("TENANT", "public")
|
||||
|
||||
@ -5,27 +5,31 @@ from django.conf import settings as s
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.http import HttpResponse
|
||||
from django.shortcuts import redirect
|
||||
from ninja import Field, Form, Router, Schema, Query
|
||||
from ninja import Field, Form, Query, Router, Schema
|
||||
from ninja.security import SessionAuth
|
||||
|
||||
from utilities.chem import FormatConverter
|
||||
from utilities.misc import PackageExporter
|
||||
|
||||
from .logic import GroupManager, PackageManager, SettingManager, UserManager, SearchManager
|
||||
from .logic import GroupManager, PackageManager, SearchManager, SettingManager, UserManager
|
||||
from .models import (
|
||||
Compound,
|
||||
CompoundStructure,
|
||||
Edge,
|
||||
EnviFormer,
|
||||
EPModel,
|
||||
MLRelativeReasoning,
|
||||
Node,
|
||||
PackageBasedModel,
|
||||
ParallelRule,
|
||||
Pathway,
|
||||
Reaction,
|
||||
Rule,
|
||||
RuleBasedRelativeReasoning,
|
||||
Scenario,
|
||||
SimpleAmbitRule,
|
||||
User,
|
||||
UserPackagePermission,
|
||||
ParallelRule,
|
||||
)
|
||||
|
||||
Package = s.GET_PACKAGE_MODEL()
|
||||
@ -237,11 +241,11 @@ def search(request, search: Query[Search]):
|
||||
if "Compound Structures" in search_res:
|
||||
res["structure"] = search_res["Compound Structures"]
|
||||
|
||||
if "Reaction" in search_res:
|
||||
res["reaction"] = search_res["Reaction"]
|
||||
if "Reactions" in search_res:
|
||||
res["reaction"] = search_res["Reactions"]
|
||||
|
||||
if "Pathway" in search_res:
|
||||
res["pathway"] = search_res["Pathway"]
|
||||
if "Pathways" in search_res:
|
||||
res["pathway"] = search_res["Pathways"]
|
||||
|
||||
if "Rules" in search_res:
|
||||
res["rule"] = search_res["Rules"]
|
||||
@ -1753,26 +1757,46 @@ class ModelWrapper(Schema):
|
||||
class ModelSchema(Schema):
|
||||
aliases: List[str] = Field([], alias="aliases")
|
||||
description: str = Field(None, alias="description")
|
||||
evalPackages: List["SimplePackage"] = Field([])
|
||||
evalPackages: List["SimplePackage"] = Field([], alias="eval_packages")
|
||||
id: str = Field(None, alias="url")
|
||||
identifier: str = "relative-reasoning"
|
||||
# "info" : {
|
||||
# "Accuracy (Single-Gen)" : "0.5932962678936605" ,
|
||||
# "Area under PR-Curve (Single-Gen)" : "0.5654653182134282" ,
|
||||
# "Area under ROC-Curve (Single-Gen)" : "0.8178302405034772" ,
|
||||
# "Precision (Single-Gen)" : "0.6978730822873083" ,
|
||||
# "Probability Threshold" : "0.5" ,
|
||||
# "Recall/Sensitivity (Single-Gen)" : "0.4484149210261006"
|
||||
# } ,
|
||||
info: dict = Field({}, alias="info")
|
||||
name: str = Field(None, alias="name")
|
||||
pathwayPackages: List["SimplePackage"] = Field([])
|
||||
pathwayPackages: List["SimplePackage"] = Field([], alias="pathway_packages")
|
||||
reviewStatus: str = Field(None, alias="review_status")
|
||||
rulePackages: List["SimplePackage"] = Field([])
|
||||
rulePackages: List["SimplePackage"] = Field([], alias="rule_packages")
|
||||
scenarios: List["SimpleScenario"] = Field([], alias="scenarios")
|
||||
status: str
|
||||
statusMessage: str
|
||||
threshold: str
|
||||
type: str
|
||||
status: str = Field(None, alias="model_status")
|
||||
statusMessage: str = Field(None, alias="status_message")
|
||||
threshold: str = Field(None, alias="threshold")
|
||||
type: str = Field(None, alias="model_type")
|
||||
|
||||
@staticmethod
|
||||
def resolve_info(obj: EPModel):
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def resolve_status_message(obj: EPModel):
|
||||
for k, v in PackageBasedModel.PROGRESS_STATUS_CHOICES.items():
|
||||
if k == obj.model_status:
|
||||
return v
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def resolve_threshold(obj: EPModel):
|
||||
return f"{obj.threshold:.2f}"
|
||||
|
||||
@staticmethod
|
||||
def resolve_model_type(obj: EPModel):
|
||||
if isinstance(obj, RuleBasedRelativeReasoning):
|
||||
return "RULEBASED"
|
||||
elif isinstance(obj, MLRelativeReasoning):
|
||||
return "ECC"
|
||||
elif isinstance(obj, EnviFormer):
|
||||
return "ENVIFORMER"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/model", response={200: ModelWrapper, 403: Error})
|
||||
|
||||
@ -1353,12 +1353,13 @@ class SimpleAmbitRule(SimpleRule):
|
||||
def get_rule_identifier(self) -> str:
|
||||
return "simple-rule"
|
||||
|
||||
def apply(self, smiles):
|
||||
def apply(self, smiles, *args, **kwargs):
|
||||
return FormatConverter.apply(
|
||||
smiles,
|
||||
self.smirks,
|
||||
reactant_filter_smarts=self.reactant_filter_smarts,
|
||||
product_filter_smarts=self.product_filter_smarts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -1388,8 +1389,8 @@ class SimpleAmbitRule(SimpleRule):
|
||||
class SimpleRDKitRule(SimpleRule):
|
||||
reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS")
|
||||
|
||||
def apply(self, smiles):
|
||||
return FormatConverter.apply(smiles, self.reaction_smarts)
|
||||
def apply(self, smiles, *args, **kwargs):
|
||||
return FormatConverter.apply(smiles, self.reaction_smarts, **kwargs)
|
||||
|
||||
def _url(self):
|
||||
return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)
|
||||
@ -1410,10 +1411,10 @@ class ParallelRule(Rule):
|
||||
def srs(self) -> QuerySet:
|
||||
return self.simple_rules.all()
|
||||
|
||||
def apply(self, structure):
|
||||
def apply(self, structure, *args, **kwargs):
|
||||
res = list()
|
||||
for simple_rule in self.srs:
|
||||
res.extend(simple_rule.apply(structure))
|
||||
res.extend(simple_rule.apply(structure, **kwargs))
|
||||
|
||||
return list(set(res))
|
||||
|
||||
@ -1518,11 +1519,11 @@ class SequentialRule(Rule):
|
||||
def srs(self):
|
||||
return self.simple_rules.all()
|
||||
|
||||
def apply(self, structure):
|
||||
def apply(self, structure, *args, **kwargs):
|
||||
# TODO determine levels or see java implementation
|
||||
res = set()
|
||||
for simple_rule in self.srs:
|
||||
res.union(set(simple_rule.apply(structure)))
|
||||
res.union(set(simple_rule.apply(structure, **kwargs)))
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from django.conf import settings as s
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import BadRequest, PermissionDenied
|
||||
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
|
||||
from django.shortcuts import redirect, render
|
||||
from django.shortcuts import get_object_or_404, redirect, render
|
||||
from django.urls import reverse
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from envipy_additional_information import NAME_MAPPING
|
||||
@ -880,7 +880,7 @@ def package_models(request, package_uuid):
|
||||
def package_model(request, package_uuid, model_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_model = EPModel.objects.get(package=current_package, uuid=model_uuid)
|
||||
current_model = get_object_or_404(EPModel, package=current_package, uuid=model_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
classify = request.GET.get("classify", False)
|
||||
@ -1212,7 +1212,7 @@ def package_compounds(request, package_uuid):
|
||||
def package_compound(request, package_uuid, compound_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid)
|
||||
current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
@ -1346,9 +1346,9 @@ def package_compound_structures(request, package_uuid, compound_uuid):
|
||||
def package_compound_structure(request, package_uuid, compound_uuid, structure_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid)
|
||||
current_structure = CompoundStructure.objects.get(
|
||||
compound=current_compound, uuid=structure_uuid
|
||||
current_compound = get_object_or_404(Compound, package=current_package, uuid=compound_uuid)
|
||||
current_structure = get_object_or_404(
|
||||
CompoundStructure, compound=current_compound, uuid=structure_uuid
|
||||
)
|
||||
|
||||
if request.method == "GET":
|
||||
@ -1534,7 +1534,7 @@ def package_rules(request, package_uuid):
|
||||
def package_rule(request, package_uuid, rule_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_rule = Rule.objects.get(package=current_package, uuid=rule_uuid)
|
||||
current_rule = get_object_or_404(Rule, package=current_package, uuid=rule_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
@ -1729,7 +1729,7 @@ def package_reactions(request, package_uuid):
|
||||
def package_reaction(request, package_uuid, reaction_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_reaction = Reaction.objects.get(package=current_package, uuid=reaction_uuid)
|
||||
current_reaction = get_object_or_404(Reaction, package=current_package, uuid=reaction_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
@ -1924,7 +1924,9 @@ def package_pathways(request, package_uuid):
|
||||
def package_pathway(request, package_uuid, pathway_uuid):
|
||||
current_user: User = _anonymous_or_real(request)
|
||||
current_package: Package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_pathway: Pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
|
||||
current_pathway: Pathway = get_object_or_404(
|
||||
Pathway, package=current_package, uuid=pathway_uuid
|
||||
)
|
||||
|
||||
if request.method == "GET":
|
||||
if request.GET.get("last_modified", False):
|
||||
@ -2079,7 +2081,7 @@ def package_pathway(request, package_uuid, pathway_uuid):
|
||||
def package_pathway_nodes(request, package_uuid, pathway_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
|
||||
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
@ -2138,8 +2140,8 @@ def package_pathway_nodes(request, package_uuid, pathway_uuid):
|
||||
def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
|
||||
current_node = Node.objects.get(pathway=current_pathway, uuid=node_uuid)
|
||||
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
|
||||
current_node = get_object_or_404(Node, pathway=current_pathway, uuid=node_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
is_image_request = request.GET.get("image")
|
||||
@ -2243,7 +2245,7 @@ def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid):
|
||||
def package_pathway_edges(request, package_uuid, pathway_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
|
||||
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
@ -2312,8 +2314,8 @@ def package_pathway_edges(request, package_uuid, pathway_uuid):
|
||||
def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid)
|
||||
current_edge = Edge.objects.get(pathway=current_pathway, uuid=edge_uuid)
|
||||
current_pathway = get_object_or_404(Pathway, package=current_package, uuid=pathway_uuid)
|
||||
current_edge = get_object_or_404(Edge, pathway=current_pathway, uuid=edge_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
is_image_request = request.GET.get("image")
|
||||
@ -2493,7 +2495,7 @@ def package_scenarios(request, package_uuid):
|
||||
def package_scenario(request, package_uuid, scenario_uuid):
|
||||
current_user = _anonymous_or_real(request)
|
||||
current_package = PackageManager.get_package_by_id(current_user, package_uuid)
|
||||
current_scenario = Scenario.objects.get(package=current_package, uuid=scenario_uuid)
|
||||
current_scenario = get_object_or_404(Scenario, package=current_package, uuid=scenario_uuid)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@ -76,9 +76,7 @@ def migration(request):
|
||||
open(s.BASE_DIR / "fixtures" / "migration_status_per_rule.json")
|
||||
)
|
||||
else:
|
||||
BBD = Package.objects.get(
|
||||
url="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1"
|
||||
)
|
||||
BBD = Package.objects.get(uuid="32de3cf4-e3e6-4168-956e-32fa5ddb0ce1")
|
||||
ALL_SMILES = [
|
||||
cs.smiles for cs in CompoundStructure.objects.filter(compound__package=BBD)
|
||||
]
|
||||
@ -147,7 +145,7 @@ def migration_detail(request, package_uuid, rule_uuid):
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
|
||||
BBD = Package.objects.get(name="EAWAG-BBD")
|
||||
BBD = Package.objects.get(uuid="32de3cf4-e3e6-4168-956e-32fa5ddb0ce1")
|
||||
STRUCTURES = CompoundStructure.objects.filter(compound__package=BBD)
|
||||
rule = Rule.objects.get(package=BBD, uuid=rule_uuid)
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ class RuleApplicationTest(TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
print(f"\nTotal errors {self.total_errors}")
|
||||
# print(f"\nTotal errors {self.total_errors}")
|
||||
|
||||
@staticmethod
|
||||
def normalize_smiles(smiles):
|
||||
|
||||
@ -2,14 +2,12 @@ import logging
|
||||
import re
|
||||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Dict, TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
from indigo import Indigo, IndigoException, IndigoObject
|
||||
from indigo.renderer import IndigoRenderer
|
||||
from rdkit import Chem, rdBase
|
||||
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
|
||||
from rdkit.Chem import rdchem
|
||||
from rdkit.Chem import rdChemReactions
|
||||
from rdkit.Chem import Descriptors, MACCSkeys, rdchem, rdChemReactions, rdFingerprintGenerator
|
||||
from rdkit.Chem.Draw import rdMolDraw2D
|
||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||
from rdkit.Chem.rdmolops import GetMolFrags
|
||||
@ -335,9 +333,14 @@ class FormatConverter(object):
|
||||
|
||||
# Inplace
|
||||
if preprocess_smiles:
|
||||
# from rdkit.Chem.rdmolops import AROMATICITY_RDKIT
|
||||
# Chem.SetAromaticity(mol, AROMATICITY_RDKIT)
|
||||
Chem.SanitizeMol(mol)
|
||||
mol = Chem.AddHs(mol)
|
||||
|
||||
# for std in BASIC:
|
||||
# mol = std.standardize(mol)
|
||||
|
||||
# Check if reactant_filter_smarts matches and we shouldn't apply the rule
|
||||
if reactant_filter_smarts and FormatConverter.smarts_matches(
|
||||
mol, reactant_filter_smarts
|
||||
@ -376,29 +379,6 @@ class FormatConverter(object):
|
||||
|
||||
prods.append(p)
|
||||
|
||||
# if kekulize:
|
||||
# # from rdkit.Chem import MolStandardize
|
||||
# #
|
||||
# # # Attempt re-sanitization via standardizer
|
||||
# # cleaner = MolStandardize.rdMolStandardize.Cleanup()
|
||||
# # mol = cleaner.cleanup(product)
|
||||
# # # Fixes
|
||||
# # # [2025-01-30 23:00:50] ERROR chem - Sanitizing and converting failed:
|
||||
# # # non-ring atom 3 marked aromatic
|
||||
# # # But does not improve overall performance
|
||||
# # # for a in product.GetAtoms():
|
||||
# # # if (not a.IsInRing()) and a.GetIsAromatic():
|
||||
# # # a.SetIsAromatic(False)
|
||||
# # #
|
||||
# # # for b in product.GetBonds():
|
||||
# # # if (not b.IsInRing()) and b.GetIsAromatic():
|
||||
# # # b.SetIsAromatic(False)
|
||||
# # for atom in product.GetAtoms():
|
||||
# # atom.SetIsAromatic(False)
|
||||
# # for bond in product.GetBonds():
|
||||
# # bond.SetIsAromatic(False)
|
||||
# Chem.Kekulize(product)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Sanitizing and converting failed:\n{e}")
|
||||
continue
|
||||
@ -524,8 +504,8 @@ class Standardizer(ABC):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def standardize(self, smiles: str) -> str:
|
||||
return FormatConverter.normalize(smiles)
|
||||
def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
|
||||
return mol
|
||||
|
||||
|
||||
class RuleStandardizer(Standardizer):
|
||||
@ -533,18 +513,20 @@ class RuleStandardizer(Standardizer):
|
||||
super().__init__(name)
|
||||
self.smirks = smirks
|
||||
|
||||
def standardize(self, smiles: str) -> str:
|
||||
standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks)))
|
||||
def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
|
||||
rxn = rdChemReactions.ReactionFromSmarts(self.smirks)
|
||||
sites = rxn.RunReactants((mol,))
|
||||
|
||||
if len(standardized_smiles) > 1:
|
||||
logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
|
||||
print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
|
||||
standardized_smiles = standardized_smiles[:1]
|
||||
if len(sites) == 1:
|
||||
sites = sites[0]
|
||||
|
||||
if standardized_smiles:
|
||||
smiles = standardized_smiles[0]
|
||||
if len(sites) > 1:
|
||||
logger.warning(f"{self.smirks} generated more than 1 compound {sites}")
|
||||
print(f"{self.smirks} generated more than 1 compound {sites}")
|
||||
|
||||
return super().standardize(smiles)
|
||||
mol = sites[0]
|
||||
|
||||
return mol
|
||||
|
||||
|
||||
class RegExStandardizer(Standardizer):
|
||||
@ -552,19 +534,20 @@ class RegExStandardizer(Standardizer):
|
||||
super().__init__(name)
|
||||
self.replacements = replacements
|
||||
|
||||
def standardize(self, smiles: str) -> str:
|
||||
smi = smiles
|
||||
mod_smi = smiles
|
||||
|
||||
for k, v in self.replacements.items():
|
||||
mod_smi = smi.replace(k, v)
|
||||
|
||||
while mod_smi != smi:
|
||||
mod_smi = smi
|
||||
for k, v in self.replacements.items():
|
||||
smi = smi.replace(k, v)
|
||||
|
||||
return super().standardize(smi)
|
||||
def standardize(self, mol: rdchem.Mol) -> rdchem.Mol:
|
||||
# smi = smiles
|
||||
# mod_smi = smiles
|
||||
#
|
||||
# for k, v in self.replacements.items():
|
||||
# mod_smi = smi.replace(k, v)
|
||||
#
|
||||
# while mod_smi != smi:
|
||||
# mod_smi = smi
|
||||
# for k, v in self.replacements.items():
|
||||
# smi = smi.replace(k, v)
|
||||
#
|
||||
# return super().standardize(smi)
|
||||
raise ValueError("Not implemented yet!")
|
||||
|
||||
|
||||
FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]
|
||||
|
||||
Reference in New Issue
Block a user