diff --git a/envipath/__init__.py b/envipath/__init__.py index 15d7c508..5568b6d7 100644 --- a/envipath/__init__.py +++ b/envipath/__init__.py @@ -2,4 +2,4 @@ # Django starts so that shared_task will use this app. from .celery import app as celery_app -__all__ = ('celery_app',) +__all__ = ("celery_app",) diff --git a/envipath/api.py b/envipath/api.py index 26983512..0e9bd46d 100644 --- a/envipath/api.py +++ b/envipath/api.py @@ -4,8 +4,6 @@ from ninja import NinjaAPI api = NinjaAPI() -from ninja import NinjaAPI - api_v1 = NinjaAPI(title="API V1 Docs", urls_namespace="api-v1") api_legacy = NinjaAPI(title="Legacy API Docs", urls_namespace="api-legacy") diff --git a/envipath/asgi.py b/envipath/asgi.py index 18443e6e..75773752 100644 --- a/envipath/asgi.py +++ b/envipath/asgi.py @@ -11,6 +11,6 @@ import os from django.core.asgi import get_asgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings") application = get_asgi_application() diff --git a/envipath/celery.py b/envipath/celery.py index 11f4cae2..02081b62 100644 --- a/envipath/celery.py +++ b/envipath/celery.py @@ -4,15 +4,15 @@ from celery import Celery from celery.signals import setup_logging # Set the default Django settings module for the 'celery' program. -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings") -app = Celery('envipath') +app = Celery("envipath") # Using a string here means the worker doesn't have to serialize # the configuration object to child processes. # - namespace='CELERY' means all celery-related configuration keys # should have a `CELERY_` prefix. -app.config_from_object('django.conf:settings', namespace='CELERY') +app.config_from_object("django.conf:settings", namespace="CELERY") @setup_logging.connect diff --git a/envipath/urls.py b/envipath/urls.py index 92a29799..487c428e 100644 --- a/envipath/urls.py +++ b/envipath/urls.py @@ -14,6 +14,7 @@ Including another URLconf 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.conf import settings as s from django.contrib import admin from django.urls import include, path @@ -21,7 +22,6 @@ from django.urls import include, path from .api import api_v1, api_legacy urlpatterns = [ - path("", include("epdb.urls")), path("", include("migration.urls")), path("admin/", admin.site.urls), diff --git a/envipath/wsgi.py b/envipath/wsgi.py index 2e45407e..9fac6525 100644 --- a/envipath/wsgi.py +++ b/envipath/wsgi.py @@ -11,6 +11,6 @@ import os from django.core.wsgi import get_wsgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings") application = get_wsgi_application() diff --git a/epdb/admin.py b/epdb/admin.py index dfb7cb41..fefcdc32 100644 --- a/epdb/admin.py +++ b/epdb/admin.py @@ -18,7 +18,7 @@ from .models import ( Scenario, Setting, ExternalDatabase, - ExternalIdentifier + ExternalIdentifier, ) @@ -39,7 +39,7 @@ class GroupPackagePermissionAdmin(admin.ModelAdmin): class EPAdmin(admin.ModelAdmin): - search_fields = ['name', 'description'] + search_fields = ["name", "description"] class PackageAdmin(EPAdmin): diff --git a/epdb/api.py b/epdb/api.py index 0920b9bd..646d873f 100644 --- a/epdb/api.py +++ b/epdb/api.py @@ -21,7 +21,7 @@ class BearerTokenAuth(HttpBearer): def _anonymous_or_real(request): if request.user.is_authenticated and not request.user.is_anonymous: return request.user - return get_user_model().objects.get(username='anonymous') + return get_user_model().objects.get(username="anonymous") router = Router(auth=BearerTokenAuth()) @@ -85,7 +85,9 @@ def get_package(request, package_uuid): try: return PackageManager.get_package_by_id(request.auth, package_id=package_uuid) except ValueError: - return 403, {'message': f'Getting Package with id {package_uuid} failed due to insufficient rights!'} + return 403, { + "message": f"Getting Package with id {package_uuid} failed due to insufficient rights!" + } @router.get("/compound", response={200: List[CompoundSchema], 403: Error}) @@ -97,7 +99,9 @@ def get_compounds(request): return qs -@router.get("/package/{uuid:package_uuid}/compound", response={200: List[CompoundSchema], 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/compound", response={200: List[CompoundSchema], 403: Error} +) @paginate def get_package_compounds(request, package_uuid): try: @@ -105,4 +109,5 @@ def get_package_compounds(request, package_uuid): return Compound.objects.filter(package=p) except ValueError: return 403, { - 'message': f'Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!" + } diff --git a/epdb/apps.py b/epdb/apps.py index 8d43941e..179f5ad7 100644 --- a/epdb/apps.py +++ b/epdb/apps.py @@ -2,8 +2,8 @@ from django.apps import AppConfig class EPDBConfig(AppConfig): - default_auto_field = 'django.db.models.BigAutoField' - name = 'epdb' + default_auto_field = "django.db.models.BigAutoField" + name = "epdb" def ready(self): import epdb.signals # noqa: F401 diff --git a/epdb/forms.py b/epdb/forms.py deleted file mode 100644 index dcb9685c..00000000 --- a/epdb/forms.py +++ /dev/null @@ -1,5 +0,0 @@ -from django import forms - - -class EmailLoginForm(forms.Form): - email = forms.EmailField() diff --git a/epdb/legacy_api.py b/epdb/legacy_api.py index a048a441..b71e7148 100644 --- a/epdb/legacy_api.py +++ b/epdb/legacy_api.py @@ -18,14 +18,15 @@ from .models import ( Scenario, Pathway, Node, - Edge, SimpleAmbitRule + Edge, + SimpleAmbitRule, ) def _anonymous_or_real(request): if request.user.is_authenticated and not request.user.is_anonymous: return request.user - return get_user_model().objects.get(username='anonymous') + return get_user_model().objects.get(username="anonymous") # router = Router(auth=SessionAuth()) @@ -35,26 +36,27 @@ router = Router() class Error(Schema): message: str + ################# # SimpleObjects # ################# class SimpleUser(Schema): id: str = Field(None, alias="url") - identifier: str = 'user' - name: str = Field(None, alias='username') - email: str = Field(None, alias='email') + identifier: str = "user" + name: str = Field(None, alias="username") + email: str = Field(None, alias="email") class SimpleGroup(Schema): id: str = Field(None, alias="url") - identifier: str = 'group' - name: str = Field(None, alias='name') + identifier: str = "group" + name: str = Field(None, alias="name") class SimpleSetting(Schema): id: str = Field(None, alias="url") - identifier: str = 'setting' - name: str = Field(None, alias='name') + identifier: str = "setting" + name: str = Field(None, alias="name") class SimpleObject(Schema): @@ -65,55 +67,55 @@ class SimpleObject(Schema): @staticmethod def resolve_review_status(obj): if isinstance(obj, Package): - return 'reviewed' if obj.reviewed else 'unreviewed' - elif hasattr(obj, 'package'): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.reviewed else "unreviewed" + elif hasattr(obj, "package"): + return "reviewed" if obj.package.reviewed else "unreviewed" elif isinstance(obj, CompoundStructure): - return 'reviewed' if obj.compound.package.reviewed else 'unreviewed' + return "reviewed" if obj.compound.package.reviewed else "unreviewed" elif isinstance(obj, Node) or isinstance(obj, Edge): - return 'reviewed' if obj.pathway.package.reviewed else 'unreviewed' + return "reviewed" if obj.pathway.package.reviewed else "unreviewed" else: - raise ValueError('Object has no package') + raise ValueError("Object has no package") class SimplePackage(SimpleObject): - identifier: str = 'package' + identifier: str = "package" class SimpleCompound(SimpleObject): - identifier: str = 'compound' + identifier: str = "compound" class SimpleCompoundStructure(SimpleObject): - identifier: str = 'structure' + identifier: str = "structure" class SimpleRule(SimpleObject): - identifier: str = 'rule' + identifier: str = "rule" @staticmethod def resolve_url(obj: Rule): - return obj.url.replace('-ambit-', '-').replace('-rdkit-', '-') + return obj.url.replace("-ambit-", "-").replace("-rdkit-", "-") class SimpleReaction(SimpleObject): - identifier: str = 'reaction' + identifier: str = "reaction" class SimpleScenario(SimpleObject): - identifier: str = 'scenario' + identifier: str = "scenario" class SimplePathway(SimpleObject): - identifier: str = 'pathway' + identifier: str = "pathway" class SimpleNode(SimpleObject): - identifier: str = 'node' + identifier: str = "node" class SimpleEdge(SimpleObject): - identifier: str = 'edge' + identifier: str = "edge" ################ @@ -123,13 +125,14 @@ class SimpleEdge(SimpleObject): def login(request, loginusername: Form[str], loginpassword: Form[str]): from django.contrib.auth import authenticate from django.contrib.auth import login + email = User.objects.get(username=loginusername).email user = authenticate(username=email, password=loginpassword) if user: login(request, user) return user else: - return 403, {'message': 'Invalid username and/or password'} + return 403, {"message": "Invalid username and/or password"} ######## @@ -141,27 +144,23 @@ class UserWrapper(Schema): class UserSchema(Schema): affiliation: Dict[str, str] = Field(None, alias="affiliation") - defaultGroup: 'SimpleGroup' = Field(None, alias="default_group") - defaultPackage: 'SimplePackage' = Field(None, alias="default_package") - defaultSetting: 'SimpleSetting' = Field(None, alias="default_setting") + defaultGroup: "SimpleGroup" = Field(None, alias="default_group") + defaultPackage: "SimplePackage" = Field(None, alias="default_package") + defaultSetting: "SimpleSetting" = Field(None, alias="default_setting") email: str = Field(None, alias="email") - forename: str = 'not specified' - groups: List['SimpleGroup'] = Field([], alias="groups") + forename: str = "not specified" + groups: List["SimpleGroup"] = Field([], alias="groups") id: str = Field(None, alias="url") - identifier: str = 'user' - link: str = 'empty' + identifier: str = "user" + link: str = "empty" name: str = Field(None, alias="username") - surname: str = 'not specified' - settings: List['SimpleSetting'] = Field([], alias="settings") + surname: str = "not specified" + settings: List["SimpleSetting"] = Field([], alias="settings") @staticmethod def resolve_affiliation(obj: User): # TODO - return { - "city": "not specified", - "country": "not specified", - "workplace": "not specified" - } + return {"city": "not specified", "country": "not specified", "workplace": "not specified"} @staticmethod def resolve_settings(obj: User): @@ -171,9 +170,9 @@ class UserSchema(Schema): @router.get("/user", response={200: UserWrapper, 403: Error}) def get_users(request, whoami: str = None): if whoami: - return {'user': [request.user]} + return {"user": [request.user]} else: - return {'user': User.objects.all()} + return {"user": User.objects.all()} @router.get("/user/{uuid:user_uuid}", response={200: UserSchema, 403: Error}) @@ -183,14 +182,15 @@ def get_user(request, user_uuid): return u except ValueError: return 403, { - 'message': f'Getting User with id {user_uuid} failed due to insufficient rights!'} + "message": f"Getting User with id {user_uuid} failed due to insufficient rights!" + } ########### # Package # ########### class PackageWrapper(Schema): - package: List['PackageSchema'] + package: List["PackageSchema"] class PackageSchema(Schema): @@ -207,40 +207,20 @@ class PackageSchema(Schema): @staticmethod def resolve_links(obj: Package): return [ - { - 'Pathways': [ - f'{obj.url}/pathway', obj.pathways.count() - ] - }, { - 'Rules': [ - f'{obj.url}/rule', obj.rules.count() - ] - }, { - 'Compounds': [ - f'{obj.url}/compound', obj.compounds.count() - ] - }, { - 'Reactions': [ - f'{obj.url}/reaction', obj.reactions.count() - ] - }, { - 'Relative Reasoning': [ - f'{obj.url}/relative-reasoning', obj.models.count() - ] - }, { - 'Scenarios': [ - f'{obj.url}/scenario', obj.scenarios.count() - ] - } + {"Pathways": [f"{obj.url}/pathway", obj.pathways.count()]}, + {"Rules": [f"{obj.url}/rule", obj.rules.count()]}, + {"Compounds": [f"{obj.url}/compound", obj.compounds.count()]}, + {"Reactions": [f"{obj.url}/reaction", obj.reactions.count()]}, + {"Relative Reasoning": [f"{obj.url}/relative-reasoning", obj.models.count()]}, + {"Scenarios": [f"{obj.url}/scenario", obj.scenarios.count()]}, ] @staticmethod def resolve_readers(obj: Package): users = User.objects.filter( id__in=UserPackagePermission.objects.filter( - package=obj, - permission=UserPackagePermission.READ[0] - ).values_list('user', flat=True) + package=obj, permission=UserPackagePermission.READ[0] + ).values_list("user", flat=True) ).distinct() return [{u.id: u.name} for u in users] @@ -249,9 +229,8 @@ class PackageSchema(Schema): def resolve_writers(obj: Package): users = User.objects.filter( id__in=UserPackagePermission.objects.filter( - package=obj, - permission=UserPackagePermission.WRITE[0] - ).values_list('user', flat=True) + package=obj, permission=UserPackagePermission.WRITE[0] + ).values_list("user", flat=True) ).distinct() return [{u.id: u.name} for u in users] @@ -262,12 +241,14 @@ class PackageSchema(Schema): @staticmethod def resolve_review_status(obj): - return 'reviewed' if obj.reviewed else 'unreviewed' + return "reviewed" if obj.reviewed else "unreviewed" @router.get("/package", response={200: PackageWrapper, 403: Error}) def get_packages(request): - return {'package': PackageManager.get_all_readable_packages(request.user, include_reviewed=True)} + return { + "package": PackageManager.get_all_readable_packages(request.user, include_reviewed=True) + } @router.get("/package/{uuid:package_uuid}", response={200: PackageSchema, 403: Error}) @@ -276,61 +257,66 @@ def get_package(request, package_uuid): return PackageManager.get_package_by_id(request.user, package_uuid) except ValueError: return 403, { - 'message': f'Getting Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Package with id {package_uuid} failed due to insufficient rights!" + } @router.post("/package") -def create_packages(request, packageName: Form[str], packageDescription: Optional[str] = Form(None)): +def create_packages( + request, packageName: Form[str], packageDescription: Optional[str] = Form(None) +): try: - if packageName.strip() == '': - raise ValueError('Package name cannot be empty!') + if packageName.strip() == "": + raise ValueError("Package name cannot be empty!") new_pacakge = PackageManager.create_package(request.user, packageName, packageDescription) return redirect(new_pacakge.url) except ValueError as e: - return 400, {'message': str(e)} + return 400, {"message": str(e)} -@router.post("/package/{uuid:package_uuid}", response={200: PackageSchema | Any , 400: Error}) +@router.post("/package/{uuid:package_uuid}", response={200: PackageSchema | Any, 400: Error}) def update_package( - request, - package_uuid, - packageDescription: Optional[str] = Form(None), - hiddenMethod: Optional[str] = Form(None), - exportAsJson: Optional[str] = Form(None), - permissions: Optional[str] = Form(None), - ppsURI: Optional[str] = Form(None), - read: Optional[str] = Form(None), - write: Optional[str] = Form(None), + request, + package_uuid, + packageDescription: Optional[str] = Form(None), + hiddenMethod: Optional[str] = Form(None), + exportAsJson: Optional[str] = Form(None), + permissions: Optional[str] = Form(None), + ppsURI: Optional[str] = Form(None), + read: Optional[str] = Form(None), + write: Optional[str] = Form(None), ): try: p = PackageManager.get_package_by_id(request.user, package_uuid) if hiddenMethod: - if hiddenMethod == 'DELETE': + if hiddenMethod == "DELETE": p.delete() - elif packageDescription and packageDescription.strip() != '': + elif packageDescription and packageDescription.strip() != "": p.description = packageDescription p.save() return - elif exportAsJson == 'true': - pack_json = PackageManager.export_package(p, include_models=False, include_external_identifiers=False) + elif exportAsJson == "true": + pack_json = PackageManager.export_package( + p, include_models=False, include_external_identifiers=False + ) return pack_json elif all([permissions, ppsURI, read]): PackageManager.update_permissions elif all([permissions, ppsURI, write]): pass - except ValueError as e: - return 400, {'message': str(e)} + return 400, {"message": str(e)} + ################################ # Compound / CompoundStructure # ################################ class CompoundWrapper(Schema): - compound: List['SimpleCompound'] + compound: List["SimpleCompound"] class CompoundPathwayScenario(Schema): @@ -345,20 +331,20 @@ class CompoundSchema(Schema): externalReferences: Dict[str, List[str]] = Field(None, alias="external_references") id: str = Field(None, alias="url") halflifes: List[Dict[str, str]] = Field([], alias="halflifes") - identifier: str = 'compound' + identifier: str = "compound" imageSize: int = 600 name: str = Field(None, alias="name") pathwayScenarios: List[CompoundPathwayScenario] = Field([], alias="pathway_scenarios") - pathways: List['SimplePathway'] = Field([], alias="related_pathways") + pathways: List["SimplePathway"] = Field([], alias="related_pathways") pubchemCompoundReferences: List[str] = Field([], alias="pubchem_compound_references") - reactions: List['SimpleReaction'] = Field([], alias="related_reactions") + reactions: List["SimpleReaction"] = Field([], alias="related_reactions") reviewStatus: str = Field(False, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") - structures: List['CompoundStructureSchema'] = [] + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") + structures: List["CompoundStructureSchema"] = [] @staticmethod def resolve_review_status(obj: CompoundStructure): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.package.reviewed else "unreviewed" @staticmethod def resolve_external_references(obj: Compound): @@ -381,9 +367,9 @@ class CompoundSchema(Schema): def resolve_pathway_scenarios(obj: Compound): return [ { - 'scenarioId': 'https://envipath.org/package/5882df9c-dae1-4d80-a40e-db4724271456/scenario/cd8350cd-4249-4111-ba9f-4e2209338501', - 'scenarioName': 'Fritz, R. & Brauner, A. (1989) - (00004)', - 'scenarioType': 'Soil' + "scenarioId": "https://envipath.org/package/5882df9c-dae1-4d80-a40e-db4724271456/scenario/cd8350cd-4249-4111-ba9f-4e2209338501", + "scenarioName": "Fritz, R. & Brauner, A. (1989) - (00004)", + "scenarioType": "Soil", } ] @@ -398,22 +384,22 @@ class CompoundStructureSchema(Schema): formula: str = Field(None, alias="formula") halflifes: List[Dict[str, str]] = Field([], alias="halflifes") id: str = Field(None, alias="url") - identifier: str = 'structure' + identifier: str = "structure" imageSize: int = 600 inchikey: str = Field(None, alias="inchikey") isDefaultStructure: bool = Field(None, alias="is_default_structure") mass: float = Field(None, alias="mass") name: str = Field(None, alias="name") - pathways: List['SimplePathway'] = Field([], alias="related_pathways") + pathways: List["SimplePathway"] = Field([], alias="related_pathways") pubchemCompoundReferences: List[str] = Field([], alias="pubchem_compound_references") - reactions: List['SimpleReaction'] = Field([], alias="related_reactions") + reactions: List["SimpleReaction"] = Field([], alias="related_reactions") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") smiles: str = Field(None, alias="smiles") @staticmethod def resolve_review_status(obj: CompoundStructure): - return 'reviewed' if obj.compound.package.reviewed else 'unreviewed' + return "reviewed" if obj.compound.package.reviewed else "unreviewed" @staticmethod def resolve_inchi(obj: CompoundStructure): @@ -448,23 +434,23 @@ class CompoundStructureSchema(Schema): def resolve_pathway_scenarios(obj: CompoundStructure): return [ { - 'scenarioId': 'https://envipath.org/package/5882df9c-dae1-4d80-a40e-db4724271456/scenario/cd8350cd-4249-4111-ba9f-4e2209338501', - 'scenarioName': 'Fritz, R. & Brauner, A. (1989) - (00004)', - 'scenarioType': 'Soil' + "scenarioId": "https://envipath.org/package/5882df9c-dae1-4d80-a40e-db4724271456/scenario/cd8350cd-4249-4111-ba9f-4e2209338501", + "scenarioName": "Fritz, R. & Brauner, A. (1989) - (00004)", + "scenarioType": "Soil", } ] class CompoundStructureWrapper(Schema): - structure: List['SimpleCompoundStructure'] + structure: List["SimpleCompoundStructure"] @router.get("/compound", response={200: CompoundWrapper, 403: Error}) def get_compounds(request): return { - 'compound': Compound.objects.filter( + "compound": Compound.objects.filter( package__in=PackageManager.get_reviewed_packages() - ).prefetch_related('package') + ).prefetch_related("package") } @@ -472,101 +458,113 @@ def get_compounds(request): def get_package_compounds(request, package_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'compound': Compound.objects.filter(package=p).prefetch_related('package')} + return {"compound": Compound.objects.filter(package=p).prefetch_related("package")} except ValueError: return 403, { - 'message': f'Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/compound/{uuid:compound_uuid}", response={200: CompoundSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/compound/{uuid:compound_uuid}", + response={200: CompoundSchema, 403: Error}, +) def get_package_compound(request, package_uuid, compound_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) return Compound.objects.get(package=p, uuid=compound_uuid) except ValueError: return 403, { - 'message': f'Getting Compound with id {compound_uuid} failed due to insufficient rights!'} + "message": f"Getting Compound with id {compound_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/compound/{uuid:compound_uuid}/structure", - response={200: CompoundStructureWrapper, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/compound/{uuid:compound_uuid}/structure", + response={200: CompoundStructureWrapper, 403: Error}, +) def get_package_compound_structures(request, package_uuid, compound_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'structure': Compound.objects.get(package=p, uuid=compound_uuid).structures.all()} + return {"structure": Compound.objects.get(package=p, uuid=compound_uuid).structures.all()} except ValueError: return 403, { - 'message': f'Getting CompoundStructures for Compound with id {compound_uuid} failed due to insufficient rights!'} + "message": f"Getting CompoundStructures for Compound with id {compound_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/compound/{uuid:compound_uuid}/structure/{uuid:structure_uuid}", - response={200: CompoundStructureSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/compound/{uuid:compound_uuid}/structure/{uuid:structure_uuid}", + response={200: CompoundStructureSchema, 403: Error}, +) def get_package_compound_structure(request, package_uuid, compound_uuid, structure_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return CompoundStructure.objects.get(uuid=structure_uuid, - compound=Compound.objects.get(package=p, uuid=compound_uuid)) + return CompoundStructure.objects.get( + uuid=structure_uuid, compound=Compound.objects.get(package=p, uuid=compound_uuid) + ) except ValueError: return 403, { - 'message': f'Getting CompoundStructure with id {structure_uuid} failed due to insufficient rights!'} + "message": f"Getting CompoundStructure with id {structure_uuid} failed due to insufficient rights!" + } ######### # Rules # ######### class RuleWrapper(Schema): - rule: List['SimpleRule'] + rule: List["SimpleRule"] class SimpleRuleSchema(Schema): aliases: List[str] = Field([], alias="aliases") description: str = Field(None, alias="description") ecNumbers: List[Dict[str, str]] = Field([], alias="ec_numbers") - engine: str = 'ambit' + engine: str = "ambit" id: str = Field(None, alias="url") identifier: str = Field(None, alias="identifier") isCompositeRule: bool = False name: str = Field(None, alias="name") - pathways: List['SimplePathway'] = Field([], alias="related_pathways") + pathways: List["SimplePathway"] = Field([], alias="related_pathways") productFilterSmarts: str = Field("", alias="product_filter_smarts") productSmarts: str = Field(None, alias="products_smarts") reactantFilterSmarts: str = Field("", alias="reactant_filter_smarts") reactantSmarts: str = Field(None, alias="reactants_smarts") - reactions: List['SimpleReaction'] = Field([], alias="related_reactions") + reactions: List["SimpleReaction"] = Field([], alias="related_reactions") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") smirks: str = Field("", alias="smirks") # TODO transformations: str = Field("", alias="transformations") @staticmethod def resolve_url(obj: Rule): - return obj.url.replace('-ambit-', '-').replace('-rdkit-', '-') + return obj.url.replace("-ambit-", "-").replace("-rdkit-", "-") @staticmethod def resolve_identifier(obj: Rule): - if 'simple-rule' in obj.url: - return 'simple-rule' - if 'simple-ambit-rule' in obj.url: - return 'simple-rule' - elif 'parallel-rule' in obj.url: - return 'parallel-rule' - elif 'sequential-rule' in obj.url: - return 'sequential-rule' + if "simple-rule" in obj.url: + return "simple-rule" + if "simple-ambit-rule" in obj.url: + return "simple-rule" + elif "parallel-rule" in obj.url: + return "parallel-rule" + elif "sequential-rule" in obj.url: + return "sequential-rule" else: return None @staticmethod def resolve_review_status(obj: Rule): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.package.reviewed else "unreviewed" @staticmethod def resolve_product_filter_smarts(obj: Rule): - return obj.product_filter_smarts if obj.product_filter_smarts else '' + return obj.product_filter_smarts if obj.product_filter_smarts else "" @staticmethod def resolve_reactant_filter_smarts(obj: Rule): - return obj.reactant_filter_smarts if obj.reactant_filter_smarts else '' + return obj.reactant_filter_smarts if obj.reactant_filter_smarts else "" class CompositeRuleSchema(Schema): @@ -577,13 +575,13 @@ class CompositeRuleSchema(Schema): identifier: str = Field(None, alias="identifier") isCompositeRule: bool = True name: str = Field(None, alias="name") - pathways: List['SimplePathway'] = Field([], alias="related_pathways") + pathways: List["SimplePathway"] = Field([], alias="related_pathways") productFilterSmarts: str = Field("", alias="product_filter_smarts") reactantFilterSmarts: str = Field("", alias="reactant_filter_smarts") - reactions: List['SimpleReaction'] = Field([], alias="related_reactions") + reactions: List["SimpleReaction"] = Field([], alias="related_reactions") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") - simpleRules: List['SimpleRule'] = Field([], alias="simple_rules") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") + simpleRules: List["SimpleRule"] = Field([], alias="simple_rules") @staticmethod def resolve_ec_numbers(obj: Rule): @@ -591,40 +589,40 @@ class CompositeRuleSchema(Schema): @staticmethod def resolve_url(obj: Rule): - return obj.url.replace('-ambit-', '-').replace('-rdkit-', '-') + return obj.url.replace("-ambit-", "-").replace("-rdkit-", "-") @staticmethod def resolve_identifier(obj: Rule): - if 'simple-rule' in obj.url: - return 'simple-rule' - if 'simple-ambit-rule' in obj.url: - return 'simple-rule' - elif 'parallel-rule' in obj.url: - return 'parallel-rule' - elif 'sequential-rule' in obj.url: - return 'sequential-rule' + if "simple-rule" in obj.url: + return "simple-rule" + if "simple-ambit-rule" in obj.url: + return "simple-rule" + elif "parallel-rule" in obj.url: + return "parallel-rule" + elif "sequential-rule" in obj.url: + return "sequential-rule" else: return None @staticmethod def resolve_review_status(obj: Rule): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.package.reviewed else "unreviewed" @staticmethod def resolve_product_filter_smarts(obj: Rule): - return obj.product_filter_smarts if obj.product_filter_smarts else '' + return obj.product_filter_smarts if obj.product_filter_smarts else "" @staticmethod def resolve_reactant_filter_smarts(obj: Rule): - return obj.reactant_filter_smarts if obj.reactant_filter_smarts else '' + return obj.reactant_filter_smarts if obj.reactant_filter_smarts else "" @router.get("/rule", response={200: RuleWrapper, 403: Error}) def get_rules(request): return { - 'rule': Rule.objects.filter( + "rule": Rule.objects.filter( package__in=PackageManager.get_reviewed_packages() - ).prefetch_related('package') + ).prefetch_related("package") } @@ -632,26 +630,33 @@ def get_rules(request): def get_package_rules(request, package_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'rule': Rule.objects.filter(package=p).prefetch_related('package')} + return {"rule": Rule.objects.filter(package=p).prefetch_related("package")} except ValueError: return 403, { - 'message': f'Getting Rules for Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Rules for Package with id {package_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/rule/{uuid:rule_uuid}", - response={200: SimpleRuleSchema | CompositeRuleSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/rule/{uuid:rule_uuid}", + response={200: SimpleRuleSchema | CompositeRuleSchema, 403: Error}, +) def get_package_rule(request, package_uuid, rule_uuid): return _get_package_rule(request, package_uuid, rule_uuid) -@router.get("/package/{uuid:package_uuid}/simple-rule/{uuid:rule_uuid}", - response={200: SimpleRuleSchema | CompositeRuleSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/simple-rule/{uuid:rule_uuid}", + response={200: SimpleRuleSchema | CompositeRuleSchema, 403: Error}, +) def get_package_simple_rule(request, package_uuid, rule_uuid): return _get_package_rule(request, package_uuid, rule_uuid) -@router.get("/package/{uuid:package_uuid}/parallel-rule/{uuid:rule_uuid}", - response={200: SimpleRuleSchema | CompositeRuleSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/parallel-rule/{uuid:rule_uuid}", + response={200: SimpleRuleSchema | CompositeRuleSchema, 403: Error}, +) def get_package_parallel_rule(request, package_uuid, rule_uuid): return _get_package_rule(request, package_uuid, rule_uuid) @@ -662,21 +667,30 @@ def _get_package_rule(request, package_uuid, rule_uuid): return Rule.objects.get(package=p, uuid=rule_uuid) except ValueError: return 403, { - 'message': f'Getting Rule with id {rule_uuid} failed due to insufficient rights!'} + "message": f"Getting Rule with id {rule_uuid} failed due to insufficient rights!" + } # POST -@router.post("/package/{uuid:package_uuid}/rule/{uuid:rule_uuid}", response={200: str | Any, 403: Error}) +@router.post( + "/package/{uuid:package_uuid}/rule/{uuid:rule_uuid}", response={200: str | Any, 403: Error} +) def post_package_rule(request, package_uuid, rule_uuid, compound: Form[str] = None): return _post_package_rule(request, package_uuid, rule_uuid, compound=compound) -@router.post("/package/{uuid:package_uuid}/simple-rule/{uuid:rule_uuid}", response={200: str | Any, 403: Error}) +@router.post( + "/package/{uuid:package_uuid}/simple-rule/{uuid:rule_uuid}", + response={200: str | Any, 403: Error}, +) def post_package_simple_rule(request, package_uuid, rule_uuid, compound: Form[str] = None): return _post_package_rule(request, package_uuid, rule_uuid, compound=compound) -@router.post("/package/{uuid:package_uuid}/parallel-rule/{uuid:rule_uuid}", response={200: str | Any, 403: Error}) +@router.post( + "/package/{uuid:package_uuid}/parallel-rule/{uuid:rule_uuid}", + response={200: str | Any, 403: Error}, +) def post_package_parallel_rule(request, package_uuid, rule_uuid, compound: Form[str] = None): return _post_package_rule(request, package_uuid, rule_uuid, compound=compound) @@ -688,7 +702,7 @@ def _post_package_rule(request, package_uuid, rule_uuid, compound: Form[str]): if compound is not None: if not compound.split(): - return 400, {'message': 'Compound is empty'} + return 400, {"message": "Compound is empty"} product_sets = r.apply(compound) @@ -697,20 +711,21 @@ def _post_package_rule(request, package_uuid, rule_uuid, compound: Form[str]): for product in p_set: res.append(product) - return HttpResponse('\n'.join(res), content_type="text/plain") + return HttpResponse("\n".join(res), content_type="text/plain") return r except ValueError: return 403, { - 'message': f'Getting Rule with id {rule_uuid} failed due to insufficient rights!'} + "message": f"Getting Rule with id {rule_uuid} failed due to insufficient rights!" + } ############ # Reaction # ############ class ReactionWrapper(Schema): - reaction: List['SimpleReaction'] + reaction: List["SimpleReaction"] class ReactionCompoundStructure(Schema): @@ -723,17 +738,17 @@ class ReactionSchema(Schema): aliases: List[str] = Field([], alias="aliases") description: str = Field(None, alias="description") ecNumbers: List[Dict[str, str]] = Field([], alias="ec_numbers") - educts: List['ReactionCompoundStructure'] = Field([], alias="educts") + educts: List["ReactionCompoundStructure"] = Field([], alias="educts") id: str = Field(None, alias="url") - identifier: str = 'reaction' + identifier: str = "reaction" medlineRefs: List[str] = Field([], alias="medline_references") multistep: bool = Field(None, alias="multi_step") name: str = Field(None, alias="name") - pathways: List['SimplePathway'] = Field([], alias="related_pathways") - products: List['ReactionCompoundStructure'] = Field([], alias="products") + pathways: List["SimplePathway"] = Field([], alias="related_pathways") + products: List["ReactionCompoundStructure"] = Field([], alias="products") references: List[Dict[str, List[str]]] = Field([], alias="references") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") smirks: str = Field("", alias="smirks") @staticmethod @@ -757,15 +772,15 @@ class ReactionSchema(Schema): @staticmethod def resolve_review_status(obj: Rule): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.package.reviewed else "unreviewed" @router.get("/reaction", response={200: ReactionWrapper, 403: Error}) def get_reactions(request): return { - 'reaction': Reaction.objects.filter( + "reaction": Reaction.objects.filter( package__in=PackageManager.get_reviewed_packages() - ).prefetch_related('package') + ).prefetch_related("package") } @@ -773,42 +788,47 @@ def get_reactions(request): def get_package_reactions(request, package_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'reaction': Reaction.objects.filter(package=p).prefetch_related('package')} + return {"reaction": Reaction.objects.filter(package=p).prefetch_related("package")} except ValueError: return 403, { - 'message': f'Getting Reactions for Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Reactions for Package with id {package_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/reaction/{uuid:reaction_uuid}", response={200: ReactionSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/reaction/{uuid:reaction_uuid}", + response={200: ReactionSchema, 403: Error}, +) def get_package_reaction(request, package_uuid, reaction_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) return Reaction.objects.get(package=p, uuid=reaction_uuid) except ValueError: return 403, { - 'message': f'Getting Reaction with id {reaction_uuid} failed due to insufficient rights!'} + "message": f"Getting Reaction with id {reaction_uuid} failed due to insufficient rights!" + } ############ # Scenario # ############ class ScenarioWrapper(Schema): - scenario: List['SimpleScenario'] + scenario: List["SimpleScenario"] class ScenarioSchema(Schema): aliases: List[str] = Field([], alias="aliases") - collection: Dict['str', List[Dict[str, Any]]] = Field([], alias="collection") + collection: Dict["str", List[Dict[str, Any]]] = Field([], alias="collection") collectionID: Optional[str] = None description: str = Field(None, alias="description") id: str = Field(None, alias="url") - identifier: str = 'scenario' + identifier: str = "scenario" linkedTo: List[Dict[str, str]] = Field({}, alias="linked_to") name: str = Field(None, alias="name") - pathways: List['SimplePathway'] = Field([], alias="related_pathways") + pathways: List["SimplePathway"] = Field([], alias="related_pathways") relatedScenarios: List[Dict[str, str]] = Field([], alias="related_scenarios") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") type: str = Field(None, alias="scenario_type") @staticmethod @@ -817,15 +837,15 @@ class ScenarioSchema(Schema): @staticmethod def resolve_review_status(obj: Rule): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.package.reviewed else "unreviewed" @router.get("/scenario", response={200: ScenarioWrapper, 403: Error}) def get_scenarios(request): return { - 'scenario': Scenario.objects.filter( + "scenario": Scenario.objects.filter( package__in=PackageManager.get_reviewed_packages() - ).prefetch_related('package') + ).prefetch_related("package") } @@ -833,27 +853,33 @@ def get_scenarios(request): def get_package_scenarios(request, package_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'scenario': Scenario.objects.filter(package=p).prefetch_related('package')} + return {"scenario": Scenario.objects.filter(package=p).prefetch_related("package")} except ValueError: return 403, { - 'message': f'Getting Scenarios for Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Scenarios for Package with id {package_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/scenario/{uuid:scenario_uuid}", response={200: ScenarioSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/scenario/{uuid:scenario_uuid}", + response={200: ScenarioSchema, 403: Error}, +) def get_package_scenario(request, package_uuid, scenario_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) return Scenario.objects.get(package=p, uuid=scenario_uuid) except ValueError: return 403, { - 'message': f'Getting Scenario with id {scenario_uuid} failed due to insufficient rights!'} + "message": f"Getting Scenario with id {scenario_uuid} failed due to insufficient rights!" + } ########### # Pathway # ########### class PathwayWrapper(Schema): - pathway: List['SimplePathway'] + pathway: List["SimplePathway"] + class PathwayEdge(Schema): ecNumbers: List[str] = Field([], alias="ec_numbers") @@ -875,6 +901,7 @@ class PathwayEdge(Schema): return r.smirks return None + class PathwayNode(Schema): atomCount: int = Field(None, alias="atom_count") depth: int = Field(None, alias="depth") @@ -892,6 +919,7 @@ class PathwayNode(Schema): @staticmethod def resolve_atom_count(obj: Node): from rdkit import Chem + return Chem.MolFromSmiles(obj.default_node_label.smiles).GetNumAtoms() @staticmethod @@ -931,12 +959,12 @@ class PathwaySchema(Schema): nodes: List[PathwayNode] = Field([], alias="nodes") pathwayName: str = Field(None, alias="name") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") upToDate: bool = Field(None, alias="up_to_date") @staticmethod def resolve_review_status(obj: Pathway): - return 'reviewed' if obj.package.reviewed else 'unreviewed' + return "reviewed" if obj.package.reviewed else "unreviewed" @staticmethod def resolve_completed(obj: Pathway): @@ -954,9 +982,9 @@ class PathwaySchema(Schema): @router.get("/pathway", response={200: PathwayWrapper, 403: Error}) def get_pathways(request): return { - 'pathway': Pathway.objects.filter( + "pathway": Pathway.objects.filter( package__in=PackageManager.get_reviewed_packages() - ).prefetch_related('package') + ).prefetch_related("package") } @@ -964,31 +992,36 @@ def get_pathways(request): def get_package_pathways(request, package_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'pathway': Pathway.objects.filter(package=p).prefetch_related('package')} + return {"pathway": Pathway.objects.filter(package=p).prefetch_related("package")} except ValueError: return 403, { - 'message': f'Getting Pathways for Package with id {package_uuid} failed due to insufficient rights!'} + "message": f"Getting Pathways for Package with id {package_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}", response={200: PathwaySchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}", + response={200: PathwaySchema, 403: Error}, +) def get_package_pathway(request, package_uuid, pathway_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) return Pathway.objects.get(package=p, uuid=pathway_uuid) except ValueError: return 403, { - 'message': f'Getting Pathway with id {pathway_uuid} failed due to insufficient rights!'} + "message": f"Getting Pathway with id {pathway_uuid} failed due to insufficient rights!" + } @router.post("/package/{uuid:package_uuid}/pathway") def create_pathway( - request, - package_uuid, - smilesinput: Form[str], - name: Optional[str] = Form(None), - description: Optional[str] = Form(None), - rootOnly: Optional[str] = Form(None), - selectedSetting: Optional[str] = Form(None) + request, + package_uuid, + smilesinput: Form[str], + name: Optional[str] = Form(None), + description: Optional[str] = Form(None), + rootOnly: Optional[str] = Form(None), + selectedSetting: Optional[str] = Form(None), ): try: p = PackageManager.get_package_by_id(request.user, package_uuid) @@ -997,14 +1030,14 @@ def create_pathway( pw = Pathway.create(p, stand_smiles, name=name, description=description) - pw_mode = 'predict' - if rootOnly and rootOnly == 'true': - pw_mode = 'build' + pw_mode = "predict" + if rootOnly and rootOnly == "true": + pw_mode = "build" - pw.kv.update({'mode': pw_mode}) + pw.kv.update({"mode": pw_mode}) pw.save() - if pw_mode == 'predict': + if pw_mode == "predict": setting = request.user.prediction_settings() if selectedSetting: @@ -1014,6 +1047,7 @@ def create_pathway( pw.save() from .tasks import predict + predict.delay(pw.pk, setting.pk, limit=-1) return redirect(pw.url) @@ -1025,35 +1059,37 @@ def create_pathway( # Node # ######## class NodeWrapper(Schema): - node: List['SimpleNode'] + node: List["SimpleNode"] + class NodeCompoundStructure(Schema): id: str = Field(None, alias="url") image: str = Field(None, alias="image") smiles: str = Field(None, alias="smiles") - name:str =Field(None, alias="name") + name: str = Field(None, alias="name") @staticmethod def resolve_image(obj: CompoundStructure): return f"{obj.url}?image=svg" + class NodeSchema(Schema): aliases: List[str] = Field([], alias="aliases") confidenceScenarios: List[SimpleScenario] = Field([], alias="confidence_scenarios") defaultStructure: NodeCompoundStructure = Field(None, alias="default_node_label") depth: int = Field(None, alias="depth") description: str = Field(None, alias="description") - engineeredIntermediate:bool = Field(None, alias="engineered_intermediate") + engineeredIntermediate: bool = Field(None, alias="engineered_intermediate") halflifes: Dict[str, str] = Field({}, alias="halflife") id: str = Field(None, alias="url") - identifier: str = 'node' + identifier: str = "node" image: str = Field(None, alias="image") name: str = Field(None, alias="name") proposedValues: List[Dict[str, str]] = Field([], alias="proposed_values") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") smiles: str = Field(None, alias="default_node_label.smiles") - structures: List['CompoundStructureSchema'] = Field([], alias="node_labels") + structures: List["CompoundStructureSchema"] = Field([], alias="node_labels") @staticmethod def resolve_engineered_intermediate(obj: Node): @@ -1075,20 +1111,27 @@ class NodeSchema(Schema): @staticmethod def resolve_review_status(obj: Node): - return 'reviewed' if obj.pathway.package.reviewed else 'unreviewed' + return "reviewed" if obj.pathway.package.reviewed else "unreviewed" -@router.get("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/node", response={200: NodeWrapper, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/node", + response={200: NodeWrapper, 403: Error}, +) def get_package_pathway_nodes(request, package_uuid, pathway_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'node': Pathway.objects.get(package=p, uuid=pathway_uuid).nodes} + return {"node": Pathway.objects.get(package=p, uuid=pathway_uuid).nodes} except ValueError: return 403, { - 'message': f'Getting Nodes for Pathway with id {pathway_uuid} failed due to insufficient rights!'} + "message": f"Getting Nodes for Pathway with id {pathway_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/node/{uuid:node_uuid}", response={200: NodeSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/node/{uuid:node_uuid}", + response={200: NodeSchema, 403: Error}, +) def get_package_pathway_node(request, package_uuid, pathway_uuid, node_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) @@ -1096,13 +1139,16 @@ def get_package_pathway_node(request, package_uuid, pathway_uuid, node_uuid): return Node.objects.get(pathway=pw, uuid=node_uuid) except ValueError: return 403, { - 'message': f'Getting Node with id {node_uuid} failed due to insufficient rights!'} + "message": f"Getting Node with id {node_uuid} failed due to insufficient rights!" + } + ######## # Edge # ######## class EdgeWrapper(Schema): - edge: List['SimpleEdge'] + edge: List["SimpleEdge"] + class EdgeNode(SimpleNode): image: str = Field(None, alias="image") @@ -1111,36 +1157,44 @@ class EdgeNode(SimpleNode): def resolve_image(obj: Node): return f"{obj.default_node_label.url}?image=svg" + class EdgeSchema(Schema): aliases: List[str] = Field([], alias="aliases") description: str = Field(None, alias="description") ecNumbers: List[str] = Field([], alias="ec_numbers") - endNodes: List['EdgeNode'] = Field([], alias="end_nodes") + endNodes: List["EdgeNode"] = Field([], alias="end_nodes") id: str = Field(None, alias="url") - identifier: str = 'edge' + identifier: str = "edge" name: str = Field(None, alias="name") reactionName: str = Field(None, alias="edge_label.name") reactionURI: str = Field(None, alias="edge_label.url") reviewStatus: str = Field(None, alias="review_status") - scenarios: List['SimpleScenario'] = Field([], alias="scenarios") - startNodes: List['EdgeNode'] = Field([], alias="start_nodes") + scenarios: List["SimpleScenario"] = Field([], alias="scenarios") + startNodes: List["EdgeNode"] = Field([], alias="start_nodes") @staticmethod def resolve_review_status(obj: Node): - return 'reviewed' if obj.pathway.package.reviewed else 'unreviewed' + return "reviewed" if obj.pathway.package.reviewed else "unreviewed" -@router.get("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/edge", response={200: EdgeWrapper, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/edge", + response={200: EdgeWrapper, 403: Error}, +) def get_package_pathway_edges(request, package_uuid, pathway_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) - return {'edge': Pathway.objects.get(package=p, uuid=pathway_uuid).edges} + return {"edge": Pathway.objects.get(package=p, uuid=pathway_uuid).edges} except ValueError: return 403, { - 'message': f'Getting Edges for Pathway with id {pathway_uuid} failed due to insufficient rights!'} + "message": f"Getting Edges for Pathway with id {pathway_uuid} failed due to insufficient rights!" + } -@router.get("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/edge/{uuid:edge_uuid}", response={200: EdgeSchema, 403: Error}) +@router.get( + "/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/edge/{uuid:edge_uuid}", + response={200: EdgeSchema, 403: Error}, +) def get_package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid): try: p = PackageManager.get_package_by_id(request.user, package_uuid) @@ -1148,39 +1202,39 @@ def get_package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid): return Edge.objects.get(pathway=pw, uuid=edge_uuid) except ValueError: return 403, { - 'message': f'Getting Edge with id {edge_uuid} failed due to insufficient rights!'} + "message": f"Getting Edge with id {edge_uuid} failed due to insufficient rights!" + } ########### # Setting # ########### class SettingWrapper(Schema): - setting: List['SimpleSetting'] + setting: List["SimpleSetting"] class SettingSchema(Schema): duplicationHash: int = -1 id: str = Field(None, alias="url") - identifier: str = 'setting' - includedPackages: List['SimplePackage'] = Field([], alias="rule_packages") + identifier: str = "setting" + includedPackages: List["SimplePackage"] = Field([], alias="rule_packages") isPublic: bool = Field(None, alias="public") name: str = Field(None, alias="name") - normalizationRules: List['SimpleRule']= Field([], alias="normalization_rules") + normalizationRules: List["SimpleRule"] = Field([], alias="normalization_rules") propertyPlugins: List[str] = Field([], alias="property_plugins") truncationstrategy: Optional[str] = Field(None, alias="truncation_strategy") -@router.get('/setting', response={200: SettingWrapper, 403: Error}) +@router.get("/setting", response={200: SettingWrapper, 403: Error}) def get_settings(request): - return { - 'setting': SettingManager.get_all_settings(request.user) - } + return {"setting": SettingManager.get_all_settings(request.user)} -@router.get('/setting/{uuid:setting_uuid}', response={200: SettingSchema, 403: Error}) +@router.get("/setting/{uuid:setting_uuid}", response={200: SettingSchema, 403: Error}) def get_setting(request, setting_uuid): try: return SettingManager.get_setting_by_id(request.user, setting_uuid) except ValueError: return 403, { - 'message': f'Getting Setting with id {setting_uuid} failed due to insufficient rights!'} + "message": f"Getting Setting with id {setting_uuid} failed due to insufficient rights!" + } diff --git a/epdb/logic.py b/epdb/logic.py index ec51f3d3..be7ecdee 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -7,9 +7,26 @@ from uuid import UUID from django.contrib.auth import get_user_model from django.db import transaction from django.conf import settings as s +from pydantic import ValidationError -from epdb.models import User, Package, UserPackagePermission, GroupPackagePermission, Permission, Group, Setting, \ - EPModel, UserSettingPermission, Rule, Pathway, Node, Edge, Compound, Reaction, CompoundStructure +from epdb.models import ( + User, + Package, + UserPackagePermission, + GroupPackagePermission, + Permission, + Group, + Setting, + EPModel, + UserSettingPermission, + Rule, + Pathway, + Node, + Edge, + Compound, + Reaction, + CompoundStructure, +) from utilities.chem import FormatConverter from utilities.misc import PackageImporter, PackageExporter @@ -17,23 +34,30 @@ logger = logging.getLogger(__name__) class EPDBURLParser: - - UUID_PATTERN = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}' + UUID_PATTERN = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" MODEL_PATTERNS = { - 'epdb.User': re.compile(rf'^.*/user/{UUID_PATTERN}'), - 'epdb.Group': re.compile(rf'^.*/group/{UUID_PATTERN}'), - 'epdb.Package': re.compile(rf'^.*/package/{UUID_PATTERN}'), - 'epdb.Compound': re.compile(rf'^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}'), - 'epdb.CompoundStructure': re.compile(rf'^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}/structure/{UUID_PATTERN}'), - 'epdb.Rule': re.compile(rf'^.*/package/{UUID_PATTERN}/(?:simple-ambit-rule|simple-rdkit-rule|parallel-rule|sequential-rule|rule)/{UUID_PATTERN}'), - 'epdb.Reaction': re.compile(rf'^.*/package/{UUID_PATTERN}/reaction/{UUID_PATTERN}$'), - 'epdb.Pathway': re.compile(rf'^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}'), - 'epdb.Node': re.compile(rf'^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/node/{UUID_PATTERN}'), - 'epdb.Edge': re.compile(rf'^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/edge/{UUID_PATTERN}'), - 'epdb.Scenario': re.compile(rf'^.*/package/{UUID_PATTERN}/scenario/{UUID_PATTERN}'), - 'epdb.EPModel': re.compile(rf'^.*/package/{UUID_PATTERN}/model/{UUID_PATTERN}'), - 'epdb.Setting': re.compile(rf'^.*/setting/{UUID_PATTERN}'), + "epdb.User": re.compile(rf"^.*/user/{UUID_PATTERN}"), + "epdb.Group": re.compile(rf"^.*/group/{UUID_PATTERN}"), + "epdb.Package": re.compile(rf"^.*/package/{UUID_PATTERN}"), + "epdb.Compound": re.compile(rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}"), + "epdb.CompoundStructure": re.compile( + rf"^.*/package/{UUID_PATTERN}/compound/{UUID_PATTERN}/structure/{UUID_PATTERN}" + ), + "epdb.Rule": re.compile( + rf"^.*/package/{UUID_PATTERN}/(?:simple-ambit-rule|simple-rdkit-rule|parallel-rule|sequential-rule|rule)/{UUID_PATTERN}" + ), + "epdb.Reaction": re.compile(rf"^.*/package/{UUID_PATTERN}/reaction/{UUID_PATTERN}$"), + "epdb.Pathway": re.compile(rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}"), + "epdb.Node": re.compile( + rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/node/{UUID_PATTERN}" + ), + "epdb.Edge": re.compile( + rf"^.*/package/{UUID_PATTERN}/pathway/{UUID_PATTERN}/edge/{UUID_PATTERN}" + ), + "epdb.Scenario": re.compile(rf"^.*/package/{UUID_PATTERN}/scenario/{UUID_PATTERN}"), + "epdb.EPModel": re.compile(rf"^.*/package/{UUID_PATTERN}/model/{UUID_PATTERN}"), + "epdb.Setting": re.compile(rf"^.*/setting/{UUID_PATTERN}"), } def __init__(self, url: str): @@ -50,7 +74,8 @@ class EPDBURLParser: def _get_model_class(self, model_path: str): try: from django.apps import apps - app_label, model_name = model_path.split('.')[-2:] + + app_label, model_name = model_path.split(".")[-2:] return apps.get_model(app_label, model_name) except (ImportError, LookupError, ValueError): raise ValueError(f"Model {model_path} does not exist!") @@ -60,39 +85,42 @@ class EPDBURLParser: return model_class.objects.get(url=url) def is_package_url(self) -> bool: - return bool(re.compile(rf'^.*/package/{self.UUID_PATTERN}$').findall(self.url)) + return bool(re.compile(rf"^.*/package/{self.UUID_PATTERN}$").findall(self.url)) def contains_package_url(self): - return bool(self.MODEL_PATTERNS['epdb.Package'].findall(self.url)) and not self.is_package_url() + return ( + bool(self.MODEL_PATTERNS["epdb.Package"].findall(self.url)) + and not self.is_package_url() + ) def is_user_url(self) -> bool: - return bool(self.MODEL_PATTERNS['epdb.User'].findall(self.url)) + return bool(self.MODEL_PATTERNS["epdb.User"].findall(self.url)) def is_group_url(self) -> bool: - return bool(self.MODEL_PATTERNS['epdb.Group'].findall(self.url)) + return bool(self.MODEL_PATTERNS["epdb.Group"].findall(self.url)) def is_setting_url(self) -> bool: - return bool(self.MODEL_PATTERNS['epdb.Setting'].findall(self.url)) + return bool(self.MODEL_PATTERNS["epdb.Setting"].findall(self.url)) def get_object(self) -> Optional[Any]: # Define priority order from most specific to least specific priority_order = [ # 3rd level - 'epdb.CompoundStructure', - 'epdb.Node', - 'epdb.Edge', + "epdb.CompoundStructure", + "epdb.Node", + "epdb.Edge", # 2nd level - 'epdb.Compound', - 'epdb.Rule', - 'epdb.Reaction', - 'epdb.Scenario', - 'epdb.EPModel', - 'epdb.Pathway', + "epdb.Compound", + "epdb.Rule", + "epdb.Reaction", + "epdb.Scenario", + "epdb.EPModel", + "epdb.Pathway", # 1st level - 'epdb.Package', - 'epdb.Setting', - 'epdb.Group', - 'epdb.User', + "epdb.Package", + "epdb.Setting", + "epdb.Group", + "epdb.User", ] for model_path in priority_order: @@ -111,21 +139,21 @@ class EPDBURLParser: hierarchy_order = [ # 1st level - 'epdb.Package', - 'epdb.Setting', - 'epdb.Group', - 'epdb.User', + "epdb.Package", + "epdb.Setting", + "epdb.Group", + "epdb.User", # 2nd level - 'epdb.Compound', - 'epdb.Rule', - 'epdb.Reaction', - 'epdb.Scenario', - 'epdb.EPModel', - 'epdb.Pathway', + "epdb.Compound", + "epdb.Rule", + "epdb.Reaction", + "epdb.Scenario", + "epdb.EPModel", + "epdb.Pathway", # 3rd level - 'epdb.CompoundStructure', - 'epdb.Node', - 'epdb.Edge', + "epdb.CompoundStructure", + "epdb.Node", + "epdb.Edge", ] for model_path in hierarchy_order: @@ -143,7 +171,9 @@ class EPDBURLParser: class UserManager(object): - user_pattern = re.compile(r".*/user/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}") + user_pattern = re.compile( + r".*/user/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" + ) @staticmethod def is_user_url(url: str): @@ -151,25 +181,25 @@ class UserManager(object): @staticmethod @transaction.atomic - def create_user(username, email, password, set_setting=True, add_to_group=True, *args, **kwargs): + def create_user( + username, email, password, set_setting=True, add_to_group=True, *args, **kwargs + ): # avoid circular import :S from .tasks import send_registration_mail - extra_fields = { - 'is_active': not s.ADMIN_APPROVAL_REQUIRED - } + extra_fields = {"is_active": not s.ADMIN_APPROVAL_REQUIRED} - if 'is_active' in kwargs: - extra_fields['is_active'] = kwargs['is_active'] + if "is_active" in kwargs: + extra_fields["is_active"] = kwargs["is_active"] - if 'uuid' in kwargs: - extra_fields['uuid'] = kwargs['uuid'] + if "uuid" in kwargs: + extra_fields["uuid"] = kwargs["uuid"] u = get_user_model().objects.create_user(username, email, password, **extra_fields) # Create package package_name = f"{u.username}{'’' if u.username[-1] in 'sxzß' else 's'} Package" - package_description = f"This package was generated during registration." + package_description = "This package was generated during registration." p = PackageManager.create_package(u, package_name, package_description) u.default_package = p u.save() @@ -183,7 +213,7 @@ class UserManager(object): u.save() if add_to_group: - g = Group.objects.get(public=True, name='enviPath Users') + g = Group.objects.get(public=True, name="enviPath Users") g.user_member.add(u) g.save() u.default_group = g @@ -203,7 +233,7 @@ class UserManager(object): @staticmethod def get_user_lp(user_url: str): - uuid = user_url.strip().split('/')[-1] + uuid = user_url.strip().split("/")[-1] return get_user_model().objects.get(uuid=uuid) @staticmethod @@ -218,8 +248,11 @@ class UserManager(object): def writable(current_user, user): return (current_user == user) or user.is_superuser + class GroupManager(object): - group_pattern = re.compile(r".*/group/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}") + group_pattern = re.compile( + r".*/group/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" + ) @staticmethod def is_group_url(url: str): @@ -240,7 +273,7 @@ class GroupManager(object): @staticmethod def get_group_lp(group_url: str): - uuid = group_url.strip().split('/')[-1] + uuid = group_url.strip().split("/")[-1] return Group.objects.get(uuid=uuid) @staticmethod @@ -249,7 +282,7 @@ class GroupManager(object): @staticmethod def get_group_by_url(user, group_url): - return GroupManager.get_group_by_id(user, group_url.split('/')[-1]) + return GroupManager.get_group_by_id(user, group_url.split("/")[-1]) @staticmethod def get_group_by_id(user, group_id): @@ -266,17 +299,16 @@ class GroupManager(object): @staticmethod @transaction.atomic def update_members(caller: User, group: Group, member: Union[User, Group], add_or_remove: str): - if caller != group.owner and not caller.is_superuser: - raise ValueError('Only the group Owner is allowed to add members!') + raise ValueError("Only the group Owner is allowed to add members!") if isinstance(member, Group): - if add_or_remove == 'add': + if add_or_remove == "add": group.group_member.add(member) else: group.group_member.remove(member) else: - if add_or_remove == 'add': + if add_or_remove == "add": group.user_member.add(member) else: group.user_member.remove(member) @@ -289,7 +321,9 @@ class GroupManager(object): class PackageManager(object): - package_pattern = re.compile(r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}") + package_pattern = re.compile( + r".*/package/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" + ) @staticmethod def is_package_url(url: str): @@ -301,51 +335,73 @@ class PackageManager(object): @staticmethod def readable(user, package): - if UserPackagePermission.objects.filter(package=package, user=user).exists() or \ - GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user)) or \ - package.reviewed is True or \ - user.is_superuser: + if ( + UserPackagePermission.objects.filter(package=package, user=user).exists() + or GroupPackagePermission.objects.filter( + package=package, group__in=GroupManager.get_groups(user) + ) + or package.reviewed is True + or user.is_superuser + ): return True return False @staticmethod def writable(user, package): - if UserPackagePermission.objects.filter(package=package, user=user, permission=Permission.WRITE[0]).exists() or \ - GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]).exists() or \ - UserPackagePermission.objects.filter(package=package, user=user, permission=Permission.ALL[0]).exists() or \ - user.is_superuser: + if ( + UserPackagePermission.objects.filter( + package=package, user=user, permission=Permission.WRITE[0] + ).exists() + or GroupPackagePermission.objects.filter( + package=package, + group__in=GroupManager.get_groups(user), + permission=Permission.WRITE[0], + ).exists() + or UserPackagePermission.objects.filter( + package=package, user=user, permission=Permission.ALL[0] + ).exists() + or user.is_superuser + ): return True return False @staticmethod def administrable(user, package): - if UserPackagePermission.objects.filter(package=package, user=user, permission=Permission.ALL[0]).exists() or \ - GroupPackagePermission.objects.filter(package=package, group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]).exists() or \ - user.is_superuser: + if ( + UserPackagePermission.objects.filter( + package=package, user=user, permission=Permission.ALL[0] + ).exists() + or GroupPackagePermission.objects.filter( + package=package, + group__in=GroupManager.get_groups(user), + permission=Permission.ALL[0], + ).exists() + or user.is_superuser + ): return True return False @staticmethod - def has_package_permission(user: 'User', package: Union[str, UUID, 'Package'], permission: str): - + def has_package_permission(user: "User", package: Union[str, UUID, "Package"], permission: str): if isinstance(package, str) or isinstance(package, UUID): package = Package.objects.get(uuid=package) groups = GroupManager.get_groups(user) - perms = { - 'all': ['all'], - 'write': ['all', 'write'], - 'read': ['all', 'write', 'read'] - } + perms = {"all": ["all"], "write": ["all", "write"], "read": ["all", "write", "read"]} valid_perms = perms.get(permission) - if UserPackagePermission.objects.filter(package=package, user=user, permission__in=valid_perms).exists() or \ - GroupPackagePermission.objects.filter(package=package, group__in=groups, - permission__in=valid_perms).exists() or \ - user.is_superuser: + if ( + UserPackagePermission.objects.filter( + package=package, user=user, permission__in=valid_perms + ).exists() + or GroupPackagePermission.objects.filter( + package=package, group__in=groups, permission__in=valid_perms + ).exists() + or user.is_superuser + ): return True return False @@ -354,7 +410,7 @@ class PackageManager(object): def get_package_lp(package_url): match = re.findall(PackageManager.package_pattern, package_url) if match: - package_id = match[0].split('/')[-1] + package_id = match[0].split("/")[-1] return Package.objects.get(uuid=package_id) return None @@ -363,10 +419,12 @@ class PackageManager(object): match = re.findall(PackageManager.package_pattern, package_url) if match: - package_id = match[0].split('/')[-1] + package_id = match[0].split("/")[-1] return PackageManager.get_package_by_id(user, package_id) else: - raise ValueError("Requested URL {} does not contain a valid package identifier!".format(package_url)) + raise ValueError( + "Requested URL {} does not contain a valid package identifier!".format(package_url) + ) @staticmethod def get_package_by_id(user, package_id): @@ -376,7 +434,8 @@ class PackageManager(object): return p else: raise ValueError( - "Insufficient permissions to access Package with ID {}".format(package_id)) + "Insufficient permissions to access Package with ID {}".format(package_id) + ) except Package.DoesNotExist: raise ValueError("Package with ID {} does not exist!".format(package_id)) @@ -387,10 +446,15 @@ class PackageManager(object): qs = Package.objects.all() else: user_package_qs = Package.objects.filter( - id__in=UserPackagePermission.objects.filter(user=user).values('package').distinct()) + id__in=UserPackagePermission.objects.filter(user=user).values("package").distinct() + ) group_package_qs = Package.objects.filter( - id__in=GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user)).values( - 'package').distinct()) + id__in=GroupPackagePermission.objects.filter( + group__in=GroupManager.get_groups(user) + ) + .values("package") + .distinct() + ) qs = user_package_qs | group_package_qs if include_reviewed: @@ -407,14 +471,34 @@ class PackageManager(object): if user.is_superuser: qs = Package.objects.all() else: - write_user_packs = UserPackagePermission.objects.filter(user=user, permission=Permission.WRITE[0]).values('package').distinct() - owner_user_packs = UserPackagePermission.objects.filter(user=user, permission=Permission.ALL[0]).values('package').distinct() + write_user_packs = ( + UserPackagePermission.objects.filter(user=user, permission=Permission.WRITE[0]) + .values("package") + .distinct() + ) + owner_user_packs = ( + UserPackagePermission.objects.filter(user=user, permission=Permission.ALL[0]) + .values("package") + .distinct() + ) user_packs = write_user_packs | owner_user_packs user_package_qs = Package.objects.filter(id__in=user_packs) - write_group_packs = GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0]).values( 'package').distinct() - owner_group_packs = GroupPackagePermission.objects.filter(group__in=GroupManager.get_groups(user), permission=Permission.ALL[0]).values( 'package').distinct() + write_group_packs = ( + GroupPackagePermission.objects.filter( + group__in=GroupManager.get_groups(user), permission=Permission.WRITE[0] + ) + .values("package") + .distinct() + ) + owner_group_packs = ( + GroupPackagePermission.objects.filter( + group__in=GroupManager.get_groups(user), permission=Permission.ALL[0] + ) + .values("package") + .distinct() + ) group_packs = write_group_packs | owner_group_packs group_package_qs = Package.objects.filter(id__in=group_packs) @@ -447,24 +531,26 @@ class PackageManager(object): @staticmethod @transaction.atomic - def update_permissions(caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str]): + def update_permissions( + caller: User, package: Package, grantee: Union[User, Group], new_perm: Optional[str] + ): caller_perm = None if not caller.is_superuser: caller_perm = UserPackagePermission.objects.get(user=caller, package=package).permission if caller_perm != Permission.ALL[0] and not caller.is_superuser: - raise ValueError(f"Only owner are allowed to modify permissions") + raise ValueError("Only owner are allowed to modify permissions") data = { - 'package': package, + "package": package, } if isinstance(grantee, User): perm_cls = UserPackagePermission - data['user'] = grantee + data["user"] = grantee else: perm_cls = GroupPackagePermission - data['group'] = grantee + data["group"] = grantee if new_perm is None: qs = perm_cls.objects.filter(**data) @@ -476,34 +562,47 @@ class PackageManager(object): else: logger.debug(f"No Permission object for {perm_cls} with filter {data} found!") else: - _ = perm_cls.objects.update_or_create(defaults={'permission': new_perm}, **data) - - + _ = perm_cls.objects.update_or_create(defaults={"permission": new_perm}, **data) @staticmethod @transaction.atomic - def import_legacy_package(data: dict, owner: User, keep_ids=False, add_import_timestamp=True, trust_reviewed=False): + def import_legacy_package( + data: dict, owner: User, keep_ids=False, add_import_timestamp=True, trust_reviewed=False + ): from uuid import UUID, uuid4 from datetime import datetime from collections import defaultdict - from .models import Package, Compound, CompoundStructure, SimpleRule, SimpleAmbitRule, SimpleRDKitRule, \ - ParallelRule, SequentialRule, SequentialRuleOrdering, Reaction, Pathway, Node, Edge, Scenario + from .models import ( + Package, + Compound, + CompoundStructure, + SimpleRule, + SimpleAmbitRule, + ParallelRule, + SequentialRule, + SequentialRuleOrdering, + Reaction, + Pathway, + Node, + Edge, + Scenario, + ) from envipy_additional_information import AdditionalInformationConverter pack = Package() - pack.uuid = UUID(data['id'].split('/')[-1]) if keep_ids else uuid4() + pack.uuid = UUID(data["id"].split("/")[-1]) if keep_ids else uuid4() if add_import_timestamp: - pack.name = '{} - {}'.format(data['name'], datetime.now().strftime('%Y-%m-%d %H:%M')) + pack.name = "{} - {}".format(data["name"], datetime.now().strftime("%Y-%m-%d %H:%M")) else: - pack.name = data['name'] + pack.name = data["name"] if trust_reviewed: - pack.reviewed = True if data['reviewStatus'] == 'reviewed' else False + pack.reviewed = True if data["reviewStatus"] == "reviewed" else False else: pack.reviewed = False - pack.description = data['description'] + pack.description = data["description"] pack.save() up = UserPackagePermission() @@ -520,123 +619,124 @@ class PackageManager(object): scen_mapping = defaultdict(list) # Store Scenarios - for scenario in data['scenarios']: + for scenario in data["scenarios"]: scen = Scenario() scen.package = pack - scen.uuid = UUID(scenario['id'].split('/')[-1]) if keep_ids else uuid4() - scen.name = scenario['name'] - scen.description = scenario['description'] - scen.scenario_type = scenario['type'] - scen.scenario_date = scenario['date'] + scen.uuid = UUID(scenario["id"].split("/")[-1]) if keep_ids else uuid4() + scen.name = scenario["name"] + scen.description = scenario["description"] + scen.scenario_type = scenario["type"] + scen.scenario_date = scenario["date"] scen.additional_information = dict() scen.save() - mapping[scenario['id']] = scen.uuid + mapping[scenario["id"]] = scen.uuid new_add_inf = defaultdict(list) # TODO Store AI... - for ex in scenario.get('additionalInformationCollection', {}).get('additionalInformation', []): - name = ex['name'] - addinf_data = ex['data'] + for ex in scenario.get("additionalInformationCollection", {}).get( + "additionalInformation", [] + ): + name = ex["name"] + addinf_data = ex["data"] # park the parent scen id for now and link it later - if name == 'referringscenario': + if name == "referringscenario": parent_mapping[scen.uuid] = addinf_data continue # Broken eP Data - if name == 'initialmasssediment' and addinf_data == 'missing data': + if name == "initialmasssediment" and addinf_data == "missing data": continue # TODO Enzymes arent ready yet - if name == 'enzyme': + if name == "enzyme": continue try: res = AdditionalInformationConverter.convert(name, addinf_data) res_cls_name = res.__class__.__name__ ai_data = json.loads(res.model_dump_json()) - ai_data['uuid'] = f"{uuid4()}" + ai_data["uuid"] = f"{uuid4()}" new_add_inf[res_cls_name].append(ai_data) - except: + except ValidationError: logger.error(f"Failed to convert {name} with {addinf_data}") scen.additional_information = new_add_inf scen.save() - print('Scenarios imported...') + print("Scenarios imported...") # Store compounds and its structures - for compound in data['compounds']: + for compound in data["compounds"]: comp = Compound() comp.package = pack - comp.uuid = UUID(compound['id'].split('/')[-1]) if keep_ids else uuid4() - comp.name = compound['name'] - comp.description = compound['description'] - comp.aliases = compound['aliases'] + comp.uuid = UUID(compound["id"].split("/")[-1]) if keep_ids else uuid4() + comp.name = compound["name"] + comp.description = compound["description"] + comp.aliases = compound["aliases"] comp.save() - mapping[compound['id']] = comp.uuid + mapping[compound["id"]] = comp.uuid - for scen in compound['scenarios']: - scen_mapping[scen['id']].append(comp) + for scen in compound["scenarios"]: + scen_mapping[scen["id"]].append(comp) default_structure = None - for structure in compound['structures']: + for structure in compound["structures"]: struc = CompoundStructure() # struc.object_url = Command.get_id(structure, keep_ids) struc.compound = comp - struc.uuid = UUID(structure['id'].split('/')[-1]) if keep_ids else uuid4() - struc.name = structure['name'] - struc.description = structure['description'] - struc.smiles = structure['smiles'] + struc.uuid = UUID(structure["id"].split("/")[-1]) if keep_ids else uuid4() + struc.name = structure["name"] + struc.description = structure["description"] + struc.smiles = structure["smiles"] struc.save() - for scen in structure['scenarios']: - scen_mapping[scen['id']].append(struc) + for scen in structure["scenarios"]: + scen_mapping[scen["id"]].append(struc) - mapping[structure['id']] = struc.uuid + mapping[structure["id"]] = struc.uuid - if structure['id'] == compound['defaultStructure']['id']: + if structure["id"] == compound["defaultStructure"]["id"]: default_structure = struc struc.save() - if default_structure is None: - raise ValueError('No default structure set') + raise ValueError("No default structure set") comp.default_structure = default_structure comp.save() - print('Compounds imported...') + print("Compounds imported...") # Store simple and parallel-rules par_rules = [] seq_rules = [] - for rule in data['rules']: - if rule['identifier'] == 'parallel-rule': + for rule in data["rules"]: + if rule["identifier"] == "parallel-rule": par_rules.append(rule) continue - if rule['identifier'] == 'sequential-rule': + if rule["identifier"] == "sequential-rule": seq_rules.append(rule) continue r = SimpleAmbitRule() - r.uuid = UUID(rule['id'].split('/')[-1]) if keep_ids else uuid4() + r.uuid = UUID(rule["id"].split("/")[-1]) if keep_ids else uuid4() r.package = pack - r.name = rule['name'] - r.description = rule['description'] - r.smirks = rule['smirks'] - r.reactant_filter_smarts = rule.get('reactantFilterSmarts', None) - r.product_filter_smarts = rule.get('productFilterSmarts', None) + r.name = rule["name"] + r.description = rule["description"] + r.smirks = rule["smirks"] + r.reactant_filter_smarts = rule.get("reactantFilterSmarts", None) + r.product_filter_smarts = rule.get("productFilterSmarts", None) r.save() - mapping[rule['id']] = r.uuid + mapping[rule["id"]] = r.uuid - for scen in rule['scenarios']: - scen_mapping[scen['id']].append(r) + for scen in rule["scenarios"]: + scen_mapping[scen["id"]].append(r) print("Par: ", len(par_rules)) print("Seq: ", len(seq_rules)) @@ -644,36 +744,36 @@ class PackageManager(object): for par_rule in par_rules: r = ParallelRule() r.package = pack - r.uuid = UUID(par_rule['id'].split('/')[-1]) if keep_ids else uuid4() - r.name = par_rule['name'] - r.description = par_rule['description'] + r.uuid = UUID(par_rule["id"].split("/")[-1]) if keep_ids else uuid4() + r.name = par_rule["name"] + r.description = par_rule["description"] r.save() - mapping[par_rule['id']] = r.uuid + mapping[par_rule["id"]] = r.uuid - for scen in par_rule['scenarios']: - scen_mapping[scen['id']].append(r) + for scen in par_rule["scenarios"]: + scen_mapping[scen["id"]].append(r) - for simple_rule in par_rule['simpleRules']: - if simple_rule['id'] in mapping: - r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule['id']])) + for simple_rule in par_rule["simpleRules"]: + if simple_rule["id"] in mapping: + r.simple_rules.add(SimpleRule.objects.get(uuid=mapping[simple_rule["id"]])) r.save() for seq_rule in seq_rules: r = SequentialRule() r.package = pack - r.uuid = UUID(seq_rule['id'].split('/')[-1]) if keep_ids else uuid4() - r.name = seq_rule['name'] - r.description = seq_rule['description'] + r.uuid = UUID(seq_rule["id"].split("/")[-1]) if keep_ids else uuid4() + r.name = seq_rule["name"] + r.description = seq_rule["description"] r.save() - mapping[seq_rule['id']] = r.uuid + mapping[seq_rule["id"]] = r.uuid - for scen in seq_rule['scenarios']: - scen_mapping[scen['id']].append(r) + for scen in seq_rule["scenarios"]: + scen_mapping[scen["id"]].append(r) - for i, simple_rule in enumerate(seq_rule['simpleRules']): + for i, simple_rule in enumerate(seq_rule["simpleRules"]): sro = SequentialRuleOrdering() sro.simple_rule = simple_rule sro.sequential_rule = r @@ -683,97 +783,97 @@ class PackageManager(object): r.save() - print('Rules imported...') + print("Rules imported...") - for reaction in data['reactions']: + for reaction in data["reactions"]: r = Reaction() r.package = pack - r.uuid = UUID(reaction['id'].split('/')[-1]) if keep_ids else uuid4() - r.name = reaction['name'] - r.description = reaction['description'] - r.medlinereferences = reaction['medlinereferences'], - r.multi_step = True if reaction['multistep'] == 'true' else False + r.uuid = UUID(reaction["id"].split("/")[-1]) if keep_ids else uuid4() + r.name = reaction["name"] + r.description = reaction["description"] + r.medlinereferences = (reaction["medlinereferences"],) + r.multi_step = True if reaction["multistep"] == "true" else False r.save() - mapping[reaction['id']] = r.uuid + mapping[reaction["id"]] = r.uuid - for scen in reaction['scenarios']: - scen_mapping[scen['id']].append(r) + for scen in reaction["scenarios"]: + scen_mapping[scen["id"]].append(r) - for educt in reaction['educts']: - r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt['id']])) + for educt in reaction["educts"]: + r.educts.add(CompoundStructure.objects.get(uuid=mapping[educt["id"]])) - for product in reaction['products']: - r.products.add(CompoundStructure.objects.get(uuid=mapping[product['id']])) + for product in reaction["products"]: + r.products.add(CompoundStructure.objects.get(uuid=mapping[product["id"]])) - if 'rules' in reaction: - for rule in reaction['rules']: + if "rules" in reaction: + for rule in reaction["rules"]: try: - r.rules.add(Rule.objects.get(uuid=mapping[rule['id']])) + r.rules.add(Rule.objects.get(uuid=mapping[rule["id"]])) except Exception as e: print(f"Rule with id {rule['id']} not found!") print(e) r.save() - print('Reactions imported...') + print("Reactions imported...") - for pathway in data['pathways']: + for pathway in data["pathways"]: pw = Pathway() pw.package = pack - pw.uuid = UUID(pathway['id'].split('/')[-1]) if keep_ids else uuid4() - pw.name = pathway['name'] - pw.description = pathway['description'] + pw.uuid = UUID(pathway["id"].split("/")[-1]) if keep_ids else uuid4() + pw.name = pathway["name"] + pw.description = pathway["description"] pw.save() - mapping[pathway['id']] = pw.uuid - for scen in pathway['scenarios']: - scen_mapping[scen['id']].append(pw) + mapping[pathway["id"]] = pw.uuid + for scen in pathway["scenarios"]: + scen_mapping[scen["id"]].append(pw) out_nodes_mapping = defaultdict(set) - root_node = None - - for node in pathway['nodes']: + for node in pathway["nodes"]: n = Node() - n.uuid = UUID(node['id'].split('/')[-1]) if keep_ids else uuid4() - n.name = node['name'] + n.uuid = UUID(node["id"].split("/")[-1]) if keep_ids else uuid4() + n.name = node["name"] n.pathway = pw - n.depth = node['depth'] - n.default_node_label = CompoundStructure.objects.get(uuid=mapping[node['defaultNodeLabel']['id']]) + n.depth = node["depth"] + n.default_node_label = CompoundStructure.objects.get( + uuid=mapping[node["defaultNodeLabel"]["id"]] + ) n.save() - mapping[node['id']] = n.uuid + mapping[node["id"]] = n.uuid - for scen in node['scenarios']: - scen_mapping[scen['id']].append(n) + for scen in node["scenarios"]: + scen_mapping[scen["id"]].append(n) - for node_label in node['nodeLabels']: - n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label['id']])) + for node_label in node["nodeLabels"]: + n.node_labels.add(CompoundStructure.objects.get(uuid=mapping[node_label["id"]])) n.save() - for out_edge in node['outEdges']: + for out_edge in node["outEdges"]: out_nodes_mapping[n.uuid].add(out_edge) - for edge in pathway['edges']: + for edge in pathway["edges"]: e = Edge() - e.uuid = UUID(edge['id'].split('/')[-1]) if keep_ids else uuid4() - e.name = edge['name'] + e.uuid = UUID(edge["id"].split("/")[-1]) if keep_ids else uuid4() + e.name = edge["name"] e.pathway = pw - e.description = edge['description'] - e.edge_label = Reaction.objects.get(uuid=mapping[edge['edgeLabel']['id']]) + e.description = edge["description"] + e.edge_label = Reaction.objects.get(uuid=mapping[edge["edgeLabel"]["id"]]) e.save() - mapping[edge['id']] = e.uuid + mapping[edge["id"]] = e.uuid - for scen in edge['scenarios']: - scen_mapping[scen['id']].append(e) + for scen in edge["scenarios"]: + scen_mapping[scen["id"]].append(e) - for start_node in edge['startNodes']: + for start_node in edge["startNodes"]: e.start_nodes.add(Node.objects.get(uuid=mapping[start_node])) - for end_node in edge['endNodes']: + for end_node in edge["endNodes"]: e.end_nodes.add(Node.objects.get(uuid=mapping[end_node])) e.save() @@ -784,7 +884,7 @@ class PackageManager(object): n.out_edges.add(Edge.objects.get(uuid=mapping[v1])) n.save() - print('Pathways imported...') + print("Pathways imported...") # Linking Phase for child, parent in parent_mapping.items(): @@ -801,13 +901,13 @@ class PackageManager(object): print("Scenarios linked...") - print('Import statistics:') - print('Package {} stored'.format(pack.url)) - print('Imported {} compounds'.format(Compound.objects.filter(package=pack).count())) - print('Imported {} rules'.format(Rule.objects.filter(package=pack).count())) - print('Imported {} reactions'.format(Reaction.objects.filter(package=pack).count())) - print('Imported {} pathways'.format(Pathway.objects.filter(package=pack).count())) - print('Imported {} Scenarios'.format(Scenario.objects.filter(package=pack).count())) + print("Import statistics:") + print("Package {} stored".format(pack.url)) + print("Imported {} compounds".format(Compound.objects.filter(package=pack).count())) + print("Imported {} rules".format(Rule.objects.filter(package=pack).count())) + print("Imported {} reactions".format(Reaction.objects.filter(package=pack).count())) + print("Imported {} pathways".format(Pathway.objects.filter(package=pack).count())) + print("Imported {} Scenarios".format(Scenario.objects.filter(package=pack).count())) print("Fixing Node depths...") total_pws = Pathway.objects.filter(package=pack).count() @@ -848,7 +948,7 @@ class PackageManager(object): if str(prod.uuid) not in seen: old_depth = prod.depth if old_depth != i + 1: - print(f'updating depth from {old_depth} to {i + 1}') + print(f"updating depth from {old_depth} to {i + 1}") prod.depth = i + 1 prod.save() @@ -859,15 +959,19 @@ class PackageManager(object): if new_level: levels.append(new_level) - print(f'{p + 1}/{total_pws} fixed.') + print(f"{p + 1}/{total_pws} fixed.") return pack @staticmethod @transaction.atomic - def import_package(data: Dict[str, Any], owner: User, preserve_uuids=False, add_import_timestamp=True, - trust_reviewed=False) -> Package: - + def import_package( + data: Dict[str, Any], + owner: User, + preserve_uuids=False, + add_import_timestamp=True, + trust_reviewed=False, + ) -> Package: importer = PackageImporter(data, preserve_uuids, add_import_timestamp, trust_reviewed) imported_package = importer.do_import() @@ -880,46 +984,64 @@ class PackageManager(object): return imported_package @staticmethod - def export_package(package: Package, include_models: bool = False, - include_external_identifiers: bool = True) -> Dict[str, Any]: + def export_package( + package: Package, include_models: bool = False, include_external_identifiers: bool = True + ) -> Dict[str, Any]: return PackageExporter(package).do_export() class SettingManager(object): - setting_pattern = re.compile(r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$") + setting_pattern = re.compile( + r".*/setting/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$" + ) @staticmethod def get_setting_by_url(user, setting_url): match = re.findall(SettingManager.setting_pattern, setting_url) if match: - setting_id = match[0].split('/')[-1] + setting_id = match[0].split("/")[-1] return SettingManager.get_setting_by_id(user, setting_id) else: - raise ValueError("Requested URL {} does not contain a valid setting identifier!".format(setting_url)) + raise ValueError( + "Requested URL {} does not contain a valid setting identifier!".format(setting_url) + ) @staticmethod def get_setting_by_id(user, setting_id): s = Setting.objects.get(uuid=setting_id) - if s.global_default or s.public or user.is_superuser or \ - UserSettingPermission.objects.filter(user=user, setting=s).exists(): + if ( + s.global_default + or s.public + or user.is_superuser + or UserSettingPermission.objects.filter(user=user, setting=s).exists() + ): return s - raise ValueError( - "Insufficient permissions to access Setting with ID {}".format(setting_id)) + raise ValueError("Insufficient permissions to access Setting with ID {}".format(setting_id)) @staticmethod def get_all_settings(user): - sp = UserSettingPermission.objects.filter(user=user).values('setting') - return (Setting.objects.filter(id__in=sp) | Setting.objects.filter(public=True) | Setting.objects.filter( - global_default=True)).distinct() + sp = UserSettingPermission.objects.filter(user=user).values("setting") + return ( + Setting.objects.filter(id__in=sp) + | Setting.objects.filter(public=True) + | Setting.objects.filter(global_default=True) + ).distinct() @staticmethod @transaction.atomic - def create_setting(user: User, name: str = None, description: str = None, max_nodes: int = None, - max_depth: int = None, rule_packages: List[Package] = None, model: EPModel = None, - model_threshold: float = None): + def create_setting( + user: User, + name: str = None, + description: str = None, + max_nodes: int = None, + max_depth: int = None, + rule_packages: List[Package] = None, + model: EPModel = None, + model_threshold: float = None, + ): s = Setting() s.name = name s.description = description @@ -945,7 +1067,6 @@ class SettingManager(object): @staticmethod def get_default_setting(user: User): - pass @staticmethod @@ -953,21 +1074,20 @@ class SettingManager(object): def set_default_setting(user: User, setting: Setting): pass + class SearchManager(object): - - @staticmethod def search(packages: Union[Package, List[Package]], searchterm: str, mode: str): match mode: - case 'text': + case "text": return SearchManager._search_text(packages, searchterm) - case 'default': + case "default": return SearchManager._search_default_smiles(packages, searchterm) - case 'exact': + case "exact": return SearchManager._search_exact_smiles(packages, searchterm) - case 'canonical': + case "canonical": return SearchManager._search_canonical_smiles(packages, searchterm) - case 'inchikey': + case "inchikey": return SearchManager._search_inchikey(packages, searchterm) case _: raise ValueError(f"Unknown search mode {mode}!") @@ -977,16 +1097,38 @@ class SearchManager(object): from django.db.models import Q search_cond = Q(inchikey=searchterm) - compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__inchikey=searchterm)).distinct() - compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond) - reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__inchikey=searchterm) | Q(products__inchikey=searchterm))).distinct() - pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__inchikey=searchterm) | Q(edge__edge_label__products__inchikey=searchterm))).distinct() + compound_qs = Compound.objects.filter( + Q(package__in=packages) & Q(compoundstructure__inchikey=searchterm) + ).distinct() + compound_structure_qs = CompoundStructure.objects.filter( + Q(compound__package__in=packages) & search_cond + ) + reactions_qs = Reaction.objects.filter( + Q(package__in=packages) + & (Q(educts__inchikey=searchterm) | Q(products__inchikey=searchterm)) + ).distinct() + pathway_qs = Pathway.objects.filter( + Q(package__in=packages) + & ( + Q(edge__edge_label__educts__inchikey=searchterm) + | Q(edge__edge_label__products__inchikey=searchterm) + ) + ).distinct() return { - 'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs], - 'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs], - 'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs], - 'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs], + "Compounds": [ + {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs + ], + "Compound Structures": [ + {"name": c.name, "description": c.description, "url": c.url} + for c in compound_structure_qs + ], + "Reactions": [ + {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs + ], + "Pathways": [ + {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs + ], } @staticmethod @@ -994,16 +1136,38 @@ class SearchManager(object): from django.db.models import Q search_cond = Q(smiles=searchterm) - compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__smiles=searchterm)).distinct() - compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond) - reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__smiles=searchterm) | Q(products__smiles=searchterm))).distinct() - pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__smiles=searchterm) | Q(edge__edge_label__products__smiles=searchterm))).distinct() + compound_qs = Compound.objects.filter( + Q(package__in=packages) & Q(compoundstructure__smiles=searchterm) + ).distinct() + compound_structure_qs = CompoundStructure.objects.filter( + Q(compound__package__in=packages) & search_cond + ) + reactions_qs = Reaction.objects.filter( + Q(package__in=packages) + & (Q(educts__smiles=searchterm) | Q(products__smiles=searchterm)) + ).distinct() + pathway_qs = Pathway.objects.filter( + Q(package__in=packages) + & ( + Q(edge__edge_label__educts__smiles=searchterm) + | Q(edge__edge_label__products__smiles=searchterm) + ) + ).distinct() return { - 'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs], - 'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs], - 'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs], - 'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs], + "Compounds": [ + {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs + ], + "Compound Structures": [ + {"name": c.name, "description": c.description, "url": c.url} + for c in compound_structure_qs + ], + "Reactions": [ + {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs + ], + "Pathways": [ + {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs + ], } @staticmethod @@ -1013,16 +1177,41 @@ class SearchManager(object): inchi_front = FormatConverter.InChIKey(searchterm)[:14] search_cond = Q(inchikey__startswith=inchi_front) - compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__inchikey__startswith=inchi_front)).distinct() - compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond) - reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__inchikey__startswith=inchi_front) | Q(products__inchikey__startswith=inchi_front))).distinct() - pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__inchikey__startswith=inchi_front) | Q(edge__edge_label__products__inchikey__startswith=inchi_front))).distinct() + compound_qs = Compound.objects.filter( + Q(package__in=packages) & Q(compoundstructure__inchikey__startswith=inchi_front) + ).distinct() + compound_structure_qs = CompoundStructure.objects.filter( + Q(compound__package__in=packages) & search_cond + ) + reactions_qs = Reaction.objects.filter( + Q(package__in=packages) + & ( + Q(educts__inchikey__startswith=inchi_front) + | Q(products__inchikey__startswith=inchi_front) + ) + ).distinct() + pathway_qs = Pathway.objects.filter( + Q(package__in=packages) + & ( + Q(edge__edge_label__educts__inchikey__startswith=inchi_front) + | Q(edge__edge_label__products__inchikey__startswith=inchi_front) + ) + ).distinct() return { - 'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs], - 'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs], - 'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs], - 'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs], + "Compounds": [ + {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs + ], + "Compound Structures": [ + {"name": c.name, "description": c.description, "url": c.url} + for c in compound_structure_qs + ], + "Reactions": [ + {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs + ], + "Pathways": [ + {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs + ], } @staticmethod @@ -1030,42 +1219,76 @@ class SearchManager(object): from django.db.models import Q search_cond = Q(canonical_smiles=searchterm) - compound_qs = Compound.objects.filter(Q(package__in=packages) & Q(compoundstructure__canonical_smiles=searchterm)).distinct() - compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond) - reactions_qs = Reaction.objects.filter(Q(package__in=packages) & (Q(educts__canonical_smiles=searchterm) | Q(products__canonical_smiles=searchterm))).distinct() - pathway_qs = Pathway.objects.filter(Q(package__in=packages) & (Q(edge__edge_label__educts__canonical_smiles=searchterm) | Q(edge__edge_label__products__canonical_smiles=searchterm))).distinct() + compound_qs = Compound.objects.filter( + Q(package__in=packages) & Q(compoundstructure__canonical_smiles=searchterm) + ).distinct() + compound_structure_qs = CompoundStructure.objects.filter( + Q(compound__package__in=packages) & search_cond + ) + reactions_qs = Reaction.objects.filter( + Q(package__in=packages) + & (Q(educts__canonical_smiles=searchterm) | Q(products__canonical_smiles=searchterm)) + ).distinct() + pathway_qs = Pathway.objects.filter( + Q(package__in=packages) + & ( + Q(edge__edge_label__educts__canonical_smiles=searchterm) + | Q(edge__edge_label__products__canonical_smiles=searchterm) + ) + ).distinct() return { - 'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs], - 'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs], - 'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs], - 'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs], + "Compounds": [ + {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs + ], + "Compound Structures": [ + {"name": c.name, "description": c.description, "url": c.url} + for c in compound_structure_qs + ], + "Reactions": [ + {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs + ], + "Pathways": [ + {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs + ], } @staticmethod def _search_text(packages: Union[Package, List[Package]], searchterm: str): from django.db.models import Q - search_cond = (Q(name__icontains=searchterm) | Q(description__icontains=searchterm)) + search_cond = Q(name__icontains=searchterm) | Q(description__icontains=searchterm) cond = Q(package__in=packages) & search_cond compound_qs = Compound.objects.filter(cond) - compound_structure_qs = CompoundStructure.objects.filter(Q(compound__package__in=packages) & search_cond) + compound_structure_qs = CompoundStructure.objects.filter( + Q(compound__package__in=packages) & search_cond + ) rule_qs = Rule.objects.filter(cond) reactions_qs = Reaction.objects.filter(cond) pathway_qs = Pathway.objects.filter(cond) res = { - 'Compounds': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_qs], - 'Compound Structures': [{'name': c.name, 'description': c.description, 'url': c.url} for c in compound_structure_qs], - 'Rules': [{'name': r.name, 'description': r.description, 'url': r.url} for r in rule_qs], - 'Reactions': [{'name': r.name, 'description': r.description, 'url': r.url} for r in reactions_qs], - 'Pathways': [{'name': p.name, 'description': p.description, 'url': p.url} for p in pathway_qs], + "Compounds": [ + {"name": c.name, "description": c.description, "url": c.url} for c in compound_qs + ], + "Compound Structures": [ + {"name": c.name, "description": c.description, "url": c.url} + for c in compound_structure_qs + ], + "Rules": [ + {"name": r.name, "description": r.description, "url": r.url} for r in rule_qs + ], + "Reactions": [ + {"name": r.name, "description": r.description, "url": r.url} for r in reactions_qs + ], + "Pathways": [ + {"name": p.name, "description": p.description, "url": p.url} for p in pathway_qs + ], } return res class SNode(object): - def __init__(self, smiles: str, depth: int, app_domain_assessment: dict = None): self.smiles = smiles self.depth = depth @@ -1084,10 +1307,13 @@ class SNode(object): class SEdge(object): - - def __init__(self, educts: Union[SNode, List[SNode]], products: Union[SNode | List[SNode]], - rule: Optional['Rule'] = None, probability: Optional[float] = None): - + def __init__( + self, + educts: Union[SNode, List[SNode]], + products: Union[SNode | List[SNode]], + rule: Optional["Rule"] = None, + probability: Optional[float] = None, + ): if not isinstance(educts, list): educts = [educts] @@ -1114,23 +1340,32 @@ class SEdge(object): if not isinstance(other, SEdge): return False - if self.rule is not None and other.rule is None or \ - self.rule is None and other.rule is not None or \ - self.rule != other.rule: + if ( + self.rule is not None + and other.rule is None + or self.rule is None + and other.rule is not None + or self.rule != other.rule + ): return False if not (len(self.educts) == len(other.educts)): return False - for n1, n2 in zip(sorted(self.educts, key=lambda x: x.smiles), sorted(other.educts, key=lambda x: x.smiles)): + for n1, n2 in zip( + sorted(self.educts, key=lambda x: x.smiles), + sorted(other.educts, key=lambda x: x.smiles), + ): if n1.smiles != n2.smiles: return False if not (len(self.products) == len(other.products)): return False - for n1, n2 in zip(sorted(self.products, key=lambda x: x.smiles), - sorted(other.products, key=lambda x: x.smiles)): + for n1, n2 in zip( + sorted(self.products, key=lambda x: x.smiles), + sorted(other.products, key=lambda x: x.smiles), + ): if n1.smiles != n2.smiles: return False @@ -1141,10 +1376,12 @@ class SEdge(object): class SPathway(object): - - def __init__(self, root_nodes: Optional[Union[str, SNode, List[str | SNode]]] = None, - persist: Optional['Pathway'] = None, prediction_setting: Optional[Setting] = None - ): + def __init__( + self, + root_nodes: Optional[Union[str, SNode, List[str | SNode]]] = None, + persist: Optional["Pathway"] = None, + prediction_setting: Optional[Setting] = None, + ): self.root_nodes = [] self.persist = persist @@ -1168,13 +1405,15 @@ class SPathway(object): self.root_nodes.append(n) self.smiles_to_node: Dict[str, SNode] = dict(**{n.smiles: n for n in self.root_nodes}) - self.edges: Set['SEdge'] = set() + self.edges: Set["SEdge"] = set() self.done = False @staticmethod - def from_pathway(pw: 'Pathway', persist: bool = True): - """ Initializes a SPathway with a state given by a Pathway """ - spw = SPathway(root_nodes=pw.root_nodes, persist=pw if persist else None, prediction_setting=pw.setting) + def from_pathway(pw: "Pathway", persist: bool = True): + """Initializes a SPathway with a state given by a Pathway""" + spw = SPathway( + root_nodes=pw.root_nodes, persist=pw if persist else None, prediction_setting=pw.setting + ) # root_nodes are already added in __init__, add remaining nodes for n in pw.nodes: snode = SNode(n.default_node_label.smiles, n.depth) @@ -1197,8 +1436,8 @@ class SPathway(object): rule = e.edge_label.rules.all().first() prob = None - if e.kv.get('probability'): - prob = float(e.kv['probability']) + if e.kv.get("probability"): + prob = float(e.kv["probability"]) sedge = SEdge(sub, prod, rule=rule, probability=prob) spw.edges.add(sedge) @@ -1232,7 +1471,7 @@ class SPathway(object): return sorted(res, key=lambda x: hash(x)) - def predict_step(self, from_depth: int = None, from_node: 'Node' = None): + def predict_step(self, from_depth: int = None, from_node: "Node" = None): substrates: List[SNode] = [] if from_depth is not None: @@ -1248,26 +1487,28 @@ class SPathway(object): new_tp = False if substrates: for sub in substrates: - if sub.app_domain_assessment is None: if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)[0] + app_domain_assessment = self.prediction_setting.model.app_domain.assess( + sub.smiles + )[0] if self.persist is not None: n = self.snode_persist_lookup[sub] - assert n.id is not None, "Node has no id! Should have been saved already... aborting!" + assert n.id is not None, ( + "Node has no id! Should have been saved already... aborting!" + ) node_data = n.simple_json() - node_data['image'] = f"{n.url}?image=svg" - app_domain_assessment['assessment']['node'] = node_data + node_data["image"] = f"{n.url}?image=svg" + app_domain_assessment["assessment"]["node"] = node_data - n.kv['app_domain_assessment'] = app_domain_assessment + n.kv["app_domain_assessment"] = app_domain_assessment n.save() sub.app_domain_assessment = app_domain_assessment - candidates = self.prediction_setting.expand(self, sub) # candidates is a List of PredictionResult. The length of the List is equal to the number of rules for cand_set in candidates: @@ -1283,14 +1524,25 @@ class SPathway(object): app_domain_assessment = None if self.prediction_setting.model: if self.prediction_setting.model.app_domain: - app_domain_assessment = self.prediction_setting.model.app_domain.assess(c)[0] + app_domain_assessment = ( + self.prediction_setting.model.app_domain.assess(c)[ + 0 + ] + ) - self.smiles_to_node[c] = SNode(c, sub.depth + 1, app_domain_assessment) + self.smiles_to_node[c] = SNode( + c, sub.depth + 1, app_domain_assessment + ) node = self.smiles_to_node[c] cand_nodes.append(node) - edge = SEdge(sub, cand_nodes, rule=cand_set.rule, probability=cand_set.probability) + edge = SEdge( + sub, + cand_nodes, + rule=cand_set.rule, + probability=cand_set.probability, + ) self.edges.add(edge) # In case no substrates are found, we're done. @@ -1314,12 +1566,14 @@ class SPathway(object): if snode.app_domain_assessment is not None: app_domain_assessment = snode.app_domain_assessment - assert n.id is not None, "Node has no id! Should have been saved already... aborting!" + assert n.id is not None, ( + "Node has no id! Should have been saved already... aborting!" + ) node_data = n.simple_json() - node_data['image'] = f"{n.url}?image=svg" - app_domain_assessment['assessment']['node'] = node_data + node_data["image"] = f"{n.url}?image=svg" + app_domain_assessment["assessment"]["node"] = node_data - n.kv['app_domain_assessment'] = app_domain_assessment + n.kv["app_domain_assessment"] = app_domain_assessment n.save() self.snode_persist_lookup[snode] = n @@ -1337,7 +1591,7 @@ class SPathway(object): e = Edge.create(self.persist, educt_nodes, product_nodes, sedge.rule) if sedge.probability: - e.kv.update({'probability': sedge.probability}) + e.kv.update({"probability": sedge.probability}) e.save() self.sedge_persist_lookup[sedge] = e @@ -1350,19 +1604,19 @@ class SPathway(object): idx_lookup = {} - for i, s in enumerate(self.smiles_to_node): - n = self.smiles_to_node[s] - idx_lookup[s] = i + for i, smiles in enumerate(self.smiles_to_node): + n = self.smiles_to_node[smiles] + idx_lookup[smiles] = i - nodes.append({'depth': n.depth, 'smiles': n.smiles, 'id': i}) + nodes.append({"depth": n.depth, "smiles": n.smiles, "id": i}) for edge in self.edges: from_idx = idx_lookup[edge.educts[0].smiles] to_indices = [idx_lookup[p.smiles] for p in edge.products] e = { - 'from': from_idx, - 'to': to_indices, + "from": from_idx, + "to": to_indices, } # if edge.rule: @@ -1373,43 +1627,6 @@ class SPathway(object): edges.append(e) return { - 'nodes': nodes, - 'edges': edges, + "nodes": nodes, + "edges": edges, } - - def graph_to_tree_string(self): - graph_json = self.to_json() - nodes = {node['id']: node for node in graph_json['nodes']} - edges = graph_json['edges'] - - children_map = {} - for edge in edges: - src = edge['from'] - for tgt in edge['to']: - children_map.setdefault(src, []).append(tgt) - - visited = set() - - def recurse(node_id, prefix=''): - if node_id in visited: - return prefix + nodes[node_id]['smiles'] + " [loop detected]\n" - visited.add(node_id) - - line = prefix + nodes[node_id]['smiles'] + f" [{node_id}]\n" - kids = children_map.get(node_id, []) - for i, kid in enumerate(kids): - if i == len(kids) - 1: - branch = '└── ' - child_prefix = prefix + ' ' - else: - branch = '├── ' - child_prefix = prefix + '│ ' - line += recurse(kid, prefix=prefix + branch) - return line - - root_nodes = [n['id'] for n in graph_json['nodes'] if n['depth'] == 0] - result = '' - for root in root_nodes: - visited.clear() - result += recurse(root) - return result diff --git a/epdb/management/commands/bootstrap.py b/epdb/management/commands/bootstrap.py index 086fb67b..c01ec5eb 100644 --- a/epdb/management/commands/bootstrap.py +++ b/epdb/management/commands/bootstrap.py @@ -5,32 +5,49 @@ from django.core.management.base import BaseCommand from django.db import transaction from epdb.logic import UserManager, GroupManager, PackageManager, SettingManager -from epdb.models import UserSettingPermission, MLRelativeReasoning, EnviFormer, Permission, User, ExternalDatabase +from epdb.models import ( + UserSettingPermission, + MLRelativeReasoning, + EnviFormer, + Permission, + User, + ExternalDatabase, +) class Command(BaseCommand): - def create_users(self): - # Anonymous User - if not User.objects.filter(email='anon@envipath.com').exists(): - anon = UserManager.create_user("anonymous", "anon@envipath.com", "SuperSafe", - is_active=True, add_to_group=False, set_setting=False) + if not User.objects.filter(email="anon@envipath.com").exists(): + anon = UserManager.create_user( + "anonymous", + "anon@envipath.com", + "SuperSafe", + is_active=True, + add_to_group=False, + set_setting=False, + ) else: - anon = User.objects.get(email='anon@envipath.com') + anon = User.objects.get(email="anon@envipath.com") # Admin User - if not User.objects.filter(email='admin@envipath.com').exists(): - admin = UserManager.create_user("admin", "admin@envipath.com", "SuperSafe", - is_active=True, add_to_group=False, set_setting=False) + if not User.objects.filter(email="admin@envipath.com").exists(): + admin = UserManager.create_user( + "admin", + "admin@envipath.com", + "SuperSafe", + is_active=True, + add_to_group=False, + set_setting=False, + ) admin.is_staff = True admin.is_superuser = True admin.save() else: - admin = User.objects.get(email='admin@envipath.com') + admin = User.objects.get(email="admin@envipath.com") # System Group - g = GroupManager.create_group(admin, 'enviPath Users', 'All enviPath Users') + g = GroupManager.create_group(admin, "enviPath Users", "All enviPath Users") g.public = True g.save() @@ -43,14 +60,20 @@ class Command(BaseCommand): admin.default_group = g admin.save() - if not User.objects.filter(email='user0@envipath.com').exists(): - user0 = UserManager.create_user("user0", "user0@envipath.com", "SuperSafe", - is_active=True, add_to_group=False, set_setting=False) + if not User.objects.filter(email="user0@envipath.com").exists(): + user0 = UserManager.create_user( + "user0", + "user0@envipath.com", + "SuperSafe", + is_active=True, + add_to_group=False, + set_setting=False, + ) user0.is_staff = True user0.is_superuser = True user0.save() else: - user0 = User.objects.get(email='user0@envipath.com') + user0 = User.objects.get(email="user0@envipath.com") g.user_member.add(user0) g.save() @@ -61,18 +84,20 @@ class Command(BaseCommand): return anon, admin, g, user0 def import_package(self, data, owner): - return PackageManager.import_legacy_package(data, owner, keep_ids=True, add_import_timestamp=False, trust_reviewed=True) + return PackageManager.import_legacy_package( + data, owner, keep_ids=True, add_import_timestamp=False, trust_reviewed=True + ) def create_default_setting(self, owner, packages): s = SettingManager.create_setting( owner, - name='Global Default Setting', - description='Global Default Setting containing BBD Rules and Max 30 Nodes and Max Depth of 8', + name="Global Default Setting", + description="Global Default Setting containing BBD Rules and Max 30 Nodes and Max Depth of 8", max_nodes=30, max_depth=5, rule_packages=packages, model=None, - model_threshold=None + model_threshold=None, ) return s @@ -84,54 +109,51 @@ class Command(BaseCommand): """ databases = [ { - 'name': 'PubChem Compound', - 'full_name': 'PubChem Compound Database', - 'description': 'Chemical database of small organic molecules', - 'base_url': 'https://pubchem.ncbi.nlm.nih.gov', - 'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}' + "name": "PubChem Compound", + "full_name": "PubChem Compound Database", + "description": "Chemical database of small organic molecules", + "base_url": "https://pubchem.ncbi.nlm.nih.gov", + "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/compound/{id}", }, { - 'name': 'PubChem Substance', - 'full_name': 'PubChem Substance Database', - 'description': 'Database of chemical substances', - 'base_url': 'https://pubchem.ncbi.nlm.nih.gov', - 'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/substance/{id}' + "name": "PubChem Substance", + "full_name": "PubChem Substance Database", + "description": "Database of chemical substances", + "base_url": "https://pubchem.ncbi.nlm.nih.gov", + "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/substance/{id}", }, { - 'name': 'ChEBI', - 'full_name': 'Chemical Entities of Biological Interest', - 'description': 'Dictionary of molecular entities', - 'base_url': 'https://www.ebi.ac.uk/chebi', - 'url_pattern': 'https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{id}' + "name": "ChEBI", + "full_name": "Chemical Entities of Biological Interest", + "description": "Dictionary of molecular entities", + "base_url": "https://www.ebi.ac.uk/chebi", + "url_pattern": "https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{id}", }, { - 'name': 'RHEA', - 'full_name': 'RHEA Reaction Database', - 'description': 'Comprehensive resource of biochemical reactions', - 'base_url': 'https://www.rhea-db.org', - 'url_pattern': 'https://www.rhea-db.org/rhea/{id}' + "name": "RHEA", + "full_name": "RHEA Reaction Database", + "description": "Comprehensive resource of biochemical reactions", + "base_url": "https://www.rhea-db.org", + "url_pattern": "https://www.rhea-db.org/rhea/{id}", }, { - 'name': 'KEGG Reaction', - 'full_name': 'KEGG Reaction Database', - 'description': 'Database of biochemical reactions', - 'base_url': 'https://www.genome.jp', - 'url_pattern': 'https://www.genome.jp/entry/{id}' + "name": "KEGG Reaction", + "full_name": "KEGG Reaction Database", + "description": "Database of biochemical reactions", + "base_url": "https://www.genome.jp", + "url_pattern": "https://www.genome.jp/entry/{id}", }, { - 'name': 'UniProt', - 'full_name': 'MetaCyc Metabolic Pathway Database', - 'description': 'UniProt is a freely accessible database of protein sequence and functional information', - 'base_url': 'https://www.uniprot.org', - 'url_pattern': 'https://www.uniprot.org/uniprotkb?query="{id}"' - } + "name": "UniProt", + "full_name": "MetaCyc Metabolic Pathway Database", + "description": "UniProt is a freely accessible database of protein sequence and functional information", + "base_url": "https://www.uniprot.org", + "url_pattern": 'https://www.uniprot.org/uniprotkb?query="{id}"', + }, ] for db_info in databases: - ExternalDatabase.objects.get_or_create( - name=db_info['name'], - defaults=db_info - ) + ExternalDatabase.objects.get_or_create(name=db_info["name"], defaults=db_info) @transaction.atomic def handle(self, *args, **options): @@ -142,20 +164,24 @@ class Command(BaseCommand): # Import Packages packages = [ - 'EAWAG-BBD.json', - 'EAWAG-SOIL.json', - 'EAWAG-SLUDGE.json', - 'EAWAG-SEDIMENT.json', + "EAWAG-BBD.json", + "EAWAG-SOIL.json", + "EAWAG-SLUDGE.json", + "EAWAG-SEDIMENT.json", ] mapping = {} for p in packages: print(f"Importing {p}...") - package_data = json.loads(open(s.BASE_DIR / 'fixtures' / 'packages' / '2025-07-18' / p, encoding='utf-8').read()) + package_data = json.loads( + open( + s.BASE_DIR / "fixtures" / "packages" / "2025-07-18" / p, encoding="utf-8" + ).read() + ) imported_package = self.import_package(package_data, admin) - mapping[p.replace('.json', '')] = imported_package + mapping[p.replace(".json", "")] = imported_package - setting = self.create_default_setting(admin, [mapping['EAWAG-BBD']]) + setting = self.create_default_setting(admin, [mapping["EAWAG-BBD"]]) setting.public = True setting.save() setting.make_global_default() @@ -171,26 +197,28 @@ class Command(BaseCommand): usp.save() # Create Model Package - pack = PackageManager.create_package(admin, "Public Prediction Models", - "Package to make Prediction Models publicly available") + pack = PackageManager.create_package( + admin, + "Public Prediction Models", + "Package to make Prediction Models publicly available", + ) pack.reviewed = True pack.save() # Create RR ml_model = MLRelativeReasoning.create( package=pack, - rule_packages=[mapping['EAWAG-BBD']], - data_packages=[mapping['EAWAG-BBD']], + rule_packages=[mapping["EAWAG-BBD"]], + data_packages=[mapping["EAWAG-BBD"]], eval_packages=[], threshold=0.5, - name='ECC - BBD - T0.5', - description='ML Relative Reasoning', + name="ECC - BBD - T0.5", + description="ML Relative Reasoning", ) ml_model.build_dataset() ml_model.build_model() - # ml_model.evaluate_model() # If available, create EnviFormerModel if s.ENVIFORMER_PRESENT: - enviFormer_model = EnviFormer.create(pack, 'EnviFormer - T0.5', 'EnviFormer Model with Threshold 0.5', 0.5) + EnviFormer.create(pack, "EnviFormer - T0.5", "EnviFormer Model with Threshold 0.5", 0.5) diff --git a/epdb/management/commands/create_ml_models.py b/epdb/management/commands/create_ml_models.py index 6042f527..8cf3fd55 100644 --- a/epdb/management/commands/create_ml_models.py +++ b/epdb/management/commands/create_ml_models.py @@ -12,11 +12,28 @@ class Command(BaseCommand): the below command would be used: `python manage.py create_ml_models enviformer mlrr -d bbd soil -e sludge """ + def add_arguments(self, parser): - parser.add_argument("model_names", nargs="+", type=str, help="The names of models to train. Options are: enviformer, mlrr") - parser.add_argument("-d", "--data-packages", nargs="+", type=str, help="Packages for training") - parser.add_argument("-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[]) - parser.add_argument("-r", "--rule-packages", nargs="*", type=str, help="Rule Packages mandatory for MLRR", default=[]) + parser.add_argument( + "model_names", + nargs="+", + type=str, + help="The names of models to train. Options are: enviformer, mlrr", + ) + parser.add_argument( + "-d", "--data-packages", nargs="+", type=str, help="Packages for training" + ) + parser.add_argument( + "-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[] + ) + parser.add_argument( + "-r", + "--rule-packages", + nargs="*", + type=str, + help="Rule Packages mandatory for MLRR", + default=[], + ) @transaction.atomic def handle(self, *args, **options): @@ -28,7 +45,9 @@ class Command(BaseCommand): sludge = Package.objects.filter(name="EAWAG-SLUDGE")[0] sediment = Package.objects.filter(name="EAWAG-SEDIMENT")[0] except IndexError: - raise IndexError("Can't find correct packages. They should be created with the bootstrap command") + raise IndexError( + "Can't find correct packages. They should be created with the bootstrap command" + ) def decode_packages(package_list): """Decode package strings into their respective packages""" @@ -52,15 +71,27 @@ class Command(BaseCommand): data_packages = decode_packages(options["data_packages"]) eval_packages = decode_packages(options["eval_packages"]) rule_packages = decode_packages(options["rule_packages"]) - for model_name in options['model_names']: + for model_name in options["model_names"]: model_name = model_name.lower() if model_name == "enviformer" and s.ENVIFORMER_PRESENT: - model = EnviFormer.create(pack, data_packages=data_packages, eval_packages=eval_packages, threshold=0.5, - name="EnviFormer - T0.5", description="EnviFormer transformer") + model = EnviFormer.create( + pack, + data_packages=data_packages, + eval_packages=eval_packages, + threshold=0.5, + name="EnviFormer - T0.5", + description="EnviFormer transformer", + ) elif model_name == "mlrr": - model = MLRelativeReasoning.create(package=pack, rule_packages=rule_packages, - data_packages=data_packages, eval_packages=eval_packages, threshold=0.5, - name='ECC - BBD - T0.5', description='ML Relative Reasoning') + model = MLRelativeReasoning.create( + package=pack, + rule_packages=rule_packages, + data_packages=data_packages, + eval_packages=eval_packages, + threshold=0.5, + name="ECC - BBD - T0.5", + description="ML Relative Reasoning", + ) else: raise ValueError(f"Cannot create model of type {model_name}, unknown model type") # Build the dataset for the model, train it, evaluate it and save it diff --git a/epdb/management/commands/import_external_identifiers.py b/epdb/management/commands/import_external_identifiers.py index be51d97a..3c2355b7 100644 --- a/epdb/management/commands/import_external_identifiers.py +++ b/epdb/management/commands/import_external_identifiers.py @@ -1,57 +1,58 @@ from csv import DictReader from django.core.management.base import BaseCommand +from django.db import transaction -from epdb.models import * +from epdb.models import Compound, CompoundStructure, Reaction, ExternalDatabase, ExternalIdentifier class Command(BaseCommand): STR_TO_MODEL = { - 'Compound': Compound, - 'CompoundStructure': CompoundStructure, - 'Reaction': Reaction, + "Compound": Compound, + "CompoundStructure": CompoundStructure, + "Reaction": Reaction, } STR_TO_DATABASE = { - 'ChEBI': ExternalDatabase.objects.get(name='ChEBI'), - 'RHEA': ExternalDatabase.objects.get(name='RHEA'), - 'KEGG Reaction': ExternalDatabase.objects.get(name='KEGG Reaction'), - 'PubChem Compound': ExternalDatabase.objects.get(name='PubChem Compound'), - 'PubChem Substance': ExternalDatabase.objects.get(name='PubChem Substance'), + "ChEBI": ExternalDatabase.objects.get(name="ChEBI"), + "RHEA": ExternalDatabase.objects.get(name="RHEA"), + "KEGG Reaction": ExternalDatabase.objects.get(name="KEGG Reaction"), + "PubChem Compound": ExternalDatabase.objects.get(name="PubChem Compound"), + "PubChem Substance": ExternalDatabase.objects.get(name="PubChem Substance"), } def add_arguments(self, parser): parser.add_argument( - '--data', + "--data", type=str, - help='Path of the ID Mapping file.', + help="Path of the ID Mapping file.", required=True, ) parser.add_argument( - '--replace-host', + "--replace-host", type=str, - help='Replace https://envipath.org/ with this host, e.g. http://localhost:8000/', + help="Replace https://envipath.org/ with this host, e.g. http://localhost:8000/", ) @transaction.atomic def handle(self, *args, **options): - with open(options['data']) as fh: + with open(options["data"]) as fh: reader = DictReader(fh) for row in reader: - clz = self.STR_TO_MODEL[row['model']] + clz = self.STR_TO_MODEL[row["model"]] - url = row['url'] - if options['replace_host']: - url = url.replace('https://envipath.org/', options['replace_host']) + url = row["url"] + if options["replace_host"]: + url = url.replace("https://envipath.org/", options["replace_host"]) instance = clz.objects.get(url=url) - db = self.STR_TO_DATABASE[row['identifier_type']] + db = self.STR_TO_DATABASE[row["identifier_type"]] ExternalIdentifier.objects.create( content_object=instance, database=db, - identifier_value=row['identifier_value'], - url=db.url_pattern.format(id=row['identifier_value']), - is_primary=False + identifier_value=row["identifier_value"], + url=db.url_pattern.format(id=row["identifier_value"]), + is_primary=False, ) diff --git a/epdb/management/commands/import_legacy_package.py b/epdb/management/commands/import_legacy_package.py index ae4b74d6..c05ff0e0 100644 --- a/epdb/management/commands/import_legacy_package.py +++ b/epdb/management/commands/import_legacy_package.py @@ -1,27 +1,29 @@ +import json + from django.core.management.base import BaseCommand +from django.db import transaction from epdb.logic import PackageManager -from epdb.models import * +from epdb.models import User class Command(BaseCommand): - def add_arguments(self, parser): parser.add_argument( - '--data', + "--data", type=str, - help='Path of the Package to import.', + help="Path of the Package to import.", required=True, ) parser.add_argument( - '--owner', + "--owner", type=str, - help='Username of the desired Owner.', + help="Username of the desired Owner.", required=True, ) @transaction.atomic def handle(self, *args, **options): - owner = User.objects.get(username=options['owner']) - package_data = json.load(open(options['data'])) + owner = User.objects.get(username=options["owner"]) + package_data = json.load(open(options["data"])) PackageManager.import_legacy_package(package_data, owner) diff --git a/epdb/management/commands/localize_urls.py b/epdb/management/commands/localize_urls.py index 9e119294..b9f95b11 100644 --- a/epdb/management/commands/localize_urls.py +++ b/epdb/management/commands/localize_urls.py @@ -6,46 +6,45 @@ from django.db.models.functions import Replace class Command(BaseCommand): - def add_arguments(self, parser): parser.add_argument( - '--old', + "--old", type=str, - help='Old Host, most likely https://envipath.org/', + help="Old Host, most likely https://envipath.org/", required=True, ) parser.add_argument( - '--new', + "--new", type=str, - help='New Host, most likely http://localhost:8000/', + help="New Host, most likely http://localhost:8000/", required=True, ) def handle(self, *args, **options): MODELS = [ - 'User', - 'Group', - 'Package', - 'Compound', - 'CompoundStructure', - 'Pathway', - 'Edge', - 'Node', - 'Reaction', - 'SimpleAmbitRule', - 'SimpleRDKitRule', - 'ParallelRule', - 'SequentialRule', - 'Scenario', - 'Setting', - 'MLRelativeReasoning', - 'RuleBasedRelativeReasoning', - 'EnviFormer', - 'ApplicabilityDomain', + "User", + "Group", + "Package", + "Compound", + "CompoundStructure", + "Pathway", + "Edge", + "Node", + "Reaction", + "SimpleAmbitRule", + "SimpleRDKitRule", + "ParallelRule", + "SequentialRule", + "Scenario", + "Setting", + "MLRelativeReasoning", + "RuleBasedRelativeReasoning", + "EnviFormer", + "ApplicabilityDomain", ] for model in MODELS: obj_cls = apps.get_model("epdb", model) print(f"Localizing urls for {model}") obj_cls.objects.update( - url=Replace(F('url'), Value(options['old']), Value(options['new'])) + url=Replace(F("url"), Value(options["old"]), Value(options["new"])) ) diff --git a/epdb/middleware/login_required_middleware.py b/epdb/middleware/login_required_middleware.py index cd3a52ee..a0595e4a 100644 --- a/epdb/middleware/login_required_middleware.py +++ b/epdb/middleware/login_required_middleware.py @@ -3,22 +3,25 @@ from django.shortcuts import redirect from django.urls import reverse from urllib.parse import quote + class LoginRequiredMiddleware: def __init__(self, get_response): self.get_response = get_response self.exempt_urls = [ - reverse('login'), - reverse('logout'), - reverse('admin:login'), - reverse('admin:index'), - ] + getattr(settings, 'LOGIN_EXEMPT_URLS', []) + reverse("login"), + reverse("logout"), + reverse("admin:login"), + reverse("admin:index"), + ] + getattr(settings, "LOGIN_EXEMPT_URLS", []) def __call__(self, request): if not request.user.is_authenticated: path = request.path_info if not any(path.startswith(url) for url in self.exempt_urls): - if request.method == 'GET': - if request.get_full_path() and request.get_full_path() != '/': - return redirect(f"{settings.LOGIN_URL}?next={quote(request.get_full_path())}") + if request.method == "GET": + if request.get_full_path() and request.get_full_path() != "/": + return redirect( + f"{settings.LOGIN_URL}?next={quote(request.get_full_path())}" + ) return redirect(settings.LOGIN_URL) return self.get_response(request) diff --git a/epdb/models.py b/epdb/models.py index 206402a3..83b29925 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -37,19 +37,34 @@ logger = logging.getLogger(__name__) # User/Groups/Permission # ########################## + class User(AbstractUser): email = models.EmailField(unique=True) - uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', unique=True, default=uuid4) - url = models.TextField(blank=False, null=True, verbose_name='URL', unique=True) - default_package = models.ForeignKey('epdb.Package', verbose_name='Default Package', null=True, - on_delete=models.SET_NULL) - default_group = models.ForeignKey('Group', verbose_name='Default Group', null=True, blank=False, - on_delete=models.SET_NULL, related_name='default_group') - default_setting = models.ForeignKey('epdb.Setting', on_delete=models.SET_NULL, - verbose_name='The users default settings', null=True, blank=False) + uuid = models.UUIDField( + null=False, blank=False, verbose_name="UUID of this object", unique=True, default=uuid4 + ) + url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True) + default_package = models.ForeignKey( + "epdb.Package", verbose_name="Default Package", null=True, on_delete=models.SET_NULL + ) + default_group = models.ForeignKey( + "Group", + verbose_name="Default Group", + null=True, + blank=False, + on_delete=models.SET_NULL, + related_name="default_group", + ) + default_setting = models.ForeignKey( + "epdb.Setting", + on_delete=models.SET_NULL, + verbose_name="The users default settings", + null=True, + blank=False, + ) USERNAME_FIELD = "email" - REQUIRED_FIELDS = ['username'] + REQUIRED_FIELDS = ["username"] def save(self, *args, **kwargs): if not self.url: @@ -58,7 +73,7 @@ class User(AbstractUser): super().save(*args, **kwargs) def _url(self): - return '{}/user/{}'.format(s.SERVER_URL, self.uuid) + return "{}/user/{}".format(s.SERVER_URL, self.uuid) def prediction_settings(self): if self.default_setting is None: @@ -73,40 +88,31 @@ class APIToken(TimeStampedModel): Provides secure token-based authentication with expiration support. """ + hashed_key = models.CharField( - max_length=128, - unique=True, - help_text="SHA-256 hash of the token key" + max_length=128, unique=True, help_text="SHA-256 hash of the token key" ) user = models.ForeignKey( User, on_delete=models.CASCADE, - related_name='api_tokens', - help_text="User who owns this token" + related_name="api_tokens", + help_text="User who owns this token", ) expires_at = models.DateTimeField( - null=True, - blank=True, - help_text="Token expiration time (null for no expiration)" + null=True, blank=True, help_text="Token expiration time (null for no expiration)" ) - name = models.CharField( - max_length=100, - help_text="Descriptive name for this token" - ) + name = models.CharField(max_length=100, help_text="Descriptive name for this token") - is_active = models.BooleanField( - default=True, - help_text="Whether this token is active" - ) + is_active = models.BooleanField(default=True, help_text="Whether this token is active") class Meta: - db_table = 'epdb_api_token' - verbose_name = 'API Token' - verbose_name_plural = 'API Tokens' - ordering = ['-created'] + db_table = "epdb_api_token" + verbose_name = "API Token" + verbose_name_plural = "API Tokens" + ordering = ["-created"] def __str__(self) -> str: return f"{self.name} ({self.user.username})" @@ -122,7 +128,9 @@ class APIToken(TimeStampedModel): return True @classmethod - def create_token(cls, user: User, name: str, expires_days: Optional[int] = None) -> Tuple['APIToken', str]: + def create_token( + cls, user: User, name: str, expires_days: Optional[int] = None + ) -> Tuple["APIToken", str]: """ Create a new API token for a user. @@ -142,10 +150,7 @@ class APIToken(TimeStampedModel): expires_at = timezone.now() + timezone.timedelta(days=expires_days) token = cls.objects.create( - user=user, - name=name, - hashed_key=hashed_key, - expires_at=expires_at + user=user, name=name, hashed_key=hashed_key, expires_at=expires_at ) return token, raw_key @@ -164,7 +169,7 @@ class APIToken(TimeStampedModel): hashed_key = hashlib.sha256(raw_key.encode()).hexdigest() try: - token = cls.objects.select_related('user').get(hashed_key=hashed_key) + token = cls.objects.select_related("user").get(hashed_key=hashed_key) if token.is_valid(): return token.user except cls.DoesNotExist: @@ -174,14 +179,22 @@ class APIToken(TimeStampedModel): class Group(TimeStampedModel): - uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', unique=True, default=uuid4) - url = models.TextField(blank=False, null=True, verbose_name='URL', unique=True) - name = models.TextField(blank=False, null=False, verbose_name='Group name') - owner = models.ForeignKey("User", verbose_name='Group Owner', on_delete=models.CASCADE) - public = models.BooleanField(verbose_name='Public Group', default=False) - description = models.TextField(blank=False, null=False, verbose_name='Descriptions', default='no description') - user_member = models.ManyToManyField("User", verbose_name='User members', related_name='users_in_group') - group_member = models.ManyToManyField("Group", verbose_name='Group member', related_name='groups_in_group', blank=True) + uuid = models.UUIDField( + null=False, blank=False, verbose_name="UUID of this object", unique=True, default=uuid4 + ) + url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True) + name = models.TextField(blank=False, null=False, verbose_name="Group name") + owner = models.ForeignKey("User", verbose_name="Group Owner", on_delete=models.CASCADE) + public = models.BooleanField(verbose_name="Public Group", default=False) + description = models.TextField( + blank=False, null=False, verbose_name="Descriptions", default="no description" + ) + user_member = models.ManyToManyField( + "User", verbose_name="User members", related_name="users_in_group" + ) + group_member = models.ManyToManyField( + "Group", verbose_name="Group member", related_name="groups_in_group", blank=True + ) def __str__(self): return f"{self.name} (pk={self.pk})" @@ -193,18 +206,14 @@ class Group(TimeStampedModel): super().save(*args, **kwargs) def _url(self): - return '{}/group/{}'.format(s.SERVER_URL, self.uuid) + return "{}/group/{}".format(s.SERVER_URL, self.uuid) class Permission(TimeStampedModel): - READ = ('read', 'Read') - WRITE = ('write', 'Write') - ALL = ('all', 'All') - PERMS = [ - READ, - WRITE, - ALL - ] + READ = ("read", "Read") + WRITE = ("write", "Write") + ALL = ("all", "All") + PERMS = [READ, WRITE, ALL] permission = models.CharField(max_length=32, choices=PERMS, null=False) def has_read(self): @@ -221,26 +230,32 @@ class Permission(TimeStampedModel): class UserPackagePermission(Permission): - uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', primary_key=True, - default=uuid4) - user = models.ForeignKey('User', verbose_name='Permission to', on_delete=models.CASCADE) - package = models.ForeignKey('epdb.Package', verbose_name='Permission on', on_delete=models.CASCADE) + uuid = models.UUIDField( + null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4 + ) + user = models.ForeignKey("User", verbose_name="Permission to", on_delete=models.CASCADE) + package = models.ForeignKey( + "epdb.Package", verbose_name="Permission on", on_delete=models.CASCADE + ) class Meta: - unique_together = [('package', 'user')] + unique_together = [("package", "user")] def __str__(self): return f"User: {self.user} has Permission: {self.permission} on Package: {self.package}" class GroupPackagePermission(Permission): - uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', primary_key=True, - default=uuid4) - group = models.ForeignKey('Group', verbose_name='Permission to', on_delete=models.CASCADE) - package = models.ForeignKey('epdb.Package', verbose_name='Permission on', on_delete=models.CASCADE) + uuid = models.UUIDField( + null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4 + ) + group = models.ForeignKey("Group", verbose_name="Permission to", on_delete=models.CASCADE) + package = models.ForeignKey( + "epdb.Package", verbose_name="Permission on", on_delete=models.CASCADE + ) class Meta: - unique_together = [('package', 'group')] + unique_together = [("package", "group")] def __str__(self): return f"Group: {self.group} has Permission: {self.permission} on Package: {self.package}" @@ -259,69 +274,77 @@ class ExternalDatabase(TimeStampedModel): max_length=500, blank=True, verbose_name="URL Pattern", - help_text="URL pattern with {id} placeholder, e.g., 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}'" + help_text="URL pattern with {id} placeholder, e.g., 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}'", ) is_active = models.BooleanField(default=True, verbose_name="Is Active") class Meta: - db_table = 'epdb_external_database' - verbose_name = 'External Database' - verbose_name_plural = 'External Databases' - ordering = ['name'] + db_table = "epdb_external_database" + verbose_name = "External Database" + verbose_name_plural = "External Databases" + ordering = ["name"] def __str__(self): return self.full_name or self.name def get_url_for_identifier(self, identifier_value): - if self.url_pattern and '{id}' in self.url_pattern: + if self.url_pattern and "{id}" in self.url_pattern: return self.url_pattern.format(id=identifier_value) return None @staticmethod def get_databases(): return { - 'compound': [ + "compound": [ { - 'database': ExternalDatabase.objects.get(name='PubChem Compound'), - 'placeholder': 'PubChem Compound ID e.g. 12345', - }, { - 'database': ExternalDatabase.objects.get(name='PubChem Substance'), - 'placeholder': 'PubChem Substance ID e.g. 12345', - }, { - 'database': ExternalDatabase.objects.get(name='KEGG Reaction'), - 'placeholder': 'KEGG ID including entity Prefix e.g. C12345', - }, { - 'database': ExternalDatabase.objects.get(name='ChEBI'), - 'placeholder': 'ChEBI ID without prefix e.g. 12345', + "database": ExternalDatabase.objects.get(name="PubChem Compound"), + "placeholder": "PubChem Compound ID e.g. 12345", + }, + { + "database": ExternalDatabase.objects.get(name="PubChem Substance"), + "placeholder": "PubChem Substance ID e.g. 12345", + }, + { + "database": ExternalDatabase.objects.get(name="KEGG Reaction"), + "placeholder": "KEGG ID including entity Prefix e.g. C12345", + }, + { + "database": ExternalDatabase.objects.get(name="ChEBI"), + "placeholder": "ChEBI ID without prefix e.g. 12345", }, ], - 'structure': [ + "structure": [ { - 'database': ExternalDatabase.objects.get(name='PubChem Compound'), - 'placeholder': 'PubChem Compound ID e.g. 12345', - }, { - 'database': ExternalDatabase.objects.get(name='PubChem Substance'), - 'placeholder': 'PubChem Substance ID e.g. 12345', - }, { - 'database': ExternalDatabase.objects.get(name='KEGG Reaction'), - 'placeholder': 'KEGG ID including entity Prefix e.g. C12345', - }, { - 'database': ExternalDatabase.objects.get(name='ChEBI'), - 'placeholder': 'ChEBI ID without prefix e.g. 12345', + "database": ExternalDatabase.objects.get(name="PubChem Compound"), + "placeholder": "PubChem Compound ID e.g. 12345", + }, + { + "database": ExternalDatabase.objects.get(name="PubChem Substance"), + "placeholder": "PubChem Substance ID e.g. 12345", + }, + { + "database": ExternalDatabase.objects.get(name="KEGG Reaction"), + "placeholder": "KEGG ID including entity Prefix e.g. C12345", + }, + { + "database": ExternalDatabase.objects.get(name="ChEBI"), + "placeholder": "ChEBI ID without prefix e.g. 12345", }, ], - 'reaction': [ + "reaction": [ { - 'database': ExternalDatabase.objects.get(name='KEGG Reaction'), - 'placeholder': 'KEGG ID including entity Prefix e.g. C12345', - }, { - 'database': ExternalDatabase.objects.get(name='RHEA'), - 'placeholder': 'RHEA ID without Prefix e.g. 12345', - }, { - 'database': ExternalDatabase.objects.get(name='UniProt'), - 'placeholder': 'Query ID for UniPro e.g. rhea:12345', - } - ] + "database": ExternalDatabase.objects.get(name="KEGG Reaction"), + "placeholder": "KEGG ID including entity Prefix e.g. C12345", + }, + { + "database": ExternalDatabase.objects.get(name="RHEA"), + "placeholder": "RHEA ID without Prefix e.g. 12345", + }, + { + "database": ExternalDatabase.objects.get(name="UniProt"), + "placeholder": "Query ID for UniPro e.g. rhea:12345", + }, + ], } @@ -331,29 +354,27 @@ class ExternalIdentifier(TimeStampedModel): # Generic foreign key to link to any model content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) object_id = models.IntegerField() - content_object = GenericForeignKey('content_type', 'object_id') + content_object = GenericForeignKey("content_type", "object_id") database = models.ForeignKey( - ExternalDatabase, - on_delete=models.CASCADE, - verbose_name="External Database" + ExternalDatabase, on_delete=models.CASCADE, verbose_name="External Database" ) identifier_value = models.CharField(max_length=255, verbose_name="Identifier Value") url = models.URLField(blank=True, null=True, verbose_name="Direct URL") is_primary = models.BooleanField( default=False, verbose_name="Is Primary", - help_text="Mark this as the primary identifier for this database" + help_text="Mark this as the primary identifier for this database", ) class Meta: - db_table = 'epdb_external_identifier' - verbose_name = 'External Identifier' - verbose_name_plural = 'External Identifiers' - unique_together = [('content_type', 'object_id', 'database', 'identifier_value')] + db_table = "epdb_external_identifier" + verbose_name = "External Identifier" + verbose_name_plural = "External Identifiers" + unique_together = [("content_type", "object_id", "database", "identifier_value")] indexes = [ - models.Index(fields=['content_type', 'object_id']), - models.Index(fields=['database', 'identifier_value']), + models.Index(fields=["content_type", "object_id"]), + models.Index(fields=["database", "identifier_value"]), ] def __str__(self): @@ -385,157 +406,154 @@ class ExternalIdentifierMixin(models.Model): database, created = ExternalDatabase.objects.get_or_create(name=database_name) if is_primary: - self.external_identifiers.filter(database=database, is_primary=True).update(is_primary=False) + self.external_identifiers.filter(database=database, is_primary=True).update( + is_primary=False + ) external_id, created = ExternalIdentifier.objects.get_or_create( content_type=ContentType.objects.get_for_model(self), object_id=self.pk, database=database, identifier_value=identifier_value, - defaults={ - 'url': url, - 'is_primary': is_primary - } + defaults={"url": url, "is_primary": is_primary}, ) return external_id def remove_external_identifier(self, database_name, identifier_value): self.external_identifiers.filter( - database__name=database_name, - identifier_value=identifier_value + database__name=database_name, identifier_value=identifier_value ).delete() class ChemicalIdentifierMixin(ExternalIdentifierMixin): - class Meta: abstract = True @property def pubchem_compound_id(self): - identifier = self.get_external_identifier('PubChem Compound') + identifier = self.get_external_identifier("PubChem Compound") return identifier.identifier_value if identifier else None @property def pubchem_substance_id(self): - identifier = self.get_external_identifier('PubChem Substance') + identifier = self.get_external_identifier("PubChem Substance") return identifier.identifier_value if identifier else None @property def chebi_id(self): - identifier = self.get_external_identifier('ChEBI') + identifier = self.get_external_identifier("ChEBI") return identifier.identifier_value if identifier else None @property def cas_number(self): - identifier = self.get_external_identifier('CAS') + identifier = self.get_external_identifier("CAS") return identifier.identifier_value if identifier else None def add_pubchem_compound_id(self, compound_id, is_primary=True): return self.add_external_identifier( - 'PubChem Compound', + "PubChem Compound", compound_id, - f'https://pubchem.ncbi.nlm.nih.gov/compound/{compound_id}', - is_primary + f"https://pubchem.ncbi.nlm.nih.gov/compound/{compound_id}", + is_primary, ) def add_pubchem_substance_id(self, substance_id): return self.add_external_identifier( - 'PubChem Substance', + "PubChem Substance", substance_id, - f'https://pubchem.ncbi.nlm.nih.gov/substance/{substance_id}' + f"https://pubchem.ncbi.nlm.nih.gov/substance/{substance_id}", ) def add_chebi_id(self, chebi_id, is_primary=False): - clean_id = chebi_id.replace('CHEBI:', '') + clean_id = chebi_id.replace("CHEBI:", "") return self.add_external_identifier( - 'ChEBI', + "ChEBI", clean_id, - f'https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{clean_id}', - is_primary + f"https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{clean_id}", + is_primary, ) def add_cas_number(self, cas_number): - return self.add_external_identifier('CAS', cas_number) + return self.add_external_identifier("CAS", cas_number) def get_pubchem_identifiers(self): - return self.get_external_identifier('PubChem Compound') or self.get_external_identifier('PubChem Substance') + return self.get_external_identifier("PubChem Compound") or self.get_external_identifier( + "PubChem Substance" + ) def get_pubchem_compound_identifiers(self): - return self.get_external_identifier('PubChem Compound') + return self.get_external_identifier("PubChem Compound") def get_pubchem_substance_identifiers(self): - return self.get_external_identifier('PubChem Substance') + return self.get_external_identifier("PubChem Substance") def get_chebi_identifiers(self): - return self.get_external_identifier('ChEBI') + return self.get_external_identifier("ChEBI") def get_cas_identifiers(self): - return self.get_external_identifier('CAS') + return self.get_external_identifier("CAS") + class ReactionIdentifierMixin(ExternalIdentifierMixin): - class Meta: abstract = True @property def rhea_id(self): - identifier = self.get_external_identifier('RHEA') + identifier = self.get_external_identifier("RHEA") return identifier.identifier_value if identifier else None @property def kegg_reaction_id(self): - identifier = self.get_external_identifier('KEGG Reaction') + identifier = self.get_external_identifier("KEGG Reaction") return identifier.identifier_value if identifier else None @property def metacyc_reaction_id(self): - identifier = self.get_external_identifier('MetaCyc') + identifier = self.get_external_identifier("MetaCyc") return identifier.identifier_value if identifier else None def add_rhea_id(self, rhea_id, is_primary=True): return self.add_external_identifier( - 'RHEA', - rhea_id, - f'https://www.rhea-db.org/rhea/{rhea_id}', - is_primary + "RHEA", rhea_id, f"https://www.rhea-db.org/rhea/{rhea_id}", is_primary ) def add_uniprot_id(self, uniprot_id, is_primary=True): return self.add_external_identifier( - 'UniProt', + "UniProt", uniprot_id, f'https://www.uniprot.org/uniprotkb?query="{uniprot_id}"', - is_primary + is_primary, ) def add_kegg_reaction_id(self, kegg_id): return self.add_external_identifier( - 'KEGG Reaction', - kegg_id, - f'https://www.genome.jp/entry/reaction+{kegg_id}' + "KEGG Reaction", kegg_id, f"https://www.genome.jp/entry/reaction+{kegg_id}" ) def add_metacyc_reaction_id(self, metacyc_id): - return self.add_external_identifier('MetaCyc', metacyc_id) + return self.add_external_identifier("MetaCyc", metacyc_id) def get_rhea_identifiers(self): - return self.get_external_identifier('RHEA') + return self.get_external_identifier("RHEA") def get_uniprot_identifiers(self): - return self.get_external_identifier('UniProt') + return self.get_external_identifier("UniProt") ############## # EP Objects # ############## class EnviPathModel(TimeStampedModel): - uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', unique=True, - default=uuid4) - name = models.TextField(blank=False, null=False, verbose_name='Name', default='no name') - description = models.TextField(blank=False, null=False, verbose_name='Descriptions', default='no description') + uuid = models.UUIDField( + null=False, blank=False, verbose_name="UUID of this object", unique=True, default=uuid4 + ) + name = models.TextField(blank=False, null=False, verbose_name="Name", default="no name") + description = models.TextField( + blank=False, null=False, verbose_name="Descriptions", default="no description" + ) - url = models.TextField(blank=False, null=True, verbose_name='URL', unique=True) + url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True) kv = JSONField(null=True, blank=True, default=dict) @@ -551,13 +569,13 @@ class EnviPathModel(TimeStampedModel): def simple_json(self, include_description=False): res = { - 'url': self.url, - 'uuid': str(self.uuid), - 'name': self.name, + "url": self.url, + "uuid": str(self.uuid), + "name": self.name, } if include_description: - res['description'] = self.description + res["description"] = self.description return res @@ -575,8 +593,7 @@ class EnviPathModel(TimeStampedModel): class AliasMixin(models.Model): aliases = ArrayField( - models.TextField(blank=False, null=False), - verbose_name='Aliases', default=list + models.TextField(blank=False, null=False), verbose_name="Aliases", default=list ) @transaction.atomic @@ -599,15 +616,15 @@ class AliasMixin(models.Model): class ScenarioMixin(models.Model): - scenarios = models.ManyToManyField("epdb.Scenario", verbose_name='Attached Scenarios') + scenarios = models.ManyToManyField("epdb.Scenario", verbose_name="Attached Scenarios") @transaction.atomic - def set_scenarios(self, scenarios: List['Scenario']): + def set_scenarios(self, scenarios: List["Scenario"]): self.scenarios.clear() self.save() - for s in scenarios: - self.scenarios.add(s) + for scen in scenarios: + self.scenarios.add(scen) self.save() @@ -616,14 +633,15 @@ class ScenarioMixin(models.Model): class License(models.Model): - link = models.URLField(blank=False, null=False, verbose_name='link') - image_link = models.URLField(blank=False, null=False, verbose_name='Image link') + link = models.URLField(blank=False, null=False, verbose_name="link") + image_link = models.URLField(blank=False, null=False, verbose_name="Image link") class Package(EnviPathModel): - reviewed = models.BooleanField(verbose_name='Reviewstatus', default=False) - license = models.ForeignKey('epdb.License', on_delete=models.SET_NULL, blank=True, null=True, - verbose_name='License') + reviewed = models.BooleanField(verbose_name="Reviewstatus", default=False) + license = models.ForeignKey( + "epdb.License", on_delete=models.SET_NULL, blank=True, null=True, verbose_name="License" + ) def delete(self, *args, **kwargs): # explicitly handle related Rules @@ -659,9 +677,9 @@ class Package(EnviPathModel): return self.epmodel_set.all() def _url(self): - return '{}/package/{}'.format(s.SERVER_URL, self.uuid) + return "{}/package/{}".format(s.SERVER_URL, self.uuid) - def get_applicable_rules(self) -> List['Rule']: + def get_applicable_rules(self) -> List["Rule"]: """ Returns a ordered set of rules where the following applies: 1. All Composite will be added to result @@ -689,57 +707,67 @@ class Package(EnviPathModel): class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin): - package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True) - default_structure = models.ForeignKey('CompoundStructure', verbose_name='Default Structure', - related_name='compound_default_structure', - on_delete=models.CASCADE, null=True) + package = models.ForeignKey( + "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True + ) + default_structure = models.ForeignKey( + "CompoundStructure", + verbose_name="Default Structure", + related_name="compound_default_structure", + on_delete=models.CASCADE, + null=True, + ) - external_identifiers = GenericRelation('ExternalIdentifier') + external_identifiers = GenericRelation("ExternalIdentifier") @property def structures(self) -> QuerySet: return CompoundStructure.objects.filter(compound=self) @property - def normalized_structure(self) -> 'CompoundStructure' : + def normalized_structure(self) -> "CompoundStructure": return CompoundStructure.objects.get(compound=self, normalized_structure=True) def _url(self): - return '{}/compound/{}'.format(self.package.url, self.uuid) + return "{}/compound/{}".format(self.package.url, self.uuid) @transaction.atomic - def set_default_structure(self, cs: 'CompoundStructure'): + def set_default_structure(self, cs: "CompoundStructure"): if cs.compound != self: - raise ValueError("Attempt to set a CompoundStructure stored in a different compound as default") + raise ValueError( + "Attempt to set a CompoundStructure stored in a different compound as default" + ) self.default_structure = cs self.save() @property def related_pathways(self): - pathways = Node.objects.filter(node_labels__in=[self.default_structure]).values_list('pathway', flat=True) - return Pathway.objects.filter(package=self.package, id__in=set(pathways)).order_by('name') + pathways = Node.objects.filter(node_labels__in=[self.default_structure]).values_list( + "pathway", flat=True + ) + return Pathway.objects.filter(package=self.package, id__in=set(pathways)).order_by("name") @property def related_reactions(self): return ( - Reaction.objects.filter(package=self.package, educts__in=[self.default_structure]) - | - Reaction.objects.filter(package=self.package, products__in=[self.default_structure]) - ).order_by('name') + Reaction.objects.filter(package=self.package, educts__in=[self.default_structure]) + | Reaction.objects.filter(package=self.package, products__in=[self.default_structure]) + ).order_by("name") @staticmethod @transaction.atomic - def create(package: Package, smiles: str, name: str = None, description: str = None, *args, **kwargs) -> 'Compound': - - if smiles is None or smiles.strip() == '': - raise ValueError('SMILES is required') + def create( + package: Package, smiles: str, name: str = None, description: str = None, *args, **kwargs + ) -> "Compound": + if smiles is None or smiles.strip() == "": + raise ValueError("SMILES is required") smiles = smiles.strip() parsed = FormatConverter.from_smiles(smiles) if parsed is None: - raise ValueError('Given SMILES is invalid') + raise ValueError("Given SMILES is invalid") standardized_smiles = FormatConverter.standardize(smiles) @@ -748,21 +776,25 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin return CompoundStructure.objects.get(smiles=smiles, compound__package=package).compound # Check if we can find the standardized one - if CompoundStructure.objects.filter(smiles=standardized_smiles, compound__package=package).exists(): + if CompoundStructure.objects.filter( + smiles=standardized_smiles, compound__package=package + ).exists(): # TODO should we add a structure? - return CompoundStructure.objects.get(smiles=standardized_smiles, compound__package=package).compound + return CompoundStructure.objects.get( + smiles=standardized_smiles, compound__package=package + ).compound # Generate Compound c = Compound() c.package = package - if name is None or name.strip() == '': + if name is None or name.strip() == "": name = f"Compound {Compound.objects.filter(package=package).count() + 1}" c.name = name # We have a default here only set the value if it carries some payload - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": c.description = description.strip() c.save() @@ -770,12 +802,17 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin is_standardized = standardized_smiles == smiles if not is_standardized: - _ = CompoundStructure.create(c, standardized_smiles, name='Normalized structure of {}'.format(name), - description='{} (in its normalized form)'.format(description), - normalized_structure=True) + _ = CompoundStructure.create( + c, + standardized_smiles, + name="Normalized structure of {}".format(name), + description="{} (in its normalized form)".format(description), + normalized_structure=True, + ) - cs = CompoundStructure.create(c, smiles, name=name, description=description, - normalized_structure=is_standardized) + cs = CompoundStructure.create( + c, smiles, name=name, description=description, normalized_structure=is_standardized + ) c.default_structure = cs c.save() @@ -783,34 +820,45 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin return c @transaction.atomic - def add_structure(self, smiles: str, name: str = None, description: str = None, default_structure: bool = False, - *args, **kwargs) -> 'CompoundStructure': - - if smiles is None or smiles == '': - raise ValueError('SMILES is required') + def add_structure( + self, + smiles: str, + name: str = None, + description: str = None, + default_structure: bool = False, + *args, + **kwargs, + ) -> "CompoundStructure": + if smiles is None or smiles == "": + raise ValueError("SMILES is required") smiles = smiles.strip() parsed = FormatConverter.from_smiles(smiles) if parsed is None: - raise ValueError('Given SMILES is invalid') + raise ValueError("Given SMILES is invalid") standardized_smiles = FormatConverter.standardize(smiles) is_standardized = standardized_smiles == smiles if self.normalized_structure.smiles != standardized_smiles: - raise ValueError('The standardized SMILES does not match the compounds standardized one!') + raise ValueError( + "The standardized SMILES does not match the compounds standardized one!" + ) if is_standardized: CompoundStructure.objects.get(smiles__in=smiles, compound__package=self.package) # Check if we find a direct match for a given SMILES and/or its standardized SMILES - if CompoundStructure.objects.filter(smiles__in=smiles, compound__package=self.package).exists(): + if CompoundStructure.objects.filter( + smiles__in=smiles, compound__package=self.package + ).exists(): return CompoundStructure.objects.get(smiles__in=smiles, compound__package=self.package) - cs = CompoundStructure.create(self, smiles, name=name, description=description, - normalized_structure=is_standardized) + cs = CompoundStructure.create( + self, smiles, name=name, description=description, normalized_structure=is_standardized + ) if default_structure: self.default_structure = cs @@ -819,7 +867,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin return cs @transaction.atomic - def copy(self, target: 'Package', mapping: Dict): + def copy(self, target: "Package", mapping: Dict): if self in mapping: return mapping[self] @@ -827,7 +875,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin package=target, name=self.name, description=self.description, - kv=self.kv.copy() if self.kv else {} + kv=self.kv.copy() if self.kv else {}, ) mapping[self] = new_compound @@ -842,7 +890,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin normalized_structure=structure.normalized_structure, name=structure.name, description=structure.description, - kv=structure.kv.copy() if structure.kv else {} + kv=structure.kv.copy() if structure.kv else {}, ) mapping[structure] = new_structure @@ -853,7 +901,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin database=ext_id.database, identifier_value=ext_id.identifier_value, url=ext_id.url, - is_primary=ext_id.is_primary + is_primary=ext_id.is_primary, ) if self.default_structure: @@ -871,23 +919,23 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin database=ext_id.database, identifier_value=ext_id.identifier_value, url=ext_id.url, - is_primary=ext_id.is_primary + is_primary=ext_id.is_primary, ) return new_compound class Meta: - unique_together = [('uuid', 'package')] + unique_together = [("uuid", "package")] class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin): - compound = models.ForeignKey('epdb.Compound', on_delete=models.CASCADE, db_index=True) - smiles = models.TextField(blank=False, null=False, verbose_name='SMILES') - canonical_smiles = models.TextField(blank=False, null=False, verbose_name='Canonical SMILES') + compound = models.ForeignKey("epdb.Compound", on_delete=models.CASCADE, db_index=True) + smiles = models.TextField(blank=False, null=False, verbose_name="SMILES") + canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES") inchikey = models.TextField(max_length=27, blank=False, null=False, verbose_name="InChIKey") normalized_structure = models.BooleanField(null=False, blank=False, default=False) - external_identifiers = GenericRelation('ExternalIdentifier') + external_identifiers = GenericRelation("ExternalIdentifier") def save(self, *args, **kwargs): # Compute these fields only on initial save call @@ -898,16 +946,20 @@ class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdenti # Generate InChIKey self.inchikey = FormatConverter.InChIKey(self.smiles) except Exception as e: - logger.error(f"Could compute canonical SMILES/InChIKey from {self.smiles}, error: {e}") + logger.error( + f"Could compute canonical SMILES/InChIKey from {self.smiles}, error: {e}" + ) super().save(*args, **kwargs) def _url(self): - return '{}/structure/{}'.format(self.compound.url, self.uuid) + return "{}/structure/{}".format(self.compound.url, self.uuid) @staticmethod @transaction.atomic - def create(compound: Compound, smiles: str, name: str = None, description: str = None, *args, **kwargs): + def create( + compound: Compound, smiles: str, name: str = None, description: str = None, *args, **kwargs + ): if CompoundStructure.objects.filter(compound=compound, smiles=smiles).exists(): return CompoundStructure.objects.get(compound=compound, smiles=smiles) @@ -924,15 +976,15 @@ class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdenti cs.smiles = smiles cs.compound = compound - if 'normalized_structure' in kwargs: - cs.normalized_structure = kwargs['normalized_structure'] + if "normalized_structure" in kwargs: + cs.normalized_structure = kwargs["normalized_structure"] cs.save() return cs @transaction.atomic - def copy(self, target: 'Package', mapping: Dict): + def copy(self, target: "Package", mapping: Dict): if self in mapping: return mapping[self] @@ -945,16 +997,17 @@ class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdenti @property def related_pathways(self): - pathways = Node.objects.filter(node_labels__in=[self]).values_list('pathway', flat=True) - return Pathway.objects.filter(package=self.compound.package, id__in=set(pathways)).order_by('name') + pathways = Node.objects.filter(node_labels__in=[self]).values_list("pathway", flat=True) + return Pathway.objects.filter(package=self.compound.package, id__in=set(pathways)).order_by( + "name" + ) @property def related_reactions(self): return ( - Reaction.objects.filter(package=self.compound.package, educts__in=[self]) - | - Reaction.objects.filter(package=self.compound.package, products__in=[self]) - ).order_by('name') + Reaction.objects.filter(package=self.compound.package, educts__in=[self]) + | Reaction.objects.filter(package=self.compound.package, products__in=[self]) + ).order_by("name") @property def is_default_structure(self): @@ -962,7 +1015,9 @@ class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdenti class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin): - package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True) + package = models.ForeignKey( + "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True + ) # # https://github.com/django-polymorphic/django-polymorphic/issues/229 # _non_polymorphic = models.Manager() @@ -976,16 +1031,16 @@ class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin): @staticmethod def cls_for_type(rule_type: str): - if rule_type == 'SimpleAmbitRule': + if rule_type == "SimpleAmbitRule": return SimpleAmbitRule - elif rule_type == 'SimpleRDKitRule': + elif rule_type == "SimpleRDKitRule": return SimpleRDKitRule - elif rule_type == 'ParallelRule': + elif rule_type == "ParallelRule": return ParallelRule - elif rule_type == 'SequentialRule': + elif rule_type == "SequentialRule": return SequentialRule else: - raise ValueError(f'{rule_type} is unknown!') + raise ValueError(f"{rule_type} is unknown!") @staticmethod @transaction.atomic @@ -993,9 +1048,8 @@ class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin): cls = Rule.cls_for_type(rule_type) return cls.create(*args, **kwargs) - @transaction.atomic - def copy(self, target: 'Package', mapping: Dict): + def copy(self, target: "Package", mapping: Dict): """Copy a rule to the target package.""" if self in mapping: return mapping[self] @@ -1011,7 +1065,7 @@ class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin): smirks=self.smirks, reactant_filter_smarts=self.reactant_filter_smarts, product_filter_smarts=self.product_filter_smarts, - kv=self.kv.copy() if self.kv else {} + kv=self.kv.copy() if self.kv else {}, ) elif rule_type == SimpleRDKitRule: new_rule = SimpleRDKitRule.objects.create( @@ -1019,14 +1073,14 @@ class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin): name=self.name, description=self.description, reaction_smarts=self.reaction_smarts, - kv=self.kv.copy() if self.kv else {} + kv=self.kv.copy() if self.kv else {}, ) elif rule_type == ParallelRule: new_rule = ParallelRule.objects.create( package=target, name=self.name, description=self.description, - kv=self.kv.copy() if self.kv else {} + kv=self.kv.copy() if self.kv else {}, ) # Copy simple rules relationships for simple_rule in self.simple_rules.all(): @@ -1049,17 +1103,22 @@ class SimpleRule(Rule): # # class SimpleAmbitRule(SimpleRule): - smirks = models.TextField(blank=False, null=False, verbose_name='SMIRKS') - reactant_filter_smarts = models.TextField(null=True, verbose_name='Reactant Filter SMARTS') - product_filter_smarts = models.TextField(null=True, verbose_name='Product Filter SMARTS') + smirks = models.TextField(blank=False, null=False, verbose_name="SMIRKS") + reactant_filter_smarts = models.TextField(null=True, verbose_name="Reactant Filter SMARTS") + product_filter_smarts = models.TextField(null=True, verbose_name="Product Filter SMARTS") @staticmethod @transaction.atomic - def create(package: Package, name: str = None, description: str = None, smirks: str = None, - reactant_filter_smarts: str = None, product_filter_smarts: str = None): - - if smirks is None or smirks.strip() == '': - raise ValueError('SMIRKS is required!') + def create( + package: Package, + name: str = None, + description: str = None, + smirks: str = None, + reactant_filter_smarts: str = None, + product_filter_smarts: str = None, + ): + if smirks is None or smirks.strip() == "": + raise ValueError("SMIRKS is required!") smirks = smirks.strip() @@ -1068,62 +1127,63 @@ class SimpleAmbitRule(SimpleRule): query = SimpleAmbitRule.objects.filter(package=package, smirks=smirks) - if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != '': + if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != "": query = query.filter(reactant_filter_smarts=reactant_filter_smarts) - if product_filter_smarts is not None and product_filter_smarts.strip() != '': + if product_filter_smarts is not None and product_filter_smarts.strip() != "": query = query.filter(product_filter_smarts=product_filter_smarts) if query.exists(): if query.count() > 1: - logger.error(f'More than one rule matched this one! {query}') + logger.error(f"More than one rule matched this one! {query}") return query.first() r = SimpleAmbitRule() r.package = package - if name is None or name.strip() == '': - name = f'Rule {Rule.objects.filter(package=package).count() + 1}' + if name is None or name.strip() == "": + name = f"Rule {Rule.objects.filter(package=package).count() + 1}" r.name = name - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": r.description = description r.smirks = smirks - if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != '': + if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != "": r.reactant_filter_smarts = reactant_filter_smarts - if product_filter_smarts is not None and product_filter_smarts.strip() != '': + if product_filter_smarts is not None and product_filter_smarts.strip() != "": r.product_filter_smarts = product_filter_smarts r.save() return r def _url(self): - return '{}/simple-ambit-rule/{}'.format(self.package.url, self.uuid) + return "{}/simple-ambit-rule/{}".format(self.package.url, self.uuid) def apply(self, smiles): return FormatConverter.apply(smiles, self.smirks) @property def reactants_smarts(self): - return self.smirks.split('>>')[0] + return self.smirks.split(">>")[0] @property def products_smarts(self): - return self.smirks.split('>>')[1] + return self.smirks.split(">>")[1] @property def related_reactions(self): qs = Package.objects.filter(reviewed=True) - return self.reaction_rule.filter(package__in=qs).order_by('name') + return self.reaction_rule.filter(package__in=qs).order_by("name") @property def related_pathways(self): return Pathway.objects.filter( - id__in=Edge.objects.filter(edge_label__in=self.related_reactions).values('pathway_id')).order_by('name') + id__in=Edge.objects.filter(edge_label__in=self.related_reactions).values("pathway_id") + ).order_by("name") @property def as_svg(self): @@ -1131,22 +1191,22 @@ class SimpleAmbitRule(SimpleRule): class SimpleRDKitRule(SimpleRule): - reaction_smarts = models.TextField(blank=False, null=False, verbose_name='SMIRKS') + reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS") def apply(self, smiles): return FormatConverter.apply(smiles, self.reaction_smarts) def _url(self): - return '{}/simple-rdkit-rule/{}'.format(self.package.url, self.uuid) + return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid) # # class ParallelRule(Rule): - simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules') + simple_rules = models.ManyToManyField("epdb.SimpleRule", verbose_name="Simple rules") def _url(self): - return '{}/parallel-rule/{}'.format(self.package.url, self.uuid) + return "{}/parallel-rule/{}".format(self.package.url, self.uuid) @cached_property def srs(self) -> QuerySet: @@ -1164,7 +1224,7 @@ class ParallelRule(Rule): res = set() for sr in self.srs: - for part in sr.reactants_smarts.split('.'): + for part in sr.reactants_smarts.split("."): res.add(part) return res @@ -1174,19 +1234,19 @@ class ParallelRule(Rule): res = set() for sr in self.srs: - for part in sr.products_smarts.split('.'): + for part in sr.products_smarts.split("."): res.add(part) return res - class SequentialRule(Rule): - simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules', - through='SequentialRuleOrdering') + simple_rules = models.ManyToManyField( + "epdb.SimpleRule", verbose_name="Simple rules", through="SequentialRuleOrdering" + ) def _url(self): - return '{}/sequential-rule/{}'.format(self.compound.url, self.uuid) + return "{}/sequential-rule/{}".format(self.compound.url, self.uuid) @property def srs(self): @@ -1207,29 +1267,37 @@ class SequentialRuleOrdering(models.Model): class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin): - package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True) - educts = models.ManyToManyField('epdb.CompoundStructure', verbose_name='Educts', related_name='reaction_educts') - products = models.ManyToManyField('epdb.CompoundStructure', verbose_name='Products', - related_name='reaction_products') - rules = models.ManyToManyField('epdb.Rule', verbose_name='Rule', related_name='reaction_rule') - multi_step = models.BooleanField(verbose_name='Multistep Reaction') + package = models.ForeignKey( + "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True + ) + educts = models.ManyToManyField( + "epdb.CompoundStructure", verbose_name="Educts", related_name="reaction_educts" + ) + products = models.ManyToManyField( + "epdb.CompoundStructure", verbose_name="Products", related_name="reaction_products" + ) + rules = models.ManyToManyField("epdb.Rule", verbose_name="Rule", related_name="reaction_rule") + multi_step = models.BooleanField(verbose_name="Multistep Reaction") medline_references = ArrayField( - models.TextField(blank=False, null=False), null=True, - verbose_name='Medline References' + models.TextField(blank=False, null=False), null=True, verbose_name="Medline References" ) - external_identifiers = GenericRelation('ExternalIdentifier') + external_identifiers = GenericRelation("ExternalIdentifier") def _url(self): - return '{}/reaction/{}'.format(self.package.url, self.uuid) + return "{}/reaction/{}".format(self.package.url, self.uuid) @staticmethod @transaction.atomic - def create(package: Package, name: str = None, description: str = None, - educts: Union[List[str], List[CompoundStructure]] = None, - products: Union[List[str], List[CompoundStructure]] = None, - rules: Union[Rule | List[Rule]] = None, multi_step: bool = True): - + def create( + package: Package, + name: str = None, + description: str = None, + educts: Union[List[str], List[CompoundStructure]] = None, + products: Union[List[str], List[CompoundStructure]] = None, + rules: Union[Rule | List[Rule]] = None, + multi_step: bool = True, + ): _educts = [] _products = [] @@ -1260,39 +1328,39 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin rules = [rules] query = Reaction.objects.annotate( - educt_count=Count('educts', filter=Q(educts__in=_educts), distinct=True), - product_count=Count('products', filter=Q(products__in=_products), distinct=True), + educt_count=Count("educts", filter=Q(educts__in=_educts), distinct=True), + product_count=Count("products", filter=Q(products__in=_products), distinct=True), ) # The annotate/filter wont work if rules is an empty list if rules: query = query.annotate( - rule_count=Count('rules', filter=Q(rules__in=rules), distinct=True) + rule_count=Count("rules", filter=Q(rules__in=rules), distinct=True) ).filter(rule_count=len(rules)) else: - query = query.annotate( - rule_count=Count('rules', distinct=True) - ).filter(rule_count=0) + query = query.annotate(rule_count=Count("rules", distinct=True)).filter(rule_count=0) existing_reaction_qs = query.filter( educt_count=len(_educts), product_count=len(_products), multi_step=multi_step, - package=package + package=package, ) if existing_reaction_qs.exists(): if existing_reaction_qs.count() > 1: - logger.error(f'Found more than one reaction for given input! {existing_reaction_qs}') + logger.error( + f"Found more than one reaction for given input! {existing_reaction_qs}" + ) return existing_reaction_qs.first() r = Reaction() r.package = package - if name is not None and name.strip() != '': + if name is not None and name.strip() != "": r.name = name - if description is not None and name.strip() != '': + if description is not None and name.strip() != "": r.description = description r.multi_step = multi_step @@ -1313,7 +1381,7 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin return r @transaction.atomic - def copy(self, target: 'Package', mapping: Dict ) -> 'Reaction': + def copy(self, target: "Package", mapping: Dict) -> "Reaction": """Copy a reaction to the target package.""" if self in mapping: return mapping[self] @@ -1325,7 +1393,7 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin description=self.description, multi_step=self.multi_step, medline_references=self.medline_references, - kv=self.kv.copy() if self.kv else {} + kv=self.kv.copy() if self.kv else {}, ) mapping[self] = new_reaction @@ -1351,7 +1419,7 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin database=ext_id.database, identifier_value=ext_id.identifier_value, url=ext_id.url, - is_primary=ext_id.is_primary + is_primary=ext_id.is_primary, ) return new_reaction @@ -1366,12 +1434,17 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin @property def related_pathways(self): return Pathway.objects.filter( - id__in=Edge.objects.filter(edge_label=self).values('pathway_id')).order_by('name') + id__in=Edge.objects.filter(edge_label=self).values("pathway_id") + ).order_by("name") class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): - package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True) - setting = models.ForeignKey('epdb.Setting', verbose_name='Setting', on_delete=models.CASCADE, null=True, blank=True) + package = models.ForeignKey( + "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True + ) + setting = models.ForeignKey( + "epdb.Setting", verbose_name="Setting", on_delete=models.CASCADE, null=True, blank=True + ) @property def root_nodes(self): @@ -1398,30 +1471,30 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): return self.edge_set.all() def _url(self): - return '{}/pathway/{}'.format(self.package.url, self.uuid) + return "{}/pathway/{}".format(self.package.url, self.uuid) # Mode def is_built(self): - return self.kv.get('mode', 'build') == 'build' + return self.kv.get("mode", "build") == "build" def is_predicted(self): - return self.kv.get('mode', 'build') == 'predicted' + return self.kv.get("mode", "build") == "predicted" def is_incremental(self): - return self.kv.get('mode', 'build') == 'incremental' + return self.kv.get("mode", "build") == "incremental" # Status def status(self): - return self.kv.get('status', 'completed') + return self.kv.get("status", "completed") def completed(self): - return self.status() == 'completed' + return self.status() == "completed" def running(self): - return self.status() == 'running' + return self.status() == "running" def failed(self): - return self.status() == 'failed' + return self.status() == "failed" def d3_json(self): # Ideally it would be something like this but @@ -1462,14 +1535,14 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): # D3 links Nodes based on indices in nodes array node_url_to_idx = dict() for i, n in enumerate(nodes): - n['id'] = i - node_url_to_idx[n['url']] = i + n["id"] = i + node_url_to_idx[n["url"]] = i adjusted_links = [] for link in links: # Check if we'll need pseudo nodes - if len(link['end_node_urls']) > 1: - start_depth = nodes[node_url_to_idx[link['start_node_urls'][0]]]['depth'] + if len(link["end_node_urls"]) > 1: + start_depth = nodes[node_url_to_idx[link["start_node_urls"][0]]]["depth"] pseudo_idx = len(nodes) pseudo_node = { "depth": start_depth + 0.5, @@ -1480,38 +1553,38 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): # add links start -> pseudo new_link = { - 'name': link['name'], - 'id': link['id'], - 'url': link['url'], - 'image': link['image'], - 'reaction': link['reaction'], - 'reaction_probability': link['reaction_probability'], - 'scenarios': link['scenarios'], - 'source': node_url_to_idx[link['start_node_urls'][0]], - 'target': pseudo_idx, - 'app_domain': link.get('app_domain', None) + "name": link["name"], + "id": link["id"], + "url": link["url"], + "image": link["image"], + "reaction": link["reaction"], + "reaction_probability": link["reaction_probability"], + "scenarios": link["scenarios"], + "source": node_url_to_idx[link["start_node_urls"][0]], + "target": pseudo_idx, + "app_domain": link.get("app_domain", None), } adjusted_links.append(new_link) # add n links pseudo -> end - for target in link['end_node_urls']: + for target in link["end_node_urls"]: new_link = { - 'name': link['name'], - 'id': link['id'], - 'url': link['url'], - 'image': link['image'], - 'reaction': link['reaction'], - 'reaction_probability': link['reaction_probability'], - 'scenarios': link['scenarios'], - 'source': pseudo_idx, - 'target': node_url_to_idx[target], - 'app_domain': link.get('app_domain', None) + "name": link["name"], + "id": link["id"], + "url": link["url"], + "image": link["image"], + "reaction": link["reaction"], + "reaction_probability": link["reaction_probability"], + "scenarios": link["scenarios"], + "source": pseudo_idx, + "target": node_url_to_idx[target], + "app_domain": link.get("app_domain", None), } adjusted_links.append(new_link) else: - link['source'] = node_url_to_idx[link['start_node_urls'][0]] - link['target'] = node_url_to_idx[link['end_node_urls'][0]] + link["source"] = node_url_to_idx[link["start_node_urls"][0]] + link["target"] = node_url_to_idx[link["end_node_urls"][0]] adjusted_links.append(link) res = { @@ -1519,16 +1592,16 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): "completed": "true", "description": self.description, "id": self.url, - "isIncremental": self.kv.get('mode') == 'incremental', - "isPredicted": self.kv.get('mode') == 'predicted', - "lastModified": self.modified.strftime('%Y-%m-%d %H:%M:%S'), + "isIncremental": self.kv.get("mode") == "incremental", + "isPredicted": self.kv.get("mode") == "predicted", + "lastModified": self.modified.strftime("%Y-%m-%d %H:%M:%S"), "pathwayName": self.name, - "reviewStatus": "reviewed" if self.package.reviewed else 'unreviewed', + "reviewStatus": "reviewed" if self.package.reviewed else "unreviewed", "scenarios": [], "upToDate": True, "links": adjusted_links, "nodes": nodes, - "modified": self.modified.strftime('%Y-%m-%d %H:%M:%S'), + "modified": self.modified.strftime("%Y-%m-%d %H:%M:%S"), "status": self.status(), } @@ -1539,16 +1612,18 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): import io rows = [] - rows.append([ - 'SMILES', - 'name', - 'depth', - 'probability', - 'rule_names', - 'rule_ids', - 'parent_smiles', - ]) - for n in self.nodes.order_by('depth'): + rows.append( + [ + "SMILES", + "name", + "depth", + "probability", + "rule_names", + "rule_ids", + "parent_smiles", + ] + ) + for n in self.nodes.order_by("depth"): cs = n.default_node_label row = [cs.smiles, cs.name, n.depth] @@ -1556,9 +1631,9 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): if len(edges): for e in edges: _row = row.copy() - _row.append(e.kv.get('probability')) - _row.append(','.join([r.name for r in e.edge_label.rules.all()])) - _row.append(','.join([r.url for r in e.edge_label.rules.all()])) + _row.append(e.kv.get("probability")) + _row.append(",".join([r.name for r in e.edge_label.rules.all()])) + _row.append(",".join([r.url for r in e.edge_label.rules.all()])) _row.append(e.start_nodes.all()[0].default_node_label.smiles) rows.append(_row) else: @@ -1576,16 +1651,21 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): @staticmethod @transaction.atomic - def create(package: 'Package', smiles: str, name: Optional[str] = None, description: Optional[str] = None): + def create( + package: "Package", + smiles: str, + name: Optional[str] = None, + description: Optional[str] = None, + ): pw = Pathway() pw.package = package - if name is None or name.strip() == '': + if name is None or name.strip() == "": name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}" pw.name = name - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": pw.description = description pw.save() @@ -1601,8 +1681,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): return pw @transaction.atomic - def copy(self, target: 'Package', mapping: Dict) -> 'Pathway': - + def copy(self, target: "Package", mapping: Dict) -> "Pathway": if self in mapping: return mapping[self] @@ -1611,8 +1690,8 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): package=target, name=self.name, description=self.description, - setting=self.setting, # TODO copy settings? - kv=self.kv.copy() if self.kv else {} + setting=self.setting, # TODO copy settings? + kv=self.kv.copy() if self.kv else {}, ) # # Copy aliases if they exist @@ -1641,7 +1720,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): depth=node.depth, name=node.name, description=node.description, - kv=node.kv.copy() if node.kv else {} + kv=node.kv.copy() if node.kv else {}, ) mapping[node] = new_node @@ -1665,7 +1744,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): edge_label=copied_reaction, name=edge.name, description=edge.description, - kv=edge.kv.copy() if edge.kv else {} + kv=edge.kv.copy() if edge.kv else {}, ) # Copy start and end nodes relationships @@ -1682,26 +1761,45 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): return new_pathway @transaction.atomic - def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None, depth: Optional[int] = 0): + def add_node( + self, + smiles: str, + name: Optional[str] = None, + description: Optional[str] = None, + depth: Optional[int] = 0, + ): return Node.create(self, smiles, depth, name=name, description=description) @transaction.atomic - def add_edge(self, start_nodes: List['Node'], end_nodes: List['Node'], rule: Optional['Rule'] = None, - name: Optional[str] = None, description: Optional[str] = None): + def add_edge( + self, + start_nodes: List["Node"], + end_nodes: List["Node"], + rule: Optional["Rule"] = None, + name: Optional[str] = None, + description: Optional[str] = None, + ): return Edge.create(self, start_nodes, end_nodes, rule, name=name, description=description) class Node(EnviPathModel, AliasMixin, ScenarioMixin): - pathway = models.ForeignKey('epdb.Pathway', verbose_name='belongs to', on_delete=models.CASCADE, db_index=True) - default_node_label = models.ForeignKey('epdb.CompoundStructure', verbose_name='Default Node Label', - on_delete=models.CASCADE, related_name='default_node_structure') - node_labels = models.ManyToManyField('epdb.CompoundStructure', verbose_name='All Node Labels', - related_name='node_structures') - out_edges = models.ManyToManyField('epdb.Edge', verbose_name='Outgoing Edges') - depth = models.IntegerField(verbose_name='Node depth', null=False, blank=False) + pathway = models.ForeignKey( + "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True + ) + default_node_label = models.ForeignKey( + "epdb.CompoundStructure", + verbose_name="Default Node Label", + on_delete=models.CASCADE, + related_name="default_node_structure", + ) + node_labels = models.ManyToManyField( + "epdb.CompoundStructure", verbose_name="All Node Labels", related_name="node_structures" + ) + out_edges = models.ManyToManyField("epdb.Edge", verbose_name="Outgoing Edges") + depth = models.IntegerField(verbose_name="Node depth", null=False, blank=False) def _url(self): - return '{}/node/{}'.format(self.pathway.url, self.uuid) + return "{}/node/{}".format(self.pathway.url, self.uuid) def d3_json(self): app_domain_data = self.get_app_domain_assessment_data() @@ -1711,19 +1809,28 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin): "url": self.url, "node_label_id": self.default_node_label.url, "image": f"{self.url}?image=svg", - "image_svg": IndigoUtils.mol_to_svg(self.default_node_label.smiles, width=40, height=40), + "image_svg": IndigoUtils.mol_to_svg( + self.default_node_label.smiles, width=40, height=40 + ), "name": self.default_node_label.name, "smiles": self.default_node_label.smiles, - "scenarios": [{'name': s.name, 'url': s.url} for s in self.scenarios.all()], + "scenarios": [{"name": s.name, "url": s.url} for s in self.scenarios.all()], "app_domain": { - 'inside_app_domain': app_domain_data['assessment']['inside_app_domain'] if app_domain_data else None, - 'uncovered_functional_groups': False, - } + "inside_app_domain": app_domain_data["assessment"]["inside_app_domain"] + if app_domain_data + else None, + "uncovered_functional_groups": False, + }, } @staticmethod - def create(pathway: 'Pathway', smiles: str, depth: int, name: Optional[str] = None, - description: Optional[str] = None): + def create( + pathway: "Pathway", + smiles: str, + depth: int, + name: Optional[str] = None, + description: Optional[str] = None, + ): c = Compound.create(pathway.package, smiles, name=name, description=description) if Node.objects.filter(pathway=pathway, default_node_label=c.default_structure).exists(): @@ -1746,7 +1853,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin): return IndigoUtils.mol_to_svg(self.default_node_label.smiles) def get_app_domain_assessment_data(self): - data = self.kv.get('app_domain_assessment', None) + data = self.kv.get("app_domain_assessment", None) if data: rule_ids = defaultdict(list) @@ -1754,68 +1861,81 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin): for r in e.edge_label.rules.all(): rule_ids[str(r.uuid)].append(e.simple_json()) - for t in data['assessment']['transformations']: - if t['rule']['uuid'] in rule_ids: - t['is_predicted'] = True - t['edges'] = rule_ids[t['rule']['uuid']] + for t in data["assessment"]["transformations"]: + if t["rule"]["uuid"] in rule_ids: + t["is_predicted"] = True + t["edges"] = rule_ids[t["rule"]["uuid"]] return data def simple_json(self, include_description=False): res = super().simple_json() - name = res.get('name', None) - if name == 'no name': - res['name'] = self.default_node_label.name + name = res.get("name", None) + if name == "no name": + res["name"] = self.default_node_label.name return res class Edge(EnviPathModel, AliasMixin, ScenarioMixin): - pathway = models.ForeignKey('epdb.Pathway', verbose_name='belongs to', on_delete=models.CASCADE, db_index=True) - edge_label = models.ForeignKey('epdb.Reaction', verbose_name='Edge label', null=True, on_delete=models.SET_NULL) - start_nodes = models.ManyToManyField('epdb.Node', verbose_name='Start Nodes', related_name='edge_educts') - end_nodes = models.ManyToManyField('epdb.Node', verbose_name='End Nodes', related_name='edge_products') + pathway = models.ForeignKey( + "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True + ) + edge_label = models.ForeignKey( + "epdb.Reaction", verbose_name="Edge label", null=True, on_delete=models.SET_NULL + ) + start_nodes = models.ManyToManyField( + "epdb.Node", verbose_name="Start Nodes", related_name="edge_educts" + ) + end_nodes = models.ManyToManyField( + "epdb.Node", verbose_name="End Nodes", related_name="edge_products" + ) def _url(self): - return '{}/edge/{}'.format(self.pathway.url, self.uuid) + return "{}/edge/{}".format(self.pathway.url, self.uuid) def d3_json(self): edge_json = { - 'name': self.name, - 'id': self.url, - 'url': self.url, - 'image': self.url + '?image=svg', - 'reaction': {'name': self.edge_label.name, 'url': self.edge_label.url} if self.edge_label else None, - 'reaction_probability': self.kv.get('probability'), - 'start_node_urls': [x.url for x in self.start_nodes.all()], - 'end_node_urls': [x.url for x in self.end_nodes.all()], - "scenarios": [{'name': s.name, 'url': s.url} for s in self.scenarios.all()], + "name": self.name, + "id": self.url, + "url": self.url, + "image": self.url + "?image=svg", + "reaction": {"name": self.edge_label.name, "url": self.edge_label.url} + if self.edge_label + else None, + "reaction_probability": self.kv.get("probability"), + "start_node_urls": [x.url for x in self.start_nodes.all()], + "end_node_urls": [x.url for x in self.end_nodes.all()], + "scenarios": [{"name": s.name, "url": s.url} for s in self.scenarios.all()], } for n in self.start_nodes.all(): app_domain_data = n.get_app_domain_assessment_data() if app_domain_data: - for t in app_domain_data['assessment']['transformations']: - if 'edges' in t: - for e in t['edges']: - if e['uuid'] == str(self.uuid): + for t in app_domain_data["assessment"]["transformations"]: + if "edges" in t: + for e in t["edges"]: + if e["uuid"] == str(self.uuid): passes_app_domain = ( - t['local_compatibility'] >= app_domain_data['ad_params'][ - 'local_compatibility_threshold'] - ) and ( - t['reliability'] >= app_domain_data['ad_params'][ - 'reliability_threshold'] - ) + t["local_compatibility"] + >= app_domain_data["ad_params"]["local_compatibility_threshold"] + ) and ( + t["reliability"] + >= app_domain_data["ad_params"]["reliability_threshold"] + ) - edge_json['app_domain'] = { - 'passes_app_domain': passes_app_domain, - 'local_compatibility': t['local_compatibility'], - 'local_compatibility_threshold': app_domain_data['ad_params'][ - 'local_compatibility_threshold'], - 'reliability': t['reliability'], - 'reliability_threshold': app_domain_data['ad_params']['reliability_threshold'], - 'times_triggered': t['times_triggered'], + edge_json["app_domain"] = { + "passes_app_domain": passes_app_domain, + "local_compatibility": t["local_compatibility"], + "local_compatibility_threshold": app_domain_data["ad_params"][ + "local_compatibility_threshold" + ], + "reliability": t["reliability"], + "reliability_threshold": app_domain_data["ad_params"][ + "reliability_threshold" + ], + "times_triggered": t["times_triggered"], } break @@ -1823,9 +1943,14 @@ class Edge(EnviPathModel, AliasMixin, ScenarioMixin): return edge_json @staticmethod - def create(pathway, start_nodes: List[Node], end_nodes: List[Node], rule: Optional[Rule] = None, - name: Optional[str] = None, - description: Optional[str] = None): + def create( + pathway, + start_nodes: List[Node], + end_nodes: List[Node], + rule: Optional[Rule] = None, + name: Optional[str] = None, + description: Optional[str] = None, + ): e = Edge() e.pathway = pathway e.save() @@ -1837,16 +1962,20 @@ class Edge(EnviPathModel, AliasMixin, ScenarioMixin): e.end_nodes.add(node) if name is None: - name = f'Reaction {pathway.package.reactions.count() + 1}' + name = f"Reaction {pathway.package.reactions.count() + 1}" if description is None: - description = s.DEFAULT_VALUES['description'] + description = s.DEFAULT_VALUES["description"] - r = Reaction.create(pathway.package, name=name, description=description, - educts=[n.default_node_label for n in e.start_nodes.all()], - products=[n.default_node_label for n in e.end_nodes.all()], - rules=rule, multi_step=False - ) + r = Reaction.create( + pathway.package, + name=name, + description=description, + educts=[n.default_node_label for n in e.start_nodes.all()], + products=[n.default_node_label for n in e.end_nodes.all()], + rules=rule, + multi_step=False, + ) e.edge_label = r e.save() @@ -1858,31 +1987,43 @@ class Edge(EnviPathModel, AliasMixin, ScenarioMixin): def simple_json(self, include_description=False): res = super().simple_json() - name = res.get('name', None) - if name == 'no name': - res['name'] = self.edge_label.name + name = res.get("name", None) + if name == "no name": + res["name"] = self.edge_label.name return res class EPModel(PolymorphicModel, EnviPathModel): - package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True) + package = models.ForeignKey( + "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True + ) def _url(self): - return '{}/model/{}'.format(self.package.url, self.uuid) + return "{}/model/{}".format(self.package.url, self.uuid) class PackageBasedModel(EPModel): - rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", - related_name="%(app_label)s_%(class)s_rule_packages") - data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", - related_name="%(app_label)s_%(class)s_data_packages") - eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", - related_name="%(app_label)s_%(class)s_eval_packages") + rule_packages = models.ManyToManyField( + "Package", + verbose_name="Rule Packages", + related_name="%(app_label)s_%(class)s_rule_packages", + ) + data_packages = models.ManyToManyField( + "Package", + verbose_name="Data Packages", + related_name="%(app_label)s_%(class)s_data_packages", + ) + eval_packages = models.ManyToManyField( + "Package", + verbose_name="Evaluation Packages", + related_name="%(app_label)s_%(class)s_eval_packages", + ) threshold = models.FloatField(null=False, blank=False, default=0.5) eval_results = JSONField(null=True, blank=True, default=dict) - app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True, - default=None) + app_domain = models.ForeignKey( + "epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None + ) multigen_eval = models.BooleanField(null=False, blank=False, default=False) INITIAL = "INITIAL" @@ -1899,9 +2040,11 @@ class PackageBasedModel(EPModel): BUILT_NOT_EVALUATED: "Model is built and can be used for predictions, Model is not evaluated yet.", EVALUATING: "Model is evaluating", FINISHED: "Model has finished building and evaluation.", - ERROR: "Model has failed." + ERROR: "Model has failed.", } - model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL) + model_status = models.CharField( + blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL + ) def status(self): return self.PROGRESS_STATUS_CHOICES[self.model_status] @@ -1916,19 +2059,21 @@ class PackageBasedModel(EPModel): res = [] - thresholds = self.eval_results['average_precision_per_threshold'].keys() + thresholds = self.eval_results["average_precision_per_threshold"].keys() for t in thresholds: - res.append({ - 'precision': self.eval_results['average_precision_per_threshold'][t], - 'recall': self.eval_results['average_recall_per_threshold'][t], - 'threshold': float(t) - }) + res.append( + { + "precision": self.eval_results["average_precision_per_threshold"][t], + "recall": self.eval_results["average_recall_per_threshold"][t], + "threshold": float(t), + } + ) return res @cached_property - def applicable_rules(self) -> List['Rule']: + def applicable_rules(self) -> List["Rule"]: """ Returns a ordered set of rules where the following applies: 1. All Composite will be added to result @@ -1983,7 +2128,7 @@ class PackageBasedModel(EPModel): ds.save(f) return ds - def load_dataset(self) -> 'Dataset': + def load_dataset(self) -> "Dataset": ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl") return Dataset.load(ds_path) @@ -2025,7 +2170,6 @@ class PackageBasedModel(EPModel): self.save() def evaluate_model(self): - if self.model_status != self.BUILT_NOT_EVALUATED: raise ValueError(f"Can't evaluate a model in state {self.model_status}!") @@ -2033,15 +2177,11 @@ class PackageBasedModel(EPModel): self.save() def train_func(X, y, train_index, model_kwargs): - clz = model_kwargs.pop('clz') - if clz == 'RuleBaseRelativeReasoning': - mod = RelativeReasoning( - **model_kwargs - ) + clz = model_kwargs.pop("clz") + if clz == "RuleBaseRelativeReasoning": + mod = RelativeReasoning(**model_kwargs) else: - mod = EnsembleClassifierChain( - **model_kwargs - ) + mod = EnsembleClassifierChain(**model_kwargs) if train_index is not None: X, y = X[train_index], y[train_index] @@ -2072,19 +2212,23 @@ class PackageBasedModel(EPModel): for t in np.arange(0, 1.05, 0.05): temp_thresholded = (y_pred_filtered >= t).astype(int) - prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0) + prec[f"{t:.2f}"] = precision_score( + y_test_filtered, temp_thresholded, zero_division=0 + ) rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0) return acc, prec, rec - def evaluate_mg(model, pathways: Union[QuerySet['Pathway']| List['Pathway']], threshold): + def evaluate_mg(model, pathways: Union[QuerySet["Pathway"] | List["Pathway"]], threshold): thresholds = np.arange(0.1, 1.1, 0.1) precision = {f"{t:.2f}": [] for t in thresholds} recall = {f"{t:.2f}": [] for t in thresholds} # Note: only one root compound supported at this time - root_compounds = [[p.default_node_label.smiles for p in p.root_nodes][0] for p in pathways] + root_compounds = [ + [p.default_node_label.smiles for p in p.root_nodes][0] for p in pathways + ] # As we need a Model Instance in our setting, get a fresh copy from db, overwrite the serialized mode and # pass it to the setting used in prediction @@ -2132,7 +2276,9 @@ class PackageBasedModel(EPModel): # If there are eval packages perform single generation evaluation on them instead of random splits if self.eval_packages.count() > 0: - eval_reactions = list(Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()) + eval_reactions = list( + Reaction.objects.filter(package__in=self.eval_packages.all()).distinct() + ) ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True) if isinstance(self, RuleBasedRelativeReasoning): X = np.array(ds.X(exclude_id_col=False, na_replacement=None)) @@ -2158,9 +2304,15 @@ class PackageBasedModel(EPModel): splits = list(shuff.split(X)) from joblib import Parallel, delayed - models = Parallel(n_jobs=10)(delayed(train_func)(X, y, train_index, self._model_args()) for train_index, _ in splits) - evaluations = Parallel(n_jobs=10)(delayed(evaluate_sg)(model, X, y, test_index, self.threshold) - for model, (_, test_index) in zip(models, splits)) + + models = Parallel(n_jobs=10)( + delayed(train_func)(X, y, train_index, self._model_args()) + for train_index, _ in splits + ) + evaluations = Parallel(n_jobs=10)( + delayed(evaluate_sg)(model, X, y, test_index, self.threshold) + for model, (_, test_index) in zip(models, splits) + ) self.eval_results = self.compute_averages(evaluations) @@ -2169,19 +2321,28 @@ class PackageBasedModel(EPModel): if self.eval_packages.count() > 0: pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct() multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold) - self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()}) + self.eval_results.update( + { + f"multigen_{k}": v + for k, v in self.compute_averages([multi_eval_result]).items() + } + ) else: - pathway_qs = Pathway.objects.prefetch_related( - 'node_set', - 'node_set__out_edges', - 'node_set__default_node_label', - 'node_set__scenarios', - 'edge_set', - 'edge_set__start_nodes', - 'edge_set__end_nodes', - 'edge_set__edge_label', - 'edge_set__scenarios' - ).filter(package__in=self.data_packages.all()).distinct() + pathway_qs = ( + Pathway.objects.prefetch_related( + "node_set", + "node_set__out_edges", + "node_set__default_node_label", + "node_set__scenarios", + "edge_set", + "edge_set__start_nodes", + "edge_set__end_nodes", + "edge_set__edge_label", + "edge_set__scenarios", + ) + .filter(package__in=self.data_packages.all()) + .distinct() + ) pathways = [] for pathway in pathway_qs: @@ -2189,7 +2350,9 @@ class PackageBasedModel(EPModel): if len(pathway.root_nodes) > 0: pathways.append(pathway) else: - logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation") + logging.warning( + f"No root compound in pathway {pathway.name}, excluding from multigen evaluation" + ) # build lookup reaction -> {uuid1, uuid2} for overlap check reaction_to_educts = defaultdict(set) @@ -2203,7 +2366,9 @@ class PackageBasedModel(EPModel): # Compute splits of the collected pathway splits = [] - for train, test in ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways): + for train, test in ShuffleSplit( + n_splits=n_splits, test_size=0.25, random_state=42 + ).split(pathways): train_pathways = [pathways[i] for i in train] test_pathways = [pathways[i] for i in test] @@ -2225,30 +2390,37 @@ class PackageBasedModel(EPModel): if educt in id_to_index: split_ids.append(id_to_index[educt]) else: - logger.debug(f"Couldn't find features in X for compound {educt}") + logger.debug( + f"Couldn't find features in X for compound {educt}" + ) else: overlap += 1 logging.debug( - f"{overlap} compounds had to be removed from multigen split due to overlap within pathways") + f"{overlap} compounds had to be removed from multigen split due to overlap within pathways" + ) # Get the rows from the dataset corresponding to compounds in the training set pathways split_x, split_y = X[split_ids], y[split_ids] splits.append([(split_x, split_y), test_pathways]) - # Build model on subsets obtained by pathway split trained_models = Parallel(n_jobs=10)( - delayed(train_func)(split_x, split_y, np.arange(split_x.shape[0]), self._model_args()) for (split_x, split_y), _ in splits + delayed(train_func)( + split_x, split_y, np.arange(split_x.shape[0]), self._model_args() + ) + for (split_x, split_y), _ in splits ) # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work multi_ret_vals = Parallel(n_jobs=1)( - delayed(evaluate_mg)(model, test_pathways, self.threshold) for model, (_, test_pathways) in - zip(trained_models, splits) + delayed(evaluate_mg)(model, test_pathways, self.threshold) + for model, (_, test_pathways) in zip(trained_models, splits) ) - self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()}) + self.eval_results.update( + {f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()} + ) self.model_status = self.FINISHED self.save() @@ -2273,11 +2445,11 @@ class PackageBasedModel(EPModel): return { "average_accuracy": float(avg_first_item), "average_precision_per_threshold": avg_dict2, - "average_recall_per_threshold": avg_dict3 + "average_recall_per_threshold": avg_dict3, } @staticmethod - def combine_products_and_probs(rules: List['Rule'], probabilities, products): + def combine_products_and_probs(rules: List["Rule"], probabilities, products): res = [] for rule, p, smis in zip(rules, probabilities, products): res.append(PredictionResult(smis, p, rule)) @@ -2293,19 +2465,26 @@ class RuleBasedRelativeReasoning(PackageBasedModel): @staticmethod @transaction.atomic - def create(package: 'Package', rule_packages: List['Package'], data_packages: List['Package'], - eval_packages: List['Package'], threshold: float = 0.5, min_count: int = 10, max_count: int = 0, - name: 'str' = None, description: str = None): - + def create( + package: "Package", + rule_packages: List["Package"], + data_packages: List["Package"], + eval_packages: List["Package"], + threshold: float = 0.5, + min_count: int = 10, + max_count: int = 0, + name: "str" = None, + description: str = None, + ): rbrr = RuleBasedRelativeReasoning() rbrr.package = package - if name is None or name.strip() == '': + if name is None or name.strip() == "": name = f"RuleBasedRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}" rbrr.name = name - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": rbrr.description = description if threshold is None or (threshold <= 0 or 1 <= threshold): @@ -2350,8 +2529,8 @@ class RuleBasedRelativeReasoning(PackageBasedModel): def _fit_model(self, ds: Dataset): X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None) model = RelativeReasoning( - start_index= ds.triggered()[0], - end_index= ds.triggered()[1], + start_index=ds.triggered()[0], + end_index=ds.triggered()[1], ) model.fit(X, y) return model @@ -2359,9 +2538,9 @@ class RuleBasedRelativeReasoning(PackageBasedModel): def _model_args(self): ds = self.load_dataset() return { - 'clz': 'RuleBaseRelativeReasoning', - 'start_index': ds.triggered()[0], - 'end_index': ds.triggered()[1], + "clz": "RuleBaseRelativeReasoning", + "start_index": ds.triggered()[0], + "end_index": ds.triggered()[1], } def _save_model(self, model): @@ -2369,11 +2548,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel): joblib.dump(model, f) @cached_property - def model(self) -> 'RelativeReasoning': - mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl')) + def model(self) -> "RelativeReasoning": + mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")) return mod - def predict(self, smiles) -> List['PredictionResult']: + def predict(self, smiles) -> List["PredictionResult"]: start = datetime.now() ds = self.load_dataset() classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules) @@ -2381,7 +2560,9 @@ class RuleBasedRelativeReasoning(PackageBasedModel): mod = self.model pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None)) - res = RuleBasedRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0]) + res = RuleBasedRelativeReasoning.combine_products_and_probs( + self.applicable_rules, pred[0], classify_prods[0] + ) end = datetime.now() logger.info(f"Full predict took {(end - start).total_seconds()}s") @@ -2389,24 +2570,30 @@ class RuleBasedRelativeReasoning(PackageBasedModel): class MLRelativeReasoning(PackageBasedModel): - @staticmethod @transaction.atomic - def create(package: 'Package', rule_packages: List['Package'], - data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5, - name: 'str' = None, description: str = None, build_app_domain: bool = False, - app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None, - app_domain_local_compatibility_threshold: float = None): - + def create( + package: "Package", + rule_packages: List["Package"], + data_packages: List["Package"], + eval_packages: List["Package"], + threshold: float = 0.5, + name: "str" = None, + description: str = None, + build_app_domain: bool = False, + app_domain_num_neighbours: int = None, + app_domain_reliability_threshold: float = None, + app_domain_local_compatibility_threshold: float = None, + ): mlrr = MLRelativeReasoning() mlrr.package = package - if name is None or name.strip() == '': + if name is None or name.strip() == "": name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}" mlrr.name = name - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": mlrr.description = description if threshold is None or (threshold <= 0 or 1 <= threshold): @@ -2434,8 +2621,12 @@ class MLRelativeReasoning(PackageBasedModel): mlrr.eval_packages.add(p) if build_app_domain: - ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold, - app_domain_local_compatibility_threshold) + ad = ApplicabilityDomain.create( + mlrr, + app_domain_num_neighbours, + app_domain_reliability_threshold, + app_domain_local_compatibility_threshold, + ) mlrr.app_domain = ad mlrr.save() @@ -2445,15 +2636,13 @@ class MLRelativeReasoning(PackageBasedModel): def _fit_model(self, ds: Dataset): X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan) - model = EnsembleClassifierChain( - **s.DEFAULT_MODEL_PARAMS - ) + model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS) model.fit(X, y) return model def _model_args(self): return { - 'clz': 'MLRelativeReasoning', + "clz": "MLRelativeReasoning", **s.DEFAULT_MODEL_PARAMS, } @@ -2462,18 +2651,20 @@ class MLRelativeReasoning(PackageBasedModel): joblib.dump(model, f) @cached_property - def model(self) -> 'EnsembleClassifierChain': - mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl')) + def model(self) -> "EnsembleClassifierChain": + mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")) mod.base_clf.n_jobs = -1 return mod - def predict(self, smiles) -> List['PredictionResult']: + def predict(self, smiles) -> List["PredictionResult"]: start = datetime.now() ds = self.load_dataset() classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules) pred = self.model.predict_proba(classify_ds.X()) - res = MLRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0]) + res = MLRelativeReasoning.combine_products_and_probs( + self.applicable_rules, pred[0], classify_prods[0] + ) end = datetime.now() logger.info(f"Full predict took {(end - start).total_seconds()}s") @@ -2491,8 +2682,12 @@ class ApplicabilityDomain(EnviPathModel): @staticmethod @transaction.atomic - def create(mlrr: MLRelativeReasoning, num_neighbours: int = 5, reliability_threshold: float = 0.5, - local_compatibility_threshold: float = 0.5): + def create( + mlrr: MLRelativeReasoning, + num_neighbours: int = 5, + reliability_threshold: float = 0.5, + local_compatibility_threshold: float = 0.5, + ): ad = ApplicabilityDomain() ad.model = mlrr # ad.uuid = mlrr.uuid @@ -2505,7 +2700,7 @@ class ApplicabilityDomain(EnviPathModel): @cached_property def pca(self) -> ApplicabilityDomainPCA: - pca = joblib.load(os.path.join(s.MODEL_DIR, f'{self.model.uuid}_pca.pkl')) + pca = joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")) return pca @cached_property @@ -2527,7 +2722,9 @@ class ApplicabilityDomain(EnviPathModel): # Collect functional Groups together with their counts for reactivity center highlighting functional_groups_counts = defaultdict(int) - for cs in CompoundStructure.objects.filter(compound__package__in=self.model.data_packages.all()): + for cs in CompoundStructure.objects.filter( + compound__package__in=self.model.data_packages.all() + ): for fg in FormatConverter.get_functional_groups(cs.smiles): functional_groups_counts[fg] += 1 @@ -2540,7 +2737,7 @@ class ApplicabilityDomain(EnviPathModel): f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl") joblib.dump(ad, f) - def assess(self, structure: Union[str, 'CompoundStructure']): + def assess(self, structure: Union[str, "CompoundStructure"]): ds = self.model.load_dataset() if isinstance(structure, CompoundStructure): @@ -2548,7 +2745,9 @@ class ApplicabilityDomain(EnviPathModel): else: smiles = structure - assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules) + assessment_ds, assessment_prods = ds.classification_dataset( + [structure], self.model.applicable_rules + ) # qualified_neighbours_per_rule is a nested dictionary structured as: # { @@ -2561,11 +2760,13 @@ class ApplicabilityDomain(EnviPathModel): # it identifies all training structures that have the same trigger reaction activated (i.e., value 1). # This is used to find "qualified neighbours" — training examples that share the same triggered feature # with a given assessment structure under a particular rule. - qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(lambda: defaultdict(list)) + qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict( + lambda: defaultdict(list) + ) for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())): feature = ds.columns[feature_index] - if feature.startswith('trig_'): + if feature.startswith("trig_"): # TODO unroll loop for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)): if int(cx[feature_index]) == 1: @@ -2575,15 +2776,16 @@ class ApplicabilityDomain(EnviPathModel): probs = self.training_set_probs # preds = self.model.model.predict_proba(assessment_ds.X()) - preds = self.model.combine_products_and_probs(self.model.applicable_rules, - self.model.model.predict_proba(assessment_ds.X())[0], - assessment_prods[0]) + preds = self.model.combine_products_and_probs( + self.model.applicable_rules, + self.model.model.predict_proba(assessment_ds.X())[0], + assessment_prods[0], + ) assessments = list() # loop through our assessment dataset for i, instance in enumerate(assessment_ds): - rule_reliabilities = dict() local_compatibilities = dict() neighbours_per_rule = dict() @@ -2591,7 +2793,6 @@ class ApplicabilityDomain(EnviPathModel): # loop through rule indices together with the collected neighbours indices from train dataset for rule_idx, vals in qualified_neighbours_per_rule[i].items(): - # collect the train dataset instances and store it along with the index (a.k.a. row number) of the # train dataset train_instances = [] @@ -2604,8 +2805,8 @@ class ApplicabilityDomain(EnviPathModel): # compute tanimoto distance for all neighbours # result ist a list of tuples with train index and computed distance dists = self._compute_distances( - instance.X()[0][sf[0]:sf[1]], - [ti[1].X()[0][sf[0]:sf[1]] for ti in train_instances] + instance.X()[0][sf[0] : sf[1]], + [ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances], ) dists_with_index = list() @@ -2614,71 +2815,86 @@ class ApplicabilityDomain(EnviPathModel): # sort them in a descending way and take at most `self.num_neighbours` dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True) - dists_with_index = dists_with_index[:self.num_neighbours] + dists_with_index = dists_with_index[: self.num_neighbours] # compute average distance - rule_reliabilities[rule_idx] = sum([d[1] for d in dists_with_index]) / len(dists_with_index) if len( - dists_with_index) > 0 else 0.0 + rule_reliabilities[rule_idx] = ( + sum([d[1] for d in dists_with_index]) / len(dists_with_index) + if len(dists_with_index) > 0 + else 0.0 + ) # for local_compatibility we'll need the datasets for the indices having the highest similarity neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index] - local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets) - neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in - neighbour_datasets] - neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index] + local_compatibilities[rule_idx] = self._compute_compatibility( + rule_idx, probs, neighbour_datasets + ) + neighbours_per_rule[rule_idx] = [ + CompoundStructure.objects.get(uuid=ds[1].structure_id()) + for ds in neighbour_datasets + ] + neighbor_probs_per_rule[rule_idx] = [ + probs[d[0]][rule_idx] for d in dists_with_index + ] ad_res = { - 'ad_params': { - 'uuid': str(self.uuid), - 'model': self.model.simple_json(), - 'num_neighbours': self.num_neighbours, - 'reliability_threshold': self.reliability_threshold, - 'local_compatibility_threshold': self.local_compatibilty_threshold, + "ad_params": { + "uuid": str(self.uuid), + "model": self.model.simple_json(), + "num_neighbours": self.num_neighbours, + "reliability_threshold": self.reliability_threshold, + "local_compatibility_threshold": self.local_compatibilty_threshold, + }, + "assessment": { + "smiles": smiles, + "inside_app_domain": self.pca.is_applicable(instance)[0], }, - 'assessment': { - 'smiles': smiles, - 'inside_app_domain': self.pca.is_applicable(instance)[0], - } } transformations = list() for rule_idx in rule_reliabilities.keys(): - rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', '')) + rule = Rule.objects.get( + uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "") + ) rule_data = rule.simple_json() - rule_data['image'] = f"{rule.url}?image=svg" + rule_data["image"] = f"{rule.url}?image=svg" neighbors = [] - for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]): + for n, n_prob in zip( + neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx] + ): neighbor = n.simple_json() - neighbor['image'] = f"{n.url}?image=svg" - neighbor['smiles'] = n.smiles - neighbor['related_pathways'] = [ - pw.simple_json() for pw in Pathway.objects.filter( - node__default_node_label=n, - package__in=self.model.data_packages.all() + neighbor["image"] = f"{n.url}?image=svg" + neighbor["smiles"] = n.smiles + neighbor["related_pathways"] = [ + pw.simple_json() + for pw in Pathway.objects.filter( + node__default_node_label=n, package__in=self.model.data_packages.all() ).distinct() ] - neighbor['probability'] = n_prob + neighbor["probability"] = n_prob neighbors.append(neighbor) transformation = { - 'rule': rule_data, - 'reliability': rule_reliabilities[rule_idx], + "rule": rule_data, + "reliability": rule_reliabilities[rule_idx], # We're setting it here to False, as we don't know whether "assess" is called during pathway # prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime - 'is_predicted': False, - 'local_compatibility': local_compatibilities[rule_idx], - 'probability': preds[rule_idx].probability, - 'transformation_products': [x.product_set for x in preds[rule_idx].product_sets], - 'times_triggered': ds.times_triggered(str(rule.uuid)), - 'neighbors': neighbors, + "is_predicted": False, + "local_compatibility": local_compatibilities[rule_idx], + "probability": preds[rule_idx].probability, + "transformation_products": [ + x.product_set for x in preds[rule_idx].product_sets + ], + "times_triggered": ds.times_triggered(str(rule.uuid)), + "neighbors": neighbors, } transformations.append(transformation) - ad_res['assessment']['transformations'] = transformations + ad_res["assessment"]["transformations"] = transformations assessments.append(ad_res) @@ -2687,12 +2903,15 @@ class ApplicabilityDomain(EnviPathModel): @staticmethod def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]): from utilities.ml import tanimoto_distance - distances = [(i, tanimoto_distance(classify_instance, train)) for i, train in - enumerate(train_instances)] + + distances = [ + (i, tanimoto_distance(classify_instance, train)) + for i, train in enumerate(train_instances) + ] return distances @staticmethod - def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, 'Dataset']]): + def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]): tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0 accuracy = 0.0 @@ -2715,22 +2934,29 @@ class ApplicabilityDomain(EnviPathModel): class EnviFormer(PackageBasedModel): - @staticmethod @transaction.atomic - def create(package: 'Package', data_packages: List['Package'], eval_packages: List['Package'], - threshold: float = 0.5, name: 'str' = None, description: str = None, build_app_domain: bool = False, - app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None, - app_domain_local_compatibility_threshold: float = None): + def create( + package: "Package", + data_packages: List["Package"], + eval_packages: List["Package"], + threshold: float = 0.5, + name: "str" = None, + description: str = None, + build_app_domain: bool = False, + app_domain_num_neighbours: int = None, + app_domain_reliability_threshold: float = None, + app_domain_local_compatibility_threshold: float = None, + ): mod = EnviFormer() mod.package = package - if name is None or name.strip() == '': + if name is None or name.strip() == "": name = f"EnviFormer {EnviFormer.objects.filter(package=package).count() + 1}" mod.name = name - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": mod.description = description if threshold is None or (threshold <= 0 or 1 <= threshold): @@ -2761,15 +2987,21 @@ class EnviFormer(PackageBasedModel): @cached_property def model(self): from enviformer import load + ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt") return load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt) - def predict(self, smiles) -> List['PredictionResult']: + def predict(self, smiles) -> List["PredictionResult"]: return self.predict_batch([smiles])[0] def predict_batch(self, smiles_list): # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately - canon_smiles = [".".join([FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]) for smiles in smiles_list] + canon_smiles = [ + ".".join( + [FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")] + ) + for smiles in smiles_list + ] logger.info(f"Submitting {canon_smiles} to {self.name}") products_list = self.model.predict_batch(canon_smiles) logger.info(f"Got results {products_list}") @@ -2779,7 +3011,12 @@ class EnviFormer(PackageBasedModel): res = [] for smi, prob in products.items(): try: - smi = ".".join([FormatConverter.standardize(smile, remove_stereo=True) for smile in smi.split(".")]) + smi = ".".join( + [ + FormatConverter.standardize(smile, remove_stereo=True) + for smile in smi.split(".") + ] + ) except ValueError: # This occurs when the predicted string is an invalid SMILES logging.info(f"EnviFormer predicted an invalid SMILES: {smi}") continue @@ -2796,8 +3033,18 @@ class EnviFormer(PackageBasedModel): # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently ds = [] for reaction in self._get_reactions(): - educts = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()]) - products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()]) + educts = ".".join( + [ + FormatConverter.standardize(smile.smiles, remove_stereo=True) + for smile in reaction.educts.all() + ] + ) + products = ".".join( + [ + FormatConverter.standardize(smile.smiles, remove_stereo=True) + for smile in reaction.products.all() + ] + ) ds.append(f"{educts}>>{products}") end = datetime.now() @@ -2807,7 +3054,7 @@ class EnviFormer(PackageBasedModel): json.dump(ds, d_file) return ds - def load_dataset(self) -> 'Dataset': + def load_dataset(self) -> "Dataset": ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json") with open(ds_path) as d_file: ds = json.load(d_file) @@ -2816,6 +3063,7 @@ class EnviFormer(PackageBasedModel): def _fit_model(self, ds): # Call to enviFormer's fine_tune function and return the model from enviformer.finetune import fine_tune + start = datetime.now() model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE) end = datetime.now() @@ -2824,7 +3072,10 @@ class EnviFormer(PackageBasedModel): def _save_model(self, model): from enviformer.utils import save_model - save_model(model, os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")) + + save_model( + model, os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt") + ) def _model_args(self) -> Dict[str, Any]: args = {"clz": "EnviFormer"} @@ -2876,7 +3127,9 @@ class EnviFormer(PackageBasedModel): for true_set in product_sets: for threshold in correct: - pred_s = [s for i, s in enumerate(pred_smiles) if pred_scores[i] > threshold] + pred_s = [ + s for i, s in enumerate(pred_smiles) if pred_scores[i] > threshold + ] predicted[threshold] += len(pred_s) for pred_set in pred_s: if len(true_set - pred_set) == 0: @@ -2886,7 +3139,9 @@ class EnviFormer(PackageBasedModel): # Recall is TP (correct) / TP + FN (len(test_reactions)) rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()} # Precision is TP (correct) / TP + FP (predicted) - prec = {f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()} + prec = { + f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items() + } # Accuracy for EnviFormer is just recall return rec[f"{model_thresh:.2f}"], prec, rec @@ -2907,8 +3162,15 @@ class EnviFormer(PackageBasedModel): for p in pathways: root_node = p.root_nodes if len(root_node) > 1: - logging.warning(f"Pathway {p.name} has more than one root compound, only {root_node[0]} will be used") - root_node = ".".join([FormatConverter.standardize(smile) for smile in root_node[0].default_node_label.smiles.split(".")]) + logging.warning( + f"Pathway {p.name} has more than one root compound, only {root_node[0]} will be used" + ) + root_node = ".".join( + [ + FormatConverter.standardize(smile) + for smile in root_node[0].default_node_label.smiles.split(".") + ] + ) root_compounds.append(root_node) # As we need a Model Instance in our setting, get a fresh copy from db, overwrite the serialized mode and # pass it to the setting used in prediction @@ -2955,17 +3217,28 @@ class EnviFormer(PackageBasedModel): # If there are eval packages perform single generation evaluation on them instead of random splits if self.eval_packages.count() > 0: ds = [] - for reaction in Reaction.objects.filter(package__in=self.eval_packages.all()).distinct(): + for reaction in Reaction.objects.filter( + package__in=self.eval_packages.all() + ).distinct(): educts = ".".join( - [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()]) - products = ".".join([FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in - reaction.products.all()]) + [ + FormatConverter.standardize(smile.smiles, remove_stereo=True) + for smile in reaction.educts.all() + ] + ) + products = ".".join( + [ + FormatConverter.standardize(smile.smiles, remove_stereo=True) + for smile in reaction.products.all() + ] + ) ds.append(f"{educts}>>{products}") test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds]) single_gen_result = evaluate_sg(ds, test_result, self.threshold) self.eval_results = self.compute_averages([single_gen_result]) else: from enviformer.finetune import fine_tune + ds = self.load_dataset() n_splits = 20 shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42) @@ -2979,7 +3252,9 @@ class EnviFormer(PackageBasedModel): start = datetime.now() model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE) end = datetime.now() - logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds") + logger.debug( + f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds" + ) model.to(s.ENVIFORMER_DEVICE) test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test]) single_gen_results.append(evaluate_sg(test, test_result, self.threshold)) @@ -2990,19 +3265,28 @@ class EnviFormer(PackageBasedModel): if self.eval_packages.count() > 0: pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct() multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold) - self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages([multi_eval_result]).items()}) + self.eval_results.update( + { + f"multigen_{k}": v + for k, v in self.compute_averages([multi_eval_result]).items() + } + ) else: - pathway_qs = Pathway.objects.prefetch_related( - 'node_set', - 'node_set__out_edges', - 'node_set__default_node_label', - 'node_set__scenarios', - 'edge_set', - 'edge_set__start_nodes', - 'edge_set__end_nodes', - 'edge_set__edge_label', - 'edge_set__scenarios' - ).filter(package__in=self.data_packages.all()).distinct() + pathway_qs = ( + Pathway.objects.prefetch_related( + "node_set", + "node_set__out_edges", + "node_set__default_node_label", + "node_set__scenarios", + "edge_set", + "edge_set__start_nodes", + "edge_set__end_nodes", + "edge_set__edge_label", + "edge_set__scenarios", + ) + .filter(package__in=self.data_packages.all()) + .distinct() + ) pathways = [] for pathway in pathway_qs: @@ -3010,7 +3294,9 @@ class EnviFormer(PackageBasedModel): if len(pathway.root_nodes) > 0: pathways.append(pathway) else: - logging.warning(f"No root compound in pathway {pathway.name}, excluding from multigen evaluation") + logging.warning( + f"No root compound in pathway {pathway.name}, excluding from multigen evaluation" + ) # build lookup reaction -> {uuid1, uuid2} for overlap check reaction_to_educts = defaultdict(set) @@ -3022,7 +3308,9 @@ class EnviFormer(PackageBasedModel): multi_gen_results = [] # Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each # iteration instead of storing all trained models. - for split_id, (train, test) in enumerate(ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)): + for split_id, (train, test) in enumerate( + ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways) + ): train_pathways = [pathways[i] for i in train] test_pathways = [pathways[i] for i in test] @@ -3039,20 +3327,39 @@ class EnviFormer(PackageBasedModel): for pathway in train_pathways: for reaction in pathway.edges: reaction = reaction.edge_label - if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]): + if any( + [ + educt in test_educts + for educt in reaction_to_educts[str(reaction.uuid)] + ] + ): overlap += 1 continue educts = ".".join( - [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.educts.all()]) + [ + FormatConverter.standardize(smile.smiles, remove_stereo=True) + for smile in reaction.educts.all() + ] + ) products = ".".join( - [FormatConverter.standardize(smile.smiles, remove_stereo=True) for smile in reaction.products.all()]) + [ + FormatConverter.standardize(smile.smiles, remove_stereo=True) + for smile in reaction.products.all() + ] + ) train_reactions.append(f"{educts}>>{products}") logging.debug( - f"{overlap} compounds had to be removed from multigen split due to overlap within pathways") + f"{overlap} compounds had to be removed from multigen split due to overlap within pathways" + ) model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}") multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold)) - self.eval_results.update({f"multigen_{k}": v for k, v in self.compute_averages(multi_gen_results).items()}) + self.eval_results.update( + { + f"multigen_{k}": v + for k, v in self.compute_averages(multi_gen_results).items() + } + ) self.model_status = self.FINISHED self.save() @@ -3067,36 +3374,47 @@ class PluginModel(EPModel): class Scenario(EnviPathModel): - package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True) - scenario_date = models.CharField(max_length=256, null=False, blank=False, default='No date') - scenario_type = models.CharField(max_length=256, null=False, blank=False, default='Not specified') + package = models.ForeignKey( + "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True + ) + scenario_date = models.CharField(max_length=256, null=False, blank=False, default="No date") + scenario_type = models.CharField( + max_length=256, null=False, blank=False, default="Not specified" + ) # for Referring Scenarios this property will be filled - parent = models.ForeignKey('self', on_delete=models.CASCADE, default=None, null=True) + parent = models.ForeignKey("self", on_delete=models.CASCADE, default=None, null=True) - additional_information = models.JSONField(verbose_name='Additional Information') + additional_information = models.JSONField(verbose_name="Additional Information") def _url(self): - return '{}/scenario/{}'.format(self.package.url, self.uuid) + return "{}/scenario/{}".format(self.package.url, self.uuid) @staticmethod @transaction.atomic - def create(package: 'Package', name:str, description:str, scenario_date:str, scenario_type:str, additional_information: List['EnviPyModel']): + def create( + package: "Package", + name: str, + description: str, + scenario_date: str, + scenario_type: str, + additional_information: List["EnviPyModel"], + ): s = Scenario() s.package = package - if name is None or name.strip() == '': + if name is None or name.strip() == "": name = f"Scenario {Scenario.objects.filter(package=package).count() + 1}" s.name = name - if description is not None and description.strip() != '': + if description is not None and description.strip() != "": s.description = description - if scenario_date is not None and scenario_date.strip() != '': + if scenario_date is not None and scenario_date.strip() != "": s.scenario_date = scenario_date - if scenario_type is not None and scenario_type.strip() != '': + if scenario_type is not None and scenario_type.strip() != "": s.scenario_type = scenario_type add_inf = defaultdict(list) @@ -3104,10 +3422,9 @@ class Scenario(EnviPathModel): for info in additional_information: cls_name = info.__class__.__name__ ai_data = json.loads(info.model_dump_json()) - ai_data['uuid'] = f"{uuid4()}" + ai_data["uuid"] = f"{uuid4()}" add_inf[cls_name].append(ai_data) - s.additional_information = add_inf s.save() @@ -3115,10 +3432,10 @@ class Scenario(EnviPathModel): return s @transaction.atomic - def add_additional_information(self, data: 'EnviPyModel'): + def add_additional_information(self, data: "EnviPyModel"): cls_name = data.__class__.__name__ ai_data = json.loads(data.model_dump_json()) - ai_data['uuid'] = f"{uuid4()}" + ai_data["uuid"] = f"{uuid4()}" if cls_name not in self.additional_information: self.additional_information[cls_name] = [] @@ -3126,7 +3443,6 @@ class Scenario(EnviPathModel): self.additional_information[cls_name].append(ai_data) self.save() - @transaction.atomic def remove_additional_information(self, ai_uuid): found_type = None @@ -3134,7 +3450,7 @@ class Scenario(EnviPathModel): for k, vals in self.additional_information.items(): for i, v in enumerate(vals): - if v['uuid'] == ai_uuid: + if v["uuid"] == ai_uuid: found_type = k found_idx = i break @@ -3149,15 +3465,15 @@ class Scenario(EnviPathModel): raise ValueError(f"Could not find additional information with uuid {ai_uuid}") @transaction.atomic - def set_additional_information(self, data: Dict[str, 'EnviPyModel']): + def set_additional_information(self, data: Dict[str, "EnviPyModel"]): new_ais = defaultdict(list) for k, vals in data.items(): for v in vals: ai_data = json.loads(v.model_dump_json()) - if hasattr(v, 'uuid'): - ai_data['uuid'] = str(v.uuid) + if hasattr(v, "uuid"): + ai_data["uuid"] = str(v.uuid) else: - ai_data['uuid'] = str(uuid4()) + ai_data["uuid"] = str(uuid4()) new_ais[k].append(ai_data) @@ -3168,7 +3484,7 @@ class Scenario(EnviPathModel): from envipy_additional_information import NAME_MAPPING for k, vals in self.additional_information.items(): - if k == 'enzyme': + if k == "enzyme": continue for v in vals: @@ -3176,20 +3492,23 @@ class Scenario(EnviPathModel): MAPPING = {c.__name__: c for c in NAME_MAPPING.values()} inst = MAPPING[k](**v) # Add uuid to uniquely identify objects for manipulation - if 'uuid' in v: - inst.__dict__['uuid'] = v['uuid'] + if "uuid" in v: + inst.__dict__["uuid"] = v["uuid"] yield inst class UserSettingPermission(Permission): - uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', primary_key=True, - default=uuid4) - user = models.ForeignKey('User', verbose_name='Permission to', on_delete=models.CASCADE) - setting = models.ForeignKey('epdb.Setting', verbose_name='Permission on', on_delete=models.CASCADE) + uuid = models.UUIDField( + null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4 + ) + user = models.ForeignKey("User", verbose_name="Permission to", on_delete=models.CASCADE) + setting = models.ForeignKey( + "epdb.Setting", verbose_name="Permission on", on_delete=models.CASCADE + ) class Meta: - unique_together = [('setting', 'user')] + unique_together = [("setting", "user")] def __str__(self): return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}" @@ -3199,17 +3518,28 @@ class Setting(EnviPathModel): public = models.BooleanField(null=False, blank=False, default=False) global_default = models.BooleanField(null=False, blank=False, default=False) - max_depth = models.IntegerField(null=False, blank=False, verbose_name='Setting Max Depth', default=5) - max_nodes = models.IntegerField(null=False, blank=False, verbose_name='Setting Max Number of Nodes', default=30) + max_depth = models.IntegerField( + null=False, blank=False, verbose_name="Setting Max Depth", default=5 + ) + max_nodes = models.IntegerField( + null=False, blank=False, verbose_name="Setting Max Number of Nodes", default=30 + ) - rule_packages = models.ManyToManyField("Package", verbose_name="Setting Rule Packages", - related_name="setting_rule_packages", blank=True) - model = models.ForeignKey('EPModel', verbose_name='Setting EPModel', on_delete=models.SET_NULL, null=True, - blank=True) - model_threshold = models.FloatField(null=True, blank=True, verbose_name='Setting Model Threshold', default=0.25) + rule_packages = models.ManyToManyField( + "Package", + verbose_name="Setting Rule Packages", + related_name="setting_rule_packages", + blank=True, + ) + model = models.ForeignKey( + "EPModel", verbose_name="Setting EPModel", on_delete=models.SET_NULL, null=True, blank=True + ) + model_threshold = models.FloatField( + null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25 + ) def _url(self): - return '{}/setting/{}'.format(s.SERVER_URL, self.uuid) + return "{}/setting/{}".format(s.SERVER_URL, self.uuid) @cached_property def applicable_rules(self): @@ -3245,11 +3575,15 @@ class Setting(EnviPathModel): def expand(self, pathway, current_node): """Decision Method whether to expand on a certain Node or not""" if pathway.num_nodes() >= self.max_nodes: - logger.info(f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}") + logger.info( + f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}" + ) return [] if pathway.depth() >= self.max_depth: - logger.info(f"Pathway has reached depth {pathway.depth()} which exceeds the limit of {self.max_depth}") + logger.info( + f"Pathway has reached depth {pathway.depth()} which exceeds the limit of {self.max_depth}" + ) return [] transformations = [] diff --git a/epdb/tasks.py b/epdb/tasks.py index b0916f8a..aabaf8d1 100644 --- a/epdb/tasks.py +++ b/epdb/tasks.py @@ -2,55 +2,57 @@ import logging from typing import Optional from celery import shared_task -from epdb.models import Pathway, Node, Edge, EPModel, Setting +from epdb.models import Pathway, Node, EPModel, Setting from epdb.logic import SPathway logger = logging.getLogger(__name__) -@shared_task(queue='background') +@shared_task(queue="background") def mul(a, b): return a * b -@shared_task(queue='predict') +@shared_task(queue="predict") def predict_simple(model_pk: int, smiles: str): mod = EPModel.objects.get(id=model_pk) res = mod.predict(smiles) return res -@shared_task(queue='background') +@shared_task(queue="background") def send_registration_mail(user_pk: int): pass -@shared_task(queue='model') +@shared_task(queue="model") def build_model(model_pk: int): mod = EPModel.objects.get(id=model_pk) mod.build_dataset() mod.build_model() -@shared_task(queue='model') +@shared_task(queue="model") def evaluate_model(model_pk: int): mod = EPModel.objects.get(id=model_pk) mod.evaluate_model() -@shared_task(queue='model') +@shared_task(queue="model") def retrain(model_pk: int): mod = EPModel.objects.get(id=model_pk) mod.retrain() -@shared_task(queue='predict') -def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway: +@shared_task(queue="predict") +def predict( + pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None +) -> Pathway: pw = Pathway.objects.get(id=pw_pk) setting = Setting.objects.get(id=pred_setting_pk) - pw.kv.update(**{'status': 'running'}) + pw.kv.update(**{"status": "running"}) pw.save() try: @@ -74,12 +76,10 @@ def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_ else: raise ValueError("Neither limit nor node_pk given!") - - except Exception as e: - pw.kv.update({'status': 'failed'}) + pw.kv.update({"status": "failed"}) pw.save() raise e - pw.kv.update(**{'status': 'completed'}) - pw.save() \ No newline at end of file + pw.kv.update(**{"status": "completed"}) + pw.save() diff --git a/epdb/templatetags/envipytags.py b/epdb/templatetags/envipytags.py index ce2fa9d3..c8c92fef 100644 --- a/epdb/templatetags/envipytags.py +++ b/epdb/templatetags/envipytags.py @@ -2,6 +2,7 @@ from django import template register = template.Library() + @register.filter def classname(obj): - return obj.__class__.__name__ \ No newline at end of file + return obj.__class__.__name__ diff --git a/epdb/tests.py b/epdb/tests.py deleted file mode 100644 index 7ce503c2..00000000 --- a/epdb/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/epdb/urls.py b/epdb/urls.py index b10ddad5..16f0f2ba 100644 --- a/epdb/urls.py +++ b/epdb/urls.py @@ -3,97 +3,177 @@ from django.contrib.auth import views as auth_views from . import views as v -UUID = '[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}' +UUID = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" urlpatterns = [ # Home - re_path(r'^$', v.index, name='index'), - + re_path(r"^$", v.index, name="index"), # Login - re_path(r'^login', v.login, name='login'), - re_path(r'^logout', v.logout, name='logout'), - re_path(r'^register', v.register, name='register'), - - # Built In views - path('password_reset/', auth_views.PasswordResetView.as_view( - template_name='static/password_reset_form.html' - ), name='password_reset'), - - path('password_reset/done/', auth_views.PasswordResetDoneView.as_view( - template_name='static/password_reset_done.html' - ), name='password_reset_done'), - - path('reset///', auth_views.PasswordResetConfirmView.as_view( - template_name='static/password_reset_confirm.html' - ), name='password_reset_confirm'), - - path('reset/done/', auth_views.PasswordResetCompleteView.as_view( - template_name='static/password_reset_complete.html' - ), name='password_reset_complete'), - - + re_path(r"^login", v.login, name="login"), + re_path(r"^logout", v.logout, name="logout"), + re_path(r"^register", v.register, name="register"), + # Built-In views + path( + "password_reset/", + auth_views.PasswordResetView.as_view(template_name="static/password_reset_form.html"), + name="password_reset", + ), + path( + "password_reset/done/", + auth_views.PasswordResetDoneView.as_view(template_name="static/password_reset_done.html"), + name="password_reset_done", + ), + path( + "reset///", + auth_views.PasswordResetConfirmView.as_view( + template_name="static/password_reset_confirm.html" + ), + name="password_reset_confirm", + ), + path( + "reset/done/", + auth_views.PasswordResetCompleteView.as_view( + template_name="static/password_reset_complete.html" + ), + name="password_reset_complete", + ), # Top level urls - re_path(r'^package$', v.packages, name='packages'), - re_path(r'^compound$', v.compounds, name='compounds'), - re_path(r'^rule$', v.rules, name='rules'), - re_path(r'^reaction$', v.reactions, name='reactions'), - re_path(r'^pathway$', v.pathways, name='pathways'), - re_path(r'^scenario$', v.scenarios, name='scenarios'), - re_path(r'^model$', v.models, name='model'), - re_path(r'^user$', v.users, name='users'), - re_path(r'^group$', v.groups, name='groups'), - re_path(r'^search$', v.search, name='search'), - - + re_path(r"^package$", v.packages, name="packages"), + re_path(r"^compound$", v.compounds, name="compounds"), + re_path(r"^rule$", v.rules, name="rules"), + re_path(r"^reaction$", v.reactions, name="reactions"), + re_path(r"^pathway$", v.pathways, name="pathways"), + re_path(r"^scenario$", v.scenarios, name="scenarios"), + re_path(r"^model$", v.models, name="model"), + re_path(r"^user$", v.users, name="users"), + re_path(r"^group$", v.groups, name="groups"), + re_path(r"^search$", v.search, name="search"), # User Detail - re_path(rf'^user/(?P{UUID})', v.user, name='user'), + re_path(rf"^user/(?P{UUID})", v.user, name="user"), # Group Detail - re_path(rf'^group/(?P{UUID})$', v.group, name='group detail'), - + re_path(rf"^group/(?P{UUID})$", v.group, name="group detail"), # "in package" urls - re_path(rf'^package/(?P{UUID})$', v.package, name='package detail'), + re_path(rf"^package/(?P{UUID})$", v.package, name="package detail"), # Compound - re_path(rf'^package/(?P{UUID})/compound$', v.package_compounds, name='package compound list'), - re_path(rf'^package/(?P{UUID})/compound/(?P{UUID})$', v.package_compound, name='package compound detail'), + re_path( + rf"^package/(?P{UUID})/compound$", + v.package_compounds, + name="package compound list", + ), + re_path( + rf"^package/(?P{UUID})/compound/(?P{UUID})$", + v.package_compound, + name="package compound detail", + ), # Compound Structure - re_path(rf'^package/(?P{UUID})/compound/(?P{UUID})/structure$', v.package_compound_structures, name='package compound structure list'), - re_path(rf'^package/(?P{UUID})/compound/(?P{UUID})/structure/(?P{UUID})$', v.package_compound_structure, name='package compound structure detail'), + re_path( + rf"^package/(?P{UUID})/compound/(?P{UUID})/structure$", + v.package_compound_structures, + name="package compound structure list", + ), + re_path( + rf"^package/(?P{UUID})/compound/(?P{UUID})/structure/(?P{UUID})$", + v.package_compound_structure, + name="package compound structure detail", + ), # Rule - re_path(rf'^package/(?P{UUID})/rule$', v.package_rules, name='package rule list'), - re_path(rf'^package/(?P{UUID})/rule/(?P{UUID})$', v.package_rule, name='package rule detail'), - re_path(rf'^package/(?P{UUID})/simple-ambit-rule/(?P{UUID})$', v.package_rule, name='package rule detail'), - re_path(rf'^package/(?P{UUID})/simple-rdkit-rule/(?P{UUID})$', v.package_rule, name='package rule detail'), - re_path(rf'^package/(?P{UUID})/parallel-rule/(?P{UUID})$', v.package_rule, name='package rule detail'), - re_path(rf'^package/(?P{UUID})/sequential-rule/(?P{UUID})$', v.package_rule, name='package rule detail'), + re_path(rf"^package/(?P{UUID})/rule$", v.package_rules, name="package rule list"), + re_path( + rf"^package/(?P{UUID})/rule/(?P{UUID})$", + v.package_rule, + name="package rule detail", + ), + re_path( + rf"^package/(?P{UUID})/simple-ambit-rule/(?P{UUID})$", + v.package_rule, + name="package rule detail", + ), + re_path( + rf"^package/(?P{UUID})/simple-rdkit-rule/(?P{UUID})$", + v.package_rule, + name="package rule detail", + ), + re_path( + rf"^package/(?P{UUID})/parallel-rule/(?P{UUID})$", + v.package_rule, + name="package rule detail", + ), + re_path( + rf"^package/(?P{UUID})/sequential-rule/(?P{UUID})$", + v.package_rule, + name="package rule detail", + ), # Reaction - re_path(rf'^package/(?P{UUID})/reaction$', v.package_reactions, name='package reaction list'), - re_path(rf'^package/(?P{UUID})/reaction/(?P{UUID})$', v.package_reaction, name='package reaction detail'), + re_path( + rf"^package/(?P{UUID})/reaction$", + v.package_reactions, + name="package reaction list", + ), + re_path( + rf"^package/(?P{UUID})/reaction/(?P{UUID})$", + v.package_reaction, + name="package reaction detail", + ), # # Pathway - re_path(rf'^package/(?P{UUID})/pathway$', v.package_pathways, name='package pathway list'), - re_path(rf'^package/(?P{UUID})/pathway/(?P{UUID})$', v.package_pathway, name='package pathway detail'), + re_path( + rf"^package/(?P{UUID})/pathway$", + v.package_pathways, + name="package pathway list", + ), + re_path( + rf"^package/(?P{UUID})/pathway/(?P{UUID})$", + v.package_pathway, + name="package pathway detail", + ), # Pathway Nodes - re_path(rf'^package/(?P{UUID})/pathway/(?P{UUID})/node$', v.package_pathway_nodes, name='package pathway node list'), - re_path(rf'^package/(?P{UUID})/pathway/(?P{UUID})/node/(?P{UUID})$', v.package_pathway_node, name='package pathway node detail'), + re_path( + rf"^package/(?P{UUID})/pathway/(?P{UUID})/node$", + v.package_pathway_nodes, + name="package pathway node list", + ), + re_path( + rf"^package/(?P{UUID})/pathway/(?P{UUID})/node/(?P{UUID})$", + v.package_pathway_node, + name="package pathway node detail", + ), # Pathway Edges - re_path(rf'^package/(?P{UUID})/pathway/(?P{UUID})/edge$', v.package_pathway_edges, name='package pathway edge list'), - re_path(rf'^package/(?P{UUID})/pathway/(?P{UUID})/edge/(?P{UUID})$', v.package_pathway_edge, name='package pathway edge detail'), + re_path( + rf"^package/(?P{UUID})/pathway/(?P{UUID})/edge$", + v.package_pathway_edges, + name="package pathway edge list", + ), + re_path( + rf"^package/(?P{UUID})/pathway/(?P{UUID})/edge/(?P{UUID})$", + v.package_pathway_edge, + name="package pathway edge detail", + ), # Scenario - re_path(rf'^package/(?P{UUID})/scenario$', v.package_scenarios, name='package scenario list'), - re_path(rf'^package/(?P{UUID})/scenario/(?P{UUID})$', v.package_scenario, name='package scenario detail'), + re_path( + rf"^package/(?P{UUID})/scenario$", + v.package_scenarios, + name="package scenario list", + ), + re_path( + rf"^package/(?P{UUID})/scenario/(?P{UUID})$", + v.package_scenario, + name="package scenario detail", + ), # Model - re_path(rf'^package/(?P{UUID})/model$', v.package_models, name='package model list'), - re_path(rf'^package/(?P{UUID})/model/(?P{UUID})$', v.package_model,name='package model detail'), - - re_path(r'^setting$', v.settings, name='settings'), - re_path(rf'^setting/(?P{UUID})', v.setting, name='setting'), - - re_path(r'^indigo/info$', v.indigo, name='indigo_info'), - re_path(r'^indigo/aromatize$', v.aromatize, name='indigo_aromatize'), - re_path(r'^indigo/dearomatize$', v.dearomatize, name='indigo_dearomatize'), - re_path(r'^indigo/layout$', v.layout, name='indigo_layout'), - - re_path(r'^depict$', v.depict, name='depict'), - + re_path( + rf"^package/(?P{UUID})/model$", v.package_models, name="package model list" + ), + re_path( + rf"^package/(?P{UUID})/model/(?P{UUID})$", + v.package_model, + name="package model detail", + ), + re_path(r"^setting$", v.settings, name="settings"), + re_path(rf"^setting/(?P{UUID})", v.setting, name="setting"), + re_path(r"^indigo/info$", v.indigo, name="indigo_info"), + re_path(r"^indigo/aromatize$", v.aromatize, name="indigo_aromatize"), + re_path(r"^indigo/dearomatize$", v.dearomatize, name="indigo_dearomatize"), + re_path(r"^indigo/layout$", v.layout, name="indigo_layout"), + re_path(r"^depict$", v.depict, name="depict"), # OAuth Stuff path("o/userinfo/", v.userinfo, name="oauth_userinfo"), ] diff --git a/epdb/views.py b/epdb/views.py index b9010014..b13c0b39 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -14,10 +14,39 @@ from oauth2_provider.decorators import protected_resource from utilities.chem import FormatConverter, IndigoUtils from utilities.decorators import package_permission_required from utilities.misc import HTMLGenerator -from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser -from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \ - EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \ - UserPackagePermission, Permission, License, User, Edge, ExternalDatabase, ExternalIdentifier +from .logic import ( + GroupManager, + PackageManager, + UserManager, + SettingManager, + SearchManager, + EPDBURLParser, +) +from .models import ( + Package, + GroupPackagePermission, + Group, + CompoundStructure, + Compound, + Reaction, + Rule, + Pathway, + Node, + EPModel, + EnviFormer, + MLRelativeReasoning, + RuleBasedRelativeReasoning, + Scenario, + SimpleAmbitRule, + APIToken, + UserPackagePermission, + Permission, + License, + User, + Edge, + ExternalDatabase, + ExternalIdentifier, +) logger = logging.getLogger(__name__) @@ -31,11 +60,11 @@ def log_post_params(request): def error(request, message: str, detail: str, code: int = 400): context = get_base_context(request) error_context = { - 'error_message': message, - 'error_detail': detail, + "error_message": message, + "error_detail": detail, } - if request.headers.get('Accept') == 'application/json': + if request.headers.get("Accept") == "application/json": return JsonResponse(error_context, status=500) context.update(**error_context) @@ -43,59 +72,59 @@ def error(request, message: str, detail: str, code: int = 400): def login(request): - current_user = _anonymous_or_real(request) context = get_base_context(request) - if request.method == 'GET': - context['title'] = 'enviPath' - context['next'] = request.GET.get('next', '') - return render(request, 'static/login.html', context) + if request.method == "GET": + context["title"] = "enviPath" + context["next"] = request.GET.get("next", "") + return render(request, "static/login.html", context) - elif request.method == 'POST': + elif request.method == "POST": from django.contrib.auth import authenticate from django.contrib.auth import login - username = request.POST.get('username') - password = request.POST.get('password') + username = request.POST.get("username") + password = request.POST.get("password") # Get email for username and check if the account is active try: temp_user = get_user_model().objects.get(username=username) if not temp_user.is_active: - context['message'] = "User account is not activated yet!" - return render(request, 'static/login.html', context) + context["message"] = "User account is not activated yet!" + return render(request, "static/login.html", context) email = temp_user.email except get_user_model().DoesNotExist: - context['message'] = "Login failed!" - return render(request, 'static/login.html', context) + context["message"] = "Login failed!" + return render(request, "static/login.html", context) try: user = authenticate(username=email, password=password) - except Exception as e: - context['message'] = "Login failed!" - return render(request, 'static/login.html', context) + except Exception: + context["message"] = "Login failed!" + return render(request, "static/login.html", context) if user is not None: login(request, user) - if next := request.POST.get('next'): + if next := request.POST.get("next"): return redirect(next) return redirect(reverse("index")) else: - context['message'] = "Login failed!" - return render(request, 'static/login.html', context) + context["message"] = "Login failed!" + return render(request, "static/login.html", context) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def logout(request): - if request.method == 'POST': - is_logout = bool(request.POST.get('logout', False)) + if request.method == "POST": + is_logout = bool(request.POST.get("logout", False)) if is_logout: from django.contrib.auth import logout + logout(request) return redirect(s.SERVER_URL) @@ -103,49 +132,52 @@ def logout(request): def register(request): - current_user = _anonymous_or_real(request) context = get_base_context(request) - if request.method == 'GET': - context['title'] = 'enviPath' - context['next'] = request.GET.get('next', '') - return render(request, 'static/register.html', context) - elif request.method == 'POST': - context['title'] = 'enviPath' - if next := request.POST.get('next'): - context['next'] = next + if request.method == "GET": + context["title"] = "enviPath" + context["next"] = request.GET.get("next", "") + return render(request, "static/register.html", context) + elif request.method == "POST": + context["title"] = "enviPath" + if next := request.POST.get("next"): + context["next"] = next - username = request.POST.get('username', '').strip() - email = request.POST.get('email', '').strip() - password = request.POST.get('password', '').strip() - rpassword = request.POST.get('rpassword', '').strip() + username = request.POST.get("username", "").strip() + email = request.POST.get("email", "").strip() + password = request.POST.get("password", "").strip() + rpassword = request.POST.get("rpassword", "").strip() if not (username and email and password): context["message"] = "Invalid username/email/password" - return render(request, 'static/register.html', context) + return render(request, "static/register.html", context) - if password != rpassword or password == '': - context['message'] = "Registration failed, provided passwords differ!" - return render(request, 'static/register.html', context) + if password != rpassword or password == "": + context["message"] = "Registration failed, provided passwords differ!" + return render(request, "static/register.html", context) try: u = UserManager.create_user(username, email, password) + logger.info(f"Created user {u.username} ({u.pk})") except Exception: - context['message'] = "Registration failed! Couldn't create User Account." - return render(request, 'static/register.html', context) + context["message"] = "Registration failed! Couldn't create User Account." + return render(request, "static/register.html", context) if s.ADMIN_APPROVAL_REQUIRED: - context['success_message'] = "Your account has been created! An admin will activate it soon!" + context["success_message"] = ( + "Your account has been created! An admin will activate it soon!" + ) else: - context['success_message'] = "Account has been created! You'll receive a mail to activate your account shortly." + context["success_message"] = ( + "Account has been created! You'll receive a mail to activate your account shortly." + ) - return render(request, 'static/login.html', context) + return render(request, "static/login.html", context) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def editable(request, user): - if user.is_superuser: return True @@ -159,8 +191,14 @@ def editable(request, user): elif UserManager.is_user_url(url): _user = UserManager.get_user_lp(request.build_absolute_uri()) return UserManager.writable(user, _user) - elif url in [s.SERVER_URL, f"{s.SERVER_URL}/", f"{s.SERVER_URL}/package", f"{s.SERVER_URL}/user", - f"{s.SERVER_URL}/group", f"{s.SERVER_URL}/search"]: + elif url in [ + s.SERVER_URL, + f"{s.SERVER_URL}/", + f"{s.SERVER_URL}/package", + f"{s.SERVER_URL}/user", + f"{s.SERVER_URL}/group", + f"{s.SERVER_URL}/search", + ]: return True else: logger.debug(f"Unknown url: {url}") @@ -181,20 +219,22 @@ def get_base_context(request, for_user=None) -> Dict[str, Any]: current_user = for_user ctx = { - 'title': 'enviPath', - 'meta': { - 'version': '0.0.1', - 'server_url': s.SERVER_URL, - 'user': current_user, - 'can_edit': can_edit, - 'url_contains_package': url_contains_package, - 'readable_packages': PackageManager.get_all_readable_packages(current_user, include_reviewed=True), - 'writeable_packages': PackageManager.get_all_writeable_packages(current_user), - 'available_groups': GroupManager.get_groups(current_user), - 'available_settings': SettingManager.get_all_settings(current_user), - 'enabled_features': s.FLAGS, - 'debug': s.DEBUG, - 'external_databases': ExternalDatabase.get_databases(), + "title": "enviPath", + "meta": { + "version": "0.0.1", + "server_url": s.SERVER_URL, + "user": current_user, + "can_edit": can_edit, + "url_contains_package": url_contains_package, + "readable_packages": PackageManager.get_all_readable_packages( + current_user, include_reviewed=True + ), + "writeable_packages": PackageManager.get_all_writeable_packages(current_user), + "available_groups": GroupManager.get_groups(current_user), + "available_settings": SettingManager.get_all_settings(current_user), + "enabled_features": s.FLAGS, + "debug": s.DEBUG, + "external_databases": ExternalDatabase.get_databases(), }, } @@ -204,26 +244,41 @@ def get_base_context(request, for_user=None) -> Dict[str, Any]: def _anonymous_or_real(request): if request.user.is_authenticated and not request.user.is_anonymous: return request.user - return get_user_model().objects.get(username='anonymous') + return get_user_model().objects.get(username="anonymous") -def breadcrumbs(first_level_object=None, second_level_namespace=None, second_level_object=None, - third_level_namespace=None, third_level_object=None) -> List[Dict[str, str]]: +def breadcrumbs( + first_level_object=None, + second_level_namespace=None, + second_level_object=None, + third_level_namespace=None, + third_level_object=None, +) -> List[Dict[str, str]]: bread = [ - {'Home': s.SERVER_URL}, - {'Package': s.SERVER_URL + '/package'}, + {"Home": s.SERVER_URL}, + {"Package": s.SERVER_URL + "/package"}, ] if first_level_object is not None: bread.append({first_level_object.name: first_level_object.url}) if second_level_namespace is not None: - bread.append({f'{second_level_namespace}'.capitalize(): first_level_object.url + f'/{second_level_namespace}'}) + bread.append( + { + f"{second_level_namespace}".capitalize(): first_level_object.url + + f"/{second_level_namespace}" + } + ) if second_level_object is not None: bread.append({second_level_object.name: second_level_object.url}) if third_level_namespace is not None: - bread.append({f'{third_level_namespace}'.capitalize(): second_level_object.url + f'/{third_level_namespace}'}) + bread.append( + { + f"{third_level_namespace}".capitalize(): second_level_object.url + + f"/{third_level_namespace}" + } + ) if third_level_object is not None: bread.append({third_level_object.name: third_level_object.url}) @@ -234,19 +289,18 @@ def breadcrumbs(first_level_object=None, second_level_namespace=None, second_lev def set_scenarios(current_user, attach_object, scenario_urls: List[str]): scens = [] for scenario_url in scenario_urls: - # As empty lists will be removed in POST request well send [''] - if scenario_url == '': + if scenario_url == "": continue package = PackageManager.get_package_by_url(current_user, scenario_url) - scen = Scenario.objects.get(package=package, uuid=scenario_url.split('/')[-1]) + scen = Scenario.objects.get(package=package, uuid=scenario_url.split("/")[-1]) scens.append(scen) attach_object.set_scenarios(scens) -def copy_object(current_user, target_package: 'Package', source_object_url: str): +def copy_object(current_user, target_package: "Package", source_object_url: str): # Ensures that source is readable source_package = PackageManager.get_package_by_url(current_user, source_object_url) @@ -262,7 +316,7 @@ def copy_object(current_user, target_package: 'Package', source_object_url: str) # Gets the most specific object source_object = parser.get_object() - if hasattr(source_object, 'copy'): + if hasattr(source_object, "copy"): mapping = dict() copy = source_object.copy(target_package, mapping) @@ -274,327 +328,346 @@ def copy_object(current_user, target_package: 'Package', source_object_url: str) raise ValueError(f"Object {source_object} can't be copied!") + def index(request): context = get_base_context(request) - context['title'] = 'enviPath - Home' - context['meta']['current_package'] = context['meta']['user'].default_package + context["title"] = "enviPath - Home" + context["meta"]["current_package"] = context["meta"]["user"].default_package - if request.GET.get('getMLServerPath', False): + if request.GET.get("getMLServerPath", False): return JsonResponse({"mlServerPath": s.SERVER_URL}) - return render(request, 'index/index.html', context) + return render(request, "index/index.html", context) def packages(request): current_user = _anonymous_or_real(request) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Packages' + context["title"] = "enviPath - Packages" - context['object_type'] = 'package' - context['meta']['current_package'] = context['meta']['user'].default_package - context['meta']['can_edit'] = True + context["object_type"] = "package" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["meta"]["can_edit"] = True - reviewed_package_qs = Package.objects.filter(reviewed=True).order_by('created') - unreviewed_package_qs = PackageManager.get_all_readable_packages(current_user).order_by('name') + reviewed_package_qs = Package.objects.filter(reviewed=True).order_by("created") + unreviewed_package_qs = PackageManager.get_all_readable_packages(current_user).order_by( + "name" + ) - context['reviewed_objects'] = reviewed_package_qs - context['unreviewed_objects'] = unreviewed_package_qs + context["reviewed_objects"] = reviewed_package_qs + context["unreviewed_objects"] = unreviewed_package_qs - return render(request, 'collections/objects_list.html', context) + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': - - - if hidden := request.POST.get('hidden', None): - - if hidden in ['import-legacy-package-json', 'import-package-json']: - f = request.FILES['file'] + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden in ["import-legacy-package-json", "import-package-json"]: + f = request.FILES["file"] try: file_data = f.read().decode("utf-8") data = json.loads(file_data) - if hidden == 'import-legacy-package-json': + if hidden == "import-legacy-package-json": pack = PackageManager.import_legacy_package(data, current_user) else: pack = PackageManager.import_package(data, current_user) return redirect(pack.url) except UnicodeDecodeError: - return error(request, 'Invalid encoding.', f'Invalid encoding, must be UTF-8') + return error(request, "Invalid encoding.", "Invalid encoding, must be UTF-8") else: return HttpResponseBadRequest() else: - package_name = request.POST.get('package-name') - package_description = request.POST.get('package-description', s.DEFAULT_VALUES['description']) + package_name = request.POST.get("package-name") + package_description = request.POST.get( + "package-description", s.DEFAULT_VALUES["description"] + ) - created_package = PackageManager.create_package(current_user, package_name, package_description) + created_package = PackageManager.create_package( + current_user, package_name, package_description + ) return redirect(created_package.url) - elif request.method == 'OPTIONS': + elif request.method == "OPTIONS": response = HttpResponse() - response['allow'] = ','.join(['GET', 'POST']) + response["allow"] = ",".join(["GET", "POST"]) return response else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def compounds(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Compounds' + context["title"] = "enviPath - Compounds" - context['object_type'] = 'compound' - context['meta']['current_package'] = context['meta']['user'].default_package + context["object_type"] = "compound" + context["meta"]["current_package"] = context["meta"]["user"].default_package reviewed_compound_qs = Compound.objects.none() for p in PackageManager.get_reviewed_packages(): reviewed_compound_qs |= Compound.objects.filter(package=p) - reviewed_compound_qs = reviewed_compound_qs.order_by('name') + reviewed_compound_qs = reviewed_compound_qs.order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": True} - for pw in reviewed_compound_qs - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": True} + for pw in reviewed_compound_qs + ] + } + ) - context['reviewed_objects'] = reviewed_compound_qs - return render(request, 'collections/objects_list.html', context) + context["reviewed_objects"] = reviewed_compound_qs + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': + elif request.method == "POST": # delegate to default package current_user = _anonymous_or_real(request) default_package = current_user.default_package return package_compounds(request, default_package.uuid) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def rules(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Rules' + context["title"] = "enviPath - Rules" - context['object_type'] = 'rule' - context['meta']['current_package'] = context['meta']['user'].default_package - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Rule': s.SERVER_URL + '/rule'}, + context["object_type"] = "rule" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Rule": s.SERVER_URL + "/rule"}, ] reviewed_rule_qs = Rule.objects.none() for p in PackageManager.get_reviewed_packages(): reviewed_rule_qs |= Rule.objects.filter(package=p) - reviewed_rule_qs = reviewed_rule_qs.order_by('name') + reviewed_rule_qs = reviewed_rule_qs.order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": True} - for pw in reviewed_rule_qs - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": True} + for pw in reviewed_rule_qs + ] + } + ) - context['reviewed_objects'] = reviewed_rule_qs - return render(request, 'collections/objects_list.html', context) + context["reviewed_objects"] = reviewed_rule_qs + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': + elif request.method == "POST": # delegate to default package current_user = _anonymous_or_real(request) default_package = current_user.default_package return package_rules(request, default_package.uuid) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def reactions(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Reactions' + context["title"] = "enviPath - Reactions" - context['object_type'] = 'reaction' - context['meta']['current_package'] = context['meta']['user'].default_package - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Reaction': s.SERVER_URL + '/reaction'}, + context["object_type"] = "reaction" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Reaction": s.SERVER_URL + "/reaction"}, ] reviewed_reaction_qs = Reaction.objects.none() for p in PackageManager.get_reviewed_packages(): - reviewed_reaction_qs |= Reaction.objects.filter(package=p).order_by('name') + reviewed_reaction_qs |= Reaction.objects.filter(package=p).order_by("name") - reviewed_reaction_qs = reviewed_reaction_qs.order_by('name') + reviewed_reaction_qs = reviewed_reaction_qs.order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": True} - for pw in reviewed_reaction_qs - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": True} + for pw in reviewed_reaction_qs + ] + } + ) - context['reviewed_objects'] = reviewed_reaction_qs - return render(request, 'collections/objects_list.html', context) + context["reviewed_objects"] = reviewed_reaction_qs + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': + elif request.method == "POST": # delegate to default package current_user = _anonymous_or_real(request) default_package = current_user.default_package return package_reactions(request, default_package.uuid) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def pathways(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Pathways' + context["title"] = "enviPath - Pathways" - context['object_type'] = 'pathway' - context['meta']['current_package'] = context['meta']['user'].default_package - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Pathway': s.SERVER_URL + '/pathway'}, + context["object_type"] = "pathway" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Pathway": s.SERVER_URL + "/pathway"}, ] reviewed_pathway_qs = Pathway.objects.none() for p in PackageManager.get_reviewed_packages(): - reviewed_pathway_qs |= Pathway.objects.filter(package=p).order_by('name') + reviewed_pathway_qs |= Pathway.objects.filter(package=p).order_by("name") - reviewed_pathway_qs = reviewed_pathway_qs.order_by('name') + reviewed_pathway_qs = reviewed_pathway_qs.order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": True} - for pw in reviewed_pathway_qs - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": True} + for pw in reviewed_pathway_qs + ] + } + ) - context['reviewed_objects'] = reviewed_pathway_qs - return render(request, 'collections/objects_list.html', context) + context["reviewed_objects"] = reviewed_pathway_qs + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': + elif request.method == "POST": # delegate to default package current_user = _anonymous_or_real(request) default_package = current_user.default_package return package_pathways(request, default_package.uuid) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def scenarios(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Scenarios' + context["title"] = "enviPath - Scenarios" - context['object_type'] = 'scenario' - context['meta']['current_package'] = context['meta']['user'].default_package - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Scenario': s.SERVER_URL + '/scenario'}, + context["object_type"] = "scenario" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Scenario": s.SERVER_URL + "/scenario"}, ] reviewed_scenario_qs = Scenario.objects.none() for p in PackageManager.get_reviewed_packages(): - reviewed_scenario_qs |= Scenario.objects.filter(package=p).order_by('name') + reviewed_scenario_qs |= Scenario.objects.filter(package=p).order_by("name") - reviewed_scenario_qs = reviewed_scenario_qs.order_by('name') + reviewed_scenario_qs = reviewed_scenario_qs.order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": s.name, "url": s.url, "reviewed": True} for s in reviewed_scenario_qs - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": s.name, "url": s.url, "reviewed": True} + for s in reviewed_scenario_qs + ] + } + ) - context['reviewed_objects'] = reviewed_scenario_qs - return render(request, 'collections/objects_list.html', context) + context["reviewed_objects"] = reviewed_scenario_qs + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': + elif request.method == "POST": # delegate to default package default_package = request.user.default_package return package_scenarios(request, default_package.uuid) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def models(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = 'enviPath - Models' + context["title"] = "enviPath - Models" - context['object_type'] = 'model' - context['meta']['current_package'] = context['meta']['user'].default_package - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Model': s.SERVER_URL + '/model'}, + context["object_type"] = "model" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Model": s.SERVER_URL + "/model"}, ] - context['model_types'] = { - 'ML Relative Reasoning': 'ml-relative-reasoning', - 'Rule Based Relative Reasoning': 'rule-based-relative-reasoning', - 'EnviFormer': 'enviformer', + context["model_types"] = { + "ML Relative Reasoning": "ml-relative-reasoning", + "Rule Based Relative Reasoning": "rule-based-relative-reasoning", + "EnviFormer": "enviformer", } for k, v in s.CLASSIFIER_PLUGINS.items(): - context['model_types'][v.display()] = k + context["model_types"][v.display()] = k reviewed_model_qs = EPModel.objects.none() for p in PackageManager.get_reviewed_packages(): - reviewed_model_qs |= EPModel.objects.filter(package=p).order_by('name') + reviewed_model_qs |= EPModel.objects.filter(package=p).order_by("name") - reviewed_model_qs = reviewed_model_qs.order_by('name') + reviewed_model_qs = reviewed_model_qs.order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": True} - for pw in reviewed_model_qs - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": True} + for pw in reviewed_model_qs + ] + } + ) - context['reviewed_objects'] = reviewed_model_qs - return render(request, 'collections/objects_list.html', context) + context["reviewed_objects"] = reviewed_model_qs + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': + elif request.method == "POST": current_user = _anonymous_or_real(request) default_package = current_user.default_package return package_models(request, default_package.uuid) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def search(request): current_user = _anonymous_or_real(request) - if request.method == 'GET': - package_urls = request.GET.getlist('packages') - searchterm = request.GET.get('search') - mode = request.GET.get('mode') + if request.method == "GET": + package_urls = request.GET.getlist("packages") + searchterm = request.GET.get("search") + mode = request.GET.get("mode") # add HTTP_ACCEPT check to differentiate between index and ajax call - if 'application/json' in request.META.get('HTTP_ACCEPT') and all([searchterm, mode]): + if "application/json" in request.META.get("HTTP_ACCEPT") and all([searchterm, mode]): if package_urls: - packages = [PackageManager.get_package_by_url(current_user, p) for p in package_urls] + packages = [ + PackageManager.get_package_by_url(current_user, p) for p in package_urls + ] else: packages = PackageManager.get_reviewed_packages() @@ -603,34 +676,36 @@ def search(request): return JsonResponse(search_result, safe=False) context = get_base_context(request) - context['title'] = 'enviPath - Search' + context["title"] = "enviPath - Search" - context['object_type'] = 'model' - context['meta']['current_package'] = context['meta']['user'].default_package - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Search': s.SERVER_URL + '/search'}, + context["object_type"] = "model" + context["meta"]["current_package"] = context["meta"]["user"].default_package + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Search": s.SERVER_URL + "/search"}, ] reviewed_package_qs = PackageManager.get_reviewed_packages() unreviewed_package_qs = PackageManager.get_all_readable_packages(current_user) - context['reviewed_objects'] = reviewed_package_qs - context['unreviewed_objects'] = unreviewed_package_qs + context["reviewed_objects"] = reviewed_package_qs + context["unreviewed_objects"] = unreviewed_package_qs if all([searchterm, mode]): if package_urls: - packages = [PackageManager.get_package_by_url(current_user, p) for p in package_urls] + packages = [ + PackageManager.get_package_by_url(current_user, p) for p in package_urls + ] else: packages = PackageManager.get_reviewed_packages() - context['search_result'] = SearchManager.search(packages, searchterm, mode) - context['search_result']['searchterm'] = searchterm + context["search_result"] = SearchManager.search(packages, searchterm, mode) + context["search_result"]["searchterm"] = searchterm - return render(request, 'search.html', context) + return render(request, "search.html", context) else: - return HttpResponseNotAllowed(['GET']) + return HttpResponseNotAllowed(["GET"]) @package_permission_required() @@ -638,111 +713,124 @@ def package_models(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - Models' + context["title"] = f"enviPath - {current_package.name} - Models" - context['meta']['current_package'] = current_package - context['object_type'] = 'model' - context['breadcrumbs'] = breadcrumbs(current_package, 'model') + context["meta"]["current_package"] = current_package + context["object_type"] = "model" + context["breadcrumbs"] = breadcrumbs(current_package, "model") reviewed_model_qs = EPModel.objects.none() unreviewed_model_qs = EPModel.objects.none() if current_package.reviewed: - reviewed_model_qs = EPModel.objects.filter(package=current_package).order_by('name') + reviewed_model_qs = EPModel.objects.filter(package=current_package).order_by("name") else: - unreviewed_model_qs = EPModel.objects.filter(package=current_package).order_by('name') + unreviewed_model_qs = EPModel.objects.filter(package=current_package).order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_model_qs if current_package.reviewed else unreviewed_model_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_model_qs if current_package.reviewed else unreviewed_model_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_model_qs - context['unreviewed_objects'] = unreviewed_model_qs + context["reviewed_objects"] = reviewed_model_qs + context["unreviewed_objects"] = unreviewed_model_qs - context['model_types'] = { - 'ML Relative Reasoning': 'ml-relative-reasoning', - 'Rule Based Relative Reasoning': 'rule-based-relative-reasoning', + context["model_types"] = { + "ML Relative Reasoning": "ml-relative-reasoning", + "Rule Based Relative Reasoning": "rule-based-relative-reasoning", } - if s.FLAGS.get('ENVIFORMER', False): - context['model_types']['EnviFormer'] = 'enviformer' + if s.FLAGS.get("ENVIFORMER", False): + context["model_types"]["EnviFormer"] = "enviformer" - if s.FLAGS.get('PLUGINS', False): + if s.FLAGS.get("PLUGINS", False): for k, v in s.CLASSIFIER_PLUGINS.items(): - context['model_types'][v.display()] = k + context["model_types"][v.display()] = k - return render(request, 'collections/objects_list.html', context) - - elif request.method == 'POST': + return render(request, "collections/objects_list.html", context) + elif request.method == "POST": log_post_params(request) - name = request.POST.get('model-name') - description = request.POST.get('model-description') + name = request.POST.get("model-name") + description = request.POST.get("model-description") - model_type = request.POST.get('model-type') + model_type = request.POST.get("model-type") - if model_type == 'enviformer': - threshold = float(request.POST.get(f'{model_type}-threshold', 0.5)) + if model_type == "enviformer": + threshold = float(request.POST.get(f"{model_type}-threshold", 0.5)) mod = EnviFormer.create(current_package, name, description, threshold) - elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning': + elif model_type == "ml-relative-reasoning" or model_type == "rule-based-relative-reasoning": # Generic fields for ML and Rule Based - rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages') - data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages') - eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', []) + rule_packages = request.POST.getlist("package-based-relative-reasoning-rule-packages") + data_packages = request.POST.getlist("package-based-relative-reasoning-data-packages") + eval_packages = request.POST.getlist( + "package-based-relative-reasoning-evaluation-packages", [] + ) # Generic params params = { - 'package' : current_package, - 'name' : name, - 'description' : description, - 'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages], - 'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages], - 'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages], + "package": current_package, + "name": name, + "description": description, + "rule_packages": [ + PackageManager.get_package_by_url(current_user, p) for p in rule_packages + ], + "data_packages": [ + PackageManager.get_package_by_url(current_user, p) for p in data_packages + ], + "eval_packages": [ + PackageManager.get_package_by_url(current_user, p) for p in eval_packages + ], } - if model_type == 'ml-relative-reasoning': + if model_type == "ml-relative-reasoning": # ML Specific - threshold = float(request.POST.get(f'{model_type}-threshold', 0.5)) - fingerprinter = request.POST.get(f'{model_type}-fingerprinter') + threshold = float(request.POST.get(f"{model_type}-threshold", 0.5)) + # TODO handle additional fingerprinter + # fingerprinter = request.POST.get(f"{model_type}-fingerprinter") # App Domain related parameters - build_ad = request.POST.get('build-app-domain', False) == 'on' - num_neighbors = request.POST.get('num-neighbors', 5) - reliability_threshold = request.POST.get('reliability-threshold', 0.5) - local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5) + build_ad = request.POST.get("build-app-domain", False) == "on" + num_neighbors = request.POST.get("num-neighbors", 5) + reliability_threshold = request.POST.get("reliability-threshold", 0.5) + local_compatibility_threshold = request.POST.get( + "local-compatibility-threshold", 0.5 + ) - params['threshold'] = threshold + params["threshold"] = threshold # params['fingerprinter'] = fingerprinter - params['build_app_domain'] = build_ad - params['app_domain_num_neighbours'] = num_neighbors - params['app_domain_reliability_threshold'] = reliability_threshold - params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold + params["build_app_domain"] = build_ad + params["app_domain_num_neighbours"] = num_neighbors + params["app_domain_reliability_threshold"] = reliability_threshold + params["app_domain_local_compatibility_threshold"] = local_compatibility_threshold - mod = MLRelativeReasoning.create( - **params - ) + mod = MLRelativeReasoning.create(**params) else: - mod = RuleBasedRelativeReasoning.create( - **params - ) + mod = RuleBasedRelativeReasoning.create(**params) from .tasks import build_model + build_model.delay(mod.pk) else: - return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."') + return error( + request, "Invalid model type.", f'Model type "{model_type}" is not supported."' + ) return redirect(mod.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -751,21 +839,21 @@ def package_model(request, package_uuid, model_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_model = EPModel.objects.get(package=current_package, uuid=model_uuid) - if request.method == 'GET': - classify = request.GET.get('classify', False) - ad_assessment = request.GET.get('app-domain-assessment', False) + if request.method == "GET": + classify = request.GET.get("classify", False) + ad_assessment = request.GET.get("app-domain-assessment", False) if classify or ad_assessment: - smiles = request.GET.get('smiles', '').strip() + smiles = request.GET.get("smiles", "").strip() # Check if smiles is non empty and valid - if smiles == '': - return JsonResponse({'error': 'Received empty SMILES'}, status=400) + if smiles == "": + return JsonResponse({"error": "Received empty SMILES"}, status=400) try: stand_smiles = FormatConverter.standardize(smiles) - except ValueError as e: - return JsonResponse({'error': f'"{smiles}" is not a valid SMILES'}, status=400) + except ValueError: + return JsonResponse({"error": f'"{smiles}" is not a valid SMILES'}, status=400) if classify: pred_res = current_model.predict(stand_smiles) @@ -778,11 +866,15 @@ def package_model(request, package_uuid, model_uuid): logger.debug(f"Checking {prod_set}") products.append(tuple([x for x in prod_set])) - res.append({ - 'products': list(set(products)), - 'probability': pr.probability, - 'btrule': {k: getattr(pr.rule, k) for k in ['url', 'name']} if pr.rule is not None else None - }) + res.append( + { + "products": list(set(products)), + "probability": pr.probability, + "btrule": {k: getattr(pr.rule, k) for k in ["url", "name"]} + if pr.rule is not None + else None, + } + ) return JsonResponse(res, safe=False) @@ -791,31 +883,32 @@ def package_model(request, package_uuid, model_uuid): return JsonResponse(app_domain_assessment, safe=False) context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_model.name}' + context["title"] = f"enviPath - {current_package.name} - {current_model.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'model' - context['breadcrumbs'] = breadcrumbs(current_package, 'model', current_model) + context["meta"]["current_package"] = current_package + context["object_type"] = "model" + context["breadcrumbs"] = breadcrumbs(current_package, "model", current_model) - context['model'] = current_model - context['current_object'] = current_model - - return render(request, 'objects/model.html', context) + context["model"] = current_model + context["current_object"] = current_model - elif request.method == 'POST': - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + return render(request, "objects/model.html", context) + + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_model.delete() - return redirect(current_package.url + '/model') - elif hidden == 'evaluate': + return redirect(current_package.url + "/model") + elif hidden == "evaluate": from .tasks import evaluate_model + evaluate_model.delay(current_model.pk) return redirect(current_model.url) else: return HttpResponseBadRequest() else: - name = request.POST.get('model-name', '').strip() - description = request.POST.get('model-description', '').strip() + name = request.POST.get("model-name", "").strip() + description = request.POST.get("model-description", "").strip() if any([name, description]): if name: @@ -830,7 +923,7 @@ def package_model(request, package_uuid, model_uuid): return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -838,84 +931,96 @@ def package(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': - + if request.method == "GET": if request.GET.get("export", False) == "true": filename = f"{current_package.name.replace(' ', '_')}_{current_package.uuid}.json" - pack_json = PackageManager.export_package(current_package, include_models=False, - include_external_identifiers=False) - response = JsonResponse(pack_json, content_type='application/json') - response['Content-Disposition'] = f'attachment; filename="{filename}"' + pack_json = PackageManager.export_package( + current_package, include_models=False, include_external_identifiers=False + ) + response = JsonResponse(pack_json, content_type="application/json") + response["Content-Disposition"] = f'attachment; filename="{filename}"' return response context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name}' + context["title"] = f"enviPath - {current_package.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'package' - context['breadcrumbs'] = breadcrumbs(current_package) + context["meta"]["current_package"] = current_package + context["object_type"] = "package" + context["breadcrumbs"] = breadcrumbs(current_package) - context['package'] = current_package + context["package"] = current_package user_perms = UserPackagePermission.objects.filter(package=current_package) users = get_user_model().objects.exclude( - id__in=UserPackagePermission.objects.filter(package=current_package).values_list('user_id', flat=True)) + id__in=UserPackagePermission.objects.filter(package=current_package).values_list( + "user_id", flat=True + ) + ) group_perms = GroupPackagePermission.objects.filter(package=current_package) groups = Group.objects.exclude( - id__in=GroupPackagePermission.objects.filter(package=current_package).values_list('group_id', flat=True)) + id__in=GroupPackagePermission.objects.filter(package=current_package).values_list( + "group_id", flat=True + ) + ) - context['users'] = users - context['groups'] = groups - context['user_permissions'] = user_perms - context['group_permissions'] = group_perms + context["users"] = users + context["groups"] = groups + context["user_permissions"] = user_perms + context["group_permissions"] = group_perms - return render(request, 'objects/package.html', context) - - elif request.method == 'POST': + return render(request, "objects/package.html", context) + elif request.method == "POST": log_post_params(request) - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': - + if hidden := request.POST.get("hidden", None): + if hidden == "delete": if current_user.default_package == current_package: - return error(request, f'Package "{current_package.name}" is the default and cannot be deleted!', - 'You cannot delete the default package. If you want to delete this package you have to set another default package first.') + return error( + request, + f'Package "{current_package.name}" is the default and cannot be deleted!', + "You cannot delete the default package. If you want to delete this package you have to set another default package first.", + ) logger.debug(current_package.delete()) - return redirect(s.SERVER_URL + '/package') - elif hidden == 'publish-package': + return redirect(s.SERVER_URL + "/package") + elif hidden == "publish-package": for g in Group.objects.filter(public=True): - PackageManager.update_permissions(current_user, current_package, g, Permission.READ[0]) + PackageManager.update_permissions( + current_user, current_package, g, Permission.READ[0] + ) return redirect(current_package.url) - elif hidden == 'copy': - object_to_copy = request.POST.get('object_to_copy') + elif hidden == "copy": + object_to_copy = request.POST.get("object_to_copy") if not object_to_copy: - return error(request, 'No object to copy', 'There was no object to copy.') + return error(request, "No object to copy", "There was no object to copy.") try: copied_object = copy_object(current_user, current_package, object_to_copy) - except ValueError as e: - return JsonResponse({'error': f"Can't copy object {object_to_copy} to the same package!"}, status=400) + except ValueError: + return JsonResponse( + {"error": f"Can't copy object {object_to_copy} to the same package!"}, + status=400, + ) - return JsonResponse({'success': copied_object.url}) + return JsonResponse({"success": copied_object.url}) else: return HttpResponseBadRequest() - new_package_name = request.POST.get('package-name') - new_package_description = request.POST.get('package-description') + new_package_name = request.POST.get("package-name") + new_package_description = request.POST.get("package-description") - grantee_url = request.POST.get('grantee') - read = request.POST.get('read') == 'on' - write = request.POST.get('write') == 'on' - owner = request.POST.get('owner') == 'on' + grantee_url = request.POST.get("grantee") + read = request.POST.get("read") == "on" + write = request.POST.get("write") == "on" + owner = request.POST.get("owner") == "on" - license = request.POST.get('license') - license_link = request.POST.get('license-link') - license_image_link = request.POST.get('license-image-link') + license = request.POST.get("license") + license_link = request.POST.get("license-link") + license_image_link = request.POST.get("license-image-link") if new_package_name: current_package.name = new_package_name @@ -928,7 +1033,7 @@ def package(request, package_uuid): return redirect(current_package.url) elif any([grantee_url, read, write, owner]): - if 'user' in grantee_url: + if "user" in grantee_url: grantee = UserManager.get_user_lp(grantee_url) else: grantee = GroupManager.get_group_lp(grantee_url) @@ -944,7 +1049,7 @@ def package(request, package_uuid): PackageManager.update_permissions(current_user, current_package, grantee, max_perm) return redirect(current_package.url) elif license is not None: - if license == 'no-license': + if license == "no-license": if current_package.license is not None: current_package.license.delete() @@ -955,12 +1060,12 @@ def package(request, package_uuid): if current_package.license is not None: current_package.license.delete() - l = License() - l.link = license_link - l.image_link = license_image_link - l.save() + license = License() + license.link = license_link + license.image_link = license_image_link + license.save() - current_package.license = l + current_package.license = license current_package.save() return redirect(current_package.url) @@ -968,7 +1073,7 @@ def package(request, package_uuid): return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -976,46 +1081,54 @@ def package_compounds(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - Compounds' + context["title"] = f"enviPath - {current_package.name} - Compounds" - context['meta']['current_package'] = current_package - context['object_type'] = 'compound' - context['breadcrumbs'] = breadcrumbs(current_package, 'compound') + context["meta"]["current_package"] = current_package + context["object_type"] = "compound" + context["breadcrumbs"] = breadcrumbs(current_package, "compound") reviewed_compound_qs = Compound.objects.none() unreviewed_compound_qs = Compound.objects.none() if current_package.reviewed: - reviewed_compound_qs = Compound.objects.filter(package=current_package).order_by('name') + reviewed_compound_qs = Compound.objects.filter(package=current_package).order_by("name") else: - unreviewed_compound_qs = Compound.objects.filter(package=current_package).order_by('name') + unreviewed_compound_qs = Compound.objects.filter(package=current_package).order_by( + "name" + ) - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_compound_qs if current_package.reviewed else unreviewed_compound_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_compound_qs + if current_package.reviewed + else unreviewed_compound_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_compound_qs - context['unreviewed_objects'] = unreviewed_compound_qs + context["reviewed_objects"] = reviewed_compound_qs + context["unreviewed_objects"] = unreviewed_compound_qs - return render(request, 'collections/objects_list.html', context) + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': - compound_name = request.POST.get('compound-name') - compound_smiles = request.POST.get('compound-smiles') - compound_description = request.POST.get('compound-description') + elif request.method == "POST": + compound_name = request.POST.get("compound-name") + compound_smiles = request.POST.get("compound-smiles") + compound_description = request.POST.get("compound-description") c = Compound.create(current_package, compound_smiles, compound_name, compound_description) return redirect(c.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1024,35 +1137,35 @@ def package_compound(request, package_uuid, compound_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_compound.name}' + context["title"] = f"enviPath - {current_package.name} - {current_compound.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'compound' - context['breadcrumbs'] = breadcrumbs(current_package, 'compound', current_compound) + context["meta"]["current_package"] = current_package + context["object_type"] = "compound" + context["breadcrumbs"] = breadcrumbs(current_package, "compound", current_compound) - context['compound'] = current_compound - context['current_object'] = current_compound + context["compound"] = current_compound + context["current_object"] = current_compound - return render(request, 'objects/compound.html', context) + return render(request, "objects/compound.html", context) - elif request.method == 'POST': - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_compound.delete() - return redirect(current_package.url + '/compound') + return redirect(current_package.url + "/compound") else: return HttpResponseBadRequest() - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_compound, selected_scenarios) return redirect(current_compound.url) - new_compound_name = request.POST.get('compound-name', '').strip() - new_compound_description = request.POST.get('compound-description', '').strip() + new_compound_name = request.POST.get("compound-name", "").strip() + new_compound_description = request.POST.get("compound-description", "").strip() if new_compound_name: current_compound.name = new_compound_name @@ -1064,8 +1177,8 @@ def package_compound(request, package_uuid, compound_uuid): current_compound.save() return redirect(current_compound.url) - selected_database = request.POST.get('selected-database', '').strip() - external_identifier = request.POST.get('identifier', '').strip() + selected_database = request.POST.get("selected-database", "").strip() + external_identifier = request.POST.get("identifier", "").strip() if selected_database and external_identifier: db = ExternalDatabase.objects.get(id=int(selected_database)) @@ -1074,14 +1187,14 @@ def package_compound(request, package_uuid, compound_uuid): database=db, identifier_value=external_identifier, url=db.url_pattern.format(id=external_identifier), - is_primary=False + is_primary=False, ) return redirect(current_compound.url) else: return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1090,38 +1203,42 @@ def package_compound_structures(request, package_uuid, compound_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_compound.name} - Structures' + context["title"] = ( + f"enviPath - {current_package.name} - {current_compound.name} - Structures" + ) - context['meta']['current_package'] = current_package - context['object_type'] = 'structure' - context['breadcrumbs'] = breadcrumbs(current_package, 'compound', current_compound, 'structure') + context["meta"]["current_package"] = current_package + context["object_type"] = "structure" + context["breadcrumbs"] = breadcrumbs( + current_package, "compound", current_compound, "structure" + ) reviewed_compound_structure_qs = CompoundStructure.objects.none() unreviewed_compound_structure_qs = CompoundStructure.objects.none() if current_package.reviewed: - reviewed_compound_structure_qs = current_compound.structures.order_by('name') + reviewed_compound_structure_qs = current_compound.structures.order_by("name") else: - unreviewed_compound_structure_qs = current_compound.structures.order_by('name') + unreviewed_compound_structure_qs = current_compound.structures.order_by("name") - context['reviewed_objects'] = reviewed_compound_structure_qs - context['unreviewed_objects'] = unreviewed_compound_structure_qs + context["reviewed_objects"] = reviewed_compound_structure_qs + context["unreviewed_objects"] = unreviewed_compound_structure_qs - return render(request, 'collections/objects_list.html', context) + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': - structure_name = request.POST.get('structure-name') - structure_smiles = request.POST.get('structure-smiles') - structure_description = request.POST.get('structure-description') + elif request.method == "POST": + structure_name = request.POST.get("structure-name") + structure_smiles = request.POST.get("structure-smiles") + structure_description = request.POST.get("structure-description") cs = current_compound.add_structure(structure_smiles, structure_name, structure_description) return redirect(cs.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1129,47 +1246,54 @@ def package_compound_structure(request, package_uuid, compound_uuid, structure_u current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_compound = Compound.objects.get(package=current_package, uuid=compound_uuid) - current_structure = CompoundStructure.objects.get(compound=current_compound, uuid=structure_uuid) + current_structure = CompoundStructure.objects.get( + compound=current_compound, uuid=structure_uuid + ) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_compound.name} - {current_structure.name}' + context["title"] = ( + f"enviPath - {current_package.name} - {current_compound.name} - {current_structure.name}" + ) - context['meta']['current_package'] = current_package - context['object_type'] = 'structure' + context["meta"]["current_package"] = current_package + context["object_type"] = "structure" - context['compound_structure'] = current_structure - context['current_object'] = current_structure - context['breadcrumbs'] = breadcrumbs(current_package, 'compound', current_compound, 'structure', current_structure) + context["compound_structure"] = current_structure + context["current_object"] = current_structure + context["breadcrumbs"] = breadcrumbs( + current_package, "compound", current_compound, "structure", current_structure + ) - return render(request, 'objects/compound_structure.html', context) + return render(request, "objects/compound_structure.html", context) - elif request.method == 'POST': - - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden == "delete": # Check if we have to delete the compound as no structure is left if len(current_structure.compound.structures.all()) == 1: # This will delete the structure as well current_compound.delete() - return redirect(current_package.url + '/compound') + return redirect(current_package.url + "/compound") else: if current_structure.normalized_structure: current_compound.delete() - return redirect(current_package.url + '/compound') + return redirect(current_package.url + "/compound") else: if current_compound.default_structure == current_structure: current_structure.delete() - current_compound.default_structure = current_compound.structures.all().first() - return redirect(current_compound.url + '/structure') + current_compound.default_structure = ( + current_compound.structures.all().first() + ) + return redirect(current_compound.url + "/structure") else: current_structure.delete() - return redirect(current_compound.url + '/structure') + return redirect(current_compound.url + "/structure") else: return HttpResponseBadRequest() - new_structure_name = request.POST.get('compound-structure-name', '').strip() - new_structure_description = request.POST.get('compound-structure-description', '').strip() + new_structure_name = request.POST.get("compound-structure-name", "").strip() + new_structure_description = request.POST.get("compound-structure-description", "").strip() if new_structure_name: current_structure.name = new_structure_name @@ -1181,14 +1305,14 @@ def package_compound_structure(request, package_uuid, compound_uuid, structure_u current_structure.save() return redirect(current_structure.url) - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_structure, selected_scenarios) return redirect(current_structure.url) - selected_database = request.POST.get('selected-database', '').strip() - external_identifier = request.POST.get('identifier', '').strip() + selected_database = request.POST.get("selected-database", "").strip() + external_identifier = request.POST.get("identifier", "").strip() if selected_database and external_identifier: db = ExternalDatabase.objects.get(id=int(selected_database)) @@ -1197,13 +1321,17 @@ def package_compound_structure(request, package_uuid, compound_uuid, structure_u database=db, identifier_value=external_identifier, url=db.url_pattern.format(id=external_identifier), - is_primary=False + is_primary=False, ) return redirect(current_structure.url) return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', ]) + return HttpResponseNotAllowed( + [ + "GET", + ] + ) @package_permission_required() @@ -1211,67 +1339,75 @@ def package_rules(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - Rules' + context["title"] = f"enviPath - {current_package.name} - Rules" - context['meta']['current_package'] = current_package - context['object_type'] = 'rule' - context['breadcrumbs'] = breadcrumbs(current_package, 'rule') + context["meta"]["current_package"] = current_package + context["object_type"] = "rule" + context["breadcrumbs"] = breadcrumbs(current_package, "rule") reviewed_rule_qs = Rule.objects.none() unreviewed_rule_qs = Rule.objects.none() if current_package.reviewed: - reviewed_rule_qs = Rule.objects.filter(package=current_package).order_by('name') + reviewed_rule_qs = Rule.objects.filter(package=current_package).order_by("name") else: - unreviewed_rule_qs = Rule.objects.filter(package=current_package).order_by('name') + unreviewed_rule_qs = Rule.objects.filter(package=current_package).order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_rule_qs if current_package.reviewed else unreviewed_rule_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_rule_qs if current_package.reviewed else unreviewed_rule_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_rule_qs - context['unreviewed_objects'] = unreviewed_rule_qs + context["reviewed_objects"] = reviewed_rule_qs + context["unreviewed_objects"] = unreviewed_rule_qs - return render(request, 'collections/objects_list.html', context) - - elif request.method == 'POST': + return render(request, "collections/objects_list.html", context) + elif request.method == "POST": log_post_params(request) # Generic params - rule_name = request.POST.get('rule-name') - rule_description = request.POST.get('rule-description') + rule_name = request.POST.get("rule-name") + rule_description = request.POST.get("rule-description") - rule_type = request.POST.get('rule-type') + rule_type = request.POST.get("rule-type") params = {} # Obtain parameters as required by rule type - if rule_type == 'SimpleAmbitRule': - params['smirks'] = request.POST.get('rule-smirks') - params['reactant_filter_smarts'] = request.POST.get('rule-reactant-smarts') - params['product_filter_smarts'] = request.POST.get('rule-product-smarts') - elif rule_type == 'SimpleRDKitRule': - params['reaction_smarts'] = request.POST.get('rule-reaction-smarts') - elif rule_type == 'ParallelRule': + if rule_type == "SimpleAmbitRule": + params["smirks"] = request.POST.get("rule-smirks") + params["reactant_filter_smarts"] = request.POST.get("rule-reactant-smarts") + params["product_filter_smarts"] = request.POST.get("rule-product-smarts") + elif rule_type == "SimpleRDKitRule": + params["reaction_smarts"] = request.POST.get("rule-reaction-smarts") + elif rule_type == "ParallelRule": pass - elif rule_type == 'SequentialRule': + elif rule_type == "SequentialRule": pass else: return HttpResponseBadRequest() - r = Rule.create(rule_type=rule_type, package=current_package, name=rule_name, description=rule_description, - **params) + r = Rule.create( + rule_type=rule_type, + package=current_package, + name=rule_name, + description=rule_description, + **params, + ) return redirect(r.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1280,14 +1416,16 @@ def package_rule(request, package_uuid, rule_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_rule = Rule.objects.get(package=current_package, uuid=rule_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - if smiles := request.GET.get('smiles', False): + if smiles := request.GET.get("smiles", False): stand_smiles = FormatConverter.standardize(smiles) res = current_rule.apply(stand_smiles) if len(res) > 1: - logger.info(f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one.") + logger.info( + f"Rule {current_rule.uuid} returned multiple product sets on {smiles}, picking the first one." + ) smirks = f"{stand_smiles}>>{'.'.join(sorted(res[0]))}" # Usually the functional groups are a mapping of fg -> count @@ -1295,41 +1433,47 @@ def package_rule(request, package_uuid, rule_uuid): educt_functional_groups = {x: 1000 for x in current_rule.reactants_smarts} product_functional_groups = {x: 1000 for x in current_rule.products_smarts} return HttpResponse( - IndigoUtils.smirks_to_svg(smirks, False, 0, 0, - educt_functional_groups=educt_functional_groups, - product_functional_groups=product_functional_groups), - content_type='image/svg+xml') + IndigoUtils.smirks_to_svg( + smirks, + False, + 0, + 0, + educt_functional_groups=educt_functional_groups, + product_functional_groups=product_functional_groups, + ), + content_type="image/svg+xml", + ) - context['title'] = f'enviPath - {current_package.name} - {current_rule.name}' + context["title"] = f"enviPath - {current_package.name} - {current_rule.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'rule' - context['breadcrumbs'] = breadcrumbs(current_package, 'rule', current_rule) + context["meta"]["current_package"] = current_package + context["object_type"] = "rule" + context["breadcrumbs"] = breadcrumbs(current_package, "rule", current_rule) - context['rule'] = current_rule - context['current_object'] = current_rule + context["rule"] = current_rule + context["current_object"] = current_rule if isinstance(current_rule, SimpleAmbitRule): - return render(request, 'objects/simple_rule.html', context) + return render(request, "objects/simple_rule.html", context) else: # isinstance(current_rule, ParallelRule) or isinstance(current_rule, SequentialRule): - return render(request, 'objects/composite_rule.html', context) + return render(request, "objects/composite_rule.html", context) - elif request.method == 'POST': - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_rule.delete() - return redirect(current_package.url + '/rule') + return redirect(current_package.url + "/rule") else: return HttpResponseBadRequest() - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_rule, selected_scenarios) return redirect(current_rule.url) - rule_name = request.POST.get('rule-name', '').strip() - rule_description = request.POST.get('rule-description', '').strip() + rule_name = request.POST.get("rule-name", "").strip() + rule_description = request.POST.get("rule-description", "").strip() if rule_name: current_rule.name = rule_name @@ -1344,7 +1488,7 @@ def package_rule(request, package_uuid, rule_uuid): return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1352,50 +1496,63 @@ def package_reactions(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_package.name} - Reactions' + context["title"] = f"enviPath - {current_package.name} - {current_package.name} - Reactions" - context['meta']['current_package'] = current_package - context['object_type'] = 'reaction' - context['breadcrumbs'] = breadcrumbs(current_package, 'reaction') + context["meta"]["current_package"] = current_package + context["object_type"] = "reaction" + context["breadcrumbs"] = breadcrumbs(current_package, "reaction") reviewed_reaction_qs = Reaction.objects.none() unreviewed_reaction_qs = Reaction.objects.none() if current_package.reviewed: - reviewed_reaction_qs = Reaction.objects.filter(package=current_package).order_by('name') + reviewed_reaction_qs = Reaction.objects.filter(package=current_package).order_by("name") else: - unreviewed_reaction_qs = Reaction.objects.filter(package=current_package).order_by('name') + unreviewed_reaction_qs = Reaction.objects.filter(package=current_package).order_by( + "name" + ) - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_reaction_qs if current_package.reviewed else unreviewed_reaction_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_reaction_qs + if current_package.reviewed + else unreviewed_reaction_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_reaction_qs - context['unreviewed_objects'] = unreviewed_reaction_qs + context["reviewed_objects"] = reviewed_reaction_qs + context["unreviewed_objects"] = unreviewed_reaction_qs - return render(request, 'collections/objects_list.html', context) + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': - reaction_name = request.POST.get('reaction-name') - reaction_description = request.POST.get('reaction-description') - reactions_smirks = request.POST.get('reaction-smirks') + elif request.method == "POST": + reaction_name = request.POST.get("reaction-name") + reaction_description = request.POST.get("reaction-description") + reactions_smirks = request.POST.get("reaction-smirks") - educts = reactions_smirks.split('>>')[0].split('.') - products = reactions_smirks.split('>>')[1].split('.') + educts = reactions_smirks.split(">>")[0].split(".") + products = reactions_smirks.split(">>")[1].split(".") - r = Reaction.create(current_package, name=reaction_name, description=reaction_description, educts=educts, - products=products) + r = Reaction.create( + current_package, + name=reaction_name, + description=reaction_description, + educts=educts, + products=products, + ) return redirect(r.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1404,35 +1561,35 @@ def package_reaction(request, package_uuid, reaction_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_reaction = Reaction.objects.get(package=current_package, uuid=reaction_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_reaction.name}' + context["title"] = f"enviPath - {current_package.name} - {current_reaction.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'reaction' - context['breadcrumbs'] = breadcrumbs(current_package, 'reaction', current_reaction) + context["meta"]["current_package"] = current_package + context["object_type"] = "reaction" + context["breadcrumbs"] = breadcrumbs(current_package, "reaction", current_reaction) - context['reaction'] = current_reaction - context['current_object'] = current_reaction + context["reaction"] = current_reaction + context["current_object"] = current_reaction - return render(request, 'objects/reaction.html', context) + return render(request, "objects/reaction.html", context) - elif request.method == 'POST': - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_reaction.delete() - return redirect(current_package.url + '/reaction') + return redirect(current_package.url + "/reaction") else: return HttpResponseBadRequest() - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_reaction, selected_scenarios) return redirect(current_reaction.url) - new_reaction_name = request.POST.get('reaction-name', '').strip() - new_reaction_description = request.POST.get('reaction-description', '').strip() + new_reaction_name = request.POST.get("reaction-name", "").strip() + new_reaction_description = request.POST.get("reaction-description", "").strip() if new_reaction_name: current_reaction.name = new_reaction_name @@ -1444,8 +1601,8 @@ def package_reaction(request, package_uuid, reaction_uuid): current_reaction.save() return redirect(current_reaction.url) - selected_database = request.POST.get('selected-database', '').strip() - external_identifier = request.POST.get('identifier', '').strip() + selected_database = request.POST.get("selected-database", "").strip() + external_identifier = request.POST.get("identifier", "").strip() if selected_database and external_identifier: db = ExternalDatabase.objects.get(id=int(selected_database)) @@ -1454,14 +1611,14 @@ def package_reaction(request, package_uuid, reaction_uuid): database=db, identifier_value=external_identifier, url=db.url_pattern.format(id=external_identifier), - is_primary=False + is_primary=False, ) return redirect(current_reaction.url) return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1469,62 +1626,76 @@ def package_pathways(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - Pathways' + context["title"] = f"enviPath - {current_package.name} - Pathways" - context['meta']['current_package'] = current_package - context['object_type'] = 'pathway' - context['breadcrumbs'] = breadcrumbs(current_package, 'pathway') + context["meta"]["current_package"] = current_package + context["object_type"] = "pathway" + context["breadcrumbs"] = breadcrumbs(current_package, "pathway") reviewed_pathway_qs = Pathway.objects.none() unreviewed_pathway_qs = Pathway.objects.none() if current_package.reviewed: - reviewed_pathway_qs = Pathway.objects.filter(package=current_package).order_by('name') + reviewed_pathway_qs = Pathway.objects.filter(package=current_package).order_by("name") else: - unreviewed_pathway_qs = Pathway.objects.filter(package=current_package).order_by('name') + unreviewed_pathway_qs = Pathway.objects.filter(package=current_package).order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_pathway_qs if current_package.reviewed else unreviewed_pathway_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_pathway_qs + if current_package.reviewed + else unreviewed_pathway_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_pathway_qs - context['unreviewed_objects'] = unreviewed_pathway_qs + context["reviewed_objects"] = reviewed_pathway_qs + context["unreviewed_objects"] = unreviewed_pathway_qs - return render(request, 'collections/objects_list.html', context) - - elif request.method == 'POST': + return render(request, "collections/objects_list.html", context) + elif request.method == "POST": log_post_params(request) - name = request.POST.get('name') - description = request.POST.get('description') - pw_mode = request.POST.get('predict', 'predict').strip() - smiles = request.POST.get('smiles', '').strip() + name = request.POST.get("name") + description = request.POST.get("description") + pw_mode = request.POST.get("predict", "predict").strip() + smiles = request.POST.get("smiles", "").strip() - if 'smiles' in request.POST and smiles == '': - return error(request, "Pathway prediction failed!", - "Pathway prediction failed due to missing or empty SMILES") + if "smiles" in request.POST and smiles == "": + return error( + request, + "Pathway prediction failed!", + "Pathway prediction failed due to missing or empty SMILES", + ) smiles = smiles.strip() try: stand_smiles = FormatConverter.standardize(smiles) except ValueError: - return error(request, "Pathway prediction failed!", - f'Pathway prediction failed as standardization of SMILES "{smiles}" failed!') + return error( + request, + "Pathway prediction failed!", + f'Pathway prediction failed as standardization of SMILES "{smiles}" failed!', + ) - modes = ['predict', 'build', 'incremental'] + modes = ["predict", "build", "incremental"] if pw_mode not in modes: - return error(request, "Pathway prediction failed!", - f'Pathway prediction failed as received mode "{pw_mode}" is none of {modes}') + return error( + request, + "Pathway prediction failed!", + f'Pathway prediction failed as received mode "{pw_mode}" is none of {modes}', + ) - prediction_setting = request.POST.get('prediction-setting', None) + prediction_setting = request.POST.get("prediction-setting", None) if prediction_setting: prediction_setting = SettingManager.get_setting_by_url(current_user, prediction_setting) else: @@ -1533,27 +1704,28 @@ def package_pathways(request, package_uuid): pw = Pathway.create(current_package, stand_smiles, name=name, description=description) # set mode - pw.kv.update({'mode': pw_mode}) + pw.kv.update({"mode": pw_mode}) pw.save() - if pw_mode == 'predict' or pw_mode == 'incremental': + if pw_mode == "predict" or pw_mode == "incremental": # unlimited pred (will be handled by setting) limit = -1 # For incremental predict first level and return - if pw_mode == 'incremental': + if pw_mode == "incremental": limit = 1 pw.setting = prediction_setting pw.save() from .tasks import predict + predict.delay(pw.pk, prediction_setting.pk, limit=limit) return redirect(pw.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1562,22 +1734,25 @@ def package_pathway(request, package_uuid, pathway_uuid): current_package: Package = PackageManager.get_package_by_id(current_user, package_uuid) current_pathway: Pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) - if request.method == 'GET': - + if request.method == "GET": if request.GET.get("last_modified", False): - return JsonResponse({'modified': current_pathway.modified.strftime('%Y-%m-%d %H:%M:%S')}) + return JsonResponse( + {"modified": current_pathway.modified.strftime("%Y-%m-%d %H:%M:%S")} + ) - if request.GET.get('status', False): - return JsonResponse({ - 'status': current_pathway.status(), - 'modified': current_pathway.modified.strftime('%Y-%m-%d %H:%M:%S'), - }) + if request.GET.get("status", False): + return JsonResponse( + { + "status": current_pathway.status(), + "modified": current_pathway.modified.strftime("%Y-%m-%d %H:%M:%S"), + } + ) if request.GET.get("download", False) == "true": filename = f"{current_pathway.name.replace(' ', '_')}_{current_pathway.uuid}.csv" csv_pw = current_pathway.to_csv() - response = HttpResponse(csv_pw, content_type='text/csv') - response['Content-Disposition'] = f'attachment; filename="{filename}"' + response = HttpResponse(csv_pw, content_type="text/csv") + response["Content-Disposition"] = f'attachment; filename="{filename}"' return response @@ -1586,62 +1761,62 @@ def package_pathway(request, package_uuid, pathway_uuid): # related objects current_pathway = Pathway.objects.prefetch_related( - 'node_set', - 'node_set__out_edges', - 'node_set__default_node_label', - 'node_set__scenarios', - 'edge_set', - 'edge_set__start_nodes', - 'edge_set__end_nodes', - 'edge_set__edge_label', - 'edge_set__scenarios' + "node_set", + "node_set__out_edges", + "node_set__default_node_label", + "node_set__scenarios", + "edge_set", + "edge_set__start_nodes", + "edge_set__end_nodes", + "edge_set__edge_label", + "edge_set__scenarios", ).get(uuid=pathway_uuid) context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_pathway.name}' + context["title"] = f"enviPath - {current_package.name} - {current_pathway.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'pathway' - context['breadcrumbs'] = breadcrumbs(current_package, 'pathway', current_pathway) + context["meta"]["current_package"] = current_package + context["object_type"] = "pathway" + context["breadcrumbs"] = breadcrumbs(current_package, "pathway", current_pathway) - context['pathway'] = current_pathway - context['current_object'] = current_pathway + context["pathway"] = current_pathway + context["current_object"] = current_pathway - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Package': s.SERVER_URL + '/package'}, + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Package": s.SERVER_URL + "/package"}, {current_package.name: current_package.url}, - {'Pathway': current_package.url + '/pathway'}, + {"Pathway": current_package.url + "/pathway"}, {current_pathway.name: current_pathway.url}, ] - return render(request, 'objects/pathway.html', context) + return render(request, "objects/pathway.html", context) # return render(request, 'pathway_playground2.html', context) - elif request.method == 'POST': - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + elif request.method == "POST": + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_pathway.delete() - return redirect(current_package.url + '/pathway') + return redirect(current_package.url + "/pathway") else: return HttpResponseBadRequest() - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_pathway, selected_scenarios) return redirect(current_pathway.url) - pathway_name = request.POST.get('pathway-name') - pathway_description = request.POST.get('pathway-description') + pathway_name = request.POST.get("pathway-name") + pathway_description = request.POST.get("pathway-description") if any([pathway_name, pathway_description]): - if pathway_name is not None and pathway_name.strip() != '': + if pathway_name is not None and pathway_name.strip() != "": pathway_name = pathway_name.strip() current_pathway.name = pathway_name - if pathway_description is not None and pathway_description.strip() != '': + if pathway_description is not None and pathway_description.strip() != "": pathway_description = pathway_description.strip() current_pathway.description = pathway_description @@ -1649,20 +1824,21 @@ def package_pathway(request, package_uuid, pathway_uuid): current_pathway.save() return redirect(current_pathway.url) - node_url = request.POST.get('node') + node_url = request.POST.get("node") if node_url: n = current_pathway.get_node(node_url) from .tasks import predict + # Dont delay? predict(current_pathway.pk, current_pathway.setting.pk, node_pk=n.pk) - return JsonResponse({'success': current_pathway.url}) + return JsonResponse({"success": current_pathway.url}) return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1671,54 +1847,57 @@ def package_pathway_nodes(request, package_uuid, pathway_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_pathway.name} - Nodes' + context["title"] = f"enviPath - {current_package.name} - {current_pathway.name} - Nodes" - context['meta']['current_package'] = current_package - context['object_type'] = 'node' - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Package': s.SERVER_URL + '/package'}, + context["meta"]["current_package"] = current_package + context["object_type"] = "node" + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Package": s.SERVER_URL + "/package"}, {current_package.name: current_package.url}, - {'Pathway': current_package.url + '/pathway'}, + {"Pathway": current_package.url + "/pathway"}, {current_pathway.name: current_pathway.url}, - {'Node': current_pathway.url + '/node'}, + {"Node": current_pathway.url + "/node"}, ] reviewed_node_qs = Node.objects.none() unreviewed_node_qs = Node.objects.none() if current_package.reviewed: - reviewed_node_qs = Node.objects.filter(pathway=current_pathway).order_by('name') + reviewed_node_qs = Node.objects.filter(pathway=current_pathway).order_by("name") else: - unreviewed_node_qs = Node.objects.filter(pathway=current_pathway).order_by('name') + unreviewed_node_qs = Node.objects.filter(pathway=current_pathway).order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_node_qs if current_package.reviewed else unreviewed_node_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_node_qs if current_package.reviewed else unreviewed_node_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_node_qs - context['unreviewed_objects'] = unreviewed_node_qs + context["reviewed_objects"] = reviewed_node_qs + context["unreviewed_objects"] = unreviewed_node_qs - return render(request, 'collections/objects_list.html', context) + return render(request, "collections/objects_list.html", context) - elif request.method == 'POST': - - node_name = request.POST.get('node-name') - node_description = request.POST.get('node-description') - node_smiles = request.POST.get('node-smiles') + elif request.method == "POST": + node_name = request.POST.get("node-name") + node_description = request.POST.get("node-description") + node_smiles = request.POST.get("node-smiles") current_pathway.add_node(node_smiles, name=node_name, description=node_description) return redirect(current_pathway.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1728,43 +1907,43 @@ def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid): current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_node = Node.objects.get(pathway=current_pathway, uuid=node_uuid) - if request.method == 'GET': - is_image_request = request.GET.get('image') + if request.method == "GET": + is_image_request = request.GET.get("image") if is_image_request: - if is_image_request == 'svg': + if is_image_request == "svg": svg_data = current_node.as_svg return HttpResponse(svg_data, content_type="image/svg+xml") context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_pathway.name}' + context["title"] = f"enviPath - {current_package.name} - {current_pathway.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'pathway' + context["meta"]["current_package"] = current_package + context["object_type"] = "pathway" - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Package': s.SERVER_URL + '/package'}, + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Package": s.SERVER_URL + "/package"}, {current_package.name: current_package.url}, - {'Pathway': current_package.url + '/pathway'}, + {"Pathway": current_package.url + "/pathway"}, {current_pathway.name: current_pathway.url}, - {'Node': current_pathway.url + '/node'}, + {"Node": current_pathway.url + "/node"}, {current_node.name: current_node.url}, ] - context['node'] = current_node - context['current_object'] = current_node + context["node"] = current_node + context["current_object"] = current_node - context['app_domain_assessment_data'] = json.dumps(current_node.get_app_domain_assessment_data()) + context["app_domain_assessment_data"] = json.dumps( + current_node.get_app_domain_assessment_data() + ) - return render(request, 'objects/node.html', context) - - elif request.method == 'POST': + return render(request, "objects/node.html", context) + elif request.method == "POST": log_post_params(request) - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': - + if hidden := request.POST.get("hidden", None): + if hidden == "delete": # pre_delete signal will take care of edge deletion current_node.delete() @@ -1772,15 +1951,15 @@ def package_pathway_node(request, package_uuid, pathway_uuid, node_uuid): else: return HttpResponseBadRequest() - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_node, selected_scenarios) return redirect(current_node.url) return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1789,61 +1968,66 @@ def package_pathway_edges(request, package_uuid, pathway_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_pathway.name} - Edges' + context["title"] = f"enviPath - {current_package.name} - {current_pathway.name} - Edges" - context['meta']['current_package'] = current_package - context['object_type'] = 'edge' - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Package': s.SERVER_URL + '/package'}, + context["meta"]["current_package"] = current_package + context["object_type"] = "edge" + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Package": s.SERVER_URL + "/package"}, {current_package.name: current_package.url}, - {'Pathway': current_package.url + '/pathway'}, + {"Pathway": current_package.url + "/pathway"}, {current_pathway.name: current_pathway.url}, - {'Edge': current_pathway.url + '/edge'}, + {"Edge": current_pathway.url + "/edge"}, ] reviewed_edge_qs = Edge.objects.none() unreviewed_edge_qs = Edge.objects.none() if current_package.reviewed: - reviewed_edge_qs = Edge.objects.filter(pathway=current_pathway).order_by('name') + reviewed_edge_qs = Edge.objects.filter(pathway=current_pathway).order_by("name") else: - unreviewed_edge_qs = Edge.objects.filter(pathway=current_pathway).order_by('name') + unreviewed_edge_qs = Edge.objects.filter(pathway=current_pathway).order_by("name") - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_edge_qs if current_package.reviewed else unreviewed_edge_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_edge_qs if current_package.reviewed else unreviewed_edge_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_edge_qs - context['unreviewed_objects'] = unreviewed_edge_qs + context["reviewed_objects"] = reviewed_edge_qs + context["unreviewed_objects"] = unreviewed_edge_qs - return render(request, 'collections/objects_list.html', context) - - elif request.method == 'POST': + return render(request, "collections/objects_list.html", context) + elif request.method == "POST": log_post_params(request) - edge_name = request.POST.get('edge-name') - edge_description = request.POST.get('edge-description') - edge_substrates = request.POST.getlist('edge-substrates') - edge_products = request.POST.getlist('edge-products') + edge_name = request.POST.get("edge-name") + edge_description = request.POST.get("edge-description") + edge_substrates = request.POST.getlist("edge-substrates") + edge_products = request.POST.getlist("edge-products") substrate_nodes = [current_pathway.get_node(url) for url in edge_substrates] product_nodes = [current_pathway.get_node(url) for url in edge_products] # TODO in the future consider Rules here? - current_pathway.add_edge(substrate_nodes, product_nodes, name=edge_name, description=edge_description) + current_pathway.add_edge( + substrate_nodes, product_nodes, name=edge_name, description=edge_description + ) return redirect(current_pathway.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1853,36 +2037,38 @@ def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid): current_pathway = Pathway.objects.get(package=current_package, uuid=pathway_uuid) current_edge = Edge.objects.get(pathway=current_pathway, uuid=edge_uuid) - if request.method == 'GET': - is_image_request = request.GET.get('image') + if request.method == "GET": + is_image_request = request.GET.get("image") if is_image_request: - if is_image_request == 'svg': + if is_image_request == "svg": svg_data = current_edge.as_svg return HttpResponse(svg_data, content_type="image/svg+xml") context = get_base_context(request) - context[ - 'title'] = f'enviPath - {current_package.name} - {current_pathway.name} - {current_edge.edge_label.name}' + context["title"] = ( + f"enviPath - {current_package.name} - {current_pathway.name} - {current_edge.edge_label.name}" + ) - context['meta']['current_package'] = current_package - context['object_type'] = 'edge' - context['breadcrumbs'] = breadcrumbs(current_package, 'pathway', current_pathway, 'edge', current_edge) - context['edge'] = current_edge - context['current_object'] = current_edge + context["meta"]["current_package"] = current_package + context["object_type"] = "edge" + context["breadcrumbs"] = breadcrumbs( + current_package, "pathway", current_pathway, "edge", current_edge + ) + context["edge"] = current_edge + context["current_object"] = current_edge - return render(request, 'objects/edge.html', context) - - elif request.method == 'POST': + return render(request, "objects/edge.html", context) + elif request.method == "POST": log_post_params(request) - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_edge.delete() return redirect(current_pathway.url) - if 'selected-scenarios' in request.POST: - selected_scenarios = request.POST.getlist('selected-scenarios') + if "selected-scenarios" in request.POST: + selected_scenarios = request.POST.getlist("selected-scenarios") set_scenarios(current_user, current_edge, selected_scenarios) return redirect(current_edge.url) @@ -1890,7 +2076,7 @@ def package_pathway_edge(request, package_uuid, pathway_uuid, edge_uuid): return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) @package_permission_required() @@ -1898,92 +2084,119 @@ def package_scenarios(request, package_uuid): current_user = _anonymous_or_real(request) current_package = PackageManager.get_package_by_id(current_user, package_uuid) - if request.method == 'GET': - - if 'application/json' in request.META.get('HTTP_ACCEPT') and not request.GET.get('all', False): - scens = Scenario.objects.filter(package=current_package).order_by('name') - res = [{'name': s.name, 'url': s.url, 'uuid': s.uuid} for s in scens] + if request.method == "GET": + if "application/json" in request.META.get("HTTP_ACCEPT") and not request.GET.get( + "all", False + ): + scens = Scenario.objects.filter(package=current_package).order_by("name") + res = [{"name": s.name, "url": s.url, "uuid": s.uuid} for s in scens] return JsonResponse(res, safe=False) context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - Scenarios' + context["title"] = f"enviPath - {current_package.name} - Scenarios" - context['meta']['current_package'] = current_package - context['object_type'] = 'scenario' - context['breadcrumbs'] = breadcrumbs(current_package, 'scenario') + context["meta"]["current_package"] = current_package + context["object_type"] = "scenario" + context["breadcrumbs"] = breadcrumbs(current_package, "scenario") reviewed_scenario_qs = Scenario.objects.none() unreviewed_scenario_qs = Scenario.objects.none() if current_package.reviewed: - reviewed_scenario_qs = Scenario.objects.filter(package=current_package).order_by('name') + reviewed_scenario_qs = Scenario.objects.filter(package=current_package).order_by("name") else: - unreviewed_scenario_qs = Scenario.objects.filter(package=current_package).order_by('name') + unreviewed_scenario_qs = Scenario.objects.filter(package=current_package).order_by( + "name" + ) - if request.GET.get('all'): - return JsonResponse({ - "objects": [ - {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} - for pw in (reviewed_scenario_qs if current_package.reviewed else unreviewed_scenario_qs) - ] - }) + if request.GET.get("all"): + return JsonResponse( + { + "objects": [ + {"name": pw.name, "url": pw.url, "reviewed": current_package.reviewed} + for pw in ( + reviewed_scenario_qs + if current_package.reviewed + else unreviewed_scenario_qs + ) + ] + } + ) - context['reviewed_objects'] = reviewed_scenario_qs - context['unreviewed_objects'] = unreviewed_scenario_qs + context["reviewed_objects"] = reviewed_scenario_qs + context["unreviewed_objects"] = unreviewed_scenario_qs - from envipy_additional_information import SLUDGE_ADDITIONAL_INFORMATION, SOIL_ADDITIONAL_INFORMATION, \ - SEDIMENT_ADDITIONAL_INFORMATION - context['scenario_types'] = { - 'Soil Data': { - 'name': 'soil', - 'widgets': [HTMLGenerator.generate_html(ai, prefix=f'soil_{0}') for ai in - [x for s in SOIL_ADDITIONAL_INFORMATION.values() for x in s]] + from envipy_additional_information import ( + SLUDGE_ADDITIONAL_INFORMATION, + SOIL_ADDITIONAL_INFORMATION, + SEDIMENT_ADDITIONAL_INFORMATION, + ) + + context["scenario_types"] = { + "Soil Data": { + "name": "soil", + "widgets": [ + HTMLGenerator.generate_html(ai, prefix=f"soil_{0}") + for ai in [x for s in SOIL_ADDITIONAL_INFORMATION.values() for x in s] + ], }, - 'Sludge Data': { - 'name': 'sludge', - 'widgets': [HTMLGenerator.generate_html(ai, prefix=f'sludge_{0}') for ai in - [x for s in SLUDGE_ADDITIONAL_INFORMATION.values() for x in s]] + "Sludge Data": { + "name": "sludge", + "widgets": [ + HTMLGenerator.generate_html(ai, prefix=f"sludge_{0}") + for ai in [x for s in SLUDGE_ADDITIONAL_INFORMATION.values() for x in s] + ], + }, + "Water-Sediment System Data": { + "name": "sediment", + "widgets": [ + HTMLGenerator.generate_html(ai, prefix=f"sediment_{0}") + for ai in [x for s in SEDIMENT_ADDITIONAL_INFORMATION.values() for x in s] + ], }, - 'Water-Sediment System Data': { - 'name': 'sediment', - 'widgets': [HTMLGenerator.generate_html(ai, prefix=f'sediment_{0}') for ai in - [x for s in SEDIMENT_ADDITIONAL_INFORMATION.values() for x in s]] - } } - context['sludge_additional_information'] = SLUDGE_ADDITIONAL_INFORMATION - context['soil_additional_information'] = SOIL_ADDITIONAL_INFORMATION - context['sediment_additional_information'] = SEDIMENT_ADDITIONAL_INFORMATION - - return render(request, 'collections/objects_list.html', context) - elif request.method == 'POST': + context["sludge_additional_information"] = SLUDGE_ADDITIONAL_INFORMATION + context["soil_additional_information"] = SOIL_ADDITIONAL_INFORMATION + context["sediment_additional_information"] = SEDIMENT_ADDITIONAL_INFORMATION + return render(request, "collections/objects_list.html", context) + elif request.method == "POST": log_post_params(request) - scenario_name = request.POST.get('scenario-name') - scenario_description = request.POST.get('scenario-description') - scenario_date_year = request.POST.get('scenario-date-year') - scenario_date_month = request.POST.get('scenario-date-month') - scenario_date_day = request.POST.get('scenario-date-day') + scenario_name = request.POST.get("scenario-name") + scenario_description = request.POST.get("scenario-description") + scenario_date_year = request.POST.get("scenario-date-year") + scenario_date_month = request.POST.get("scenario-date-month") + scenario_date_day = request.POST.get("scenario-date-day") scenario_date = scenario_date_year - if scenario_date_month is not None and scenario_date_month.strip() != '': - scenario_date += f'-{int(scenario_date_month):02d}' - if scenario_date_day is not None and scenario_date_day.strip() != '': - scenario_date += f'-{int(scenario_date_day):02d}' + if scenario_date_month is not None and scenario_date_month.strip() != "": + scenario_date += f"-{int(scenario_date_month):02d}" + if scenario_date_day is not None and scenario_date_day.strip() != "": + scenario_date += f"-{int(scenario_date_day):02d}" - scenario_type = request.POST.get('scenario-type') + scenario_type = request.POST.get("scenario-type") additional_information = HTMLGenerator.build_models(request.POST.dict()) additional_information = [x for s in additional_information.values() for x in s] - s = Scenario.create(current_package, name=scenario_name, description=scenario_description, - scenario_date=scenario_date, scenario_type=scenario_type, - additional_information=additional_information) + s = Scenario.create( + current_package, + name=scenario_name, + description=scenario_description, + scenario_date=scenario_date, + scenario_type=scenario_type, + additional_information=additional_information, + ) return redirect(s.url) else: - return HttpResponseNotAllowed(['GET', ]) + return HttpResponseNotAllowed( + [ + "GET", + ] + ) @package_permission_required() @@ -1992,47 +2205,50 @@ def package_scenario(request, package_uuid, scenario_uuid): current_package = PackageManager.get_package_by_id(current_user, package_uuid) current_scenario = Scenario.objects.get(package=current_package, uuid=scenario_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_package.name} - {current_scenario.name}' + context["title"] = f"enviPath - {current_package.name} - {current_scenario.name}" - context['meta']['current_package'] = current_package - context['object_type'] = 'scenario' - context['breadcrumbs'] = breadcrumbs(current_package, 'scenario', current_scenario) + context["meta"]["current_package"] = current_package + context["object_type"] = "scenario" + context["breadcrumbs"] = breadcrumbs(current_package, "scenario", current_scenario) - context['scenario'] = current_scenario + context["scenario"] = current_scenario available_add_infs = [] for add_inf in NAME_MAPPING.values(): - available_add_infs.append({ - 'display_name': add_inf.property_name(None), - 'name': add_inf.__name__, - 'widget': HTMLGenerator.generate_html(add_inf, prefix=f'{0}') - }) - context['available_additional_information'] = available_add_infs + available_add_infs.append( + { + "display_name": add_inf.property_name(None), + "name": add_inf.__name__, + "widget": HTMLGenerator.generate_html(add_inf, prefix=f"{0}"), + } + ) + context["available_additional_information"] = available_add_infs - context['update_widgets'] = [HTMLGenerator.generate_html(ai, prefix=f'{i}') for i, ai in enumerate(current_scenario.get_additional_information())] + context["update_widgets"] = [ + HTMLGenerator.generate_html(ai, prefix=f"{i}") + for i, ai in enumerate(current_scenario.get_additional_information()) + ] - return render(request, 'objects/scenario.html', context) - - elif request.method == 'POST': + return render(request, "objects/scenario.html", context) + elif request.method == "POST": log_post_params(request) - if hidden := request.POST.get('hidden', None): - - if hidden == 'delete': + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_scenario.delete() - return redirect(current_package.url + '/scenario') - elif hidden == 'delete-additional-information': - uuid = request.POST.get('uuid') + return redirect(current_package.url + "/scenario") + elif hidden == "delete-additional-information": + uuid = request.POST.get("uuid") current_scenario.remove_additional_information(uuid) return redirect(current_scenario.url) - elif hidden == 'delete-all-additional-information': + elif hidden == "delete-all-additional-information": current_scenario.additional_information = dict() current_scenario.save() return redirect(current_scenario.url) - elif hidden == 'set-additional-information': + elif hidden == "set-additional-information": ais = HTMLGenerator.build_models(request.POST.dict()) if s.DEBUG: @@ -2040,11 +2256,13 @@ def package_scenario(request, package_uuid, scenario_uuid): current_scenario.set_additional_information(ais) return redirect(current_scenario.url) - elif hidden == 'add-additional-information': + elif hidden == "add-additional-information": ais = HTMLGenerator.build_models(request.POST.dict()) if len(ais.keys()) != 1: - raise ValueError('Only one additional information field can be added at a time.') + raise ValueError( + "Only one additional information field can be added at a time." + ) ai = list(ais.values())[0][0] @@ -2058,36 +2276,35 @@ def package_scenario(request, package_uuid, scenario_uuid): else: return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) ############## # User/Group # ############## def users(request): - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - Users' + context["title"] = "enviPath - Users" - context['object_type'] = 'user' - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'User': s.SERVER_URL + '/user'}, + context["object_type"] = "user" + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"User": s.SERVER_URL + "/user"}, ] - context['objects'] = get_user_model().objects.all() + context["objects"] = get_user_model().objects.all() - return render(request, 'collections/objects_list.html', context) + return render(request, "collections/objects_list.html", context) else: - return HttpResponseNotAllowed(['GET']) + return HttpResponseNotAllowed(["GET"]) def user(request, user_uuid): current_user = _anonymous_or_real(request) - if request.method == 'GET': - + if request.method == "GET": # Check if current user is the one matching to the url if str(current_user.uuid) != user_uuid and not current_user.is_superuser: return HttpResponseBadRequest() @@ -2095,41 +2312,42 @@ def user(request, user_uuid): requested_user = UserManager.get_user_by_id(current_user, user_uuid) context = get_base_context(request, for_user=requested_user) - context['title'] = f'enviPath - User' + context["title"] = "enviPath - User" - context['object_type'] = 'user' - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'User': s.SERVER_URL + '/user'}, - {current_user.username: requested_user.url} + context["object_type"] = "user" + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"User": s.SERVER_URL + "/user"}, + {current_user.username: requested_user.url}, ] - context['user'] = requested_user + context["user"] = requested_user model_qs = EPModel.objects.none() for p in PackageManager.get_all_readable_packages(requested_user, include_reviewed=True): model_qs |= p.models - context['models'] = model_qs + context["models"] = model_qs - context['tokens'] = APIToken.objects.filter(user=requested_user) + context["tokens"] = APIToken.objects.filter(user=requested_user) - return render(request, 'objects/user.html', context) + return render(request, "objects/user.html", context) - elif request.method == 'POST': + elif request.method == "POST": + is_hidden_method = bool(request.POST.get("hidden", False)) - is_hidden_method = bool(request.POST.get('hidden', False)) - - if is_hidden_method and request.POST['hidden'] == 'request-api-token': - name = request.POST.get('name', 'No Name') - valid_for = min(max(int(request.POST.get('valid-for', 90)), 1), 90) + if is_hidden_method and request.POST["hidden"] == "request-api-token": + name = request.POST.get("name", "No Name") + valid_for = min(max(int(request.POST.get("valid-for", 90)), 1), 90) token, raw_token = APIToken.create_token(request.user, name=name, valid_for=valid_for) - return JsonResponse({"raw_token": raw_token, 'token': {'id': token.id, 'name': token.name}}) + return JsonResponse( + {"raw_token": raw_token, "token": {"id": token.id, "name": token.name}} + ) - if is_hidden_method and request.POST['hidden'] == 'delete': - token_id = request.POST.get('token-id') + if is_hidden_method and request.POST["hidden"] == "delete": + token_id = request.POST.get("token-id") if token_id is None: return HttpResponseBadRequest("Token ID missing!") @@ -2140,122 +2358,105 @@ def user(request, user_uuid): return HttpResponse("success") - default_package = request.POST.get('default-package') - default_group = request.POST.get('default-group') - default_prediction_setting = request.POST.get('default-prediction-setting') + default_package = request.POST.get("default-package") + default_group = request.POST.get("default-group") + default_prediction_setting = request.POST.get("default-prediction-setting") if any([default_package, default_group, default_prediction_setting]): - current_user.default_package = PackageManager.get_package_by_url(current_user, default_package) + current_user.default_package = PackageManager.get_package_by_url( + current_user, default_package + ) current_user.default_group = GroupManager.get_group_by_url(current_user, default_group) - current_user.default_setting = SettingManager.get_setting_by_url(current_user, default_prediction_setting) + current_user.default_setting = SettingManager.get_setting_by_url( + current_user, default_prediction_setting + ) current_user.save() return redirect(current_user.url) - prediction_model_pk = request.POST.get('model') - prediction_threshold = request.POST.get('threshold') - prediction_max_nodes = request.POST.get('max_nodes') - prediction_max_depth = request.POST.get('max_depth') - - if all([prediction_model_pk, prediction_threshold, prediction_max_nodes, prediction_max_depth]): - # validate input.. - mod = EPModel.objects.get(id=prediction_model_pk) - if not PackageManager.readable(current_user, mod.package): - return HttpResponseBadRequest() - - threshold = float(prediction_threshold) - if threshold < 0 or threshold > 1: - return HttpResponseBadRequest() - - max_nodes = min(max(int(prediction_max_nodes), 1), 50) - max_depth = min(max(int(prediction_max_depth), 1), 8) - - setting = { - 'model': mod, - 'model_parameters': { - 'threshold': threshold - }, - 'truncator': { - 'max_nodes': max_nodes, - 'max_depth': max_depth, - } - } - return HttpResponseBadRequest() else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def groups(request): current_user = _anonymous_or_real(request) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - Groups' + context["title"] = "enviPath - Groups" - context['object_type'] = 'group' - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Group': s.SERVER_URL + '/group'}, + context["object_type"] = "group" + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Group": s.SERVER_URL + "/group"}, ] - context['objects'] = Group.objects.all() + context["objects"] = Group.objects.all() - return render(request, 'collections/objects_list.html', context) - elif request.method == 'POST': - group_name = request.POST.get('group-name') - group_description = request.POST.get('group-description', s.DEFAULT_VALUES['description']) + return render(request, "collections/objects_list.html", context) + elif request.method == "POST": + group_name = request.POST.get("group-name") + group_description = request.POST.get("group-description", s.DEFAULT_VALUES["description"]) g = GroupManager.create_group(current_user, group_name, group_description) return redirect(g.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def group(request, group_uuid): current_user = _anonymous_or_real(request) current_group = GroupManager.get_group_by_id(current_user, group_uuid) - if request.method == 'GET': + if request.method == "GET": context = get_base_context(request) - context['title'] = f'enviPath - {current_group.name}' + context["title"] = f"enviPath - {current_group.name}" - context['object_type'] = 'group' - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Group': s.SERVER_URL + '/group'}, - {current_group.name: current_group.url} + context["object_type"] = "group" + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Group": s.SERVER_URL + "/group"}, + {current_group.name: current_group.url}, ] - context['group'] = current_group + context["group"] = current_group - context['users'] = UserManager.get_users_lp().exclude(id__in=current_group.user_member.all()) - context['groups'] = GroupManager.get_groups_lp().exclude(id__in=current_group.group_member.all()).exclude(id=current_group.pk) + context["users"] = UserManager.get_users_lp().exclude( + id__in=current_group.user_member.all() + ) + context["groups"] = ( + GroupManager.get_groups_lp() + .exclude(id__in=current_group.group_member.all()) + .exclude(id=current_group.pk) + ) - context['packages'] = Package.objects.filter( - id__in=GroupPackagePermission.objects.filter(group=current_group).values('package').distinct()) + context["packages"] = Package.objects.filter( + id__in=GroupPackagePermission.objects.filter(group=current_group) + .values("package") + .distinct() + ) - return render(request, 'objects/group.html', context) - - elif request.method == 'POST': + return render(request, "objects/group.html", context) + elif request.method == "POST": log_post_params(request) - if hidden := request.POST.get('hidden', None): - if hidden == 'delete': + if hidden := request.POST.get("hidden", None): + if hidden == "delete": current_group.delete() - return redirect(s.SERVER_URL + '/group') + return redirect(s.SERVER_URL + "/group") else: return HttpResponseBadRequest() - member_url = request.POST.get('member') - action = request.POST.get('action') + member_url = request.POST.get("member") + action = request.POST.get("action") - if all([member_url, action]) and action in ['add', 'remove']: - if 'user' in member_url: + if all([member_url, action]) and action in ["add", "remove"]: + if "user" in member_url: member = UserManager.get_user_lp(member_url) else: member = GroupManager.get_group_lp(member_url) @@ -2265,60 +2466,76 @@ def group(request, group_uuid): return redirect(current_group.url) else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def settings(request): current_user = _anonymous_or_real(request) context = get_base_context(request) - if request.method == 'GET': - context['object_type'] = 'setting' + if request.method == "GET": + context["object_type"] = "setting" # Even if settings are aready in "meta", for consistency add it on root level - context['settings'] = SettingManager.get_all_settings(current_user) - context['breadcrumbs'] = [ - {'Home': s.SERVER_URL}, - {'Group': s.SERVER_URL + '/setting'}, + context["settings"] = SettingManager.get_all_settings(current_user) + context["breadcrumbs"] = [ + {"Home": s.SERVER_URL}, + {"Group": s.SERVER_URL + "/setting"}, ] return - elif request.method == 'POST': + elif request.method == "POST": if s.DEBUG: for k, v in request.POST.items(): logger.info("Parameters received:") logger.info(f"{k}\t{v}") - name = request.POST.get('prediction-setting-name') - description = request.POST.get('prediction-setting-description') - new_default = request.POST.get('prediction-setting-new-default', 'off') == 'on' + name = request.POST.get("prediction-setting-name") + description = request.POST.get("prediction-setting-description") + new_default = request.POST.get("prediction-setting-new-default", "off") == "on" - max_nodes = min(max(int(request.POST.get('prediction-setting-max-nodes', 1)), s.DEFAULT_MAX_NUMBER_OF_NODES), - s.DEFAULT_MAX_NUMBER_OF_NODES) - max_depth = min(max(int(request.POST.get('prediction-setting-max-depth', 1)), s.DEFAULT_MAX_DEPTH), - s.DEFAULT_MAX_DEPTH) + max_nodes = min( + max( + int(request.POST.get("prediction-setting-max-nodes", 1)), + s.DEFAULT_MAX_NUMBER_OF_NODES, + ), + s.DEFAULT_MAX_NUMBER_OF_NODES, + ) + max_depth = min( + max(int(request.POST.get("prediction-setting-max-depth", 1)), s.DEFAULT_MAX_DEPTH), + s.DEFAULT_MAX_DEPTH, + ) - tp_gen_method = request.POST.get('tp-generation-method') + tp_gen_method = request.POST.get("tp-generation-method") params = {} - if tp_gen_method == 'model-based-prediction-setting': - model_url = request.POST.get('model-based-prediction-setting-model') + if tp_gen_method == "model-based-prediction-setting": + model_url = request.POST.get("model-based-prediction-setting-model") - model_uuid = model_url.split('/')[-1] - params['model'] = EPModel.objects.get(uuid=model_uuid) - params['model_threshold'] = request.POST.get('model-based-prediction-setting-threshold', - s.DEFAULT_MODEL_THRESHOLD) + model_uuid = model_url.split("/")[-1] + params["model"] = EPModel.objects.get(uuid=model_uuid) + params["model_threshold"] = request.POST.get( + "model-based-prediction-setting-threshold", s.DEFAULT_MODEL_THRESHOLD + ) - if not PackageManager.readable(current_user, params['model'].package): + if not PackageManager.readable(current_user, params["model"].package): raise ValueError("") - elif tp_gen_method == 'rule-based-prediction-setting': - rule_packages = request.POST.getlist('rule-based-prediction-setting-packages') - params['rule_packages'] = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages] + elif tp_gen_method == "rule-based-prediction-setting": + rule_packages = request.POST.getlist("rule-based-prediction-setting-packages") + params["rule_packages"] = [ + PackageManager.get_package_by_url(current_user, p) for p in rule_packages + ] else: raise ValueError("") - created_setting = SettingManager.create_setting(current_user, name=name, description=description, - max_nodes=max_nodes, max_depth=max_depth, **params) + created_setting = SettingManager.create_setting( + current_user, + name=name, + description=description, + max_nodes=max_nodes, + max_depth=max_depth, + **params, + ) if new_default: current_user.default_setting = created_setting @@ -2327,7 +2544,7 @@ def settings(request): return HttpResponse("Success!") else: - return HttpResponseNotAllowed(['GET', 'POST']) + return HttpResponseNotAllowed(["GET", "POST"]) def setting(request, setting_uuid): @@ -2338,16 +2555,18 @@ def setting(request, setting_uuid): # KETCHER # ########### + def indigo(request): from indigo import Indigo - return JsonResponse({'Indigo': {'version': Indigo().version()}}) + + return JsonResponse({"Indigo": {"version": Indigo().version()}}) @csrf_exempt def aromatize(request): - if request.method == 'POST': + if request.method == "POST": data = json.loads(request.body) - mol_data = data.get('struct') + mol_data = data.get("struct") aromatized = IndigoUtils.aromatize(mol_data, False) return JsonResponse({"struct": aromatized}) else: @@ -2356,9 +2575,9 @@ def aromatize(request): @csrf_exempt def dearomatize(request): - if request.method == 'POST': + if request.method == "POST": data = json.loads(request.body) - mol_data = data.get('struct') + mol_data = data.get("struct") dearomatized = IndigoUtils.dearomatize(mol_data, False) return JsonResponse({"struct": dearomatized}) else: @@ -2367,9 +2586,9 @@ def dearomatize(request): @csrf_exempt def layout(request): - if request.method == 'POST': + if request.method == "POST": data = json.loads(request.body) - mol_data = data.get('struct') + mol_data = data.get("struct") lay = IndigoUtils.layout(mol_data) return JsonResponse({"struct": lay}) else: @@ -2380,12 +2599,14 @@ def layout(request): # Generic/Non-Persistent # ########################## def depict(request): - if smiles := request.GET.get('smiles'): - return HttpResponse(IndigoUtils.mol_to_svg(smiles), content_type='image/svg+xml') + if smiles := request.GET.get("smiles"): + return HttpResponse(IndigoUtils.mol_to_svg(smiles), content_type="image/svg+xml") - elif smirks := request.GET.get('smirks'): - query_smirks = request.GET.get('is_query_smirks', False) == 'true' - return HttpResponse(IndigoUtils.smirks_to_svg(smirks, query_smirks), content_type='image/svg+xml') + elif smirks := request.GET.get("smirks"): + query_smirks = request.GET.get("is_query_smirks", False) == "true" + return HttpResponse( + IndigoUtils.smirks_to_svg(smirks, query_smirks), content_type="image/svg+xml" + ) else: return HttpResponseBadRequest() diff --git a/tests/test_compound_model.py b/tests/test_compound_model.py index 2d499c1d..6be2a005 100644 --- a/tests/test_compound_model.py +++ b/tests/test_compound_model.py @@ -13,91 +13,80 @@ class CompoundTest(TestCase): @classmethod def setUpClass(cls): super(CompoundTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") def test_smoke(self): c = Compound.create( self.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - name='Afoxolaner', - description='No Desc' + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + name="Afoxolaner", + description="No Desc", ) - self.assertEqual(c.default_structure.smiles, - 'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F') - self.assertEqual(c.name, 'Afoxolaner') - self.assertEqual(c.description, 'No Desc') + self.assertEqual( + c.default_structure.smiles, + "C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + ) + self.assertEqual(c.name, "Afoxolaner") + self.assertEqual(c.description, "No Desc") def test_missing_smiles(self): with self.assertRaises(ValueError): - _ = Compound.create( - self.package, - smiles=None, - name='Afoxolaner', - description='No Desc' - ) + _ = Compound.create(self.package, smiles=None, name="Afoxolaner", description="No Desc") with self.assertRaises(ValueError): - _ = Compound.create( - self.package, - smiles='', - name='Afoxolaner', - description='No Desc' - ) + _ = Compound.create(self.package, smiles="", name="Afoxolaner", description="No Desc") with self.assertRaises(ValueError): - _ = Compound.create( - self.package, - smiles=' ', - name='Afoxolaner', - description='No Desc' - ) + _ = Compound.create(self.package, smiles=" ", name="Afoxolaner", description="No Desc") def test_smiles_are_trimmed(self): c = Compound.create( self.package, - smiles=' C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F ', - name='Afoxolaner', - description='No Desc' + smiles=" C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F ", + name="Afoxolaner", + description="No Desc", ) - self.assertEqual(c.default_structure.smiles, - 'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F') + self.assertEqual( + c.default_structure.smiles, + "C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + ) def test_name_and_description_optional(self): c = Compound.create( self.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", ) - self.assertEqual(c.name, 'Compound 1') - self.assertEqual(c.description, 'no description') + self.assertEqual(c.name, "Compound 1") + self.assertEqual(c.description, "no description") def test_empty_name_and_description_are_ignored(self): c = Compound.create( self.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - name='', - description='', + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + name="", + description="", ) - self.assertEqual(c.name, 'Compound 1') - self.assertEqual(c.description, 'no description') + self.assertEqual(c.name, "Compound 1") + self.assertEqual(c.description, "no description") def test_deduplication(self): c1 = Compound.create( self.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - name='Afoxolaner', - description='No Desc' + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + name="Afoxolaner", + description="No Desc", ) c2 = Compound.create( self.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - name='Afoxolaner', - description='No Desc' + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + name="Afoxolaner", + description="No Desc", ) # Check if create detects that this Compound already exist @@ -109,36 +98,36 @@ class CompoundTest(TestCase): with self.assertRaises(ValueError): _ = Compound.create( self.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - name='Afoxolaner', - description='No Desc' + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + name="Afoxolaner", + description="No Desc", ) def test_create_with_standardized_smiles(self): c = Compound.create( self.package, - smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', - name='Standardized SMILES', - description='No Desc' + smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1", + name="Standardized SMILES", + description="No Desc", ) self.assertEqual(len(c.structures.all()), 1) cs = c.structures.all()[0] self.assertEqual(cs.normalized_structure, True) - self.assertEqual(cs.smiles, 'O=C(O)C1=CC=C([N+](=O)[O-])C=C1') + self.assertEqual(cs.smiles, "O=C(O)C1=CC=C([N+](=O)[O-])C=C1") def test_create_with_non_standardized_smiles(self): c = Compound.create( self.package, - smiles='[O-][N+](=O)c1ccc(C(=O)[O-])cc1', - name='Non Standardized SMILES', - description='No Desc' + smiles="[O-][N+](=O)c1ccc(C(=O)[O-])cc1", + name="Non Standardized SMILES", + description="No Desc", ) self.assertEqual(len(c.structures.all()), 2) for cs in c.structures.all(): if cs.normalized_structure: - self.assertEqual(cs.smiles, 'O=C(O)C1=CC=C([N+](=O)[O-])C=C1') + self.assertEqual(cs.smiles, "O=C(O)C1=CC=C([N+](=O)[O-])C=C1") break else: # Loop finished without break, lets fail... @@ -147,51 +136,54 @@ class CompoundTest(TestCase): def test_add_structure_smoke(self): c = Compound.create( self.package, - smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', - name='Standardized SMILES', - description='No Desc' + smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1", + name="Standardized SMILES", + description="No Desc", ) - c.add_structure('[O-][N+](=O)c1ccc(C(=O)[O-])cc1', 'Non Standardized SMILES') + c.add_structure("[O-][N+](=O)c1ccc(C(=O)[O-])cc1", "Non Standardized SMILES") self.assertEqual(len(c.structures.all()), 2) def test_add_structure_with_different_normalized_smiles(self): c = Compound.create( self.package, - smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', - name='Standardized SMILES', - description='No Desc' + smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1", + name="Standardized SMILES", + description="No Desc", ) with self.assertRaises(ValueError): c.add_structure( - 'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - 'Different Standardized SMILES') + "C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + "Different Standardized SMILES", + ) def test_delete(self): c = Compound.create( self.package, - smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', - name='Standardization Test', - description='No Desc' + smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1", + name="Standardization Test", + description="No Desc", ) c.delete() self.assertEqual(Compound.objects.filter(package=self.package).count(), 0) - self.assertEqual(CompoundStructure.objects.filter(compound__package=self.package).count(), 0) + self.assertEqual( + CompoundStructure.objects.filter(compound__package=self.package).count(), 0 + ) def test_set_as_default_structure(self): c1 = Compound.create( self.package, - smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', - name='Standardized SMILES', - description='No Desc' + smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1", + name="Standardized SMILES", + description="No Desc", ) default_structure = c1.default_structure - c2 = c1.add_structure('[O-][N+](=O)c1ccc(C(=O)[O-])cc1', 'Non Standardized SMILES') + c2 = c1.add_structure("[O-][N+](=O)c1ccc(C(=O)[O-])cc1", "Non Standardized SMILES") c1.set_default_structure(c2) self.assertNotEqual(default_structure, c2) diff --git a/tests/test_copy_objects.py b/tests/test_copy_objects.py index 09d388c9..9bad1dd9 100644 --- a/tests/test_copy_objects.py +++ b/tests/test_copy_objects.py @@ -1,6 +1,5 @@ from django.test import TestCase -from django.test import TestCase from epdb.logic import PackageManager from epdb.models import Compound, User, Reaction @@ -12,50 +11,47 @@ class CopyTest(TestCase): @classmethod def setUpClass(cls): super(CopyTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Source Package', 'No Desc') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Source Package", "No Desc") cls.AFOXOLANER = Compound.create( cls.package, - smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', - name='Afoxolaner', - description='Test compound for copying' + smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F", + name="Afoxolaner", + description="Test compound for copying", ) cls.FOUR_NITROBENZOIC_ACID = Compound.create( cls.package, - smiles='[O-][N+](=O)c1ccc(C(=O)[O-])cc1', # Normalized: O=C(O)C1=CC=C([N+](=O)[O-])C=C1', - name='Test Compound', - description='Compound with multiple structures' + smiles="[O-][N+](=O)c1ccc(C(=O)[O-])cc1", # Normalized: O=C(O)C1=CC=C([N+](=O)[O-])C=C1', + name="Test Compound", + description="Compound with multiple structures", ) cls.ETHANOL = Compound.create( - cls.package, - smiles='CCO', - name='Ethanol', - description='Simple alcohol' + cls.package, smiles="CCO", name="Ethanol", description="Simple alcohol" ) - cls.target_package = PackageManager.create_package(cls.user, 'Target Package', 'No Desc') + cls.target_package = PackageManager.create_package(cls.user, "Target Package", "No Desc") cls.reaction_educt = Compound.create( cls.package, - smiles='C(CCl)Cl', - name='1,2-Dichloroethane', - description='Eawag BBD compound c0001' + smiles="C(CCl)Cl", + name="1,2-Dichloroethane", + description="Eawag BBD compound c0001", ).default_structure cls.reaction_product = Compound.create( cls.package, - smiles='C(CO)Cl', - name='2-Chloroethanol', - description='Eawag BBD compound c0005' + smiles="C(CO)Cl", + name="2-Chloroethanol", + description="Eawag BBD compound c0005", ).default_structure cls.REACTION = Reaction.create( package=cls.package, - name='Eawag BBD reaction r0001', + name="Eawag BBD reaction r0001", educts=[cls.reaction_educt], products=[cls.reaction_product], - multi_step=False + multi_step=False, ) def test_compound_copy_basic(self): @@ -68,7 +64,9 @@ class CopyTest(TestCase): self.assertEqual(self.AFOXOLANER.description, copied_compound.description) self.assertEqual(copied_compound.package, self.target_package) self.assertEqual(self.AFOXOLANER.package, self.package) - self.assertEqual(self.AFOXOLANER.default_structure.smiles, copied_compound.default_structure.smiles) + self.assertEqual( + self.AFOXOLANER.default_structure.smiles, copied_compound.default_structure.smiles + ) def test_compound_copy_with_multiple_structures(self): """Test copying a compound with multiple structures""" @@ -86,7 +84,7 @@ class CopyTest(TestCase): self.assertIsNotNone(copied_compound.default_structure) self.assertEqual( copied_compound.default_structure.smiles, - self.FOUR_NITROBENZOIC_ACID.default_structure.smiles + self.FOUR_NITROBENZOIC_ACID.default_structure.smiles, ) def test_compound_copy_preserves_aliases(self): @@ -95,15 +93,15 @@ class CopyTest(TestCase): original_compound = self.ETHANOL # Add aliases if the method exists - if hasattr(original_compound, 'add_alias'): - original_compound.add_alias('Ethyl alcohol') - original_compound.add_alias('Grain alcohol') + if hasattr(original_compound, "add_alias"): + original_compound.add_alias("Ethyl alcohol") + original_compound.add_alias("Grain alcohol") mapping = dict() copied_compound = original_compound.copy(self.target_package, mapping) # Verify aliases were copied if they exist - if hasattr(original_compound, 'aliases') and hasattr(copied_compound, 'aliases'): + if hasattr(original_compound, "aliases") and hasattr(copied_compound, "aliases"): original_aliases = original_compound.aliases copied_aliases = copied_compound.aliases self.assertEqual(original_aliases, copied_aliases) @@ -113,10 +111,10 @@ class CopyTest(TestCase): original_compound = self.ETHANOL # Add external identifiers if the methods exist - if hasattr(original_compound, 'add_cas_number'): - original_compound.add_cas_number('64-17-5') - if hasattr(original_compound, 'add_pubchem_compound_id'): - original_compound.add_pubchem_compound_id('702') + if hasattr(original_compound, "add_cas_number"): + original_compound.add_cas_number("64-17-5") + if hasattr(original_compound, "add_pubchem_compound_id"): + original_compound.add_pubchem_compound_id("702") mapping = dict() copied_compound = original_compound.copy(self.target_package, mapping) @@ -146,7 +144,9 @@ class CopyTest(TestCase): self.assertEqual(original_structure.smiles, copied_structure.smiles) self.assertEqual(original_structure.canonical_smiles, copied_structure.canonical_smiles) self.assertEqual(original_structure.inchikey, copied_structure.inchikey) - self.assertEqual(original_structure.normalized_structure, copied_structure.normalized_structure) + self.assertEqual( + original_structure.normalized_structure, copied_structure.normalized_structure + ) # Verify they are different objects self.assertNotEqual(original_structure.uuid, copied_structure.uuid) @@ -177,7 +177,9 @@ class CopyTest(TestCase): self.assertEqual(orig_educt.compound.package, self.package) self.assertEqual(orig_educt.smiles, copy_educt.smiles) - for orig_product, copy_product in zip(self.REACTION.products.all(), copied_reaction.products.all()): + for orig_product, copy_product in zip( + self.REACTION.products.all(), copied_reaction.products.all() + ): self.assertNotEqual(orig_product.uuid, copy_product.uuid) self.assertEqual(orig_product.name, copy_product.name) self.assertEqual(orig_product.description, copy_product.description) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9549d124..eb5a7924 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,21 +11,21 @@ class DatasetTest(TestCase): def setUp(self): self.cs1 = Compound.create( self.package, - name='2,6-Dibromohydroquinone', - description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b', - smiles='C1=C(C(=C(C=C1O)Br)O)Br', + name="2,6-Dibromohydroquinone", + description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b", + smiles="C1=C(C(=C(C=C1O)Br)O)Br", ).default_structure self.cs2 = Compound.create( self.package, - smiles='O=C(O)CC(=O)/C=C(/Br)C(=O)O', + smiles="O=C(O)CC(=O)/C=C(/Br)C(=O)O", ).default_structure self.rule1 = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - smirks='[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\[#6:3]=[#6:2](\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]', - description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6' + smirks="[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\\[#6:3]=[#6:2](\\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]", + description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6", ) self.reaction1 = Reaction.create( @@ -33,14 +33,14 @@ class DatasetTest(TestCase): educts=[self.cs1], products=[self.cs2], rules=[self.rule1], - multi_step=False + multi_step=False, ) @classmethod def setUpClass(cls): super(DatasetTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") def test_smoke(self): reactions = [r for r in Reaction.objects.filter(package=self.package)] diff --git a/tests/test_enviformer.py b/tests/test_enviformer.py index 536046ad..1a688cb1 100644 --- a/tests/test_enviformer.py +++ b/tests/test_enviformer.py @@ -1,18 +1,19 @@ from tempfile import TemporaryDirectory -from django.test import TestCase +from django.test import TestCase, tag from epdb.logic import PackageManager from epdb.models import User, EnviFormer, Package +@tag("slow") class EnviFormerTest(TestCase): fixtures = ["test_fixtures.jsonl.gz"] @classmethod def setUpClass(cls): super(EnviFormerTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') - cls.BBD_SUBSET = Package.objects.get(name='Fixtures') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") + cls.BBD_SUBSET = Package.objects.get(name="Fixtures") def test_model_flow(self): """Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference""" @@ -21,11 +22,14 @@ class EnviFormerTest(TestCase): threshold = float(0.5) data_package_objs = [self.BBD_SUBSET] eval_packages_objs = [self.BBD_SUBSET] - mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold) + mod = EnviFormer.create( + self.package, data_package_objs, eval_packages_objs, threshold=threshold + ) mod.build_dataset() mod.build_model() mod.multigen_eval = True mod.save() mod.evaluate_model() - results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C') + + mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") diff --git a/tests/test_formatconverter.py b/tests/test_formatconverter.py index 7d871cee..006d267a 100644 --- a/tests/test_formatconverter.py +++ b/tests/test_formatconverter.py @@ -4,8 +4,7 @@ from utilities.chem import FormatConverter class FormatConverterTestCase(TestCase): - def test_standardization(self): - smiles = 'C[n+]1c([N-](C))cccc1' + smiles = "C[n+]1c([N-](C))cccc1" standardized_smiles = FormatConverter.standardize(smiles) - self.assertEqual(standardized_smiles, 'CN=C1C=CC=CN1C') + self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C") diff --git a/tests/test_model.py b/tests/test_model.py index 36a1fd39..e46046ec 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,7 +4,7 @@ import numpy as np from django.test import TestCase from epdb.logic import PackageManager -from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package +from epdb.models import User, MLRelativeReasoning, Package class ModelTest(TestCase): @@ -13,9 +13,9 @@ class ModelTest(TestCase): @classmethod def setUpClass(cls): super(ModelTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') - cls.BBD_SUBSET = Package.objects.get(name='Fixtures') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") + cls.BBD_SUBSET = Package.objects.get(name="Fixtures") def test_smoke(self): with TemporaryDirectory() as tmpdir: @@ -32,8 +32,8 @@ class ModelTest(TestCase): data_package_objs, eval_packages_objs, threshold=threshold, - name='ECC - BBD - 0.5', - description='Created MLRelativeReasoning in Testcase', + name="ECC - BBD - 0.5", + description="Created MLRelativeReasoning in Testcase", ) # mod = RuleBasedRelativeReasoning.create( @@ -54,7 +54,7 @@ class ModelTest(TestCase): mod.save() mod.evaluate_model() - results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C') + results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C") products = dict() for r in results: @@ -62,8 +62,11 @@ class ModelTest(TestCase): products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability) expected = { - ('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)), - ('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)), + ("CC=O", "CCNC(=O)C1=CC(C)=CC=C1"): ( + "bt0243-4301", + np.float64(0.33333333333333337), + ), + ("CC1=CC=CC(C(=O)O)=C1", "CCNCC"): ("bt0430-4011", np.float64(0.25)), } self.assertEqual(products, expected) diff --git a/tests/test_multigen_eval.py b/tests/test_multigen_eval.py index 018bb92d..7959d81c 100644 --- a/tests/test_multigen_eval.py +++ b/tests/test_multigen_eval.py @@ -1,4 +1,3 @@ -import json from django.test import TestCase from networkx.utils.misc import graphs_equal from epdb.logic import PackageManager, SPathway @@ -12,9 +11,11 @@ class MultiGenTest(TestCase): @classmethod def setUpClass(cls): super(MultiGenTest, cls).setUpClass() - cls.user: 'User' = User.objects.get(username='anonymous') - cls.package: 'Package' = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') - cls.BBD_SUBSET: 'Package' = Package.objects.get(name='Fixtures') + cls.user: "User" = User.objects.get(username="anonymous") + cls.package: "Package" = PackageManager.create_package( + cls.user, "Anon Test Package", "No Desc" + ) + cls.BBD_SUBSET: "Package" = Package.objects.get(name="Fixtures") def test_equal_pathways(self): """Test that two identical pathways return a precision and recall of 1.0""" @@ -23,14 +24,23 @@ class MultiGenTest(TestCase): if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges continue score, precision, recall = multigen_eval(pathway, pathway) - self.assertEqual(precision, 1.0, f"Precision should be one for identical pathways. " - f"Failed on pathway: {pathway.name}") - self.assertEqual(recall, 1.0, f"Recall should be one for identical pathways. " - f"Failed on pathway: {pathway.name}") + self.assertEqual( + precision, + 1.0, + f"Precision should be one for identical pathways. " + f"Failed on pathway: {pathway.name}", + ) + self.assertEqual( + recall, + 1.0, + f"Recall should be one for identical pathways. Failed on pathway: {pathway.name}", + ) def test_intermediates(self): """Test that an intermediate can be correctly identified and the metrics are correctly adjusted""" - score, precision, recall, intermediates = multigen_eval(*self.intermediate_case(), return_intermediates=True) + score, precision, recall, intermediates = multigen_eval( + *self.intermediate_case(), return_intermediates=True + ) self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate") self.assertEqual(precision, 1, "Precision should be 1") self.assertEqual(recall, 1, "Recall should be 1") @@ -49,7 +59,9 @@ class MultiGenTest(TestCase): def test_all(self): """Test an intermediate, false-positive and false-negative together""" - score, precision, recall, intermediates = multigen_eval(*self.all_case(), return_intermediates=True) + score, precision, recall, intermediates = multigen_eval( + *self.all_case(), return_intermediates=True + ) self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate") self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6") self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75") @@ -57,19 +69,22 @@ class MultiGenTest(TestCase): def test_shallow_pathway(self): pathways = self.BBD_SUBSET.pathways.all() for pathway in pathways: - pathway_name = pathway.name if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges continue + shallow_pathway = graph_from_pathway(SPathway.from_pathway(pathway)) pathway = graph_from_pathway(pathway) if not graphs_equal(shallow_pathway, pathway): - print('\n\nS', shallow_pathway.adj) - print('\n\nPW', pathway.adj) + print("\n\nS", shallow_pathway.adj) + print("\n\nPW", pathway.adj) # print(shallow_pathway.nodes, pathway.nodes) # print(shallow_pathway.graph, pathway.graph) - self.assertTrue(graphs_equal(shallow_pathway, pathway), f"Networkx graph from shallow pathway not " - f"equal to pathway for pathway {pathway.name}") + self.assertTrue( + graphs_equal(shallow_pathway, pathway), + f"Networkx graph from shallow pathway not " + f"equal to pathway for pathway {pathway.name}", + ) def test_graph_edit_eval(self): """Performs all the previous tests but with graph_edit_eval @@ -79,10 +94,16 @@ class MultiGenTest(TestCase): if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges continue score = pathway_edit_eval(pathway, pathway) - self.assertEqual(score, 0.0, "Pathway edit distance should be zero for identical pathways. " - f"Failed on pathway: {pathway.name}") + self.assertEqual( + score, + 0.0, + "Pathway edit distance should be zero for identical pathways. " + f"Failed on pathway: {pathway.name}", + ) inter_score = pathway_edit_eval(*self.intermediate_case()) - self.assertAlmostEqual(inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case") + self.assertAlmostEqual( + inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case" + ) fp_score = pathway_edit_eval(*self.fp_case()) self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case") fn_score = pathway_edit_eval(*self.fn_case()) @@ -93,22 +114,30 @@ class MultiGenTest(TestCase): def intermediate_case(self): """Create an example with an intermediate in the predicted pathway""" true_pathway = Pathway.create(self.package, "CCO") - true_pathway.add_edge([true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)]) + true_pathway.add_edge( + [true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)] + ) pred_pathway = Pathway.create(self.package, "CCO") - pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], - [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)]) + pred_pathway.add_edge( + [pred_pathway.root_nodes.all()[0]], + [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)], + ) pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)]) return true_pathway, pred_pathway def fp_case(self): """Create an example with an extra compound in the predicted pathway""" true_pathway = Pathway.create(self.package, "CCO") - true_pathway.add_edge([true_pathway.root_nodes.all()[0]], - [acetaldehyde := true_pathway.add_node("CC=O", depth=1)]) + true_pathway.add_edge( + [true_pathway.root_nodes.all()[0]], + [acetaldehyde := true_pathway.add_node("CC=O", depth=1)], + ) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)]) pred_pathway = Pathway.create(self.package, "CCO") - pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], - [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)]) + pred_pathway.add_edge( + [pred_pathway.root_nodes.all()[0]], + [acetaldehyde := pred_pathway.add_node("CC=O", depth=1)], + ) pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)]) pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)]) return true_pathway, pred_pathway @@ -116,22 +145,30 @@ class MultiGenTest(TestCase): def fn_case(self): """Create an example with a missing compound in the predicted pathway""" true_pathway = Pathway.create(self.package, "CCO") - true_pathway.add_edge([true_pathway.root_nodes.all()[0]], - [acetaldehyde := true_pathway.add_node("CC=O", depth=1)]) + true_pathway.add_edge( + [true_pathway.root_nodes.all()[0]], + [acetaldehyde := true_pathway.add_node("CC=O", depth=1)], + ) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)]) pred_pathway = Pathway.create(self.package, "CCO") - pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)]) + pred_pathway.add_edge( + [pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)] + ) return true_pathway, pred_pathway def all_case(self): """Create an example with an intermediate, extra compound and missing compound""" true_pathway = Pathway.create(self.package, "CCO") - true_pathway.add_edge([true_pathway.root_nodes.all()[0]], - [acetaldehyde := true_pathway.add_node("CC=O", depth=1)]) + true_pathway.add_edge( + [true_pathway.root_nodes.all()[0]], + [acetaldehyde := true_pathway.add_node("CC=O", depth=1)], + ) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)]) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)]) pred_pathway = Pathway.create(self.package, "CCO") - pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)]) + pred_pathway.add_edge( + [pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)] + ) pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)]) pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)]) return true_pathway, pred_pathway diff --git a/tests/test_reaction_model.py b/tests/test_reaction_model.py index 15d1562d..ef8c93c0 100644 --- a/tests/test_reaction_model.py +++ b/tests/test_reaction_model.py @@ -10,127 +10,127 @@ class ReactionTest(TestCase): @classmethod def setUpClass(cls): super(ReactionTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") def test_smoke(self): educt = Compound.create( self.package, - smiles='C(CCl)Cl', - name='1,2-Dichloroethane', - description='Eawag BBD compound c0001' + smiles="C(CCl)Cl", + name="1,2-Dichloroethane", + description="Eawag BBD compound c0001", ).default_structure product = Compound.create( self.package, - smiles='C(CO)Cl', - name='2-Chloroethanol', - description='Eawag BBD compound c0005' + smiles="C(CO)Cl", + name="2-Chloroethanol", + description="Eawag BBD compound c0005", ).default_structure r = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', + name="Eawag BBD reaction r0001", educts=[educt], products=[product], - multi_step=False + multi_step=False, ) - self.assertEqual(r.smirks(), 'C(CCl)Cl>>C(CO)Cl') - self.assertEqual(r.name, 'Eawag BBD reaction r0001') - self.assertEqual(r.description, 'no description') + self.assertEqual(r.smirks(), "C(CCl)Cl>>C(CO)Cl") + self.assertEqual(r.name, "Eawag BBD reaction r0001") + self.assertEqual(r.description, "no description") def test_string_educts_and_products(self): r = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], - multi_step=False + name="Eawag BBD reaction r0001", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], + multi_step=False, ) - self.assertEqual(r.smirks(), 'C(CCl)Cl>>C(CO)Cl') + self.assertEqual(r.smirks(), "C(CCl)Cl>>C(CO)Cl") def test_missing_smiles(self): educt = Compound.create( self.package, - smiles='C(CCl)Cl', - name='1,2-Dichloroethane', - description='Eawag BBD compound c0001' + smiles="C(CCl)Cl", + name="1,2-Dichloroethane", + description="Eawag BBD compound c0001", ).default_structure product = Compound.create( self.package, - smiles='C(CO)Cl', - name='2-Chloroethanol', - description='Eawag BBD compound c0005' + smiles="C(CO)Cl", + name="2-Chloroethanol", + description="Eawag BBD compound c0005", ).default_structure with self.assertRaises(ValueError): _ = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', + name="Eawag BBD reaction r0001", educts=[educt], products=[], - multi_step=False + multi_step=False, ) with self.assertRaises(ValueError): _ = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', + name="Eawag BBD reaction r0001", educts=[], products=[product], - multi_step=False + multi_step=False, ) with self.assertRaises(ValueError): _ = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', + name="Eawag BBD reaction r0001", educts=[], products=[], - multi_step=False + multi_step=False, ) def test_empty_name_and_description_are_ignored(self): r = Reaction.create( package=self.package, - name='', - description='', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], + name="", + description="", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], multi_step=False, ) - self.assertEqual(r.name, 'no name') - self.assertEqual(r.description, 'no description') + self.assertEqual(r.name, "no name") + self.assertEqual(r.description, "no description") def test_deduplication(self): rule = Rule.create( package=self.package, - rule_type='SimpleAmbitRule', - name='bt0022-2833', - description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative', - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', + rule_type="SimpleAmbitRule", + name="bt0022-2833", + description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative", + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", ) r1 = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], + name="Eawag BBD reaction r0001", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], rules=[rule], - multi_step=False + multi_step=False, ) r2 = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], + name="Eawag BBD reaction r0001", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], rules=[rule], - multi_step=False + multi_step=False, ) # Check if create detects that this Compound already exist @@ -141,18 +141,18 @@ class ReactionTest(TestCase): def test_deduplication_without_rules(self): r1 = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], - multi_step=False + name="Eawag BBD reaction r0001", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], + multi_step=False, ) r2 = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], - multi_step=False + name="Eawag BBD reaction r0001", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], + multi_step=False, ) # Check if create detects that this Compound already exist @@ -164,19 +164,19 @@ class ReactionTest(TestCase): with self.assertRaises(ValueError): _ = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['ASDF'], - products=['C(CO)Cl'], - multi_step=False + name="Eawag BBD reaction r0001", + educts=["ASDF"], + products=["C(CO)Cl"], + multi_step=False, ) def test_delete(self): r = Reaction.create( package=self.package, - name='Eawag BBD reaction r0001', - educts=['C(CCl)Cl'], - products=['C(CO)Cl'], - multi_step=False + name="Eawag BBD reaction r0001", + educts=["C(CCl)Cl"], + products=["C(CO)Cl"], + multi_step=False, ) r.delete() diff --git a/tests/test_rule_model.py b/tests/test_rule_model.py index 694dfeda..520e049f 100644 --- a/tests/test_rule_model.py +++ b/tests/test_rule_model.py @@ -10,73 +10,79 @@ class RuleTest(TestCase): @classmethod def setUpClass(cls): super(RuleTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") def test_smoke(self): r = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - name='bt0022-2833', - description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative', - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', + name="bt0022-2833", + description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative", + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", ) - self.assertEqual(r.smirks, - '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]') - self.assertEqual(r.name, 'bt0022-2833') - self.assertEqual(r.description, - 'Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative') + self.assertEqual( + r.smirks, + "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + ) + self.assertEqual(r.name, "bt0022-2833") + self.assertEqual( + r.description, + "Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative", + ) def test_smirks_are_trimmed(self): r = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - name='bt0022-2833', - description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative', - smirks=' [H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4] ', + name="bt0022-2833", + description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative", + smirks=" [H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4] ", ) - self.assertEqual(r.smirks, - '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]') + self.assertEqual( + r.smirks, + "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + ) def test_name_and_description_optional(self): r = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", ) - self.assertRegex(r.name, 'Rule \\d+') - self.assertEqual(r.description, 'no description') + self.assertRegex(r.name, "Rule \\d+") + self.assertEqual(r.description, "no description") def test_empty_name_and_description_are_ignored(self): r = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', - name='', - description='', + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + name="", + description="", ) - self.assertRegex(r.name, 'Rule \\d+') - self.assertEqual(r.description, 'no description') + self.assertRegex(r.name, "Rule \\d+") + self.assertEqual(r.description, "no description") def test_deduplication(self): r1 = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', - name='', - description='', + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + name="", + description="", ) r2 = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', - name='', - description='', + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + name="", + description="", ) self.assertEqual(r1.pk, r2.pk) @@ -84,21 +90,21 @@ class RuleTest(TestCase): def test_valid_smirks(self): with self.assertRaises(ValueError): - r = Rule.create( - rule_type='SimpleAmbitRule', + Rule.create( + rule_type="SimpleAmbitRule", package=self.package, - smirks='This is not a valid SMIRKS', - name='', - description='', + smirks="This is not a valid SMIRKS", + name="", + description="", ) def test_delete(self): r = Rule.create( - rule_type='SimpleAmbitRule', + rule_type="SimpleAmbitRule", package=self.package, - smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', - name='', - description='', + smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + name="", + description="", ) r.delete() diff --git a/tests/test_ruleapplication.py b/tests/test_ruleapplication.py index cd05d6c2..690f73d8 100644 --- a/tests/test_ruleapplication.py +++ b/tests/test_ruleapplication.py @@ -8,18 +8,19 @@ from envipy_ambit import apply from rdkit import Chem from rdkit.Chem.MolStandardize import rdMolStandardize + @tag("slow") class RuleApplicationTest(TestCase): - def setUp(self): self.total_errors = 0 @classmethod def setUpClass(cls): super(RuleApplicationTest, cls).setUpClass() - cls.data = json.load(gzip.open(s.BASE_DIR / 'fixtures' / 'ambit_rules.json.gz', 'rb')) + cls.data = json.load(gzip.open(s.BASE_DIR / "fixtures" / "ambit_rules.json.gz", "rb")) cls.error_smiles = list() from collections import defaultdict + cls.triggered = defaultdict(lambda: defaultdict(lambda: 0)) @classmethod @@ -29,6 +30,7 @@ class RuleApplicationTest(TestCase): # print(cls.error_smiles) from pprint import pprint from collections import Counter + pprint(Counter(cls.error_smiles)) # import json # pprint(json.loads(json.dumps(cls.triggered))) @@ -53,7 +55,13 @@ class RuleApplicationTest(TestCase): ambit_res = apply(smirks, smiles) ambit_res = list( - set([RuleApplicationTest.normalize_smiles(str(x)) for x in FormatConverter.sanitize_smiles([str(s) for s in ambit_res])[0]])) + set( + [ + RuleApplicationTest.normalize_smiles(str(x)) + for x in FormatConverter.sanitize_smiles([str(s) for s in ambit_res])[0] + ] + ) + ) products = FormatConverter.apply(smiles, smirks) @@ -64,23 +72,30 @@ class RuleApplicationTest(TestCase): all_rdkit_prods = list(set(all_rdkit_prods)) - all_rdkit_res = list(set([RuleApplicationTest.normalize_smiles(str(x)) for x in - FormatConverter.sanitize_smiles([str(s) for s in all_rdkit_prods])[0]])) + all_rdkit_res = list( + set( + [ + RuleApplicationTest.normalize_smiles(str(x)) + for x in FormatConverter.sanitize_smiles([str(s) for s in all_rdkit_prods])[0] + ] + ) + ) return ambit_res, 0, all_rdkit_res, 0 def run_bt_test(self, bt_rule_name): bt_rule = self.data[bt_rule_name] - smirks = bt_rule['smirks'] + smirks = bt_rule["smirks"] res = True all_prods = set() - for comp in bt_rule['compounds']: + for comp in bt_rule["compounds"]: + smi = comp["smiles"] - smi = comp['smiles'] - - ambit_smiles, ambit_errors, rdkit_smiles, rdkit_errors = self.run_both_engines(smi, smirks) + ambit_smiles, ambit_errors, rdkit_smiles, rdkit_errors = self.run_both_engines( + smi, smirks + ) for x in ambit_smiles: all_prods.add(x) @@ -100,8 +115,8 @@ class RuleApplicationTest(TestCase): if len(ambit_smiles) and not partial_res: print(f""" BT {bt_rule_name} - SMIRKS {bt_rule['smirks']} - Compound {comp['smiles']} + SMIRKS {bt_rule["smirks"]} + Compound {comp["smiles"]} Num ambit {len(set(ambit_smiles))} Num rdkit {len(set(rdkit_smiles))} Num Intersection A {len(set(ambit_smiles).intersection(set(rdkit_smiles)))} @@ -115,7 +130,7 @@ class RuleApplicationTest(TestCase): """) if not partial_res: - self.error_smiles.append(comp['smiles']) + self.error_smiles.append(comp["smiles"]) self.total_errors += 1 res &= partial_res @@ -123,871 +138,871 @@ class RuleApplicationTest(TestCase): self.assertTrue(res) def test_bt0349_3023(self): - self.run_bt_test('bt0349-3023') + self.run_bt_test("bt0349-3023") def test_bt0306_3442(self): - self.run_bt_test('bt0306-3442') + self.run_bt_test("bt0306-3442") def test_bt0342_4298_1(self): - self.run_bt_test('bt0342-4298.1') + self.run_bt_test("bt0342-4298.1") def test_bt0064_3707(self): - self.run_bt_test('bt0064-3707') + self.run_bt_test("bt0064-3707") def test_bt0108_470(self): - self.run_bt_test('bt0108-470') + self.run_bt_test("bt0108-470") def test_bt0231_1871_2(self): - self.run_bt_test('bt0231-1871.2') + self.run_bt_test("bt0231-1871.2") def test_bt0374_4148(self): - self.run_bt_test('bt0374-4148') + self.run_bt_test("bt0374-4148") def test_bt0417_3777(self): - self.run_bt_test('bt0417-3777') + self.run_bt_test("bt0417-3777") def test_bt0153_3077(self): - self.run_bt_test('bt0153-3077') + self.run_bt_test("bt0153-3077") def test_bt0213_3524(self): - self.run_bt_test('bt0213-3524') + self.run_bt_test("bt0213-3524") def test_bt0257_3855_2(self): - self.run_bt_test('bt0257-3855.2') + self.run_bt_test("bt0257-3855.2") def test_bt0037_3717(self): - self.run_bt_test('bt0037-3717') + self.run_bt_test("bt0037-3717") def test_bt0102_4062(self): - self.run_bt_test('bt0102-4062') + self.run_bt_test("bt0102-4062") def test_bt0431_4039(self): - self.run_bt_test('bt0431-4039') + self.run_bt_test("bt0431-4039") def test_bt0444_4310(self): - self.run_bt_test('bt0444-4310') + self.run_bt_test("bt0444-4310") def test_bt0242_3803(self): - self.run_bt_test('bt0242-3803') + self.run_bt_test("bt0242-3803") def test_bt0231_1871_3(self): - self.run_bt_test('bt0231-1871.3') + self.run_bt_test("bt0231-1871.3") def test_bt0388_4159(self): - self.run_bt_test('bt0388-4159') + self.run_bt_test("bt0388-4159") def test_bt0022_2833(self): - self.run_bt_test('bt0022-2833') + self.run_bt_test("bt0022-2833") def test_bt0393_3367(self): - self.run_bt_test('bt0393-3367') + self.run_bt_test("bt0393-3367") def test_bt0282_3656(self): - self.run_bt_test('bt0282-3656') + self.run_bt_test("bt0282-3656") def test_bt0399_3488(self): - self.run_bt_test('bt0399-3488') + self.run_bt_test("bt0399-3488") def test_bt0330_3930(self): - self.run_bt_test('bt0330-3930') + self.run_bt_test("bt0330-3930") def test_bt0363_4185(self): - self.run_bt_test('bt0363-4185') + self.run_bt_test("bt0363-4185") def test_bt0243_4301(self): - self.run_bt_test('bt0243-4301') + self.run_bt_test("bt0243-4301") def test_bt0407_3651(self): - self.run_bt_test('bt0407-3651') + self.run_bt_test("bt0407-3651") def test_bt0055_3469_3(self): - self.run_bt_test('bt0055-3469.3') + self.run_bt_test("bt0055-3469.3") def test_bt0230_3525(self): - self.run_bt_test('bt0230-3525') + self.run_bt_test("bt0230-3525") def test_bt0051_3151(self): - self.run_bt_test('bt0051-3151') + self.run_bt_test("bt0051-3151") def test_bt0212_3523(self): - self.run_bt_test('bt0212-3523') + self.run_bt_test("bt0212-3523") def test_bt0005_4282(self): - self.run_bt_test('bt0005-4282') + self.run_bt_test("bt0005-4282") def test_bt0037_3718(self): - self.run_bt_test('bt0037-3718') + self.run_bt_test("bt0037-3718") def test_bt0418_3842(self): - self.run_bt_test('bt0418-3842') + self.run_bt_test("bt0418-3842") def test_bt0062_925(self): - self.run_bt_test('bt0062-925') + self.run_bt_test("bt0062-925") def test_bt0428_3946(self): - self.run_bt_test('bt0428-3946') + self.run_bt_test("bt0428-3946") def test_bt0420_3811(self): - self.run_bt_test('bt0420-3811') + self.run_bt_test("bt0420-3811") def test_bt0351_3769(self): - self.run_bt_test('bt0351-3769') + self.run_bt_test("bt0351-3769") def test_bt0383_3210(self): - self.run_bt_test('bt0383-3210') + self.run_bt_test("bt0383-3210") def test_bt0421_3907(self): - self.run_bt_test('bt0421-3907') + self.run_bt_test("bt0421-3907") def test_bt0079_1087(self): - self.run_bt_test('bt0079-1087') + self.run_bt_test("bt0079-1087") def test_bt0013_4165_2(self): - self.run_bt_test('bt0013-4165.2') + self.run_bt_test("bt0013-4165.2") def test_bt0337_3542(self): - self.run_bt_test('bt0337-3542') + self.run_bt_test("bt0337-3542") def test_bt0325_3638(self): - self.run_bt_test('bt0325-3638') + self.run_bt_test("bt0325-3638") def test_bt0435_4212(self): - self.run_bt_test('bt0435-4212') + self.run_bt_test("bt0435-4212") def test_bt0071_4150(self): - self.run_bt_test('bt0071-4150') + self.run_bt_test("bt0071-4150") def test_bt0351_3944(self): - self.run_bt_test('bt0351-3944') + self.run_bt_test("bt0351-3944") def test_bt0270_3919(self): - self.run_bt_test('bt0270-3919') + self.run_bt_test("bt0270-3919") def test_bt0349_2798(self): - self.run_bt_test('bt0349-2798') + self.run_bt_test("bt0349-2798") def test_bt0154_1367(self): - self.run_bt_test('bt0154-1367') + self.run_bt_test("bt0154-1367") def test_bt0401_3575(self): - self.run_bt_test('bt0401-3575') + self.run_bt_test("bt0401-3575") def test_bt0430_4011(self): - self.run_bt_test('bt0430-4011') + self.run_bt_test("bt0430-4011") def test_bt0337_3545(self): - self.run_bt_test('bt0337-3545') + self.run_bt_test("bt0337-3545") def test_bt0389_3302(self): - self.run_bt_test('bt0389-3302') + self.run_bt_test("bt0389-3302") def test_bt0346_2639(self): - self.run_bt_test('bt0346-2639') + self.run_bt_test("bt0346-2639") def test_bt0268_3530(self): - self.run_bt_test('bt0268-3530') + self.run_bt_test("bt0268-3530") def test_bt0379_3190(self): - self.run_bt_test('bt0379-3190') + self.run_bt_test("bt0379-3190") def test_bt0013_4165(self): - self.run_bt_test('bt0013-4165') + self.run_bt_test("bt0013-4165") def test_bt0351_2780(self): - self.run_bt_test('bt0351-2780') + self.run_bt_test("bt0351-2780") def test_bt0353_4167(self): - self.run_bt_test('bt0353-4167') + self.run_bt_test("bt0353-4167") def test_bt0291_1129(self): - self.run_bt_test('bt0291-1129') + self.run_bt_test("bt0291-1129") def test_bt0103_3648(self): - self.run_bt_test('bt0103-3648') + self.run_bt_test("bt0103-3648") def test_bt0044_3232(self): - self.run_bt_test('bt0044-3232') + self.run_bt_test("bt0044-3232") def test_bt0110_3663(self): - self.run_bt_test('bt0110-3663') + self.run_bt_test("bt0110-3663") def test_bt0107_3557(self): - self.run_bt_test('bt0107-3557') + self.run_bt_test("bt0107-3557") def test_bt0034_2448(self): - self.run_bt_test('bt0034-2448') + self.run_bt_test("bt0034-2448") def test_bt0073_3591(self): - self.run_bt_test('bt0073-3591') + self.run_bt_test("bt0073-3591") def test_bt0219_4295(self): - self.run_bt_test('bt0219-4295') + self.run_bt_test("bt0219-4295") def test_bt0066_3867(self): - self.run_bt_test('bt0066-3867') + self.run_bt_test("bt0066-3867") def test_bt0295_3520(self): - self.run_bt_test('bt0295-3520') + self.run_bt_test("bt0295-3520") def test_bt0021_3858(self): - self.run_bt_test('bt0021-3858') + self.run_bt_test("bt0021-3858") def test_bt0177_3159(self): - self.run_bt_test('bt0177-3159') + self.run_bt_test("bt0177-3159") def test_bt0318_3664(self): - self.run_bt_test('bt0318-3664') + self.run_bt_test("bt0318-3664") def test_bt0080_4217(self): - self.run_bt_test('bt0080-4217') + self.run_bt_test("bt0080-4217") def test_bt0181_1278(self): - self.run_bt_test('bt0181-1278') + self.run_bt_test("bt0181-1278") def test_bt0254_4224_2(self): - self.run_bt_test('bt0254-4224.2') + self.run_bt_test("bt0254-4224.2") def test_bt0237_2957(self): - self.run_bt_test('bt0237-2957') + self.run_bt_test("bt0237-2957") def test_bt0342_4298_2(self): - self.run_bt_test('bt0342-4298.2') + self.run_bt_test("bt0342-4298.2") def test_bt0280_2426(self): - self.run_bt_test('bt0280-2426') + self.run_bt_test("bt0280-2426") def test_bt0438_4230(self): - self.run_bt_test('bt0438-4230') + self.run_bt_test("bt0438-4230") def test_bt0270_3922(self): - self.run_bt_test('bt0270-3922') + self.run_bt_test("bt0270-3922") def test_bt0021_3859(self): - self.run_bt_test('bt0021-3859') + self.run_bt_test("bt0021-3859") def test_bt0323_3394(self): - self.run_bt_test('bt0323-3394') + self.run_bt_test("bt0323-3394") def test_bt0408_3666(self): - self.run_bt_test('bt0408-3666') + self.run_bt_test("bt0408-3666") def test_bt0429_4043(self): - self.run_bt_test('bt0429-4043') + self.run_bt_test("bt0429-4043") def test_bt0198_546(self): - self.run_bt_test('bt0198-546') + self.run_bt_test("bt0198-546") def test_bt0312_3818(self): - self.run_bt_test('bt0312-3818') + self.run_bt_test("bt0312-3818") def test_bt0348_4121(self): - self.run_bt_test('bt0348-4121') + self.run_bt_test("bt0348-4121") def test_bt0153_3078(self): - self.run_bt_test('bt0153-3078') + self.run_bt_test("bt0153-3078") def test_bt0031_1217(self): - self.run_bt_test('bt0031-1217') + self.run_bt_test("bt0031-1217") def test_bt0184_4187(self): - self.run_bt_test('bt0184-4187') + self.run_bt_test("bt0184-4187") def test_bt0055_3469_4(self): - self.run_bt_test('bt0055-3469.4') + self.run_bt_test("bt0055-3469.4") def test_bt0257_3855_1(self): - self.run_bt_test('bt0257-3855.1') + self.run_bt_test("bt0257-3855.1") def test_bt0242_3804(self): - self.run_bt_test('bt0242-3804') + self.run_bt_test("bt0242-3804") def test_bt0077_441(self): - self.run_bt_test('bt0077-441') + self.run_bt_test("bt0077-441") def test_bt0011_4163(self): - self.run_bt_test('bt0011-4163') + self.run_bt_test("bt0011-4163") def test_bt0270_3921(self): - self.run_bt_test('bt0270-3921') + self.run_bt_test("bt0270-3921") def test_bt0376_4266(self): - self.run_bt_test('bt0376-4266') + self.run_bt_test("bt0376-4266") def test_bt0036_3571(self): - self.run_bt_test('bt0036-3571') + self.run_bt_test("bt0036-3571") def test_bt0352_4297_1(self): - self.run_bt_test('bt0352-4297.1') + self.run_bt_test("bt0352-4297.1") def test_bt0199_3639(self): - self.run_bt_test('bt0199-3639') + self.run_bt_test("bt0199-3639") def test_bt0143_3211(self): - self.run_bt_test('bt0143-3211') + self.run_bt_test("bt0143-3211") def test_bt0020_1610(self): - self.run_bt_test('bt0020-1610') + self.run_bt_test("bt0020-1610") def test_bt0440_4255(self): - self.run_bt_test('bt0440-4255') + self.run_bt_test("bt0440-4255") def test_bt0286_846(self): - self.run_bt_test('bt0286-846') + self.run_bt_test("bt0286-846") def test_bt0337_3543(self): - self.run_bt_test('bt0337-3543') + self.run_bt_test("bt0337-3543") def test_bt0416_4269(self): - self.run_bt_test('bt0416-4269') + self.run_bt_test("bt0416-4269") def test_bt0195_3744(self): - self.run_bt_test('bt0195-3744') + self.run_bt_test("bt0195-3744") def test_bt0334_3582(self): - self.run_bt_test('bt0334-3582') + self.run_bt_test("bt0334-3582") def test_bt0327_3585(self): - self.run_bt_test('bt0327-3585') + self.run_bt_test("bt0327-3585") def test_bt0384_4048(self): - self.run_bt_test('bt0384-4048') + self.run_bt_test("bt0384-4048") def test_bt0056_2685(self): - self.run_bt_test('bt0056-2685') + self.run_bt_test("bt0056-2685") def test_bt0337_4117(self): - self.run_bt_test('bt0337-4117') + self.run_bt_test("bt0337-4117") def test_bt0405_3633(self): - self.run_bt_test('bt0405-3633') + self.run_bt_test("bt0405-3633") def test_bt0439_4270(self): - self.run_bt_test('bt0439-4270') + self.run_bt_test("bt0439-4270") def test_bt0332_3924(self): - self.run_bt_test('bt0332-3924') + self.run_bt_test("bt0332-3924") def test_bt0423_3876(self): - self.run_bt_test('bt0423-3876') + self.run_bt_test("bt0423-3876") def test_bt0351_3138(self): - self.run_bt_test('bt0351-3138') + self.run_bt_test("bt0351-3138") def test_bt0351_4118(self): - self.run_bt_test('bt0351-4118') + self.run_bt_test("bt0351-4118") def test_bt0147_3336(self): - self.run_bt_test('bt0147-3336') + self.run_bt_test("bt0147-3336") def test_bt0067_4013(self): - self.run_bt_test('bt0067-4013') + self.run_bt_test("bt0067-4013") def test_bt0063_3938(self): - self.run_bt_test('bt0063-3938') + self.run_bt_test("bt0063-3938") def test_bt0166_3560(self): - self.run_bt_test('bt0166-3560') + self.run_bt_test("bt0166-3560") def test_bt0156_3659(self): - self.run_bt_test('bt0156-3659') + self.run_bt_test("bt0156-3659") def test_bt0372_3657(self): - self.run_bt_test('bt0372-3657') + self.run_bt_test("bt0372-3657") def test_bt0012_4164_2(self): - self.run_bt_test('bt0012-4164.2') + self.run_bt_test("bt0012-4164.2") def test_bt0359_3668(self): - self.run_bt_test('bt0359-3668') + self.run_bt_test("bt0359-3668") def test_bt0385_3220(self): - self.run_bt_test('bt0385-3220') + self.run_bt_test("bt0385-3220") def test_bt0403_3595(self): - self.run_bt_test('bt0403-3595') + self.run_bt_test("bt0403-3595") def test_bt0231_1871_4(self): - self.run_bt_test('bt0231-1871.4') + self.run_bt_test("bt0231-1871.4") def test_bt0003_1196(self): - self.run_bt_test('bt0003-1196') + self.run_bt_test("bt0003-1196") def test_bt0397_3475(self): - self.run_bt_test('bt0397-3475') + self.run_bt_test("bt0397-3475") def test_bt0066_3856(self): - self.run_bt_test('bt0066-3856') + self.run_bt_test("bt0066-3856") def test_bt0423_3824(self): - self.run_bt_test('bt0423-3824') + self.run_bt_test("bt0423-3824") def test_bt0374_3801(self): - self.run_bt_test('bt0374-3801') + self.run_bt_test("bt0374-3801") def test_bt0357_2817(self): - self.run_bt_test('bt0357-2817') + self.run_bt_test("bt0357-2817") def test_bt0281_676(self): - self.run_bt_test('bt0281-676') + self.run_bt_test("bt0281-676") def test_bt0192_3861(self): - self.run_bt_test('bt0192-3861') + self.run_bt_test("bt0192-3861") def test_bt0218_3579(self): - self.run_bt_test('bt0218-3579') + self.run_bt_test("bt0218-3579") def test_bt0055_3469_2(self): - self.run_bt_test('bt0055-3469.2') + self.run_bt_test("bt0055-3469.2") def test_bt0349_4276(self): - self.run_bt_test('bt0349-4276') + self.run_bt_test("bt0349-4276") def test_bt0055_3469_1(self): - self.run_bt_test('bt0055-3469.1') + self.run_bt_test("bt0055-3469.1") def test_bt0420_3794(self): - self.run_bt_test('bt0420-3794') + self.run_bt_test("bt0420-3794") def test_bt0352_2748(self): - self.run_bt_test('bt0352-2748') + self.run_bt_test("bt0352-2748") def test_bt0078_232(self): - self.run_bt_test('bt0078-232') + self.run_bt_test("bt0078-232") def test_bt0051_3093(self): - self.run_bt_test('bt0051-3093') + self.run_bt_test("bt0051-3093") def test_bt0035_1206(self): - self.run_bt_test('bt0035-1206') + self.run_bt_test("bt0035-1206") def test_bt0070_3850(self): - self.run_bt_test('bt0070-3850') + self.run_bt_test("bt0070-3850") def test_bt0343_2675(self): - self.run_bt_test('bt0343-2675') + self.run_bt_test("bt0343-2675") def test_bt0392_3341(self): - self.run_bt_test('bt0392-3341') + self.run_bt_test("bt0392-3341") def test_bt0068_3564(self): - self.run_bt_test('bt0068-3564') + self.run_bt_test("bt0068-3564") def test_bt0005_3776(self): - self.run_bt_test('bt0005-3776') + self.run_bt_test("bt0005-3776") def test_bt0324_3864(self): - self.run_bt_test('bt0324-3864') + self.run_bt_test("bt0324-3864") def test_bt0002_3673(self): - self.run_bt_test('bt0002-3673') + self.run_bt_test("bt0002-3673") def test_bt0366_2884(self): - self.run_bt_test('bt0366-2884') + self.run_bt_test("bt0366-2884") def test_bt0288_2641(self): - self.run_bt_test('bt0288-2641') + self.run_bt_test("bt0288-2641") def test_bt0210_3411(self): - self.run_bt_test('bt0210-3411') + self.run_bt_test("bt0210-3411") def test_bt0180_2844(self): - self.run_bt_test('bt0180-2844') + self.run_bt_test("bt0180-2844") def test_bt0255_2690_4(self): - self.run_bt_test('bt0255-2690.4') + self.run_bt_test("bt0255-2690.4") def test_bt0337_3544(self): - self.run_bt_test('bt0337-3544') + self.run_bt_test("bt0337-3544") def test_bt0217_3026(self): - self.run_bt_test('bt0217-3026') + self.run_bt_test("bt0217-3026") def test_bt0065_4171(self): - self.run_bt_test('bt0065-4171') + self.run_bt_test("bt0065-4171") def test_bt0443_4291(self): - self.run_bt_test('bt0443-4291') + self.run_bt_test("bt0443-4291") def test_bt0386_3347(self): - self.run_bt_test('bt0386-3347') + self.run_bt_test("bt0386-3347") def test_bt0104_2853(self): - self.run_bt_test('bt0104-2853') + self.run_bt_test("bt0104-2853") def test_bt0104_2854(self): - self.run_bt_test('bt0104-2854') + self.run_bt_test("bt0104-2854") def test_bt0349_3022(self): - self.run_bt_test('bt0349-3022') + self.run_bt_test("bt0349-3022") def test_bt0427_3943(self): - self.run_bt_test('bt0427-3943') + self.run_bt_test("bt0427-3943") def test_bt0350_3319(self): - self.run_bt_test('bt0350-3319') + self.run_bt_test("bt0350-3319") def test_bt0166_3562(self): - self.run_bt_test('bt0166-3562') + self.run_bt_test("bt0166-3562") def test_bt0318_4289(self): - self.run_bt_test('bt0318-4289') + self.run_bt_test("bt0318-4289") def test_bt0352_2746(self): - self.run_bt_test('bt0352-2746') + self.run_bt_test("bt0352-2746") def test_bt0014_4215(self): - self.run_bt_test('bt0014-4215') + self.run_bt_test("bt0014-4215") def test_bt0190_1386(self): - self.run_bt_test('bt0190-1386') + self.run_bt_test("bt0190-1386") def test_bt0345_2623(self): - self.run_bt_test('bt0345-2623') + self.run_bt_test("bt0345-2623") def test_bt0370_2992(self): - self.run_bt_test('bt0370-2992') + self.run_bt_test("bt0370-2992") def test_bt0156_3660(self): - self.run_bt_test('bt0156-3660') + self.run_bt_test("bt0156-3660") def test_bt0387_3298(self): - self.run_bt_test('bt0387-3298') + self.run_bt_test("bt0387-3298") def test_bt0030_4292(self): - self.run_bt_test('bt0030-4292') + self.run_bt_test("bt0030-4292") def test_bt0388_3311(self): - self.run_bt_test('bt0388-3311') + self.run_bt_test("bt0388-3311") def test_bt0060_4170(self): - self.run_bt_test('bt0060-4170') + self.run_bt_test("bt0060-4170") def test_bt0144_4271(self): - self.run_bt_test('bt0144-4271') + self.run_bt_test("bt0144-4271") def test_bt0188_1382(self): - self.run_bt_test('bt0188-1382') + self.run_bt_test("bt0188-1382") def test_bt0082_2982(self): - self.run_bt_test('bt0082-2982') + self.run_bt_test("bt0082-2982") def test_bt0425_3892(self): - self.run_bt_test('bt0425-3892') + self.run_bt_test("bt0425-3892") def test_bt0254_4224_1(self): - self.run_bt_test('bt0254-4224.1') + self.run_bt_test("bt0254-4224.1") def test_bt0202_3925(self): - self.run_bt_test('bt0202-3925') + self.run_bt_test("bt0202-3925") def test_bt0333_3583(self): - self.run_bt_test('bt0333-3583') + self.run_bt_test("bt0333-3583") def test_bt0058_2811(self): - self.run_bt_test('bt0058-2811') + self.run_bt_test("bt0058-2811") def test_bt0352_2745(self): - self.run_bt_test('bt0352-2745') + self.run_bt_test("bt0352-2745") def test_bt0034_4082(self): - self.run_bt_test('bt0034-4082') + self.run_bt_test("bt0034-4082") def test_bt0375_3152(self): - self.run_bt_test('bt0375-3152') + self.run_bt_test("bt0375-3152") def test_bt0231_1871_1(self): - self.run_bt_test('bt0231-1871.1') + self.run_bt_test("bt0231-1871.1") def test_bt0362_3080(self): - self.run_bt_test('bt0362-3080') + self.run_bt_test("bt0362-3080") def test_bt0350_3318(self): - self.run_bt_test('bt0350-3318') + self.run_bt_test("bt0350-3318") def test_bt0337_3901(self): - self.run_bt_test('bt0337-3901') + self.run_bt_test("bt0337-3901") def test_bt0001_3568(self): - self.run_bt_test('bt0001-3568') + self.run_bt_test("bt0001-3568") def test_bt0391_4285(self): - self.run_bt_test('bt0391-4285') + self.run_bt_test("bt0391-4285") def test_bt0434_4149(self): - self.run_bt_test('bt0434-4149') + self.run_bt_test("bt0434-4149") def test_bt0156_3760(self): - self.run_bt_test('bt0156-3760') + self.run_bt_test("bt0156-3760") def test_bt0216_3640(self): - self.run_bt_test('bt0216-3640') + self.run_bt_test("bt0216-3640") def test_bt0330_3931(self): - self.run_bt_test('bt0330-3931') + self.run_bt_test("bt0330-3931") def test_bt0320_3863(self): - self.run_bt_test('bt0320-3863') + self.run_bt_test("bt0320-3863") def test_bt0352_2744(self): - self.run_bt_test('bt0352-2744') + self.run_bt_test("bt0352-2744") def test_bt0348_4120(self): - self.run_bt_test('bt0348-4120') + self.run_bt_test("bt0348-4120") def test_bt0255_2690_2(self): - self.run_bt_test('bt0255-2690.2') + self.run_bt_test("bt0255-2690.2") def test_bt0024_2218(self): - self.run_bt_test('bt0024-2218') + self.run_bt_test("bt0024-2218") def test_bt0033_1219(self): - self.run_bt_test('bt0033-1219') + self.run_bt_test("bt0033-1219") def test_bt0418_3806(self): - self.run_bt_test('bt0418-3806') + self.run_bt_test("bt0418-3806") def test_bt0208_3256(self): - self.run_bt_test('bt0208-3256') + self.run_bt_test("bt0208-3256") def test_bt0072_4172(self): - self.run_bt_test('bt0072-4172') + self.run_bt_test("bt0072-4172") def test_bt0362_2882(self): - self.run_bt_test('bt0362-2882') + self.run_bt_test("bt0362-2882") def test_bt0374_4081(self): - self.run_bt_test('bt0374-4081') + self.run_bt_test("bt0374-4081") def test_bt0023_3819(self): - self.run_bt_test('bt0023-3819') + self.run_bt_test("bt0023-3819") def test_bt0404_3928(self): - self.run_bt_test('bt0404-3928') + self.run_bt_test("bt0404-3928") def test_bt0260_2032(self): - self.run_bt_test('bt0260-2032') + self.run_bt_test("bt0260-2032") def test_bt0042_4168(self): - self.run_bt_test('bt0042-4168') + self.run_bt_test("bt0042-4168") def test_bt0416_4253(self): - self.run_bt_test('bt0416-4253') + self.run_bt_test("bt0416-4253") def test_bt0173_1376(self): - self.run_bt_test('bt0173-1376') + self.run_bt_test("bt0173-1376") def test_bt0214_1877(self): - self.run_bt_test('bt0214-1877') + self.run_bt_test("bt0214-1877") def test_bt0353_4312(self): - self.run_bt_test('bt0353-4312') + self.run_bt_test("bt0353-4312") def test_bt0436_4247(self): - self.run_bt_test('bt0436-4247') + self.run_bt_test("bt0436-4247") def test_bt0124_3980(self): - self.run_bt_test('bt0124-3980') + self.run_bt_test("bt0124-3980") def test_bt0275_899(self): - self.run_bt_test('bt0275-899') + self.run_bt_test("bt0275-899") def test_bt0390_3346(self): - self.run_bt_test('bt0390-3346') + self.run_bt_test("bt0390-3346") def test_bt0255_3670(self): - self.run_bt_test('bt0255-3670') + self.run_bt_test("bt0255-3670") def test_bt0298_3335(self): - self.run_bt_test('bt0298-3335') + self.run_bt_test("bt0298-3335") def test_bt0366_2930(self): - self.run_bt_test('bt0366-2930') + self.run_bt_test("bt0366-2930") def test_bt0270_3920(self): - self.run_bt_test('bt0270-3920') + self.run_bt_test("bt0270-3920") def test_bt0241_3580(self): - self.run_bt_test('bt0241-3580') + self.run_bt_test("bt0241-3580") def test_bt0028_3647(self): - self.run_bt_test('bt0028-3647') + self.run_bt_test("bt0028-3647") def test_bt0255_2690_3(self): - self.run_bt_test('bt0255-2690.3') + self.run_bt_test("bt0255-2690.3") def test_bt0257_3855_4(self): - self.run_bt_test('bt0257-3855.4') + self.run_bt_test("bt0257-3855.4") def test_bt0373_3577(self): - self.run_bt_test('bt0373-3577') + self.run_bt_test("bt0373-3577") def test_bt0255_2690_1(self): - self.run_bt_test('bt0255-2690.1') + self.run_bt_test("bt0255-2690.1") def test_bt0364_2870(self): - self.run_bt_test('bt0364-2870') + self.run_bt_test("bt0364-2870") def test_bt0086_4153(self): - self.run_bt_test('bt0086-4153') + self.run_bt_test("bt0086-4153") def test_bt0398_3470(self): - self.run_bt_test('bt0398-3470') + self.run_bt_test("bt0398-3470") def test_bt0086_4152(self): - self.run_bt_test('bt0086-4152') + self.run_bt_test("bt0086-4152") def test_bt0339_3800(self): - self.run_bt_test('bt0339-3800') + self.run_bt_test("bt0339-3800") def test_bt0362_3079(self): - self.run_bt_test('bt0362-3079') + self.run_bt_test("bt0362-3079") def test_bt0340_4250(self): - self.run_bt_test('bt0340-4250') + self.run_bt_test("bt0340-4250") def test_bt0050_2173(self): - self.run_bt_test('bt0050-2173') + self.run_bt_test("bt0050-2173") def test_bt0388_3799(self): - self.run_bt_test('bt0388-3799') + self.run_bt_test("bt0388-3799") def test_bt0348_4122(self): - self.run_bt_test('bt0348-4122') + self.run_bt_test("bt0348-4122") def test_bt0269_3646(self): - self.run_bt_test('bt0269-3646') + self.run_bt_test("bt0269-3646") def test_bt0103_3816(self): - self.run_bt_test('bt0103-3816') + self.run_bt_test("bt0103-3816") def test_bt0158_3361(self): - self.run_bt_test('bt0158-3361') + self.run_bt_test("bt0158-3361") def test_bt0128_4173(self): - self.run_bt_test('bt0128-4173') + self.run_bt_test("bt0128-4173") def test_bt0051_3573(self): - self.run_bt_test('bt0051-3573') + self.run_bt_test("bt0051-3573") def test_bt0348_3840(self): - self.run_bt_test('bt0348-3840') + self.run_bt_test("bt0348-3840") def test_bt0184_3607(self): - self.run_bt_test('bt0184-3607') + self.run_bt_test("bt0184-3607") def test_bt0063_2552(self): - self.run_bt_test('bt0063-2552') + self.run_bt_test("bt0063-2552") def test_bt0330_3929(self): - self.run_bt_test('bt0330-3929') + self.run_bt_test("bt0330-3929") def test_bt0146_3853(self): - self.run_bt_test('bt0146-3853') + self.run_bt_test("bt0146-3853") def test_bt0397_3474(self): - self.run_bt_test('bt0397-3474') + self.run_bt_test("bt0397-3474") def test_bt0257_3855_3(self): - self.run_bt_test('bt0257-3855.3') + self.run_bt_test("bt0257-3855.3") def test_bt0402_3576(self): - self.run_bt_test('bt0402-3576') + self.run_bt_test("bt0402-3576") def test_bt0259_3814(self): - self.run_bt_test('bt0259-3814') + self.run_bt_test("bt0259-3814") def test_bt0023_3854(self): - self.run_bt_test('bt0023-3854') + self.run_bt_test("bt0023-3854") def test_bt0027_3456(self): - self.run_bt_test('bt0027-3456') + self.run_bt_test("bt0027-3456") def test_bt0055_4169(self): - self.run_bt_test('bt0055-4169') + self.run_bt_test("bt0055-4169") def test_bt0051_3501(self): - self.run_bt_test('bt0051-3501') + self.run_bt_test("bt0051-3501") def test_bt0011_4163_2(self): - self.run_bt_test('bt0011-4163.2') + self.run_bt_test("bt0011-4163.2") def test_bt0352_4297_2(self): - self.run_bt_test('bt0352-4297.2') + self.run_bt_test("bt0352-4297.2") def test_bt0012_4164(self): - self.run_bt_test('bt0012-4164') + self.run_bt_test("bt0012-4164") def test_bt0316_4136(self): - self.run_bt_test('bt0316-4136') + self.run_bt_test("bt0316-4136") def test_bt0026_4218(self): - self.run_bt_test('bt0026-4218') + self.run_bt_test("bt0026-4218") def test_bt0234_3381(self): - self.run_bt_test('bt0234-3381') + self.run_bt_test("bt0234-3381") def test_bt0377_4300(self): - self.run_bt_test('bt0377-4300') + self.run_bt_test("bt0377-4300") def test_bt0061_2451(self): - self.run_bt_test('bt0061-2451') + self.run_bt_test("bt0061-2451") def test_bt0313_1587(self): - self.run_bt_test('bt0313-1587') + self.run_bt_test("bt0313-1587") def test_bt0348_4119(self): - self.run_bt_test('bt0348-4119') + self.run_bt_test("bt0348-4119") def test_bt0193_4263(self): - self.run_bt_test('bt0193-4263') + self.run_bt_test("bt0193-4263") def test_bt0378_3188(self): - self.run_bt_test('bt0378-3188') + self.run_bt_test("bt0378-3188") def test_bt0242_3805(self): - self.run_bt_test('bt0242-3805') + self.run_bt_test("bt0242-3805") def test_bt0432_4254(self): - self.run_bt_test('bt0432-4254') + self.run_bt_test("bt0432-4254") def test_bt0189_1410(self): - self.run_bt_test('bt0189-1410') + self.run_bt_test("bt0189-1410") def test_bt0284_3687(self): - self.run_bt_test('bt0284-3687') + self.run_bt_test("bt0284-3687") def test_bt0351_3908(self): - self.run_bt_test('bt0351-3908') + self.run_bt_test("bt0351-3908") def test_bt0029_3674(self): - self.run_bt_test('bt0029-3674') + self.run_bt_test("bt0029-3674") def test_bt0158_3397(self): - self.run_bt_test('bt0158-3397') + self.run_bt_test("bt0158-3397") def test_bt0361_4141(self): - self.run_bt_test('bt0361-4141') + self.run_bt_test("bt0361-4141") def test_bt0225_3862(self): - self.run_bt_test('bt0225-3862') + self.run_bt_test("bt0225-3862") def test_bt0209_3257(self): - self.run_bt_test('bt0209-3257') + self.run_bt_test("bt0209-3257") def test_bt0049_3745(self): - self.run_bt_test('bt0049-3745') + self.run_bt_test("bt0049-3745") def test_bt0358_3553(self): - self.run_bt_test('bt0358-3553') + self.run_bt_test("bt0358-3553") def test_bt0162_4180(self): - self.run_bt_test('bt0162-4180') + self.run_bt_test("bt0162-4180") def test_bt0227_1872(self): - self.run_bt_test('bt0227-1872') + self.run_bt_test("bt0227-1872") def test_bt0226_141(self): - self.run_bt_test('bt0226-141') + self.run_bt_test("bt0226-141") def test_bt0322_3393(self): - self.run_bt_test('bt0322-3393') + self.run_bt_test("bt0322-3393") diff --git a/tests/test_simpleambitrule.py b/tests/test_simpleambitrule.py index ac183714..b2e7fdd7 100644 --- a/tests/test_simpleambitrule.py +++ b/tests/test_simpleambitrule.py @@ -12,34 +12,32 @@ class SimpleAmbitRuleTest(TestCase): @classmethod def setUpClass(cls): super(SimpleAmbitRuleTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Simple Ambit Rule Test Package', - 'Test Package for SimpleAmbitRule') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package( + cls.user, "Simple Ambit Rule Test Package", "Test Package for SimpleAmbitRule" + ) def test_create_basic_rule(self): """Test creating a basic SimpleAmbitRule with minimal parameters.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" - rule = SimpleAmbitRule.create( - package=self.package, - smirks=smirks - ) + rule = SimpleAmbitRule.create(package=self.package, smirks=smirks) self.assertIsInstance(rule, SimpleAmbitRule) self.assertEqual(rule.smirks, smirks) self.assertEqual(rule.package, self.package) - self.assertRegex(rule.name, r'Rule \d+') - self.assertEqual(rule.description, 'no description') + self.assertRegex(rule.name, r"Rule \d+") + self.assertEqual(rule.description, "no description") self.assertIsNone(rule.reactant_filter_smarts) self.assertIsNone(rule.product_filter_smarts) def test_create_with_all_parameters(self): """Test creating SimpleAmbitRule with all parameters.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' - name = 'Test Rule' - description = 'A test biotransformation rule' - reactant_filter = '[CH2X4]' - product_filter = '[OH]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" + name = "Test Rule" + description = "A test biotransformation rule" + reactant_filter = "[CH2X4]" + product_filter = "[OH]" rule = SimpleAmbitRule.create( package=self.package, @@ -47,7 +45,7 @@ class SimpleAmbitRuleTest(TestCase): description=description, smirks=smirks, reactant_filter_smarts=reactant_filter, - product_filter_smarts=product_filter + product_filter_smarts=product_filter, ) self.assertEqual(rule.name, name) @@ -60,127 +58,114 @@ class SimpleAmbitRuleTest(TestCase): """Test that SMIRKS is required for rule creation.""" with self.assertRaises(ValueError) as cm: SimpleAmbitRule.create(package=self.package, smirks=None) - self.assertIn('SMIRKS is required', str(cm.exception)) + self.assertIn("SMIRKS is required", str(cm.exception)) with self.assertRaises(ValueError) as cm: - SimpleAmbitRule.create(package=self.package, smirks='') - self.assertIn('SMIRKS is required', str(cm.exception)) + SimpleAmbitRule.create(package=self.package, smirks="") + self.assertIn("SMIRKS is required", str(cm.exception)) with self.assertRaises(ValueError) as cm: - SimpleAmbitRule.create(package=self.package, smirks=' ') - self.assertIn('SMIRKS is required', str(cm.exception)) + SimpleAmbitRule.create(package=self.package, smirks=" ") + self.assertIn("SMIRKS is required", str(cm.exception)) - @patch('epdb.models.FormatConverter.is_valid_smirks') + @patch("epdb.models.FormatConverter.is_valid_smirks") def test_invalid_smirks_validation(self, mock_is_valid): """Test validation of SMIRKS format.""" mock_is_valid.return_value = False - invalid_smirks = 'invalid_smirks_string' + invalid_smirks = "invalid_smirks_string" with self.assertRaises(ValueError) as cm: - SimpleAmbitRule.create( - package=self.package, - smirks=invalid_smirks - ) + SimpleAmbitRule.create(package=self.package, smirks=invalid_smirks) self.assertIn(f'SMIRKS "{invalid_smirks}" is invalid', str(cm.exception)) mock_is_valid.assert_called_once_with(invalid_smirks) def test_smirks_trimming(self): """Test that SMIRKS strings are trimmed during creation.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' - smirks_with_whitespace = f' {smirks} ' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" + smirks_with_whitespace = f" {smirks} " - rule = SimpleAmbitRule.create( - package=self.package, - smirks=smirks_with_whitespace - ) + rule = SimpleAmbitRule.create(package=self.package, smirks=smirks_with_whitespace) self.assertEqual(rule.smirks, smirks) def test_empty_name_and_description_handling(self): """Test that empty name and description are handled appropriately.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" rule = SimpleAmbitRule.create( - package=self.package, - smirks=smirks, - name='', - description=' ' + package=self.package, smirks=smirks, name="", description=" " ) - self.assertRegex(rule.name, r'Rule \d+') - self.assertEqual(rule.description, 'no description') + self.assertRegex(rule.name, r"Rule \d+") + self.assertEqual(rule.description, "no description") def test_deduplication_basic(self): """Test that identical rules are deduplicated.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" - rule1 = SimpleAmbitRule.create( - package=self.package, - smirks=smirks, - name='Rule 1' - ) + rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks, name="Rule 1") rule2 = SimpleAmbitRule.create( package=self.package, smirks=smirks, - name='Rule 2' # Different name, but same SMIRKS + name="Rule 2", # Different name, but same SMIRKS ) self.assertEqual(rule1.pk, rule2.pk) - self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 1) + self.assertEqual( + SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 1 + ) def test_deduplication_with_filters(self): """Test deduplication with filter SMARTS.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' - reactant_filter = '[CH2X4]' - product_filter = '[OH]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" + reactant_filter = "[CH2X4]" + product_filter = "[OH]" rule1 = SimpleAmbitRule.create( package=self.package, smirks=smirks, reactant_filter_smarts=reactant_filter, - product_filter_smarts=product_filter + product_filter_smarts=product_filter, ) rule2 = SimpleAmbitRule.create( package=self.package, smirks=smirks, reactant_filter_smarts=reactant_filter, - product_filter_smarts=product_filter + product_filter_smarts=product_filter, ) self.assertEqual(rule1.pk, rule2.pk) def test_no_deduplication_different_filters(self): """Test that rules with different filters are not deduplicated.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" rule1 = SimpleAmbitRule.create( - package=self.package, - smirks=smirks, - reactant_filter_smarts='[CH2X4]' + package=self.package, smirks=smirks, reactant_filter_smarts="[CH2X4]" ) rule2 = SimpleAmbitRule.create( - package=self.package, - smirks=smirks, - reactant_filter_smarts='[CH3X4]' + package=self.package, smirks=smirks, reactant_filter_smarts="[CH3X4]" ) self.assertNotEqual(rule1.pk, rule2.pk) - self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 2) + self.assertEqual( + SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 2 + ) def test_filter_smarts_trimming(self): """Test that filter SMARTS are trimmed and handled correctly.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" # Test with whitespace-only filters (should be treated as None) rule = SimpleAmbitRule.create( package=self.package, smirks=smirks, - reactant_filter_smarts=' ', - product_filter_smarts=' ' + reactant_filter_smarts=" ", + product_filter_smarts=" ", ) self.assertIsNone(rule.reactant_filter_smarts) @@ -188,94 +173,85 @@ class SimpleAmbitRuleTest(TestCase): def test_url_property(self): """Test the URL property generation.""" - rule = SimpleAmbitRule.create( - package=self.package, - smirks='[H:1][C:2]>>[H:1][O:2]' - ) + rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]") - expected_url = f'{self.package.url}/simple-ambit-rule/{rule.uuid}' + expected_url = f"{self.package.url}/simple-ambit-rule/{rule.uuid}" self.assertEqual(rule.url, expected_url) - @patch('epdb.models.FormatConverter.apply') + @patch("epdb.models.FormatConverter.apply") def test_apply_method(self, mock_apply): """Test the apply method delegates to FormatConverter.""" - mock_apply.return_value = ['product1', 'product2'] + mock_apply.return_value = ["product1", "product2"] - rule = SimpleAmbitRule.create( - package=self.package, - smirks='[H:1][C:2]>>[H:1][O:2]' - ) + rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]") - test_smiles = 'CCO' + test_smiles = "CCO" result = rule.apply(test_smiles) mock_apply.assert_called_once_with(test_smiles, rule.smirks) - self.assertEqual(result, ['product1', 'product2']) + self.assertEqual(result, ["product1", "product2"]) def test_reactants_smarts_property(self): """Test reactants_smarts property extracts correct part of SMIRKS.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' - expected_reactants = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" + expected_reactants = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]" - rule = SimpleAmbitRule.create( - package=self.package, - smirks=smirks - ) + rule = SimpleAmbitRule.create(package=self.package, smirks=smirks) self.assertEqual(rule.reactants_smarts, expected_reactants) def test_products_smarts_property(self): """Test products_smarts property extracts correct part of SMIRKS.""" - smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' - expected_products = '[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' + smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" + expected_products = "[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]" - rule = SimpleAmbitRule.create( - package=self.package, - smirks=smirks - ) + rule = SimpleAmbitRule.create(package=self.package, smirks=smirks) self.assertEqual(rule.products_smarts, expected_products) - @patch('epdb.models.Package.objects') + @patch("epdb.models.Package.objects") def test_related_reactions_property(self, mock_package_objects): """Test related_reactions property returns correct queryset.""" mock_qs = MagicMock() mock_package_objects.filter.return_value = mock_qs - rule = SimpleAmbitRule.create( - package=self.package, - smirks='[H:1][C:2]>>[H:1][O:2]' - ) + rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]") # Instead of directly assigning, patch the property or use with patch.object - with patch.object(type(rule), 'reaction_rule', new_callable=PropertyMock) as mock_reaction_rule: - mock_reaction_rule.return_value.filter.return_value.order_by.return_value = ['reaction1', 'reaction2'] + with patch.object( + type(rule), "reaction_rule", new_callable=PropertyMock + ) as mock_reaction_rule: + mock_reaction_rule.return_value.filter.return_value.order_by.return_value = [ + "reaction1", + "reaction2", + ] result = rule.related_reactions mock_package_objects.filter.assert_called_once_with(reviewed=True) mock_reaction_rule.return_value.filter.assert_called_once_with(package__in=mock_qs) - mock_reaction_rule.return_value.filter.return_value.order_by.assert_called_once_with('name') - self.assertEqual(result, ['reaction1', 'reaction2']) + mock_reaction_rule.return_value.filter.return_value.order_by.assert_called_once_with( + "name" + ) + self.assertEqual(result, ["reaction1", "reaction2"]) - @patch('epdb.models.Pathway.objects') - @patch('epdb.models.Edge.objects') + @patch("epdb.models.Pathway.objects") + @patch("epdb.models.Edge.objects") def test_related_pathways_property(self, mock_edge_objects, mock_pathway_objects): """Test related_pathways property returns correct queryset.""" - mock_related_reactions = ['reaction1', 'reaction2'] + mock_related_reactions = ["reaction1", "reaction2"] - with patch.object(SimpleAmbitRule, "related_reactions", new_callable=PropertyMock) as mock_prop: + with patch.object( + SimpleAmbitRule, "related_reactions", new_callable=PropertyMock + ) as mock_prop: mock_prop.return_value = mock_related_reactions - rule = SimpleAmbitRule.create( - package=self.package, - smirks='[H:1][C:2]>>[H:1][O:2]' - ) + rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]") # Mock Edge objects query mock_edge_values = MagicMock() - mock_edge_values.values.return_value = ['pathway_id1', 'pathway_id2'] + mock_edge_values.values.return_value = ["pathway_id1", "pathway_id2"] mock_edge_objects.filter.return_value = mock_edge_values # Mock Pathway objects query @@ -285,52 +261,49 @@ class SimpleAmbitRuleTest(TestCase): result = rule.related_pathways mock_edge_objects.filter.assert_called_once_with(edge_label__in=mock_related_reactions) - mock_edge_values.values.assert_called_once_with('pathway_id') + mock_edge_values.values.assert_called_once_with("pathway_id") mock_pathway_objects.filter.assert_called_once() self.assertEqual(result, mock_pathway_qs) - @patch('epdb.models.IndigoUtils.smirks_to_svg') + @patch("epdb.models.IndigoUtils.smirks_to_svg") def test_as_svg_property(self, mock_smirks_to_svg): """Test as_svg property calls IndigoUtils correctly.""" - mock_smirks_to_svg.return_value = 'test_svg' + mock_smirks_to_svg.return_value = "test_svg" - rule = SimpleAmbitRule.create( - package=self.package, - smirks='[H:1][C:2]>>[H:1][O:2]' - ) + rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]") result = rule.as_svg mock_smirks_to_svg.assert_called_once_with(rule.smirks, True, width=800, height=400) - self.assertEqual(result, 'test_svg') + self.assertEqual(result, "test_svg") def test_atomic_transaction(self): """Test that rule creation is atomic.""" - smirks = '[H:1][C:2]>>[H:1][O:2]' + smirks = "[H:1][C:2]>>[H:1][O:2]" # This should work normally rule = SimpleAmbitRule.create(package=self.package, smirks=smirks) self.assertIsInstance(rule, SimpleAmbitRule) # Test transaction rollback on error - with patch('epdb.models.SimpleAmbitRule.save', side_effect=Exception('Database error')): + with patch("epdb.models.SimpleAmbitRule.save", side_effect=Exception("Database error")): with self.assertRaises(Exception): - SimpleAmbitRule.create(package=self.package, smirks='[H:3][C:4]>>[H:3][O:4]') + SimpleAmbitRule.create(package=self.package, smirks="[H:3][C:4]>>[H:3][O:4]") # Verify no partial data was saved self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package).count(), 1) def test_multiple_duplicate_warning(self): """Test logging when multiple duplicates are found.""" - smirks = '[H:1][C:2]>>[H:1][O:2]' + smirks = "[H:1][C:2]>>[H:1][O:2]" # Create first rule rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks) # Manually create a duplicate to simulate the error condition - rule2 = SimpleAmbitRule(package=self.package, smirks=smirks, name='Manual Rule') + rule2 = SimpleAmbitRule(package=self.package, smirks=smirks, name="Manual Rule") rule2.save() - with patch('epdb.models.logger') as mock_logger: + with patch("epdb.models.logger") as mock_logger: # This should find the existing rule and log an error about multiple matches result = SimpleAmbitRule.create(package=self.package, smirks=smirks) @@ -339,24 +312,28 @@ class SimpleAmbitRuleTest(TestCase): # Should log an error about multiple matches mock_logger.error.assert_called() - self.assertIn('More than one rule matched', mock_logger.error.call_args[0][0]) + self.assertIn("More than one rule matched", mock_logger.error.call_args[0][0]) def test_model_fields(self): """Test model field properties.""" rule = SimpleAmbitRule.create( package=self.package, - smirks='[H:1][C:2]>>[H:1][O:2]', - reactant_filter_smarts='[CH3]', - product_filter_smarts='[OH]' + smirks="[H:1][C:2]>>[H:1][O:2]", + reactant_filter_smarts="[CH3]", + product_filter_smarts="[OH]", ) # Test field properties - self.assertFalse(rule._meta.get_field('smirks').blank) - self.assertFalse(rule._meta.get_field('smirks').null) - self.assertTrue(rule._meta.get_field('reactant_filter_smarts').null) - self.assertTrue(rule._meta.get_field('product_filter_smarts').null) + self.assertFalse(rule._meta.get_field("smirks").blank) + self.assertFalse(rule._meta.get_field("smirks").null) + self.assertTrue(rule._meta.get_field("reactant_filter_smarts").null) + self.assertTrue(rule._meta.get_field("product_filter_smarts").null) # Test verbose names - self.assertEqual(rule._meta.get_field('smirks').verbose_name, 'SMIRKS') - self.assertEqual(rule._meta.get_field('reactant_filter_smarts').verbose_name, 'Reactant Filter SMARTS') - self.assertEqual(rule._meta.get_field('product_filter_smarts').verbose_name, 'Product Filter SMARTS') + self.assertEqual(rule._meta.get_field("smirks").verbose_name, "SMIRKS") + self.assertEqual( + rule._meta.get_field("reactant_filter_smarts").verbose_name, "Reactant Filter SMARTS" + ) + self.assertEqual( + rule._meta.get_field("product_filter_smarts").verbose_name, "Product Filter SMARTS" + ) diff --git a/tests/test_sobjects.py b/tests/test_sobjects.py index 5bd3261d..0212466e 100644 --- a/tests/test_sobjects.py +++ b/tests/test_sobjects.py @@ -1,32 +1,29 @@ from django.test import TestCase -from epdb.logic import SNode, SEdge, SPathway +from epdb.logic import SNode, SEdge class SObjectTest(TestCase): - def setUp(self): pass def test_snode_eq(self): - snode1 = SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0) - snode2 = SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0) + snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0) + snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0) assert snode1 == snode2 def test_snode_hash(self): pass def test_sedge_eq(self): - sedge1 = SEdge([SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)], - [SNode('CN1C(=O)NC2=C(C1=O)N(C)C=N2', 1), SNode('C=O', 1)], - rule=None) - sedge2 = SEdge([SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)], - [SNode('CN1C(=O)NC2=C(C1=O)N(C)C=N2', 1), SNode('C=O', 1)], - rule=None) + sedge1 = SEdge( + [SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)], + [SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)], + rule=None, + ) + sedge2 = SEdge( + [SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)], + [SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)], + rule=None, + ) assert sedge1 == sedge2 - - def test_sedge_hash(self): - pass - - def test_spathway(self): - pw = SPathway() diff --git a/tests/views/test_compound_views.py b/tests/views/test_compound_views.py index bd089144..731524e2 100644 --- a/tests/views/test_compound_views.py +++ b/tests/views/test_compound_views.py @@ -3,7 +3,7 @@ from django.urls import reverse from envipy_additional_information import Temperature, Interval from epdb.logic import UserManager, PackageManager -from epdb.models import Compound, Scenario, ExternalIdentifier, ExternalDatabase +from epdb.models import Compound, Scenario, ExternalDatabase class CompoundViewTest(TestCase): @@ -12,21 +12,28 @@ class CompoundViewTest(TestCase): @classmethod def setUpClass(cls): super(CompoundViewTest, cls).setUpClass() - cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", - set_setting=False, add_to_group=True, is_active=True) + cls.user1 = UserManager.create_user( + "user1", + "user1@envipath.com", + "SuperSafe", + set_setting=False, + add_to_group=True, + is_active=True, + ) cls.user1_default_package = cls.user1.default_package - cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') + cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack") def setUp(self): self.client.force_login(self.user1) def test_create_compound(self): response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -38,17 +45,18 @@ class CompoundViewTest(TestCase): self.assertEqual(c.name, "1,2-Dichloroethane") self.assertEqual(c.description, "Eawag BBD compound c0001") self.assertEqual(c.default_structure.smiles, "C(CCl)Cl") - self.assertEqual(c.default_structure.canonical_smiles, 'ClCCCl') + self.assertEqual(c.default_structure.canonical_smiles, "ClCCCl") self.assertEqual(c.structures.all().count(), 2) self.assertEqual(self.user1_default_package.compounds.count(), 1) # Adding the same rule again should return the existing one, hence not increasing the number of rules response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.url, compound_url) @@ -57,11 +65,12 @@ class CompoundViewTest(TestCase): # Adding the same rule in a different package should create a new rule response = self.client.post( - reverse("package compound list", kwargs={'package_uuid': self.package.uuid}), { + reverse("package compound list", kwargs={"package_uuid": self.package.uuid}), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -69,11 +78,12 @@ class CompoundViewTest(TestCase): # adding another reaction should increase count response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "2-Chloroethanol", "compound-description": "Eawag BBD compound c0005", "compound-smiles": "C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -82,11 +92,12 @@ class CompoundViewTest(TestCase): # Edit def test_edit_rule(self): response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -95,13 +106,17 @@ class CompoundViewTest(TestCase): c = Compound.objects.get(url=compound_url) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(self.user1_default_package.uuid), - 'compound_uuid': str(c.uuid) - }), { + reverse( + "package compound detail", + kwargs={ + "package_uuid": str(self.user1_default_package.uuid), + "compound_uuid": str(c.uuid), + }, + ), + { "compound-name": "Test Compound Adjusted", "compound-description": "New Description", - } + }, ) self.assertEqual(response.status_code, 302) @@ -121,7 +136,7 @@ class CompoundViewTest(TestCase): "Test Desc", "2025-10", "soil", - [Temperature(interval=Interval(start=20, end=30))] + [Temperature(interval=Interval(start=20, end=30))], ) s2 = Scenario.create( @@ -130,15 +145,16 @@ class CompoundViewTest(TestCase): "Test Desc2", "2025-10", "soil", - [Temperature(interval=Interval(start=10, end=20))] + [Temperature(interval=Interval(start=10, end=20))], ) response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -147,36 +163,35 @@ class CompoundViewTest(TestCase): c = Compound.objects.get(url=compound_url) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(c.package.uuid), - 'compound_uuid': str(c.uuid) - }), { - "selected-scenarios": [s1.url, s2.url] - } + reverse( + "package compound detail", + kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)}, + ), + {"selected-scenarios": [s1.url, s2.url]}, ) self.assertEqual(len(c.scenarios.all()), 2) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(c.package.uuid), - 'compound_uuid': str(c.uuid) - }), { - "selected-scenarios": [s1.url] - } + reverse( + "package compound detail", + kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)}, + ), + {"selected-scenarios": [s1.url]}, ) self.assertEqual(len(c.scenarios.all()), 1) self.assertEqual(c.scenarios.first().url, s1.url) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(c.package.uuid), - 'compound_uuid': str(c.uuid) - }), { + reverse( + "package compound detail", + kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)}, + ), + { # We have to set an empty string to avoid that the parameter is removed "selected-scenarios": "" - } + }, ) self.assertEqual(len(c.scenarios.all()), 0) @@ -184,11 +199,12 @@ class CompoundViewTest(TestCase): # def test_copy(self): response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -196,12 +212,13 @@ class CompoundViewTest(TestCase): c = Compound.objects.get(url=compound_url) response = self.client.post( - reverse("package detail", kwargs={ - 'package_uuid': str(self.package.uuid), - }), { - "hidden": "copy", - "object_to_copy": c.url - } + reverse( + "package detail", + kwargs={ + "package_uuid": str(self.package.uuid), + }, + ), + {"hidden": "copy", "object_to_copy": c.url}, ) self.assertEqual(response.status_code, 200) @@ -215,44 +232,48 @@ class CompoundViewTest(TestCase): # Copy to the same package should fail response = self.client.post( - reverse("package detail", kwargs={ - 'package_uuid': str(c.package.uuid), - }), { - "hidden": "copy", - "object_to_copy": c.url - } + reverse( + "package detail", + kwargs={ + "package_uuid": str(c.package.uuid), + }, + ), + {"hidden": "copy", "object_to_copy": c.url}, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json()['error'], f"Can't copy object {compound_url} to the same package!") + self.assertEqual( + response.json()["error"], f"Can't copy object {compound_url} to the same package!" + ) def test_references(self): ext_db, _ = ExternalDatabase.objects.get_or_create( - name='PubChem Compound', + name="PubChem Compound", defaults={ - 'full_name': 'PubChem Compound Database', - 'description': 'Chemical database of small organic molecules', - 'base_url': 'https://pubchem.ncbi.nlm.nih.gov', - 'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}' - } + "full_name": "PubChem Compound Database", + "description": "Chemical database of small organic molecules", + "base_url": "https://pubchem.ncbi.nlm.nih.gov", + "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/compound/{id}", + }, ) ext_db2, _ = ExternalDatabase.objects.get_or_create( - name='PubChem Substance', + name="PubChem Substance", defaults={ - 'full_name': 'PubChem Substance Database', - 'description': 'Database of chemical substances', - 'base_url': 'https://pubchem.ncbi.nlm.nih.gov', - 'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/substance/{id}' - } + "full_name": "PubChem Substance Database", + "description": "Database of chemical substances", + "base_url": "https://pubchem.ncbi.nlm.nih.gov", + "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/substance/{id}", + }, ) response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -260,42 +281,49 @@ class CompoundViewTest(TestCase): c = Compound.objects.get(url=compound_url) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(c.package.uuid), - 'compound_uuid': str(c.uuid), - }), { - 'selected-database': ext_db.pk, - 'identifier': '25154249' - } + reverse( + "package compound detail", + kwargs={ + "package_uuid": str(c.package.uuid), + "compound_uuid": str(c.uuid), + }, + ), + {"selected-database": ext_db.pk, "identifier": "25154249"}, ) self.assertEqual(c.external_identifiers.count(), 1) self.assertEqual(c.external_identifiers.first().database, ext_db) - self.assertEqual(c.external_identifiers.first().identifier_value, '25154249') - self.assertEqual(c.external_identifiers.first().url, 'https://pubchem.ncbi.nlm.nih.gov/compound/25154249') + self.assertEqual(c.external_identifiers.first().identifier_value, "25154249") + self.assertEqual( + c.external_identifiers.first().url, "https://pubchem.ncbi.nlm.nih.gov/compound/25154249" + ) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(c.package.uuid), - 'compound_uuid': str(c.uuid), - }), { - 'selected-database': ext_db2.pk, - 'identifier': '25154249' - } + reverse( + "package compound detail", + kwargs={ + "package_uuid": str(c.package.uuid), + "compound_uuid": str(c.uuid), + }, + ), + {"selected-database": ext_db2.pk, "identifier": "25154249"}, ) self.assertEqual(c.external_identifiers.count(), 2) self.assertEqual(c.external_identifiers.last().database, ext_db2) - self.assertEqual(c.external_identifiers.last().identifier_value, '25154249') - self.assertEqual(c.external_identifiers.last().url, 'https://pubchem.ncbi.nlm.nih.gov/substance/25154249') + self.assertEqual(c.external_identifiers.last().identifier_value, "25154249") + self.assertEqual( + c.external_identifiers.last().url, "https://pubchem.ncbi.nlm.nih.gov/substance/25154249" + ) def test_delete(self): response = self.client.post( - reverse("compounds"), { + reverse("compounds"), + { "compound-name": "1,2-Dichloroethane", "compound-description": "Eawag BBD compound c0001", "compound-smiles": "C(CCl)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -304,12 +332,11 @@ class CompoundViewTest(TestCase): c = Compound.objects.get(url=compound_url) response = self.client.post( - reverse("package compound detail", kwargs={ - 'package_uuid': str(c.package.uuid), - 'compound_uuid': str(c.uuid) - }), { - "hidden": "delete" - } + reverse( + "package compound detail", + kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)}, + ), + {"hidden": "delete"}, ) self.assertEqual(self.user1_default_package.compounds.count(), 0) diff --git a/tests/views/test_model_views.py b/tests/views/test_model_views.py index 36ca4cf6..558277f5 100644 --- a/tests/views/test_model_views.py +++ b/tests/views/test_model_views.py @@ -2,8 +2,8 @@ from django.test import TestCase, override_settings from django.urls import reverse from django.conf import settings as s -from epdb.logic import UserManager, PackageManager -from epdb.models import Pathway, Edge, Package, User +from epdb.logic import UserManager +from epdb.models import Package, User @override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models") @@ -13,10 +13,16 @@ class PathwayViewTest(TestCase): @classmethod def setUpClass(cls): super(PathwayViewTest, cls).setUpClass() - cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", - set_setting=True, add_to_group=True, is_active=True) + cls.user1 = UserManager.create_user( + "user1", + "user1@envipath.com", + "SuperSafe", + set_setting=True, + add_to_group=True, + is_active=True, + ) cls.user1_default_package = cls.user1.default_package - cls.model_package = Package.objects.get(name='Fixtures') + cls.model_package = Package.objects.get(name="Fixtures") def setUp(self): self.client.force_login(self.user1) @@ -24,90 +30,96 @@ class PathwayViewTest(TestCase): def test_predict(self): self.client.force_login(User.objects.get(username="admin")) response = self.client.get( - reverse("package model detail", kwargs={ - 'package_uuid': str(self.model_package.uuid), - 'model_uuid': str(self.model_package.models.first().uuid) - }), { - 'classify': 'ILikeCats!', - 'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO', - } + reverse( + "package model detail", + kwargs={ + "package_uuid": str(self.model_package.uuid), + "model_uuid": str(self.model_package.models.first().uuid), + }, + ), + { + "classify": "ILikeCats!", + "smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO", + }, ) expected = [ { - 'products': [ - [ - 'O=C(O)C1=CC(CO)=CC=C1', - 'CCNCC' - ] - ], - 'probability': 0.25, - 'btrule': { - 'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/0e6e9290-b658-4450-b291-3ec19fa19206', - 'name': 'bt0430-4011' - } - }, { - 'products': [ - [ - 'CCNC(=O)C1=CC(CO)=CC=C1', - 'CC=O' - ] - ], 'probability': 0.0, - 'btrule': { - 'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/27a3a353-0b66-4228-bd16-e407949e90df', - 'name': 'bt0243-4301' - } - }, { - 'products': [ - [ - 'CCN(CC)C(=O)C1=CC(C=O)=CC=C1' - ] - ], 'probability': 0.75, - 'btrule': { - 'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/2f2e0c39-e109-4836-959f-2bda2524f022', - 'name': 'bt0001-3568' - } - } + "products": [["O=C(O)C1=CC(CO)=CC=C1", "CCNCC"]], + "probability": 0.25, + "btrule": { + "url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/0e6e9290-b658-4450-b291-3ec19fa19206", + "name": "bt0430-4011", + }, + }, + { + "products": [["CCNC(=O)C1=CC(CO)=CC=C1", "CC=O"]], + "probability": 0.0, + "btrule": { + "url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/27a3a353-0b66-4228-bd16-e407949e90df", + "name": "bt0243-4301", + }, + }, + { + "products": [["CCN(CC)C(=O)C1=CC(C=O)=CC=C1"]], + "probability": 0.75, + "btrule": { + "url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/2f2e0c39-e109-4836-959f-2bda2524f022", + "name": "bt0001-3568", + }, + }, ] actual = response.json() self.assertEqual(actual, expected) response = self.client.get( - reverse("package model detail", kwargs={ - 'package_uuid': str(self.model_package.uuid), - 'model_uuid': str(self.model_package.models.first().uuid) - }), { - 'classify': 'ILikeCats!', - 'smiles': '', - } + reverse( + "package model detail", + kwargs={ + "package_uuid": str(self.model_package.uuid), + "model_uuid": str(self.model_package.models.first().uuid), + }, + ), + { + "classify": "ILikeCats!", + "smiles": "", + }, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json()['error'], 'Received empty SMILES') + self.assertEqual(response.json()["error"], "Received empty SMILES") response = self.client.get( - reverse("package model detail", kwargs={ - 'package_uuid': str(self.model_package.uuid), - 'model_uuid': str(self.model_package.models.first().uuid) - }), { - 'classify': 'ILikeCats!', - 'smiles': ' ', # Input should be stripped - } + reverse( + "package model detail", + kwargs={ + "package_uuid": str(self.model_package.uuid), + "model_uuid": str(self.model_package.models.first().uuid), + }, + ), + { + "classify": "ILikeCats!", + "smiles": " ", # Input should be stripped + }, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json()['error'], 'Received empty SMILES') + self.assertEqual(response.json()["error"], "Received empty SMILES") response = self.client.get( - reverse("package model detail", kwargs={ - 'package_uuid': str(self.model_package.uuid), - 'model_uuid': str(self.model_package.models.first().uuid) - }), { - 'classify': 'ILikeCats!', - 'smiles': 'RandomInput', - } + reverse( + "package model detail", + kwargs={ + "package_uuid": str(self.model_package.uuid), + "model_uuid": str(self.model_package.models.first().uuid), + }, + ), + { + "classify": "ILikeCats!", + "smiles": "RandomInput", + }, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json()['error'], '"RandomInput" is not a valid SMILES') + self.assertEqual(response.json()["error"], '"RandomInput" is not a valid SMILES') diff --git a/tests/views/test_package_views.py b/tests/views/test_package_views.py index 77bbf724..f185f721 100644 --- a/tests/views/test_package_views.py +++ b/tests/views/test_package_views.py @@ -13,19 +13,34 @@ class PackageViewTest(TestCase): @classmethod def setUpClass(cls): super(PackageViewTest, cls).setUpClass() - cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", - set_setting=False, add_to_group=True, is_active=True) - cls.user2 = UserManager.create_user("user2", "user2@envipath.com", "SuperSafe", - set_setting=False, add_to_group=True, is_active=True) + cls.user1 = UserManager.create_user( + "user1", + "user1@envipath.com", + "SuperSafe", + set_setting=False, + add_to_group=True, + is_active=True, + ) + cls.user2 = UserManager.create_user( + "user2", + "user2@envipath.com", + "SuperSafe", + set_setting=False, + add_to_group=True, + is_active=True, + ) def setUp(self): self.client.force_login(self.user1) def test_create_package(self): - response = self.client.post(reverse("packages"), { - "package-name": "Test Package", - "package-description": "Just a Description", - }) + response = self.client.post( + reverse("packages"), + { + "package-name": "Test Package", + "package-description": "Just a Description", + }, + ) self.assertEqual(response.status_code, 302) package_url = response.url @@ -41,13 +56,12 @@ class PackageViewTest(TestCase): file = SimpleUploadedFile( "Fixture_Package.json", open(s.FIXTURE_DIRS[0] / "Fixture_Package.json", "rb").read(), - content_type="application/json" + content_type="application/json", ) - response = self.client.post(reverse("packages"), { - "file": file, - "hidden": "import-package-json" - }) + response = self.client.post( + reverse("packages"), {"file": file, "hidden": "import-package-json"} + ) self.assertEqual(response.status_code, 302) package_url = response.url @@ -67,13 +81,12 @@ class PackageViewTest(TestCase): file = SimpleUploadedFile( "EAWAG-BBD.json", open(s.FIXTURE_DIRS[0] / "packages" / "2025-07-18" / "EAWAG-BBD.json", "rb").read(), - content_type="application/json" + content_type="application/json", ) - response = self.client.post(reverse("packages"), { - "file": file, - "hidden": "import-legacy-package-json" - }) + response = self.client.post( + reverse("packages"), {"file": file, "hidden": "import-legacy-package-json"} + ) self.assertEqual(response.status_code, 302) package_url = response.url @@ -90,17 +103,23 @@ class PackageViewTest(TestCase): self.assertEqual(upp.permission, Permission.ALL[0]) def test_edit_package(self): - response = self.client.post(reverse("packages"), { - "package-name": "Test Package", - "package-description": "Just a Description", - }) + response = self.client.post( + reverse("packages"), + { + "package-name": "Test Package", + "package-description": "Just a Description", + }, + ) self.assertEqual(response.status_code, 302) package_url = response.url - self.client.post(package_url, { - "package-name": "New Name", - "package-description": "New Description", - }) + self.client.post( + package_url, + { + "package-name": "New Name", + "package-description": "New Description", + }, + ) p = Package.objects.get(url=package_url) @@ -108,10 +127,13 @@ class PackageViewTest(TestCase): self.assertEqual(p.description, "New Description") def test_edit_package_permissions(self): - response = self.client.post(reverse("packages"), { - "package-name": "Test Package", - "package-description": "Just a Description", - }) + response = self.client.post( + reverse("packages"), + { + "package-name": "Test Package", + "package-description": "Just a Description", + }, + ) self.assertEqual(response.status_code, 302) package_url = response.url p = Package.objects.get(url=package_url) @@ -119,57 +141,63 @@ class PackageViewTest(TestCase): with self.assertRaises(UserPackagePermission.DoesNotExist): UserPackagePermission.objects.get(package=p, user=self.user2) - self.client.post(package_url, { - "grantee": self.user2.url, - "read": "on", - "write": "on", - }) + self.client.post( + package_url, + { + "grantee": self.user2.url, + "read": "on", + "write": "on", + }, + ) upp = UserPackagePermission.objects.get(package=p, user=self.user2) self.assertEqual(upp.permission, Permission.WRITE[0]) def test_publish_package(self): - response = self.client.post(reverse("packages"), { - "package-name": "Test Package", - "package-description": "Just a Description", - }) + response = self.client.post( + reverse("packages"), + { + "package-name": "Test Package", + "package-description": "Just a Description", + }, + ) self.assertEqual(response.status_code, 302) package_url = response.url p = Package.objects.get(url=package_url) - self.client.post(package_url, { - "hidden": "publish-package" - }) + self.client.post(package_url, {"hidden": "publish-package"}) self.assertEqual(Group.objects.filter(public=True).count(), 1) g = Group.objects.get(public=True) gpp = GroupPackagePermission.objects.get(package=p, group=g) self.assertEqual(gpp.permission, Permission.READ[0]) - def test_set_package_license(self): - response = self.client.post(reverse("packages"), { - "package-name": "Test Package", - "package-description": "Just a Description", - }) + response = self.client.post( + reverse("packages"), + { + "package-name": "Test Package", + "package-description": "Just a Description", + }, + ) package_url = response.url p = Package.objects.get(url=package_url) - self.client.post(package_url, { - "license": "no-license" - }) + self.client.post(package_url, {"license": "no-license"}) self.assertIsNone(p.license) # TODO test others - def test_delete_package(self): - response = self.client.post(reverse("packages"), { - "package-name": "Test Package", - "package-description": "Just a Description", - }) + response = self.client.post( + reverse("packages"), + { + "package-name": "Test Package", + "package-description": "Just a Description", + }, + ) package_url = response.url p = Package.objects.get(url=package_url) @@ -182,11 +210,11 @@ class PackageViewTest(TestCase): def test_delete_default_package(self): self.client.force_login(self.user1) # Try to delete the default package - response = self.client.post(self.user1.default_package.url, { - "hidden": "delete" - }) + response = self.client.post(self.user1.default_package.url, {"hidden": "delete"}) self.assertEqual(response.status_code, 400) - self.assertTrue(f'You cannot delete the default package. ' - f'If you want to delete this package you have to ' - f'set another default package first' in response.content.decode()) + self.assertTrue( + "You cannot delete the default package. " + "If you want to delete this package you have to " + "set another default package first" in response.content.decode() + ) diff --git a/tests/views/test_pathway_views.py b/tests/views/test_pathway_views.py index 129aac60..1094bae3 100644 --- a/tests/views/test_pathway_views.py +++ b/tests/views/test_pathway_views.py @@ -5,6 +5,7 @@ from django.conf import settings as s from epdb.logic import UserManager, PackageManager from epdb.models import Pathway, Edge + @override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models") class PathwayViewTest(TestCase): fixtures = ["test_fixtures_incl_model.jsonl.gz"] @@ -12,41 +13,52 @@ class PathwayViewTest(TestCase): @classmethod def setUpClass(cls): super(PathwayViewTest, cls).setUpClass() - cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", - set_setting=True, add_to_group=True, is_active=True) + cls.user1 = UserManager.create_user( + "user1", + "user1@envipath.com", + "SuperSafe", + set_setting=True, + add_to_group=True, + is_active=True, + ) cls.user1_default_package = cls.user1.default_package - cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') + cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack") def setUp(self): self.client.force_login(self.user1) def test_predict_pathway(self): - response = self.client.post(reverse("pathways"), { - 'name': 'Test Pathway', - 'description': 'Just a Description', - 'predict': 'predict', - 'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO', - }) + response = self.client.post( + reverse("pathways"), + { + "name": "Test Pathway", + "description": "Just a Description", + "predict": "predict", + "smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO", + }, + ) self.assertEqual(response.status_code, 302) pathway_url = response.url pw = Pathway.objects.get(url=pathway_url) self.assertEqual(self.user1_default_package, pw.package) - self.assertEqual(pw.name, 'Test Pathway') - self.assertEqual(pw.description, 'Just a Description') + self.assertEqual(pw.name, "Test Pathway") + self.assertEqual(pw.description, "Just a Description") self.assertEqual(len(pw.root_nodes), 1) - self.assertEqual(pw.root_nodes.first().default_node_label.smiles, 'CCN(CC)C(=O)C1=CC(CO)=CC=C1') + self.assertEqual( + pw.root_nodes.first().default_node_label.smiles, "CCN(CC)C(=O)C1=CC(CO)=CC=C1" + ) first_level_nodes = { # Edge 1 - 'CCN(CC)C(=O)C1=CC(C=O)=CC=C1', + "CCN(CC)C(=O)C1=CC(C=O)=CC=C1", # Edge 2 - 'CCNC(=O)C1=CC(CO)=CC=C1', - 'CC=O', + "CCNC(=O)C1=CC(CO)=CC=C1", + "CC=O", # Edge 3 - 'CCNCC', - 'O=C(O)C1=CC(CO)=CC=C1', + "CCNCC", + "O=C(O)C1=CC(CO)=CC=C1", } predicted_nodes = set() @@ -60,32 +72,36 @@ class PathwayViewTest(TestCase): def test_predict_package_pathway(self): response = self.client.post( - reverse("package pathway list", kwargs={'package_uuid': str(self.package.uuid)}), { - 'name': 'Test Pathway', - 'description': 'Just a Description', - 'predict': 'predict', - 'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO', - }) + reverse("package pathway list", kwargs={"package_uuid": str(self.package.uuid)}), + { + "name": "Test Pathway", + "description": "Just a Description", + "predict": "predict", + "smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO", + }, + ) self.assertEqual(response.status_code, 302) pathway_url = response.url pw = Pathway.objects.get(url=pathway_url) self.assertEqual(self.package, pw.package) - self.assertEqual(pw.name, 'Test Pathway') - self.assertEqual(pw.description, 'Just a Description') + self.assertEqual(pw.name, "Test Pathway") + self.assertEqual(pw.description, "Just a Description") self.assertEqual(len(pw.root_nodes), 1) - self.assertEqual(pw.root_nodes.first().default_node_label.smiles, 'CCN(CC)C(=O)C1=CC(CO)=CC=C1') + self.assertEqual( + pw.root_nodes.first().default_node_label.smiles, "CCN(CC)C(=O)C1=CC(CO)=CC=C1" + ) first_level_nodes = { # Edge 1 - 'CCN(CC)C(=O)C1=CC(C=O)=CC=C1', + "CCN(CC)C(=O)C1=CC(C=O)=CC=C1", # Edge 2 - 'CCNC(=O)C1=CC(CO)=CC=C1', - 'CC=O', + "CCNC(=O)C1=CC(CO)=CC=C1", + "CC=O", # Edge 3 - 'CCNCC', - 'O=C(O)C1=CC(CO)=CC=C1', + "CCNCC", + "O=C(O)C1=CC(CO)=CC=C1", } predicted_nodes = set() diff --git a/tests/views/test_reaction_views.py b/tests/views/test_reaction_views.py index 8d34b93d..1dafa297 100644 --- a/tests/views/test_reaction_views.py +++ b/tests/views/test_reaction_views.py @@ -3,7 +3,7 @@ from django.urls import reverse from envipy_additional_information import Temperature, Interval from epdb.logic import UserManager, PackageManager -from epdb.models import Reaction, Scenario, ExternalIdentifier, ExternalDatabase +from epdb.models import Reaction, Scenario, ExternalDatabase class ReactionViewTest(TestCase): @@ -12,21 +12,28 @@ class ReactionViewTest(TestCase): @classmethod def setUpClass(cls): super(ReactionViewTest, cls).setUpClass() - cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", - set_setting=False, add_to_group=True, is_active=True) + cls.user1 = UserManager.create_user( + "user1", + "user1@envipath.com", + "SuperSafe", + set_setting=False, + add_to_group=True, + is_active=True, + ) cls.user1_default_package = cls.user1.default_package - cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') + cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack") def setUp(self): self.client.force_login(self.user1) def test_create_reaction(self): response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -42,11 +49,12 @@ class ReactionViewTest(TestCase): # Adding the same rule again should return the existing one, hence not increasing the number of rules response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.url, reaction_url) @@ -55,11 +63,12 @@ class ReactionViewTest(TestCase): # Adding the same rule in a different package should create a new rule response = self.client.post( - reverse("package reaction list", kwargs={'package_uuid': self.package.uuid}), { + reverse("package reaction list", kwargs={"package_uuid": self.package.uuid}), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -67,11 +76,12 @@ class ReactionViewTest(TestCase): # adding another reaction should increase count response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0002", "reaction-description": "Description for Eawag BBD reaction r0002", "reaction-smirks": "C(CO)Cl>>C(C=O)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -80,11 +90,12 @@ class ReactionViewTest(TestCase): # Edit def test_edit_rule(self): response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -93,13 +104,17 @@ class ReactionViewTest(TestCase): r = Reaction.objects.get(url=reaction_url) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(self.user1_default_package.uuid), - 'reaction_uuid': str(r.uuid) - }), { + reverse( + "package reaction detail", + kwargs={ + "package_uuid": str(self.user1_default_package.uuid), + "reaction_uuid": str(r.uuid), + }, + ), + { "reaction-name": "Test Reaction Adjusted", "reaction-description": "New Description", - } + }, ) self.assertEqual(response.status_code, 302) @@ -119,7 +134,7 @@ class ReactionViewTest(TestCase): "Test Desc", "2025-10", "soil", - [Temperature(interval=Interval(start=20, end=30))] + [Temperature(interval=Interval(start=20, end=30))], ) s2 = Scenario.create( @@ -128,15 +143,16 @@ class ReactionViewTest(TestCase): "Test Desc2", "2025-10", "soil", - [Temperature(interval=Interval(start=10, end=20))] + [Temperature(interval=Interval(start=10, end=20))], ) response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -144,47 +160,47 @@ class ReactionViewTest(TestCase): r = Reaction.objects.get(url=reaction_url) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'reaction_uuid': str(r.uuid) - }), { - "selected-scenarios": [s1.url, s2.url] - } + reverse( + "package reaction detail", + kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)}, + ), + {"selected-scenarios": [s1.url, s2.url]}, ) self.assertEqual(len(r.scenarios.all()), 2) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'reaction_uuid': str(r.uuid) - }), { - "selected-scenarios": [s1.url] - } + reverse( + "package reaction detail", + kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)}, + ), + {"selected-scenarios": [s1.url]}, ) self.assertEqual(len(r.scenarios.all()), 1) self.assertEqual(r.scenarios.first().url, s1.url) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'reaction_uuid': str(r.uuid) - }), { + reverse( + "package reaction detail", + kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)}, + ), + { # We have to set an empty string to avoid that the parameter is removed "selected-scenarios": "" - } + }, ) self.assertEqual(len(r.scenarios.all()), 0) def test_copy(self): response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -192,12 +208,13 @@ class ReactionViewTest(TestCase): r = Reaction.objects.get(url=reaction_url) response = self.client.post( - reverse("package detail", kwargs={ - 'package_uuid': str(self.package.uuid), - }), { - "hidden": "copy", - "object_to_copy": r.url - } + reverse( + "package detail", + kwargs={ + "package_uuid": str(self.package.uuid), + }, + ), + {"hidden": "copy", "object_to_copy": r.url}, ) self.assertEqual(response.status_code, 200) @@ -211,44 +228,48 @@ class ReactionViewTest(TestCase): # Copy to the same package should fail response = self.client.post( - reverse("package detail", kwargs={ - 'package_uuid': str(r.package.uuid), - }), { - "hidden": "copy", - "object_to_copy": r.url - } + reverse( + "package detail", + kwargs={ + "package_uuid": str(r.package.uuid), + }, + ), + {"hidden": "copy", "object_to_copy": r.url}, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json()['error'], f"Can't copy object {reaction_url} to the same package!") + self.assertEqual( + response.json()["error"], f"Can't copy object {reaction_url} to the same package!" + ) def test_references(self): ext_db, _ = ExternalDatabase.objects.get_or_create( - name='KEGG Reaction', + name="KEGG Reaction", defaults={ - 'full_name': 'KEGG Reaction Database', - 'description': 'Database of biochemical reactions', - 'base_url': 'https://www.genome.jp', - 'url_pattern': 'https://www.genome.jp/entry/{id}' - } + "full_name": "KEGG Reaction Database", + "description": "Database of biochemical reactions", + "base_url": "https://www.genome.jp", + "url_pattern": "https://www.genome.jp/entry/{id}", + }, ) ext_db2, _ = ExternalDatabase.objects.get_or_create( - name='RHEA', + name="RHEA", defaults={ - 'full_name': 'RHEA Reaction Database', - 'description': 'Comprehensive resource of biochemical reactions', - 'base_url': 'https://www.rhea-db.org', - 'url_pattern': 'https://www.rhea-db.org/rhea/{id}' + "full_name": "RHEA Reaction Database", + "description": "Comprehensive resource of biochemical reactions", + "base_url": "https://www.rhea-db.org", + "url_pattern": "https://www.rhea-db.org/rhea/{id}", }, ) response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -256,45 +277,49 @@ class ReactionViewTest(TestCase): r = Reaction.objects.get(url=reaction_url) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'reaction_uuid': str(r.uuid), - }), { - 'selected-database': ext_db.pk, - 'identifier': 'C12345' - } + reverse( + "package reaction detail", + kwargs={ + "package_uuid": str(r.package.uuid), + "reaction_uuid": str(r.uuid), + }, + ), + {"selected-database": ext_db.pk, "identifier": "C12345"}, ) self.assertEqual(r.external_identifiers.count(), 1) self.assertEqual(r.external_identifiers.first().database, ext_db) - self.assertEqual(r.external_identifiers.first().identifier_value, 'C12345') + self.assertEqual(r.external_identifiers.first().identifier_value, "C12345") # TODO Fixture contains old url template there the real test fails, use old value instead # self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/C12345') - self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/reaction+C12345') + self.assertEqual( + r.external_identifiers.first().url, "https://www.genome.jp/entry/reaction+C12345" + ) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'reaction_uuid': str(r.uuid), - }), { - 'selected-database': ext_db2.pk, - 'identifier': '60116' - } + reverse( + "package reaction detail", + kwargs={ + "package_uuid": str(r.package.uuid), + "reaction_uuid": str(r.uuid), + }, + ), + {"selected-database": ext_db2.pk, "identifier": "60116"}, ) self.assertEqual(r.external_identifiers.count(), 2) self.assertEqual(r.external_identifiers.last().database, ext_db2) - self.assertEqual(r.external_identifiers.last().identifier_value, '60116') - self.assertEqual(r.external_identifiers.last().url, 'https://www.rhea-db.org/rhea/60116') - + self.assertEqual(r.external_identifiers.last().identifier_value, "60116") + self.assertEqual(r.external_identifiers.last().url, "https://www.rhea-db.org/rhea/60116") def test_delete(self): response = self.client.post( - reverse("reactions"), { + reverse("reactions"), + { "reaction-name": "Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl", - } + }, ) self.assertEqual(response.status_code, 302) @@ -302,12 +327,11 @@ class ReactionViewTest(TestCase): r = Reaction.objects.get(url=reaction_url) response = self.client.post( - reverse("package reaction detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'reaction_uuid': str(r.uuid) - }), { - "hidden": "delete" - } + reverse( + "package reaction detail", + kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)}, + ), + {"hidden": "delete"}, ) self.assertEqual(self.user1_default_package.reactions.count(), 0) diff --git a/tests/views/test_rule_views.py b/tests/views/test_rule_views.py index b42e87b1..92b0727b 100644 --- a/tests/views/test_rule_views.py +++ b/tests/views/test_rule_views.py @@ -12,22 +12,29 @@ class RuleViewTest(TestCase): @classmethod def setUpClass(cls): super(RuleViewTest, cls).setUpClass() - cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", - set_setting=False, add_to_group=True, is_active=True) + cls.user1 = UserManager.create_user( + "user1", + "user1@envipath.com", + "SuperSafe", + set_setting=False, + add_to_group=True, + is_active=True, + ) cls.user1_default_package = cls.user1.default_package - cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') + cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack") def setUp(self): self.client.force_login(self.user1) def test_create_rule(self): response = self.client.post( - reverse("rules"), { + reverse("rules"), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.status_code, 302) @@ -38,18 +45,21 @@ class RuleViewTest(TestCase): self.assertEqual(r.package, self.user1_default_package) self.assertEqual(r.name, "Test Rule") self.assertEqual(r.description, "Just a Description") - self.assertEqual(r.smirks, - "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]") + self.assertEqual( + r.smirks, + "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", + ) self.assertEqual(self.user1_default_package.rules.count(), 1) # Adding the same rule again should return the existing one, hence not increasing the number of rules response = self.client.post( - reverse("rules"), { + reverse("rules"), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.url, rule_url) @@ -58,12 +68,13 @@ class RuleViewTest(TestCase): # Adding the same rule in a different package should create a new rule response = self.client.post( - reverse("package rule list", kwargs={'package_uuid': self.package.uuid}), { + reverse("package rule list", kwargs={"package_uuid": self.package.uuid}), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.status_code, 302) @@ -72,12 +83,13 @@ class RuleViewTest(TestCase): # Edit def test_edit_rule(self): response = self.client.post( - reverse("rules"), { + reverse("rules"), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.status_code, 302) @@ -86,13 +98,17 @@ class RuleViewTest(TestCase): r = Rule.objects.get(url=rule_url) response = self.client.post( - reverse("package rule detail", kwargs={ - 'package_uuid': str(self.user1_default_package.uuid), - 'rule_uuid': str(r.uuid) - }), { + reverse( + "package rule detail", + kwargs={ + "package_uuid": str(self.user1_default_package.uuid), + "rule_uuid": str(r.uuid), + }, + ), + { "rule-name": "Test Rule Adjusted", "rule-description": "New Description", - } + }, ) self.assertEqual(response.status_code, 302) @@ -108,7 +124,7 @@ class RuleViewTest(TestCase): "Test Desc", "2025-10", "soil", - [Temperature(interval=Interval(start=20, end=30))] + [Temperature(interval=Interval(start=20, end=30))], ) s2 = Scenario.create( @@ -117,16 +133,17 @@ class RuleViewTest(TestCase): "Test Desc2", "2025-10", "soil", - [Temperature(interval=Interval(start=10, end=20))] + [Temperature(interval=Interval(start=10, end=20))], ) response = self.client.post( - reverse("rules"), { + reverse("rules"), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.status_code, 302) @@ -134,48 +151,48 @@ class RuleViewTest(TestCase): r = Rule.objects.get(url=rule_url) response = self.client.post( - reverse("package rule detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'rule_uuid': str(r.uuid) - }), { - "selected-scenarios": [s1.url, s2.url] - } + reverse( + "package rule detail", + kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)}, + ), + {"selected-scenarios": [s1.url, s2.url]}, ) self.assertEqual(len(r.scenarios.all()), 2) response = self.client.post( - reverse("package rule detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'rule_uuid': str(r.uuid) - }), { - "selected-scenarios": [s1.url] - } + reverse( + "package rule detail", + kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)}, + ), + {"selected-scenarios": [s1.url]}, ) self.assertEqual(len(r.scenarios.all()), 1) self.assertEqual(r.scenarios.first().url, s1.url) response = self.client.post( - reverse("package rule detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'rule_uuid': str(r.uuid) - }), { + reverse( + "package rule detail", + kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)}, + ), + { # We have to set an empty string to avoid that the parameter is removed "selected-scenarios": "" - } + }, ) self.assertEqual(len(r.scenarios.all()), 0) def test_copy(self): response = self.client.post( - reverse("rules"), { + reverse("rules"), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.status_code, 302) @@ -183,12 +200,13 @@ class RuleViewTest(TestCase): r = Rule.objects.get(url=rule_url) response = self.client.post( - reverse("package detail", kwargs={ - 'package_uuid': str(self.package.uuid), - }), { - "hidden": "copy", - "object_to_copy": r.url - } + reverse( + "package detail", + kwargs={ + "package_uuid": str(self.package.uuid), + }, + ), + {"hidden": "copy", "object_to_copy": r.url}, ) self.assertEqual(response.status_code, 200) @@ -202,26 +220,29 @@ class RuleViewTest(TestCase): # Copy to the same package should fail response = self.client.post( - reverse("package detail", kwargs={ - 'package_uuid': str(r.package.uuid), - }), { - "hidden": "copy", - "object_to_copy": r.url - } + reverse( + "package detail", + kwargs={ + "package_uuid": str(r.package.uuid), + }, + ), + {"hidden": "copy", "object_to_copy": r.url}, ) self.assertEqual(response.status_code, 400) - self.assertEqual(response.json()['error'], f"Can't copy object {rule_url} to the same package!") - + self.assertEqual( + response.json()["error"], f"Can't copy object {rule_url} to the same package!" + ) def test_delete(self): response = self.client.post( - reverse("rules"), { + reverse("rules"), + { "rule-name": "Test Rule", "rule-description": "Just a Description", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-type": "SimpleAmbitRule", - } + }, ) self.assertEqual(response.status_code, 302) @@ -229,12 +250,11 @@ class RuleViewTest(TestCase): r = Rule.objects.get(url=rule_url) response = self.client.post( - reverse("package rule detail", kwargs={ - 'package_uuid': str(r.package.uuid), - 'rule_uuid': str(r.uuid) - }), { - "hidden": "delete" - } + reverse( + "package rule detail", + kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)}, + ), + {"hidden": "delete"}, ) self.assertEqual(self.user1_default_package.rules.count(), 0) diff --git a/tests/views/test_user_views.py b/tests/views/test_user_views.py index 5e1aa814..1760aaa8 100644 --- a/tests/views/test_user_views.py +++ b/tests/views/test_user_views.py @@ -11,70 +11,81 @@ class UserViewTest(TestCase): @classmethod def setUpClass(cls): super(UserViewTest, cls).setUpClass() - cls.user = User.objects.get(username='anonymous') - cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') - cls.BBD_SUBSET = Package.objects.get(name='Fixtures') + cls.user = User.objects.get(username="anonymous") + cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc") + cls.BBD_SUBSET = Package.objects.get(name="Fixtures") def test_login_with_valid_credentials(self): - response = self.client.post(reverse("login"), { - "username": "user0", - "password": 'SuperSafe', - }) + response = self.client.post( + reverse("login"), + { + "username": "user0", + "password": "SuperSafe", + }, + ) self.assertRedirects(response, reverse("index")) self.assertTrue(response.wsgi_request.user.is_authenticated) def test_login_with_invalid_credentials(self): - response = self.client.post(reverse("login"), { - "username": "user0", - "password": "wrongpassword", - }) + response = self.client.post( + reverse("login"), + { + "username": "user0", + "password": "wrongpassword", + }, + ) self.assertEqual(response.status_code, 200) self.assertFalse(response.wsgi_request.user.is_authenticated) def test_register(self): - response = self.client.post(reverse("register"), { - "username": "user1", - "email": "user1@envipath.com", - "password": "SuperSafe", - "rpassword": "SuperSafe", - }) + response = self.client.post( + reverse("register"), + { + "username": "user1", + "email": "user1@envipath.com", + "password": "SuperSafe", + "rpassword": "SuperSafe", + }, + ) self.assertEqual(response.status_code, 200) # TODO currently fails as the fixture does not provide a global setting... self.assertContains(response, "Registration failed!") def test_register_password_mismatch(self): - response = self.client.post(reverse("register"), { - "username": "user1", - "email": "user1@envipath.com", - "password": "SuperSafe", - "rpassword": "SuperSaf3", - }) + response = self.client.post( + reverse("register"), + { + "username": "user1", + "email": "user1@envipath.com", + "password": "SuperSafe", + "rpassword": "SuperSaf3", + }, + ) self.assertEqual(response.status_code, 200) self.assertContains(response, "Registration failed, provided passwords differ") def test_logout(self): - response = self.client.post(reverse("login"), { - "username": "user0", - "password": 'SuperSafe', - "login": "true" - }) + response = self.client.post( + reverse("login"), {"username": "user0", "password": "SuperSafe", "login": "true"} + ) self.assertTrue(response.wsgi_request.user.is_authenticated) - response = self.client.post(reverse('logout'), { - "logout": "true", - }) + response = self.client.post( + reverse("logout"), + { + "logout": "true", + }, + ) self.assertFalse(response.wsgi_request.user.is_authenticated) def test_next_param_properly_handled(self): - response = self.client.get(reverse('packages')) + response = self.client.get(reverse("packages")) self.assertRedirects(response, f"{reverse('login')}/?next=/package") - response = self.client.post(reverse('login'), { - "username": "user0", - "password": 'SuperSafe', - "login": "true", - "next": "/package" - }) + response = self.client.post( + reverse("login"), + {"username": "user0", "password": "SuperSafe", "login": "true", "next": "/package"}, + ) - self.assertRedirects(response, reverse('packages')) + self.assertRedirects(response, reverse("packages")) diff --git a/utilities/biodeg.py b/utilities/biodeg.py deleted file mode 100644 index d47d13d4..00000000 --- a/utilities/biodeg.py +++ /dev/null @@ -1,13 +0,0 @@ -import abc -from enviPy.epdb import Pathway - -class PredictionSchema(abc.ABC): - pass - -class DFS(PredictionSchema): - - def __init__(self, pw: Pathway, settings=None): - self.setting = settings or pw.prediction_settings - - def predict(self): - pass diff --git a/utilities/chem.py b/utilities/chem.py index b91e7edb..6de46147 100644 --- a/utilities/chem.py +++ b/utilities/chem.py @@ -2,12 +2,11 @@ import logging import re from abc import ABC from collections import defaultdict -from typing import List, Optional, Dict +from typing import List, Optional, Dict, TYPE_CHECKING from indigo import Indigo, IndigoException, IndigoObject from indigo.renderer import IndigoRenderer -from rdkit import Chem -from rdkit import RDLogger +from rdkit import Chem, rdBase from rdkit.Chem import MACCSkeys, Descriptors from rdkit.Chem import rdChemReactions from rdkit.Chem.Draw import rdMolDraw2D @@ -15,9 +14,11 @@ from rdkit.Chem.MolStandardize import rdMolStandardize from rdkit.Chem.rdmolops import GetMolFrags from rdkit.Contrib.IFG import ifg -logger = logging.getLogger(__name__) -RDLogger.DisableLog('rdApp.*') +if TYPE_CHECKING: + from epdb.models import Rule +logger = logging.getLogger(__name__) +rdBase.DisableLog("rdApp.*") # from rdkit import rdBase # rdBase.LogToPythonLogger() @@ -28,7 +29,6 @@ RDLogger.DisableLog('rdApp.*') class ProductSet(object): - def __init__(self, product_set: List[str]): self.product_set = product_set @@ -42,15 +42,18 @@ class ProductSet(object): return iter(self.product_set) def __eq__(self, other): - return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(other.product_set) + return isinstance(other, ProductSet) and sorted(self.product_set) == sorted( + other.product_set + ) def __hash__(self): - return hash('-'.join(sorted(self.product_set))) + return hash("-".join(sorted(self.product_set))) class PredictionResult(object): - - def __init__(self, product_sets: List['ProductSet'], probability: float, rule: Optional['Rule'] = None): + def __init__( + self, product_sets: List["ProductSet"], probability: float, rule: Optional["Rule"] = None + ): self.product_sets = product_sets self.probability = probability self.rule = rule @@ -66,7 +69,6 @@ class PredictionResult(object): class FormatConverter(object): - @staticmethod def mass(smiles): return Descriptors.MolWt(FormatConverter.from_smiles(smiles)) @@ -127,7 +129,7 @@ class FormatConverter(object): if kekulize: try: mol = Chem.Kekulize(mol) - except: + except Exception: mol = Chem.Mol(mol.ToBinary()) if not mol.GetNumConformers(): @@ -139,8 +141,8 @@ class FormatConverter(object): opts.clearBackground = False drawer.DrawMolecule(mol) drawer.FinishDrawing() - svg = drawer.GetDrawingText().replace('svg:', '') - svg = re.sub("<\?xml.*\?>", '', svg) + svg = drawer.GetDrawingText().replace("svg:", "") + svg = re.sub("<\?xml.*\?>", "", svg) return svg @@ -151,7 +153,7 @@ class FormatConverter(object): if kekulize: try: Chem.Kekulize(mol) - except: + except Exception: mc = Chem.Mol(mol.ToBinary()) if not mc.GetNumConformers(): @@ -178,7 +180,7 @@ class FormatConverter(object): smiles = tmp_smiles if change is False: - print(f"nothing changed") + print("nothing changed") return smiles @@ -198,7 +200,9 @@ class FormatConverter(object): parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) # try to neutralize molecule - uncharger = rdMolStandardize.Uncharger() # annoying, but necessary as no convenience method exists + uncharger = ( + rdMolStandardize.Uncharger() + ) # annoying, but necessary as no convenience method exists uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) # note that no attempt is made at reionization at this step @@ -239,17 +243,24 @@ class FormatConverter(object): try: rdChemReactions.ReactionFromSmarts(smirks) return True - except: + except Exception: return False @staticmethod - def apply(smiles: str, smirks: str, preprocess_smiles: bool = True, bracketize: bool = True, - standardize: bool = True, kekulize: bool = True, remove_stereo: bool = True) -> List['ProductSet']: - logger.debug(f'Applying {smirks} on {smiles}') + def apply( + smiles: str, + smirks: str, + preprocess_smiles: bool = True, + bracketize: bool = True, + standardize: bool = True, + kekulize: bool = True, + remove_stereo: bool = True, + ) -> List["ProductSet"]: + logger.debug(f"Applying {smirks} on {smiles}") # If explicitly wanted or rule generates multiple products add brackets around products to capture all if bracketize: # or "." in smirks: - smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")" + smirks = smirks.split(">>")[0] + ">>(" + smirks.split(">>")[1] + ")" # List of ProductSet objects pss = set() @@ -274,7 +285,9 @@ class FormatConverter(object): Chem.SanitizeMol(product) product = GetMolFrags(product, asMols=True) for p in product: - p = FormatConverter.standardize(Chem.MolToSmiles(p), remove_stereo=remove_stereo) + p = FormatConverter.standardize( + Chem.MolToSmiles(p), remove_stereo=remove_stereo + ) prods.append(p) # if kekulize: @@ -300,9 +313,8 @@ class FormatConverter(object): # # bond.SetIsAromatic(False) # Chem.Kekulize(product) - except ValueError as e: - logger.error(f'Sanitizing and converting failed:\n{e}') + logger.error(f"Sanitizing and converting failed:\n{e}") continue if len(prods): @@ -310,7 +322,7 @@ class FormatConverter(object): pss.add(ps) except Exception as e: - logger.error(f'Applying {smirks} on {smiles} failed:\n{e}') + logger.error(f"Applying {smirks} on {smiles} failed:\n{e}") return pss @@ -340,22 +352,19 @@ class FormatConverter(object): smi_p = Chem.MolToSmiles(mol, kekuleSmiles=True) smi_p = Chem.CanonSmiles(smi_p) - if '~' in smi_p: - smi_p1 = smi_p.replace('~', '') + if "~" in smi_p: + smi_p1 = smi_p.replace("~", "") parsed_smiles.append(smi_p1) else: parsed_smiles.append(smi_p) - except Exception as e: + except Exception: errors += 1 pass return parsed_smiles, errors - - class Standardizer(ABC): - def __init__(self, name): self.name = name @@ -364,7 +373,6 @@ class Standardizer(ABC): class RuleStandardizer(Standardizer): - def __init__(self, name, smirks): super().__init__(name) self.smirks = smirks @@ -373,8 +381,8 @@ class RuleStandardizer(Standardizer): standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks))) if len(standardized_smiles) > 1: - logger.warning(f'{self.smirks} generated more than 1 compound {standardized_smiles}') - print(f'{self.smirks} generated more than 1 compound {standardized_smiles}') + logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}") + print(f"{self.smirks} generated more than 1 compound {standardized_smiles}") standardized_smiles = standardized_smiles[:1] if standardized_smiles: @@ -384,7 +392,6 @@ class RuleStandardizer(Standardizer): class RegExStandardizer(Standardizer): - def __init__(self, name, replacements: dict): super().__init__(name) self.replacements = replacements @@ -404,28 +411,39 @@ class RegExStandardizer(Standardizer): return super().standardize(smi) -FLATTEN = [ - RegExStandardizer("Remove Stereo", {"@": ""}) -] +FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})] -UN_CIS_TRANS = [ - RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""}) -] +UN_CIS_TRANS = [RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})] BASIC = [ RuleStandardizer("ammoniumstandardization", "[H][N+:1]([H])([H])[#6:2]>>[H][#7:1]([H])-[#6:2]"), RuleStandardizer("cyanate", "[H][#8:1][C:2]#[N:3]>>[#8-:1][C:2]#[N:3]"), RuleStandardizer("deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"), RuleStandardizer("forNOOH", "[H][#8:1]-[#7+:2](-[*:3])=[O:4]>>[#8-:1]-[#7+:2](-[*:3])=[O:4]"), - RuleStandardizer("Hydroxylprotonation", "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]"), - RuleStandardizer("phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"), - RuleStandardizer("PicricAcid", - "[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]"), - RuleStandardizer("Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"), - RuleStandardizer("Sulfate2", - "[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]"), - RuleStandardizer("Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"), - RuleStandardizer("Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"), + RuleStandardizer( + "Hydroxylprotonation", + "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]", + ), + RuleStandardizer( + "phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]" + ), + RuleStandardizer( + "PicricAcid", + "[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]", + ), + RuleStandardizer( + "Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]" + ), + RuleStandardizer( + "Sulfate2", + "[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]", + ), + RuleStandardizer( + "Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]" + ), + RuleStandardizer( + "Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]" + ), ] ENHANCED = BASIC + [ @@ -433,28 +451,30 @@ ENHANCED = BASIC + [ ] EXOTIC = ENHANCED + [ - RuleStandardizer("ThioPhosphate1", "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]") + RuleStandardizer( + "ThioPhosphate1", + "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]", + ) ] COA_CUTTER = [ - RuleStandardizer("CutCoEnzymeAOff", - "CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]") + RuleStandardizer( + "CutCoEnzymeAOff", + "CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]", + ) ] -ENOL_KETO = [ - RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]") -] +ENOL_KETO = [RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")] MATCH_STANDARDIZER = EXOTIC + FLATTEN + UN_CIS_TRANS + COA_CUTTER + ENOL_KETO class IndigoUtils(object): - @staticmethod def layout(mol_data): i = Indigo() try: - if mol_data.startswith('$RXN') or '>>' in mol_data: + if mol_data.startswith("$RXN") or ">>" in mol_data: rxn = i.loadQueryReaction(mol_data) rxn.layout() return rxn.rxnfile() @@ -462,14 +482,14 @@ class IndigoUtils(object): mol = i.loadQueryMolecule(mol_data) mol.layout() return mol.molfile() - except IndigoException as e: + except IndigoException: try: logger.info("layout() failed, trying loadReactionSMARTS as fallback!") rxn = IndigoUtils.load_reaction_SMARTS(mol_data) rxn.layout() return rxn.molfile() except IndigoException as e2: - logger.error(f'layout() failed due to {e2}!') + logger.error(f"layout() failed due to {e2}!") @staticmethod def load_reaction_SMARTS(mol): @@ -479,7 +499,7 @@ class IndigoUtils(object): def aromatize(mol_data, is_query): i = Indigo() try: - if mol_data.startswith('$RXN'): + if mol_data.startswith("$RXN"): if is_query: rxn = i.loadQueryReaction(mol_data) else: @@ -495,20 +515,20 @@ class IndigoUtils(object): mol.aromatize() return mol.molfile() - except IndigoException as e: + except IndigoException: try: logger.info("Aromatizing failed, trying loadReactionSMARTS as fallback!") rxn = IndigoUtils.load_reaction_SMARTS(mol_data) rxn.aromatize() return rxn.molfile() except IndigoException as e2: - logger.error(f'Aromatizing failed due to {e2}!') + logger.error(f"Aromatizing failed due to {e2}!") @staticmethod def dearomatize(mol_data, is_query): i = Indigo() try: - if mol_data.startswith('$RXN'): + if mol_data.startswith("$RXN"): if is_query: rxn = i.loadQueryReaction(mol_data) else: @@ -524,14 +544,14 @@ class IndigoUtils(object): mol.dearomatize() return mol.molfile() - except IndigoException as e: + except IndigoException: try: logger.info("De-Aromatizing failed, trying loadReactionSMARTS as fallback!") rxn = IndigoUtils.load_reaction_SMARTS(mol_data) rxn.dearomatize() return rxn.molfile() except IndigoException as e2: - logger.error(f'De-Aromatizing failed due to {e2}!') + logger.error(f"De-Aromatizing failed due to {e2}!") @staticmethod def sanitize_functional_group(functional_group: str): @@ -543,7 +563,7 @@ class IndigoUtils(object): # special environment handling (amines, hydroxy, esters, ethers) # the higher substituted should not contain H env. - if functional_group == '[C]=O': + if functional_group == "[C]=O": functional_group = "[H][C](=O)[CX4,c]" # aldamines @@ -577,15 +597,20 @@ class IndigoUtils(object): functional_group = "[nH1,nX2](a)a" # pyrrole (with H) or pyridine (no other connections); currently overlaps with neighboring aromatic atoms # substituted aromatic nitrogen - functional_group = functional_group.replace("N*(R)R", - "n(a)a") # substituent will be before N*; currently overlaps with neighboring aromatic atoms + functional_group = functional_group.replace( + "N*(R)R", "n(a)a" + ) # substituent will be before N*; currently overlaps with neighboring aromatic atoms # pyridinium if functional_group == "RN*(R)(R)(R)R": - functional_group = "[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms + functional_group = ( + "[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms + ) # N-oxide if functional_group == "[H]ON*(R)(R)(R)R": - functional_group = "[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms + functional_group = ( + "[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms + ) # other aromatic hetero atoms functional_group = functional_group.replace("C*", "c") @@ -598,7 +623,9 @@ class IndigoUtils(object): # other replacement, to accomodate for the standardization rules in enviPath # This is not the perfect way to do it; there should be a way to replace substructure SMARTS in SMARTS? # nitro groups are broken, due to charge handling. this SMARTS matches both forms (formal charges and hypervalent); Ertl-CDK still treats both forms separately... - functional_group = functional_group.replace("[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]") + functional_group = functional_group.replace( + "[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]" + ) functional_group = functional_group.replace("O=N(=O)R", "[CX4,c][NX3](~[OX1])~[OX1]") # carboxylic acid: this SMARTS matches both neutral and anionic form; includes COOH in larger functional_groups functional_group = functional_group.replace("[H]OC(=O)", "[OD1]C(=O)") @@ -616,7 +643,9 @@ class IndigoUtils(object): return functional_group @staticmethod - def _colorize(indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool): + def _colorize( + indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool + ): indigo.setOption("render-atom-color-property", "color") indigo.setOption("aromaticity-model", "generic") @@ -646,7 +675,6 @@ class IndigoUtils(object): for match in matcher.iterateMatches(query): if match is not None: - for atom in query.iterateAtoms(): mappedAtom = match.mapAtom(atom) if mappedAtom is None or mappedAtom.index() in environment: @@ -655,7 +683,7 @@ class IndigoUtils(object): counts[mappedAtom.index()] = max(v, counts[mappedAtom.index()]) except IndigoException as e: - logger.debug(f'Colorizing failed due to {e}') + logger.debug(f"Colorizing failed due to {e}") for k, v in counts.items(): if is_reaction: @@ -669,8 +697,9 @@ class IndigoUtils(object): molecule.addDataSGroup([k], [], "color", color) @staticmethod - def mol_to_svg(mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None): - + def mol_to_svg( + mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None + ): if functional_groups is None: functional_groups = {} @@ -682,7 +711,7 @@ class IndigoUtils(object): i.setOption("render-image-size", width, height) i.setOption("render-bond-line-width", 2.0) - if '~' in mol_data: + if "~" in mol_data: mol = i.loadSmarts(mol_data) else: mol = i.loadMolecule(mol_data) @@ -690,11 +719,17 @@ class IndigoUtils(object): if len(functional_groups.keys()) > 0: IndigoUtils._colorize(i, mol, functional_groups, False) - return renderer.renderToBuffer(mol).decode('UTF-8') + return renderer.renderToBuffer(mol).decode("UTF-8") @staticmethod - def smirks_to_svg(smirks: str, is_query_smirks, width: int = 0, height: int = 0, - educt_functional_groups: Dict[str, int] = None, product_functional_groups: Dict[str, int] = None): + def smirks_to_svg( + smirks: str, + is_query_smirks, + width: int = 0, + height: int = 0, + educt_functional_groups: Dict[str, int] = None, + product_functional_groups: Dict[str, int] = None, + ): if educt_functional_groups is None: educt_functional_groups = {} @@ -721,18 +756,18 @@ class IndigoUtils(object): for prod in obj.iterateProducts(): IndigoUtils._colorize(i, prod, product_functional_groups, True) - return renderer.renderToBuffer(obj).decode('UTF-8') + return renderer.renderToBuffer(obj).decode("UTF-8") -if __name__ == '__main__': +if __name__ == "__main__": data = { "struct": "\n Ketcher 2172510 12D 1 1.00000 0.00000 0\n\n 6 6 0 0 0 999 V2000\n 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 1 2 2 0 0 0 0\n 2 3 1 0 0 0 0\n 3 4 2 0 0 0 0\n 4 5 1 0 0 0 0\n 5 6 2 0 0 0 0\n 6 1 1 0 0 0 0\nM END\n", "options": { "smart-layout": True, "ignore-stereochemistry-errors": True, "mass-skip-error-on-pseudoatoms": False, - "gross-formula-add-rsites": True - } + "gross-formula-add-rsites": True, + }, } - print(IndigoUtils.aromatize(data['struct'], False)) + print(IndigoUtils.aromatize(data["struct"], False)) diff --git a/utilities/clients.py b/utilities/clients.py deleted file mode 100644 index 2b4ff0c4..00000000 --- a/utilities/clients.py +++ /dev/null @@ -1,83 +0,0 @@ -import json - -import requests - - -class AMBITResult: - - def __init__(self, *args, **kwargs): - self.smiles = kwargs['smiles'] - self.tps = [] - for bt in kwargs['products']: - if len(bt['products']): - self.tps.append(bt) - - self.probs = None - - def __str__(self): - x = self.smiles + "\n" - total_bts = len(self.tps) - - for i, tp in enumerate(self.tps): - prob = "" - if self.probs: - prob = f" (p={self.probs[tp['id']]})" - - if i == total_bts - 1: - x += f"\t└── {tp['name']}{prob}\n" - else: - x += f"\t├── {tp['name']}{prob}\n" - - total_products = len(tp['products']) - for j, p in enumerate(tp['products']): - if j == total_products - 1: - if i == total_bts - 1: - x += f"\t\t└── {p}" - else: - x += f"\t│\t└── {p}\n" - else: - if i == total_bts - 1: - x += f"\t\t├── {p}\n" - else: - x += f"\t│\t├── {p}\n" - return x - - def set_probs(self, probs): - self.probs = probs - - -class AMBIT: - - def __init__(self, host, rules=None): - self.host = host - self.rules = rules - self.ambit_params = { - 'singlePos': True, - 'split': False, - } - - def batch_apply(self, smiles: list): - payload = { - 'smiles': smiles, - 'rules': self.rules, - } - payload.update(**self.ambit_params) - - res = self._execute(payload) - - tps = list() - for r in res['result']: - ar = AMBITResult(**r) - if len(ar.tps): - tps.append(ar) - else: - tps.append(None) - return tps - - def apply(self, smiles: str): - return self.batch_apply([smiles])[0] - - def _execute(self, payload): - res = requests.post(self.host + '/ambit', data=json.dumps(payload)) - res.raise_for_status() - return res.json() diff --git a/utilities/decorators.py b/utilities/decorators.py index ae37ceef..eabbde16 100644 --- a/utilities/decorators.py +++ b/utilities/decorators.py @@ -8,9 +8,9 @@ from epdb.models import Package # Map HTTP methods to required permissions DEFAULT_METHOD_PERMISSIONS = { - 'GET': 'read', - 'POST': 'write', - 'DELETE': 'write', + "GET": "read", + "POST": "write", + "DELETE": "write", } @@ -22,6 +22,7 @@ def package_permission_required(method_permissions=None): @wraps(view_func) def _wrapped_view(request, package_uuid, *args, **kwargs): from epdb.views import _anonymous_or_real + user = _anonymous_or_real(request) permission_required = method_permissions[request.method] @@ -30,11 +31,12 @@ def package_permission_required(method_permissions=None): if not PackageManager.has_package_permission(user, package_uuid, permission_required): from epdb.views import error + return error( request, "Operation failed!", f"Couldn't perform the desired operation as {user.username} does not have the required permissions!", - code=403 + code=403, ) return view_func(request, package_uuid, *args, **kwargs) diff --git a/utilities/misc.py b/utilities/misc.py index 185441f2..3e4eeb59 100644 --- a/utilities/misc.py +++ b/utilities/misc.py @@ -17,11 +17,27 @@ from envipy_additional_information import NAME_MAPPING from pydantic import BaseModel, HttpUrl from epdb.models import ( - Package, Compound, CompoundStructure, SimpleRule, SimpleAmbitRule, - SimpleRDKitRule, ParallelRule, SequentialRule, Reaction, Pathway, Node, Edge, Scenario, EPModel, + Package, + Compound, + CompoundStructure, + SimpleRule, + SimpleAmbitRule, + SimpleRDKitRule, + ParallelRule, + SequentialRule, + Reaction, + Pathway, + Node, + Edge, + Scenario, + EPModel, MLRelativeReasoning, - RuleBasedRelativeReasoning, EnviFormer, PluginModel, ExternalIdentifier, - ExternalDatabase, License + RuleBasedRelativeReasoning, + EnviFormer, + PluginModel, + ExternalIdentifier, + ExternalDatabase, + License, ) logger = logging.getLogger(__name__) @@ -31,7 +47,7 @@ class HTMLGenerator: registry = {x.__name__: x for x in NAME_MAPPING.values()} @staticmethod - def generate_html(additional_information: 'EnviPyModel', prefix='') -> str: + def generate_html(additional_information: "EnviPyModel", prefix="") -> str: from typing import get_origin, get_args, Union if isinstance(additional_information, type): @@ -39,9 +55,9 @@ class HTMLGenerator: else: clz_name = additional_information.__class__.__name__ - widget = f'

{clz_name}

' + widget = f"

{clz_name}

" - if hasattr(additional_information, 'uuid'): + if hasattr(additional_information, "uuid"): uuid = additional_information.uuid widget += f'' @@ -61,62 +77,61 @@ class HTMLGenerator: field_type = base_type is_interval_float = ( - field_type == Interval[float] or - str(field_type) == str(Interval[float]) or - 'Interval[float]' in str(field_type) + field_type == Interval[float] + or str(field_type) == str(Interval[float]) + or "Interval[float]" in str(field_type) ) if is_interval_float: widget += f"""
- - + +
- - + +
""" elif issubclass(field_type, Enum): - options: str = '' + options: str = "" for e in field_type: options += f'' widget += f"""
- +
""" else: - if field_type == str or field_type == HttpUrl: - input_type = 'text' - elif field_type == float or field_type == int: - input_type = 'number' - elif field_type == bool: - input_type = 'checkbox' + if field_type is str or field_type is HttpUrl: + input_type = "text" + elif field_type is float or field_type is int: + input_type = "number" + elif field_type is bool: + input_type = "checkbox" else: raise ValueError(f"Could not parse field type {field_type} for {name}") - value_to_use = value if value and field_type != bool else '' + value_to_use = value if value and field_type is not bool else "" widget += f"""
- - + +
""" return widget + "
" @staticmethod - def build_models(params) -> Dict[str, List['EnviPyModel']]: - + def build_models(params) -> Dict[str, List["EnviPyModel"]]: def has_non_none(d): """ Recursively checks if any value in a (possibly nested) dict is not None. @@ -143,7 +158,7 @@ class HTMLGenerator: # Step 1: group fields by ClassName and Number for key, value in params.items(): - if value == '': + if value == "": value = None parts = key.split("__") @@ -174,28 +189,35 @@ class HTMLGenerator: print(f"Skipping empty {class_name} {number} {fields}") continue - uuid = fields.pop('uuid', None) + uuid = fields.pop("uuid", None) instance = model_cls(**fields) if uuid: - instance.__dict__['uuid'] = uuid + instance.__dict__["uuid"] = uuid instances[class_name].append(instance) return instances class PackageExporter: - - def __init__(self, package: Package, include_models: bool = False, include_external_identifiers: bool = True): + def __init__( + self, + package: Package, + include_models: bool = False, + include_external_identifiers: bool = True, + ): self._raw_package = package self.include_modes = include_models self.include_external_identifiers = include_external_identifiers def do_export(self): - return PackageExporter._export_package_as_json(self._raw_package, self.include_modes, self.include_external_identifiers) + return PackageExporter._export_package_as_json( + self._raw_package, self.include_modes, self.include_external_identifiers + ) @staticmethod - def _export_package_as_json(package: Package, include_models: bool = False, - include_external_identifiers: bool = True) -> Dict[str, Any]: + def _export_package_as_json( + package: Package, include_models: bool = False, include_external_identifiers: bool = True + ) -> Dict[str, Any]: """ Dumps a Package and all its related objects as JSON. @@ -208,52 +230,54 @@ class PackageExporter: Dict containing the complete package data as JSON-serializable structure """ - def serialize_base_object(obj, include_aliases: bool = True, include_scenarios: bool = True) -> Dict[str, Any]: + def serialize_base_object( + obj, include_aliases: bool = True, include_scenarios: bool = True + ) -> Dict[str, Any]: """Serialize common EnviPathModel fields""" base_dict = { - 'uuid': str(obj.uuid), - 'name': obj.name, - 'description': obj.description, - 'url': obj.url, - 'kv': obj.kv, + "uuid": str(obj.uuid), + "name": obj.name, + "description": obj.description, + "url": obj.url, + "kv": obj.kv, } # Add aliases if the object has them - if include_aliases and hasattr(obj, 'aliases'): - base_dict['aliases'] = obj.aliases + if include_aliases and hasattr(obj, "aliases"): + base_dict["aliases"] = obj.aliases # Add scenarios if the object has them - if include_scenarios and hasattr(obj, 'scenarios'): - base_dict['scenarios'] = [ - {'uuid': str(s.uuid), 'url': s.url} for s in obj.scenarios.all() + if include_scenarios and hasattr(obj, "scenarios"): + base_dict["scenarios"] = [ + {"uuid": str(s.uuid), "url": s.url} for s in obj.scenarios.all() ] return base_dict def serialize_external_identifiers(obj) -> List[Dict[str, Any]]: """Serialize external identifiers for an object""" - if not include_external_identifiers or not hasattr(obj, 'external_identifiers'): + if not include_external_identifiers or not hasattr(obj, "external_identifiers"): return [] identifiers = [] for ext_id in obj.external_identifiers.all(): identifier_dict = { - 'uuid': str(ext_id.uuid), - 'database': { - 'uuid': str(ext_id.database.uuid), - 'name': ext_id.database.name, - 'base_url': ext_id.database.base_url + "uuid": str(ext_id.uuid), + "database": { + "uuid": str(ext_id.database.uuid), + "name": ext_id.database.name, + "base_url": ext_id.database.base_url, }, - 'identifier_value': ext_id.identifier_value, - 'url': ext_id.url, - 'is_primary': ext_id.is_primary + "identifier_value": ext_id.identifier_value, + "url": ext_id.url, + "is_primary": ext_id.is_primary, } identifiers.append(identifier_dict) return identifiers # Start with the package itself result = serialize_base_object(package, include_aliases=True, include_scenarios=True) - result['reviewed'] = package.reviewed + result["reviewed"] = package.reviewed # # Add license information # if package.license: @@ -267,247 +291,302 @@ class PackageExporter: # result['license'] = None # Initialize collections - result.update({ - 'compounds': [], - 'structures': [], - 'rules': { - 'simple_rules': [], - 'parallel_rules': [], - 'sequential_rules': [] - }, - 'reactions': [], - 'pathways': [], - 'nodes': [], - 'edges': [], - 'scenarios': [], - 'models': [] - }) + result.update( + { + "compounds": [], + "structures": [], + "rules": {"simple_rules": [], "parallel_rules": [], "sequential_rules": []}, + "reactions": [], + "pathways": [], + "nodes": [], + "edges": [], + "scenarios": [], + "models": [], + } + ) print(f"Exporting package: {package.name}") # Export compounds print("Exporting compounds...") - for compound in package.compounds.prefetch_related('default_structure').order_by('url'): - compound_dict = serialize_base_object(compound, include_aliases=True, include_scenarios=True) + for compound in package.compounds.prefetch_related("default_structure").order_by("url"): + compound_dict = serialize_base_object( + compound, include_aliases=True, include_scenarios=True + ) if compound.default_structure: - compound_dict['default_structure'] = { - 'uuid': str(compound.default_structure.uuid), - 'url': compound.default_structure.url + compound_dict["default_structure"] = { + "uuid": str(compound.default_structure.uuid), + "url": compound.default_structure.url, } else: - compound_dict['default_structure'] = None + compound_dict["default_structure"] = None - compound_dict['external_identifiers'] = serialize_external_identifiers(compound) - result['compounds'].append(compound_dict) + compound_dict["external_identifiers"] = serialize_external_identifiers(compound) + result["compounds"].append(compound_dict) # Export compound structures print("Exporting compound structures...") - compound_structures = CompoundStructure.objects.filter( - compound__package=package - ).select_related('compound').order_by('url') + compound_structures = ( + CompoundStructure.objects.filter(compound__package=package) + .select_related("compound") + .order_by("url") + ) for structure in compound_structures: - structure_dict = serialize_base_object(structure, include_aliases=True, include_scenarios=True) - structure_dict.update({ - 'compound': { - 'uuid': str(structure.compound.uuid), - 'url': structure.compound.url - }, - 'smiles': structure.smiles, - 'canonical_smiles': structure.canonical_smiles, - 'inchikey': structure.inchikey, - 'normalized_structure': structure.normalized_structure, - 'external_identifiers': serialize_external_identifiers(structure) - }) - result['structures'].append(structure_dict) + structure_dict = serialize_base_object( + structure, include_aliases=True, include_scenarios=True + ) + structure_dict.update( + { + "compound": { + "uuid": str(structure.compound.uuid), + "url": structure.compound.url, + }, + "smiles": structure.smiles, + "canonical_smiles": structure.canonical_smiles, + "inchikey": structure.inchikey, + "normalized_structure": structure.normalized_structure, + "external_identifiers": serialize_external_identifiers(structure), + } + ) + result["structures"].append(structure_dict) # Export rules print("Exporting rules...") # Simple rules (including SimpleAmbitRule and SimpleRDKitRule) - for rule in SimpleRule.objects.filter(package=package).order_by('url'): + for rule in SimpleRule.objects.filter(package=package).order_by("url"): rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True) # Add specific fields for SimpleAmbitRule if isinstance(rule, SimpleAmbitRule): - rule_dict.update({ - 'rule_type': 'SimpleAmbitRule', - 'smirks': rule.smirks, - 'reactant_filter_smarts': rule.reactant_filter_smarts or '', - 'product_filter_smarts': rule.product_filter_smarts or '' - }) + rule_dict.update( + { + "rule_type": "SimpleAmbitRule", + "smirks": rule.smirks, + "reactant_filter_smarts": rule.reactant_filter_smarts or "", + "product_filter_smarts": rule.product_filter_smarts or "", + } + ) elif isinstance(rule, SimpleRDKitRule): - rule_dict.update({ - 'rule_type': 'SimpleRDKitRule', - 'reaction_smarts': rule.reaction_smarts - }) + rule_dict.update( + {"rule_type": "SimpleRDKitRule", "reaction_smarts": rule.reaction_smarts} + ) else: - rule_dict['rule_type'] = 'SimpleRule' + rule_dict["rule_type"] = "SimpleRule" - result['rules']['simple_rules'].append(rule_dict) + result["rules"]["simple_rules"].append(rule_dict) # Parallel rules - for rule in ParallelRule.objects.filter(package=package).prefetch_related('simple_rules').order_by('url'): + for rule in ( + ParallelRule.objects.filter(package=package) + .prefetch_related("simple_rules") + .order_by("url") + ): rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True) - rule_dict['rule_type'] = 'ParallelRule' - rule_dict['simple_rules'] = [ - {'uuid': str(sr.uuid), 'url': sr.url} for sr in rule.simple_rules.all() + rule_dict["rule_type"] = "ParallelRule" + rule_dict["simple_rules"] = [ + {"uuid": str(sr.uuid), "url": sr.url} for sr in rule.simple_rules.all() ] - result['rules']['parallel_rules'].append(rule_dict) + result["rules"]["parallel_rules"].append(rule_dict) # Sequential rules - for rule in SequentialRule.objects.filter(package=package).prefetch_related('simple_rules').order_by('url'): + for rule in ( + SequentialRule.objects.filter(package=package) + .prefetch_related("simple_rules") + .order_by("url") + ): rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True) - rule_dict['rule_type'] = 'SequentialRule' - rule_dict['simple_rules'] = [ - {'uuid': str(sr.uuid), 'url': sr.url, - 'order_index': sr.sequentialruleordering_set.get(sequential_rule=rule).order_index} + rule_dict["rule_type"] = "SequentialRule" + rule_dict["simple_rules"] = [ + { + "uuid": str(sr.uuid), + "url": sr.url, + "order_index": sr.sequentialruleordering_set.get( + sequential_rule=rule + ).order_index, + } for sr in rule.simple_rules.all() ] - result['rules']['sequential_rules'].append(rule_dict) + result["rules"]["sequential_rules"].append(rule_dict) # Export reactions print("Exporting reactions...") - for reaction in package.reactions.prefetch_related('educts', 'products', 'rules').order_by('url'): - reaction_dict = serialize_base_object(reaction, include_aliases=True, include_scenarios=True) - reaction_dict.update({ - 'educts': [{'uuid': str(e.uuid), 'url': e.url} for e in reaction.educts.all()], - 'products': [{'uuid': str(p.uuid), 'url': p.url} for p in reaction.products.all()], - 'rules': [{'uuid': str(r.uuid), 'url': r.url} for r in reaction.rules.all()], - 'multi_step': reaction.multi_step, - 'medline_references': reaction.medline_references, - 'external_identifiers': serialize_external_identifiers(reaction) - }) - result['reactions'].append(reaction_dict) + for reaction in package.reactions.prefetch_related("educts", "products", "rules").order_by( + "url" + ): + reaction_dict = serialize_base_object( + reaction, include_aliases=True, include_scenarios=True + ) + reaction_dict.update( + { + "educts": [{"uuid": str(e.uuid), "url": e.url} for e in reaction.educts.all()], + "products": [ + {"uuid": str(p.uuid), "url": p.url} for p in reaction.products.all() + ], + "rules": [{"uuid": str(r.uuid), "url": r.url} for r in reaction.rules.all()], + "multi_step": reaction.multi_step, + "medline_references": reaction.medline_references, + "external_identifiers": serialize_external_identifiers(reaction), + } + ) + result["reactions"].append(reaction_dict) # Export pathways print("Exporting pathways...") - for pathway in package.pathways.order_by('url'): - pathway_dict = serialize_base_object(pathway, include_aliases=True, include_scenarios=True) + for pathway in package.pathways.order_by("url"): + pathway_dict = serialize_base_object( + pathway, include_aliases=True, include_scenarios=True + ) # Add setting reference if exists - if hasattr(pathway, 'setting') and pathway.setting: - pathway_dict['setting'] = { - 'uuid': str(pathway.setting.uuid), - 'url': pathway.setting.url + if hasattr(pathway, "setting") and pathway.setting: + pathway_dict["setting"] = { + "uuid": str(pathway.setting.uuid), + "url": pathway.setting.url, } else: - pathway_dict['setting'] = None + pathway_dict["setting"] = None - result['pathways'].append(pathway_dict) + result["pathways"].append(pathway_dict) # Export nodes print("Exporting nodes...") - pathway_nodes = Node.objects.filter( - pathway__package=package - ).select_related('pathway', 'default_node_label').prefetch_related('node_labels', 'out_edges').order_by('url') + pathway_nodes = ( + Node.objects.filter(pathway__package=package) + .select_related("pathway", "default_node_label") + .prefetch_related("node_labels", "out_edges") + .order_by("url") + ) for node in pathway_nodes: node_dict = serialize_base_object(node, include_aliases=True, include_scenarios=True) - node_dict.update({ - 'pathway': {'uuid': str(node.pathway.uuid), 'url': node.pathway.url}, - 'default_node_label': { - 'uuid': str(node.default_node_label.uuid), - 'url': node.default_node_label.url - }, - 'node_labels': [ - {'uuid': str(label.uuid), 'url': label.url} for label in node.node_labels.all() - ], - 'out_edges': [ - {'uuid': str(edge.uuid), 'url': edge.url} for edge in node.out_edges.all() - ], - 'depth': node.depth - }) - result['nodes'].append(node_dict) + node_dict.update( + { + "pathway": {"uuid": str(node.pathway.uuid), "url": node.pathway.url}, + "default_node_label": { + "uuid": str(node.default_node_label.uuid), + "url": node.default_node_label.url, + }, + "node_labels": [ + {"uuid": str(label.uuid), "url": label.url} + for label in node.node_labels.all() + ], + "out_edges": [ + {"uuid": str(edge.uuid), "url": edge.url} for edge in node.out_edges.all() + ], + "depth": node.depth, + } + ) + result["nodes"].append(node_dict) # Export edges print("Exporting edges...") - pathway_edges = Edge.objects.filter( - pathway__package=package - ).select_related('pathway', 'edge_label').prefetch_related('start_nodes', 'end_nodes').order_by('url') + pathway_edges = ( + Edge.objects.filter(pathway__package=package) + .select_related("pathway", "edge_label") + .prefetch_related("start_nodes", "end_nodes") + .order_by("url") + ) for edge in pathway_edges: edge_dict = serialize_base_object(edge, include_aliases=True, include_scenarios=True) - edge_dict.update({ - 'pathway': {'uuid': str(edge.pathway.uuid), 'url': edge.pathway.url}, - 'edge_label': {'uuid': str(edge.edge_label.uuid), 'url': edge.edge_label.url}, - 'start_nodes': [ - {'uuid': str(node.uuid), 'url': node.url} for node in edge.start_nodes.all() - ], - 'end_nodes': [ - {'uuid': str(node.uuid), 'url': node.url} for node in edge.end_nodes.all() - ] - }) - result['edges'].append(edge_dict) + edge_dict.update( + { + "pathway": {"uuid": str(edge.pathway.uuid), "url": edge.pathway.url}, + "edge_label": {"uuid": str(edge.edge_label.uuid), "url": edge.edge_label.url}, + "start_nodes": [ + {"uuid": str(node.uuid), "url": node.url} for node in edge.start_nodes.all() + ], + "end_nodes": [ + {"uuid": str(node.uuid), "url": node.url} for node in edge.end_nodes.all() + ], + } + ) + result["edges"].append(edge_dict) # Export scenarios print("Exporting scenarios...") - for scenario in package.scenarios.order_by('url'): - scenario_dict = serialize_base_object(scenario, include_aliases=False, include_scenarios=False) - scenario_dict.update({ - 'scenario_date': scenario.scenario_date, - 'scenario_type': scenario.scenario_type, - 'parent': { - 'uuid': str(scenario.parent.uuid), - 'url': scenario.parent.url - } if scenario.parent else None, - 'additional_information': scenario.additional_information - }) - result['scenarios'].append(scenario_dict) + for scenario in package.scenarios.order_by("url"): + scenario_dict = serialize_base_object( + scenario, include_aliases=False, include_scenarios=False + ) + scenario_dict.update( + { + "scenario_date": scenario.scenario_date, + "scenario_type": scenario.scenario_type, + "parent": {"uuid": str(scenario.parent.uuid), "url": scenario.parent.url} + if scenario.parent + else None, + "additional_information": scenario.additional_information, + } + ) + result["scenarios"].append(scenario_dict) # Export models if include_models: print("Exporting models...") - package_models = package.models.select_related('app_domain').prefetch_related( - 'rule_packages', 'data_packages', 'eval_packages' - ).order_by('url') + package_models = ( + package.models.select_related("app_domain") + .prefetch_related("rule_packages", "data_packages", "eval_packages") + .order_by("url") + ) for model in package_models: - model_dict = serialize_base_object(model, include_aliases=True, include_scenarios=False) + model_dict = serialize_base_object( + model, include_aliases=True, include_scenarios=False + ) # Common fields for PackageBasedModel - if hasattr(model, 'rule_packages'): - model_dict.update({ - 'rule_packages': [ - {'uuid': str(p.uuid), 'url': p.url} for p in model.rule_packages.all() - ], - 'data_packages': [ - {'uuid': str(p.uuid), 'url': p.url} for p in model.data_packages.all() - ], - 'eval_packages': [ - {'uuid': str(p.uuid), 'url': p.url} for p in model.eval_packages.all() - ], - 'threshold': model.threshold, - 'eval_results': model.eval_results, - 'model_status': model.model_status - }) + if hasattr(model, "rule_packages"): + model_dict.update( + { + "rule_packages": [ + {"uuid": str(p.uuid), "url": p.url} + for p in model.rule_packages.all() + ], + "data_packages": [ + {"uuid": str(p.uuid), "url": p.url} + for p in model.data_packages.all() + ], + "eval_packages": [ + {"uuid": str(p.uuid), "url": p.url} + for p in model.eval_packages.all() + ], + "threshold": model.threshold, + "eval_results": model.eval_results, + "model_status": model.model_status, + } + ) if model.app_domain: - model_dict['app_domain'] = { - 'uuid': str(model.app_domain.uuid), - 'url': model.app_domain.url + model_dict["app_domain"] = { + "uuid": str(model.app_domain.uuid), + "url": model.app_domain.url, } else: - model_dict['app_domain'] = None + model_dict["app_domain"] = None # Specific fields for different model types if isinstance(model, RuleBasedRelativeReasoning): - model_dict.update({ - 'model_type': 'RuleBasedRelativeReasoning', - 'min_count': model.min_count, - 'max_count': model.max_count - }) + model_dict.update( + { + "model_type": "RuleBasedRelativeReasoning", + "min_count": model.min_count, + "max_count": model.max_count, + } + ) elif isinstance(model, MLRelativeReasoning): - model_dict['model_type'] = 'MLRelativeReasoning' + model_dict["model_type"] = "MLRelativeReasoning" elif isinstance(model, EnviFormer): - model_dict['model_type'] = 'EnviFormer' + model_dict["model_type"] = "EnviFormer" elif isinstance(model, PluginModel): - model_dict['model_type'] = 'PluginModel' + model_dict["model_type"] = "PluginModel" else: - model_dict['model_type'] = 'EPModel' + model_dict["model_type"] = "EPModel" - result['models'].append(model_dict) + result["models"].append(model_dict) print(f"Export completed for package: {package.name}") print(f"- Compounds: {len(result['compounds'])}") @@ -524,14 +603,20 @@ class PackageExporter: return result + class PackageImporter: """ Imports package data from JSON export. Handles object creation, relationship mapping, and dependency resolution. """ - def __init__(self, package: Dict[str, Any], preserve_uuids: bool = False, add_import_timestamp=True, - trust_reviewed=False): + def __init__( + self, + package: Dict[str, Any], + preserve_uuids: bool = False, + add_import_timestamp=True, + trust_reviewed=False, + ): """ Initialize the importer. @@ -570,7 +655,7 @@ class PackageImporter: def sign(data: Dict[str, Any], key: str) -> Dict[str, Any]: json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) signature = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest() - data['_signature'] = base64.b64encode(signature).decode() + data["_signature"] = base64.b64encode(signature).decode() return data @staticmethod @@ -582,7 +667,6 @@ class PackageImporter: expected = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest() return hmac.compare_digest(signature, expected) - @transaction.atomic def _import_package_from_json(self, package_data: Dict[str, Any]) -> Package: """ @@ -600,58 +684,58 @@ class PackageImporter: package = self._create_package(package_data) # Import in dependency order - self._import_compounds(package, package_data.get('compounds', [])) - self._import_structures(package, package_data.get('structures', [])) - self._import_rules(package, package_data.get('rules', {})) - self._import_reactions(package, package_data.get('reactions', [])) - self._import_pathways(package, package_data.get('pathways', [])) - self._import_nodes(package, package_data.get('nodes', [])) - self._import_edges(package, package_data.get('edges', [])) - self._import_scenarios(package, package_data.get('scenarios', [])) + self._import_compounds(package, package_data.get("compounds", [])) + self._import_structures(package, package_data.get("structures", [])) + self._import_rules(package, package_data.get("rules", {})) + self._import_reactions(package, package_data.get("reactions", [])) + self._import_pathways(package, package_data.get("pathways", [])) + self._import_nodes(package, package_data.get("nodes", [])) + self._import_edges(package, package_data.get("edges", [])) + self._import_scenarios(package, package_data.get("scenarios", [])) - if package_data.get('models'): - self._import_models(package, package_data['models']) + if package_data.get("models"): + self._import_models(package, package_data["models"]) # Set default structures for compounds (after all structures are created) - self._set_default_structures(package_data.get('compounds', [])) + self._set_default_structures(package_data.get("compounds", [])) print(f"Package import completed: {package.name}") return package def _create_package(self, package_data: Dict[str, Any]) -> Package: """Create the main package object.""" - package_uuid = self._get_or_generate_uuid(package_data['uuid']) + package_uuid = self._get_or_generate_uuid(package_data["uuid"]) # Handle license license_obj = None - if package_data.get('license'): - license_data = package_data['license'] + if package_data.get("license"): + license_data = package_data["license"] license_obj, _ = License.objects.get_or_create( - name=license_data['name'], + name=license_data["name"], defaults={ - 'link': license_data.get('link', ''), - 'image_link': license_data.get('image_link', '') - } + "link": license_data.get("link", ""), + "image_link": license_data.get("image_link", ""), + }, ) - new_name = package_data.get('name') + new_name = package_data.get("name") if self.add_import_timestamp: new_name = f"{new_name} - Imported at {datetime.now()}" new_reviewed = False if self.trust_reviewed: - new_reviewed = package_data.get('reviewed', False) + new_reviewed = package_data.get("reviewed", False) package = Package.objects.create( uuid=package_uuid, name=new_name, - description=package_data['description'], - kv=package_data.get('kv', {}), + description=package_data["description"], + kv=package_data.get("kv", {}), reviewed=new_reviewed, - license=license_obj + license=license_obj, ) - self._cache_object('Package', package_data['uuid'], package) + self._cache_object("Package", package_data["uuid"], package) print(f"Created package: {package.name}") return package @@ -660,35 +744,37 @@ class PackageImporter: print(f"Importing {len(compounds_data)} compounds...") for compound_data in compounds_data: - compound_uuid = self._get_or_generate_uuid(compound_data['uuid']) + compound_uuid = self._get_or_generate_uuid(compound_data["uuid"]) compound = Compound.objects.create( uuid=compound_uuid, package=package, - name=compound_data['name'], - description=compound_data['description'], - kv=compound_data.get('kv', {}), + name=compound_data["name"], + description=compound_data["description"], + kv=compound_data.get("kv", {}), # default_structure will be set later ) # Set aliases if present - if compound_data.get('aliases'): - compound.aliases = compound_data['aliases'] + if compound_data.get("aliases"): + compound.aliases = compound_data["aliases"] compound.save() - self._cache_object('Compound', compound_data['uuid'], compound) + self._cache_object("Compound", compound_data["uuid"], compound) # Handle external identifiers - self._create_external_identifiers(compound, compound_data.get('external_identifiers', [])) + self._create_external_identifiers( + compound, compound_data.get("external_identifiers", []) + ) def _import_structures(self, package: Package, structures_data: List[Dict[str, Any]]): """Import compound structures.""" print(f"Importing {len(structures_data)} compound structures...") for structure_data in structures_data: - structure_uuid = self._get_or_generate_uuid(structure_data['uuid']) - compound_uuid = structure_data['compound']['uuid'] - compound = self._get_cached_object('Compound', compound_uuid) + structure_uuid = self._get_or_generate_uuid(structure_data["uuid"]) + compound_uuid = structure_data["compound"]["uuid"] + compound = self._get_cached_object("Compound", compound_uuid) if not compound: print(f"Warning: Compound with UUID {compound_uuid} not found for structure") @@ -697,38 +783,40 @@ class PackageImporter: structure = CompoundStructure.objects.create( uuid=structure_uuid, compound=compound, - name=structure_data['name'], - description=structure_data['description'], - kv=structure_data.get('kv', {}), - smiles=structure_data['smiles'], - canonical_smiles=structure_data['canonical_smiles'], - inchikey=structure_data['inchikey'], - normalized_structure=structure_data.get('normalized_structure', False) + name=structure_data["name"], + description=structure_data["description"], + kv=structure_data.get("kv", {}), + smiles=structure_data["smiles"], + canonical_smiles=structure_data["canonical_smiles"], + inchikey=structure_data["inchikey"], + normalized_structure=structure_data.get("normalized_structure", False), ) # Set aliases if present - if structure_data.get('aliases'): - structure.aliases = structure_data['aliases'] + if structure_data.get("aliases"): + structure.aliases = structure_data["aliases"] structure.save() - self._cache_object('CompoundStructure', structure_data['uuid'], structure) + self._cache_object("CompoundStructure", structure_data["uuid"], structure) # Handle external identifiers - self._create_external_identifiers(structure, structure_data.get('external_identifiers', [])) + self._create_external_identifiers( + structure, structure_data.get("external_identifiers", []) + ) def _import_rules(self, package: Package, rules_data: Dict[str, Any]): """Import all types of rules.""" print("Importing rules...") # Import simple rules first - simple_rules_data = rules_data.get('simple_rules', []) + simple_rules_data = rules_data.get("simple_rules", []) print(f"Importing {len(simple_rules_data)} simple rules...") for rule_data in simple_rules_data: self._create_simple_rule(package, rule_data) # Import parallel rules - parallel_rules_data = rules_data.get('parallel_rules', []) + parallel_rules_data = rules_data.get("parallel_rules", []) print(f"Importing {len(parallel_rules_data)} parallel rules...") for rule_data in parallel_rules_data: @@ -736,64 +824,63 @@ class PackageImporter: def _create_simple_rule(self, package: Package, rule_data: Dict[str, Any]): """Create a simple rule (SimpleAmbitRule or SimpleRDKitRule).""" - rule_uuid = self._get_or_generate_uuid(rule_data['uuid']) - rule_type = rule_data.get('rule_type', 'SimpleRule') + rule_uuid = self._get_or_generate_uuid(rule_data["uuid"]) + rule_type = rule_data.get("rule_type", "SimpleRule") common_fields = { - 'uuid': rule_uuid, - 'package': package, - 'name': rule_data['name'], - 'description': rule_data['description'], - 'kv': rule_data.get('kv', {}) + "uuid": rule_uuid, + "package": package, + "name": rule_data["name"], + "description": rule_data["description"], + "kv": rule_data.get("kv", {}), } - if rule_type == 'SimpleAmbitRule': + if rule_type == "SimpleAmbitRule": rule = SimpleAmbitRule.objects.create( **common_fields, - smirks=rule_data.get('smirks', ''), - reactant_filter_smarts=rule_data.get('reactant_filter_smarts', ''), - product_filter_smarts=rule_data.get('product_filter_smarts', '') + smirks=rule_data.get("smirks", ""), + reactant_filter_smarts=rule_data.get("reactant_filter_smarts", ""), + product_filter_smarts=rule_data.get("product_filter_smarts", ""), ) - elif rule_type == 'SimpleRDKitRule': + elif rule_type == "SimpleRDKitRule": rule = SimpleRDKitRule.objects.create( - **common_fields, - reaction_smarts=rule_data.get('reaction_smarts', '') + **common_fields, reaction_smarts=rule_data.get("reaction_smarts", "") ) else: rule = SimpleRule.objects.create(**common_fields) # Set aliases if present - if rule_data.get('aliases'): - rule.aliases = rule_data['aliases'] + if rule_data.get("aliases"): + rule.aliases = rule_data["aliases"] rule.save() - self._cache_object('SimpleRule', rule_data['uuid'], rule) + self._cache_object("SimpleRule", rule_data["uuid"], rule) return rule def _create_parallel_rule(self, package: Package, rule_data: Dict[str, Any]): """Create a parallel rule.""" - rule_uuid = self._get_or_generate_uuid(rule_data['uuid']) + rule_uuid = self._get_or_generate_uuid(rule_data["uuid"]) rule = ParallelRule.objects.create( uuid=rule_uuid, package=package, - name=rule_data['name'], - description=rule_data['description'], - kv=rule_data.get('kv', {}) + name=rule_data["name"], + description=rule_data["description"], + kv=rule_data.get("kv", {}), ) # Set aliases if present - if rule_data.get('aliases'): - rule.aliases = rule_data['aliases'] + if rule_data.get("aliases"): + rule.aliases = rule_data["aliases"] rule.save() # Add simple rules - for simple_rule_ref in rule_data.get('simple_rules', []): - simple_rule = self._get_cached_object('SimpleRule', simple_rule_ref['uuid']) + for simple_rule_ref in rule_data.get("simple_rules", []): + simple_rule = self._get_cached_object("SimpleRule", simple_rule_ref["uuid"]) if simple_rule: rule.simple_rules.add(simple_rule) - self._cache_object('ParallelRule', rule_data['uuid'], rule) + self._cache_object("ParallelRule", rule_data["uuid"], rule) return rule def _import_reactions(self, package: Package, reactions_data: List[Dict[str, Any]]): @@ -801,78 +888,81 @@ class PackageImporter: print(f"Importing {len(reactions_data)} reactions...") for reaction_data in reactions_data: - reaction_uuid = self._get_or_generate_uuid(reaction_data['uuid']) + reaction_uuid = self._get_or_generate_uuid(reaction_data["uuid"]) reaction = Reaction.objects.create( uuid=reaction_uuid, package=package, - name=reaction_data['name'], - description=reaction_data['description'], - kv=reaction_data.get('kv', {}), - multi_step=reaction_data.get('multi_step', False), - medline_references=reaction_data.get('medline_references', []) + name=reaction_data["name"], + description=reaction_data["description"], + kv=reaction_data.get("kv", {}), + multi_step=reaction_data.get("multi_step", False), + medline_references=reaction_data.get("medline_references", []), ) # Set aliases if present - if reaction_data.get('aliases'): - reaction.aliases = reaction_data['aliases'] + if reaction_data.get("aliases"): + reaction.aliases = reaction_data["aliases"] reaction.save() # Add educts and products - for educt_ref in reaction_data.get('educts', []): - compound = self._get_cached_object('CompoundStructure', educt_ref['uuid']) + for educt_ref in reaction_data.get("educts", []): + compound = self._get_cached_object("CompoundStructure", educt_ref["uuid"]) if compound: reaction.educts.add(compound) - for product_ref in reaction_data.get('products', []): - compound = self._get_cached_object('CompoundStructure', product_ref['uuid']) + for product_ref in reaction_data.get("products", []): + compound = self._get_cached_object("CompoundStructure", product_ref["uuid"]) if compound: reaction.products.add(compound) # Add rules - for rule_ref in reaction_data.get('rules', []): + for rule_ref in reaction_data.get("rules", []): # Try to find rule in different caches - rule = (self._get_cached_object('SimpleRule', rule_ref['uuid']) or - self._get_cached_object('ParallelRule', rule_ref['uuid'])) + rule = self._get_cached_object( + "SimpleRule", rule_ref["uuid"] + ) or self._get_cached_object("ParallelRule", rule_ref["uuid"]) if rule: reaction.rules.add(rule) - self._cache_object('Reaction', reaction_data['uuid'], reaction) + self._cache_object("Reaction", reaction_data["uuid"], reaction) # Handle external identifiers - self._create_external_identifiers(reaction, reaction_data.get('external_identifiers', [])) + self._create_external_identifiers( + reaction, reaction_data.get("external_identifiers", []) + ) def _import_pathways(self, package: Package, pathways_data: List[Dict[str, Any]]): """Import pathways.""" print(f"Importing {len(pathways_data)} pathways...") for pathway_data in pathways_data: - pathway_uuid = self._get_or_generate_uuid(pathway_data['uuid']) + pathway_uuid = self._get_or_generate_uuid(pathway_data["uuid"]) pathway = Pathway.objects.create( uuid=pathway_uuid, package=package, - name=pathway_data['name'], - description=pathway_data['description'], - kv=pathway_data.get('kv', {}) + name=pathway_data["name"], + description=pathway_data["description"], + kv=pathway_data.get("kv", {}), # setting will be handled separately if needed ) # Set aliases if present - if pathway_data.get('aliases'): - pathway.aliases = pathway_data['aliases'] + if pathway_data.get("aliases"): + pathway.aliases = pathway_data["aliases"] pathway.save() - self._cache_object('Pathway', pathway_data['uuid'], pathway) + self._cache_object("Pathway", pathway_data["uuid"], pathway) def _import_nodes(self, package: Package, nodes_data: List[Dict[str, Any]]): """Import pathway nodes.""" print(f"Importing {len(nodes_data)} nodes...") for node_data in nodes_data: - node_uuid = self._get_or_generate_uuid(node_data['uuid']) - pathway_uuid = node_data['pathway']['uuid'] - pathway = self._get_cached_object('Pathway', pathway_uuid) + node_uuid = self._get_or_generate_uuid(node_data["uuid"]) + pathway_uuid = node_data["pathway"]["uuid"] + pathway = self._get_cached_object("Pathway", pathway_uuid) if not pathway: print(f"Warning: Pathway with UUID {pathway_uuid} not found for node") @@ -883,19 +973,21 @@ class PackageImporter: node = Node.objects.create( uuid=node_uuid, pathway=pathway, - name=node_data['name'], - description=node_data['description'], - kv=node_data.get('kv', {}), - depth=node_data.get('depth', 0), - default_node_label=self._get_cached_object('CompoundStructure', node_data['default_node_label']['uuid']) + name=node_data["name"], + description=node_data["description"], + kv=node_data.get("kv", {}), + depth=node_data.get("depth", 0), + default_node_label=self._get_cached_object( + "CompoundStructure", node_data["default_node_label"]["uuid"] + ), ) # Set aliases if present - if node_data.get('aliases'): - node.aliases = node_data['aliases'] + if node_data.get("aliases"): + node.aliases = node_data["aliases"] node.save() - self._cache_object('Node', node_data['uuid'], node) + self._cache_object("Node", node_data["uuid"], node) # Store node_data for later processing of relationships node._import_data = node_data @@ -904,9 +996,9 @@ class PackageImporter: print(f"Importing {len(edges_data)} edges...") for edge_data in edges_data: - edge_uuid = self._get_or_generate_uuid(edge_data['uuid']) - pathway_uuid = edge_data['pathway']['uuid'] - pathway = self._get_cached_object('Pathway', pathway_uuid) + edge_uuid = self._get_or_generate_uuid(edge_data["uuid"]) + pathway_uuid = edge_data["pathway"]["uuid"] + pathway = self._get_cached_object("Pathway", pathway_uuid) if not pathway: print(f"Warning: Pathway with UUID {pathway_uuid} not found for edge") @@ -916,29 +1008,29 @@ class PackageImporter: edge = Edge.objects.create( uuid=edge_uuid, pathway=pathway, - name=edge_data['name'], - description=edge_data['description'], - kv=edge_data.get('kv', {}), - edge_label=self._get_cached_object('Reaction', edge_data['edge_label']['uuid']) + name=edge_data["name"], + description=edge_data["description"], + kv=edge_data.get("kv", {}), + edge_label=self._get_cached_object("Reaction", edge_data["edge_label"]["uuid"]), ) # Set aliases if present - if edge_data.get('aliases'): - edge.aliases = edge_data['aliases'] + if edge_data.get("aliases"): + edge.aliases = edge_data["aliases"] edge.save() # Add start and end nodes - for start_node_ref in edge_data.get('start_nodes', []): - node = self._get_cached_object('Node', start_node_ref['uuid']) + for start_node_ref in edge_data.get("start_nodes", []): + node = self._get_cached_object("Node", start_node_ref["uuid"]) if node: edge.start_nodes.add(node) - for end_node_ref in edge_data.get('end_nodes', []): - node = self._get_cached_object('Node', end_node_ref['uuid']) + for end_node_ref in edge_data.get("end_nodes", []): + node = self._get_cached_object("Node", end_node_ref["uuid"]) if node: edge.end_nodes.add(node) - self._cache_object('Edge', edge_data['uuid'], edge) + self._cache_object("Edge", edge_data["uuid"], edge) def _import_scenarios(self, package: Package, scenarios_data: List[Dict[str, Any]]): """Import scenarios.""" @@ -946,32 +1038,32 @@ class PackageImporter: # First pass: create scenarios without parent relationships for scenario_data in scenarios_data: - scenario_uuid = self._get_or_generate_uuid(scenario_data['uuid']) + scenario_uuid = self._get_or_generate_uuid(scenario_data["uuid"]) scenario_date = None - if scenario_data.get('scenario_date'): - scenario_date = scenario_data['scenario_date'] + if scenario_data.get("scenario_date"): + scenario_date = scenario_data["scenario_date"] scenario = Scenario.objects.create( uuid=scenario_uuid, package=package, - name=scenario_data['name'], - description=scenario_data['description'], - kv=scenario_data.get('kv', {}), + name=scenario_data["name"], + description=scenario_data["description"], + kv=scenario_data.get("kv", {}), scenario_date=scenario_date, - scenario_type=scenario_data.get('scenario_type'), - additional_information=scenario_data.get('additional_information', {}) + scenario_type=scenario_data.get("scenario_type"), + additional_information=scenario_data.get("additional_information", {}), ) - self._cache_object('Scenario', scenario_data['uuid'], scenario) + self._cache_object("Scenario", scenario_data["uuid"], scenario) # Store scenario_data for later processing of parent relationships scenario._import_data = scenario_data # Second pass: set parent relationships for scenario_data in scenarios_data: - if scenario_data.get('parent'): - scenario = self._get_cached_object('Scenario', scenario_data['uuid']) - parent = self._get_cached_object('Scenario', scenario_data['parent']['uuid']) + if scenario_data.get("parent"): + scenario = self._get_cached_object("Scenario", scenario_data["uuid"]) + parent = self._get_cached_object("Scenario", scenario_data["parent"]["uuid"]) if scenario and parent: scenario.parent = parent scenario.save() @@ -981,74 +1073,77 @@ class PackageImporter: print(f"Importing {len(models_data)} models...") for model_data in models_data: - model_uuid = self._get_or_generate_uuid(model_data['uuid']) - model_type = model_data.get('model_type', 'EPModel') + model_uuid = self._get_or_generate_uuid(model_data["uuid"]) + model_type = model_data.get("model_type", "EPModel") common_fields = { - 'uuid': model_uuid, - 'package': package, - 'name': model_data['name'], - 'description': model_data['description'], - 'kv': model_data.get('kv', {}) + "uuid": model_uuid, + "package": package, + "name": model_data["name"], + "description": model_data["description"], + "kv": model_data.get("kv", {}), } # Add PackageBasedModel fields if present - if 'threshold' in model_data: - common_fields.update({ - 'threshold': model_data.get('threshold'), - 'eval_results': model_data.get('eval_results', {}), - 'model_status': model_data.get('model_status', 'INITIAL') - }) + if "threshold" in model_data: + common_fields.update( + { + "threshold": model_data.get("threshold"), + "eval_results": model_data.get("eval_results", {}), + "model_status": model_data.get("model_status", "INITIAL"), + } + ) # Create the appropriate model type - if model_type == 'RuleBasedRelativeReasoning': + if model_type == "RuleBasedRelativeReasoning": model = RuleBasedRelativeReasoning.objects.create( **common_fields, - min_count=model_data.get('min_count', 1), - max_count=model_data.get('max_count', 10) + min_count=model_data.get("min_count", 1), + max_count=model_data.get("max_count", 10), ) - elif model_type == 'MLRelativeReasoning': + elif model_type == "MLRelativeReasoning": model = MLRelativeReasoning.objects.create(**common_fields) - elif model_type == 'EnviFormer': + elif model_type == "EnviFormer": model = EnviFormer.objects.create(**common_fields) - elif model_type == 'PluginModel': + elif model_type == "PluginModel": model = PluginModel.objects.create(**common_fields) else: model = EPModel.objects.create(**common_fields) # Set aliases if present - if model_data.get('aliases'): - model.aliases = model_data['aliases'] + if model_data.get("aliases"): + model.aliases = model_data["aliases"] model.save() # Add package relationships for PackageBasedModel - if hasattr(model, 'rule_packages'): - for pkg_ref in model_data.get('rule_packages', []): - pkg = self._get_cached_object('Package', pkg_ref['uuid']) + if hasattr(model, "rule_packages"): + for pkg_ref in model_data.get("rule_packages", []): + pkg = self._get_cached_object("Package", pkg_ref["uuid"]) if pkg: model.rule_packages.add(pkg) - for pkg_ref in model_data.get('data_packages', []): - pkg = self._get_cached_object('Package', pkg_ref['uuid']) + for pkg_ref in model_data.get("data_packages", []): + pkg = self._get_cached_object("Package", pkg_ref["uuid"]) if pkg: model.data_packages.add(pkg) - for pkg_ref in model_data.get('eval_packages', []): - pkg = self._get_cached_object('Package', pkg_ref['uuid']) + for pkg_ref in model_data.get("eval_packages", []): + pkg = self._get_cached_object("Package", pkg_ref["uuid"]) if pkg: model.eval_packages.add(pkg) - self._cache_object('EPModel', model_data['uuid'], model) + self._cache_object("EPModel", model_data["uuid"], model) def _set_default_structures(self, compounds_data: List[Dict[str, Any]]): """Set default structures for compounds after all structures are created.""" print("Setting default structures for compounds...") for compound_data in compounds_data: - if compound_data.get('default_structure'): - compound = self._get_cached_object('Compound', compound_data['uuid']) - structure = self._get_cached_object('CompoundStructure', - compound_data['default_structure']['uuid']) + if compound_data.get("default_structure"): + compound = self._get_cached_object("Compound", compound_data["uuid"]) + structure = self._get_cached_object( + "CompoundStructure", compound_data["default_structure"]["uuid"] + ) if compound and structure: compound.default_structure = structure compound.save() @@ -1057,22 +1152,22 @@ class PackageImporter: """Create external identifiers for an object.""" for identifier_data in identifiers_data: # Get or create the external database - db_data = identifier_data['database'] + db_data = identifier_data["database"] database, _ = ExternalDatabase.objects.get_or_create( - name=db_data['name'], + name=db_data["name"], defaults={ - 'base_url': db_data.get('base_url', ''), - 'full_name': db_data.get('name', ''), - 'description': '', - 'is_active': True - } + "base_url": db_data.get("base_url", ""), + "full_name": db_data.get("name", ""), + "description": "", + "is_active": True, + }, ) # Create the external identifier ExternalIdentifier.objects.create( content_object=obj, database=database, - identifier_value=identifier_data['identifier_value'], - url=identifier_data.get('url', ''), - is_primary=identifier_data.get('is_primary', False) + identifier_value=identifier_data["identifier_value"], + url=identifier_data.get("url", ""), + is_primary=identifier_data.get("is_primary", False), ) diff --git a/utilities/ml.py b/utilities/ml.py index 8dbaa1ef..a93fafd9 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -1,37 +1,35 @@ from __future__ import annotations import copy - -import numpy as np -from numpy.random import default_rng -from sklearn.dummy import DummyClassifier -from sklearn.tree import DecisionTreeClassifier import logging -from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime -from typing import List, Dict, Set, Tuple +from pathlib import Path +from typing import List, Dict, Set, Tuple, TYPE_CHECKING import networkx as nx - +import numpy as np +from numpy.random import default_rng from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.decomposition import PCA +from sklearn.dummy import DummyClassifier from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score from sklearn.multioutput import ClassifierChain from sklearn.preprocessing import StandardScaler +from utilities.chem import FormatConverter, PredictionResult + logger = logging.getLogger(__name__) - -from dataclasses import dataclass, field - -from utilities.chem import FormatConverter, PredictionResult +if TYPE_CHECKING: + from epdb.models import Rule, CompoundStructure, Reaction class Dataset: - - def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None): + def __init__( + self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None + ): self.columns: List[str] = columns self.num_labels: int = num_labels @@ -41,9 +39,9 @@ class Dataset: self.data = data self.num_features: int = len(columns) - self.num_labels - self._struct_features: Tuple[int, int] = self._block_indices('feature_') - self._triggered: Tuple[int, int] = self._block_indices('trig_') - self._observed: Tuple[int, int] = self._block_indices('obs_') + self._struct_features: Tuple[int, int] = self._block_indices("feature_") + self._triggered: Tuple[int, int] = self._block_indices("trig_") + self._observed: Tuple[int, int] = self._block_indices("obs_") def _block_indices(self, prefix) -> Tuple[int, int]: indices: List[int] = [] @@ -62,7 +60,7 @@ class Dataset: self.data.append(row) def times_triggered(self, rule_uuid) -> int: - idx = self.columns.index(f'trig_{rule_uuid}') + idx = self.columns.index(f"trig_{rule_uuid}") times_triggered = 0 for row in self.data: @@ -89,12 +87,12 @@ class Dataset: def __iter__(self): return (self.at(i) for i, _ in enumerate(self.data)) - - def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]: + def classification_dataset( + self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"] + ) -> Tuple[Dataset, List[List[PredictionResult]]]: classify_data = [] classify_products = [] for struct in structures: - if isinstance(struct, str): struct_id = None struct_smiles = struct @@ -119,10 +117,14 @@ class Dataset: classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) classify_products.append(prods) - return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products + return Dataset( + columns=self.columns, num_labels=self.num_labels, data=classify_data + ), classify_products @staticmethod - def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset: + def generate_dataset( + reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True + ) -> Dataset: _structures = set() for r in reactions: @@ -155,12 +157,11 @@ class Dataset: for prod_set in product_sets: for smi in prod_set: - try: smi = FormatConverter.standardize(smi, remove_stereo=True) except Exception: # :shrug: - logger.debug(f'Standardizing SMILES failed for {smi}') + logger.debug(f"Standardizing SMILES failed for {smi}") pass triggered[key].add(smi) @@ -188,7 +189,7 @@ class Dataset: smi = FormatConverter.standardize(smi, remove_stereo=True) except Exception as e: # :shrug: - logger.debug(f'Standardizing SMILES failed for {smi}') + logger.debug(f"Standardizing SMILES failed for {smi}") pass standardized_products.append(smi) @@ -224,19 +225,22 @@ class Dataset: obs.append(0) if ds is None: - header = ['structure_id'] + \ - [f'feature_{i}' for i, _ in enumerate(feat)] \ - + [f'trig_{r.uuid}' for r in applicable_rules] \ - + [f'obs_{r.uuid}' for r in applicable_rules] + header = ( + ["structure_id"] + + [f"feature_{i}" for i, _ in enumerate(feat)] + + [f"trig_{r.uuid}" for r in applicable_rules] + + [f"obs_{r.uuid}" for r in applicable_rules] + ) ds = Dataset(header, len(applicable_rules)) ds.add_row([str(comp.uuid)] + feat + trig + obs) return ds - def X(self, exclude_id_col=True, na_replacement=0): - res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))) + res = self.__getitem__( + (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)) + ) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res @@ -247,14 +251,12 @@ class Dataset: res = [[x if x is not None else na_replacement for x in row] for row in res] return res - def y(self, na_replacement=0): res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) if na_replacement is not None: res = [[x if x is not None else na_replacement for x in row] for row in res] return res - def __getitem__(self, key): if not isinstance(key, tuple): raise TypeError("Dataset must be indexed with dataset[rows, columns]") @@ -271,42 +273,50 @@ class Dataset: if isinstance(col_key, int): res = [row[col_key] for row in rows] else: - res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice) - else [row[i] for i in col_key] for row in rows] + res = [ + [row[i] for i in range(*col_key.indices(len(row)))] + if isinstance(col_key, slice) + else [row[i] for i in col_key] + for row in rows + ] return res - def save(self, path: 'Path'): + def save(self, path: "Path"): import pickle + with open(path, "wb") as fh: pickle.dump(self, fh) @staticmethod - def load(path: 'Path') -> 'Dataset': + def load(path: "Path") -> "Dataset": import pickle + return pickle.load(open(path, "rb")) - def to_arff(self, path: 'Path'): + def to_arff(self, path: "Path"): arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n" arff += "\n" - for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]: - if c == 'structure_id': + for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]: + if c == "structure_id": arff += f"@attribute {c} string\n" else: arff += f"@attribute {c} {{0,1}}\n" - arff += f"\n@data\n" + arff += "\n@data\n" for d in self.data: - ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]]) - xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]]) - arff += f'{ys},{xs}\n' + ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]]) + xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]]) + arff += f"{ys},{xs}\n" with open(path, "w") as fh: fh.write(arff) fh.flush() def __repr__(self): - return f"" + return ( + f"" + ) class SparseLabelECC(BaseEstimator, ClassifierMixin): @@ -315,8 +325,11 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin): Removes labels that are constant across all samples in training. """ - def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42), - num_chains: int = 10): + def __init__( + self, + base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42), + num_chains: int = 10, + ): self.base_clf = base_clf self.num_chains = num_chains @@ -384,16 +397,16 @@ class BinaryRelevance: if self.classifiers is None: self.classifiers = [] - for l in range(len(Y[0])): - X_l = X[~np.isnan(Y[:, l])] - Y_l = (Y[~np.isnan(Y[:, l]), l]) + for label in range(len(Y[0])): + X_l = X[~np.isnan(Y[:, label])] + Y_l = Y[~np.isnan(Y[:, label]), label] if len(X_l) == 0: # all labels are nan -> predict 0 - clf = DummyClassifier(strategy='constant', constant=0) + clf = DummyClassifier(strategy="constant", constant=0) clf.fit([X[0]], [0]) self.classifiers.append(clf) continue elif len(np.unique(Y_l)) == 1: # only one class -> predict that class - clf = DummyClassifier(strategy='most_frequent') + clf = DummyClassifier(strategy="most_frequent") else: clf = copy.deepcopy(self.clf) clf.fit(X_l, Y_l) @@ -439,17 +452,19 @@ class MissingValuesClassifierChain: X_p = X[~np.isnan(Y[:, p])] Y_p = Y[~np.isnan(Y[:, p]), p] if len(X_p) == 0: # all labels are nan -> predict 0 - clf = DummyClassifier(strategy='constant', constant=0) + clf = DummyClassifier(strategy="constant", constant=0) self.classifiers.append(clf.fit([X[0]], [0])) elif len(np.unique(Y_p)) == 1: # only one class -> predict that class - clf = DummyClassifier(strategy='most_frequent') + clf = DummyClassifier(strategy="most_frequent") self.classifiers.append(clf.fit(X_p, Y_p)) else: clf = copy.deepcopy(self.base_clf) self.classifiers.append(clf.fit(X_p, Y_p)) newcol = Y[:, p] pred = clf.predict(X) - newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions + newcol[np.isnan(newcol)] = pred[ + np.isnan(newcol) + ] # fill in missing values with clf predictions X = np.column_stack((X, newcol)) def predict(self, X): @@ -541,13 +556,10 @@ class RelativeReasoning: # We've seen more than self.min_count wins, more wins than loosing, no looses and no ties if ( - countwin >= self.min_count and - countwin > countloose and - ( - countloose <= self.max_count or - self.max_count < 0 - ) and - countboth == 0 + countwin >= self.min_count + and countwin > countloose + and (countloose <= self.max_count or self.max_count < 0) + and countboth == 0 ): self.winmap[i].append(j) @@ -557,13 +569,13 @@ class RelativeReasoning: # Loop through all instances for inst_idx, inst in enumerate(X): # Loop through all "triggered" features - for i, t in enumerate(inst[self.start_index: self.end_index + 1]): + for i, t in enumerate(inst[self.start_index : self.end_index + 1]): # Set label res[inst_idx][i] = t # If we predict a 1, check if the rule gets dominated by another if t: # Second loop to check other triggered rules - for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]): + for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]): if i != i2: # Check if rule idx is in "dominated by" list if i2 in self.winmap.get(i, []): @@ -579,7 +591,6 @@ class RelativeReasoning: class ApplicabilityDomainPCA(PCA): - def __init__(self, num_neighbours: int = 5): super().__init__(n_components=num_neighbours) self.scaler = StandardScaler() @@ -587,7 +598,7 @@ class ApplicabilityDomainPCA(PCA): self.min_vals = None self.max_vals = None - def build(self, train_dataset: 'Dataset'): + def build(self, train_dataset: "Dataset"): # transform X_scaled = self.scaler.fit_transform(train_dataset.X()) # fit pca @@ -601,7 +612,7 @@ class ApplicabilityDomainPCA(PCA): instances_pca = self.transform(instances_scaled) return instances_pca - def is_applicable(self, classify_instances: 'Dataset'): + def is_applicable(self, classify_instances: "Dataset"): instances_pca = self.__transform(classify_instances.X()) is_applicable = [] @@ -632,6 +643,7 @@ def graph_from_pathway(data): """Convert Pathway or SPathway to networkx""" from epdb.models import Pathway from epdb.logic import SPathway + graph = nx.DiGraph() co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation @@ -645,7 +657,9 @@ def graph_from_pathway(data): def get_sources_targets(): if isinstance(data, Pathway): - return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()] + return [n.node for n in edge.start_nodes.constrained_target.all()], [ + n.node for n in edge.end_nodes.constrained_target.all() + ] elif isinstance(data, SPathway): return edge.educts, edge.products else: @@ -662,7 +676,7 @@ def graph_from_pathway(data): def get_probability(): try: if isinstance(data, Pathway): - return edge.kv.get('probability') + return edge.kv.get("probability") elif isinstance(data, SPathway): return edge.probability else: @@ -680,17 +694,29 @@ def graph_from_pathway(data): for source in sources: source_smiles, source_depth = get_smiles_depth(source) if source_smiles not in graph: - graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles, - root=source_smiles in root_smiles) + graph.add_node( + source_smiles, + depth=source_depth, + smiles=source_smiles, + root=source_smiles in root_smiles, + ) else: - graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"]) + graph.nodes[source_smiles]["depth"] = min( + source_depth, graph.nodes[source_smiles]["depth"] + ) for target in targets: target_smiles, target_depth = get_smiles_depth(target) if target_smiles not in graph and target_smiles not in co2: - graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles, - root=target_smiles in root_smiles) + graph.add_node( + target_smiles, + depth=target_depth, + smiles=target_smiles, + root=target_smiles in root_smiles, + ) elif target_smiles not in co2: - graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"]) + graph.nodes[target_smiles]["depth"] = min( + target_depth, graph.nodes[target_smiles]["depth"] + ) if target_smiles not in co2 and target_smiles != source_smiles: graph.add_edge(source_smiles, target_smiles, probability=probability) return graph @@ -710,7 +736,9 @@ def set_pathway_eval_weight(pathway): node_eval_weights = {} for node in pathway.nodes: # Scale score according to depth level - node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0 + node_eval_weights[node] = ( + 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0 + ) return node_eval_weights @@ -731,8 +759,11 @@ def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates): shortest_path_list.append(shortest_path_nodes) if shortest_path_list: shortest_path_nodes = min(shortest_path_list, key=len) - num_ints = sum(1 for shortest_path_node in shortest_path_nodes if - shortest_path_node in intermediates) + num_ints = sum( + 1 + for shortest_path_node in shortest_path_nodes + if shortest_path_node in intermediates + ) pred_pathway.nodes[node]["depth"] -= num_ints return pred_pathway @@ -879,6 +910,11 @@ def pathway_edit_eval(data_pathway, pred_pathway): data_pathway = initialise_pathway(data_pathway) pred_pathway = initialise_pathway(pred_pathway) roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0]) - return nx.graph_edit_distance(data_pathway, pred_pathway, - node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost, - node_ins_cost=node_ins_del_cost, roots=roots) + return nx.graph_edit_distance( + data_pathway, + pred_pathway, + node_subst_cost=node_subst_cost, + node_del_cost=node_ins_del_cost, + node_ins_cost=node_ins_del_cost, + roots=roots, + ) diff --git a/utilities/plugin.py b/utilities/plugin.py index 97d295e1..d2248d3c 100644 --- a/utilities/plugin.py +++ b/utilities/plugin.py @@ -23,11 +23,11 @@ def install_wheel(wheel_path): def extract_package_name_from_wheel(wheel_filename): # Example: my_plugin-0.1.0-py3-none-any.whl -> my_plugin - return wheel_filename.split('-')[0] + return wheel_filename.split("-")[0] def ensure_plugins_installed(): - wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, '*.whl')) + wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, "*.whl")) for wheel_path in wheel_files: wheel_filename = os.path.basename(wheel_path) @@ -45,7 +45,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]: plugins = {} - for entry_point in importlib.metadata.entry_points(group='enviPy_plugins'): + for entry_point in importlib.metadata.entry_points(group="enviPy_plugins"): try: plugin_class = entry_point.load() if _cls: @@ -54,9 +54,9 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]: plugins[instance.name()] = instance else: if ( - issubclass(plugin_class, Classifier) - or issubclass(plugin_class, Descriptor) - or issubclass(plugin_class, Property) + issubclass(plugin_class, Classifier) + or issubclass(plugin_class, Descriptor) + or issubclass(plugin_class, Property) ): instance = plugin_class() plugins[instance.name()] = instance