[Chore] Linted Files (#150)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
This commit is contained in:
2025-10-09 07:25:13 +13:00
parent 22f0bbe10b
commit afeb56622c
50 changed files with 5616 additions and 4408 deletions

View File

@ -2,4 +2,4 @@
# Django starts so that shared_task will use this app. # Django starts so that shared_task will use this app.
from .celery import app as celery_app from .celery import app as celery_app
__all__ = ('celery_app',) __all__ = ("celery_app",)

View File

@ -4,8 +4,6 @@ from ninja import NinjaAPI
api = NinjaAPI() api = NinjaAPI()
from ninja import NinjaAPI
api_v1 = NinjaAPI(title="API V1 Docs", urls_namespace="api-v1") api_v1 = NinjaAPI(title="API V1 Docs", urls_namespace="api-v1")
api_legacy = NinjaAPI(title="Legacy API Docs", urls_namespace="api-legacy") api_legacy = NinjaAPI(title="Legacy API Docs", urls_namespace="api-legacy")

View File

@ -11,6 +11,6 @@ import os
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings")
application = get_asgi_application() application = get_asgi_application()

View File

@ -4,15 +4,15 @@ from celery import Celery
from celery.signals import setup_logging from celery.signals import setup_logging
# Set the default Django settings module for the 'celery' program. # Set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings")
app = Celery('envipath') app = Celery("envipath")
# Using a string here means the worker doesn't have to serialize # Using a string here means the worker doesn't have to serialize
# the configuration object to child processes. # the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys # - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix. # should have a `CELERY_` prefix.
app.config_from_object('django.conf:settings', namespace='CELERY') app.config_from_object("django.conf:settings", namespace="CELERY")
@setup_logging.connect @setup_logging.connect

View File

@ -14,6 +14,7 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path 1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
""" """
from django.conf import settings as s from django.conf import settings as s
from django.contrib import admin from django.contrib import admin
from django.urls import include, path from django.urls import include, path
@ -21,7 +22,6 @@ from django.urls import include, path
from .api import api_v1, api_legacy from .api import api_v1, api_legacy
urlpatterns = [ urlpatterns = [
path("", include("epdb.urls")), path("", include("epdb.urls")),
path("", include("migration.urls")), path("", include("migration.urls")),
path("admin/", admin.site.urls), path("admin/", admin.site.urls),

View File

@ -11,6 +11,6 @@ import os
from django.core.wsgi import get_wsgi_application from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings")
application = get_wsgi_application() application = get_wsgi_application()

View File

@ -18,7 +18,7 @@ from .models import (
Scenario, Scenario,
Setting, Setting,
ExternalDatabase, ExternalDatabase,
ExternalIdentifier ExternalIdentifier,
) )
@ -39,7 +39,7 @@ class GroupPackagePermissionAdmin(admin.ModelAdmin):
class EPAdmin(admin.ModelAdmin): class EPAdmin(admin.ModelAdmin):
search_fields = ['name', 'description'] search_fields = ["name", "description"]
class PackageAdmin(EPAdmin): class PackageAdmin(EPAdmin):

View File

@ -21,7 +21,7 @@ class BearerTokenAuth(HttpBearer):
def _anonymous_or_real(request): def _anonymous_or_real(request):
if request.user.is_authenticated and not request.user.is_anonymous: if request.user.is_authenticated and not request.user.is_anonymous:
return request.user return request.user
return get_user_model().objects.get(username='anonymous') return get_user_model().objects.get(username="anonymous")
router = Router(auth=BearerTokenAuth()) router = Router(auth=BearerTokenAuth())
@ -85,7 +85,9 @@ def get_package(request, package_uuid):
try: try:
return PackageManager.get_package_by_id(request.auth, package_id=package_uuid) return PackageManager.get_package_by_id(request.auth, package_id=package_uuid)
except ValueError: except ValueError:
return 403, {'message': f'Getting Package with id {package_uuid} failed due to insufficient rights!'} return 403, {
"message": f"Getting Package with id {package_uuid} failed due to insufficient rights!"
}
@router.get("/compound", response={200: List[CompoundSchema], 403: Error}) @router.get("/compound", response={200: List[CompoundSchema], 403: Error})
@ -97,7 +99,9 @@ def get_compounds(request):
return qs return qs
@router.get("/package/{uuid:package_uuid}/compound", response={200: List[CompoundSchema], 403: Error}) @router.get(
"/package/{uuid:package_uuid}/compound", response={200: List[CompoundSchema], 403: Error}
)
@paginate @paginate
def get_package_compounds(request, package_uuid): def get_package_compounds(request, package_uuid):
try: try:
@ -105,4 +109,5 @@ def get_package_compounds(request, package_uuid):
return Compound.objects.filter(package=p) return Compound.objects.filter(package=p)
except ValueError: except ValueError:
return 403, { return 403, {
'message': f'Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!'} "message": f"Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!"
}

View File

@ -2,8 +2,8 @@ from django.apps import AppConfig
class EPDBConfig(AppConfig): class EPDBConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'epdb' name = "epdb"
def ready(self): def ready(self):
import epdb.signals # noqa: F401 import epdb.signals # noqa: F401

View File

@ -1,5 +0,0 @@
from django import forms
class EmailLoginForm(forms.Form):
email = forms.EmailField()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -5,32 +5,49 @@ from django.core.management.base import BaseCommand
from django.db import transaction from django.db import transaction
from epdb.logic import UserManager, GroupManager, PackageManager, SettingManager from epdb.logic import UserManager, GroupManager, PackageManager, SettingManager
from epdb.models import UserSettingPermission, MLRelativeReasoning, EnviFormer, Permission, User, ExternalDatabase from epdb.models import (
UserSettingPermission,
MLRelativeReasoning,
EnviFormer,
Permission,
User,
ExternalDatabase,
)
class Command(BaseCommand): class Command(BaseCommand):
def create_users(self): def create_users(self):
# Anonymous User # Anonymous User
if not User.objects.filter(email='anon@envipath.com').exists(): if not User.objects.filter(email="anon@envipath.com").exists():
anon = UserManager.create_user("anonymous", "anon@envipath.com", "SuperSafe", anon = UserManager.create_user(
is_active=True, add_to_group=False, set_setting=False) "anonymous",
"anon@envipath.com",
"SuperSafe",
is_active=True,
add_to_group=False,
set_setting=False,
)
else: else:
anon = User.objects.get(email='anon@envipath.com') anon = User.objects.get(email="anon@envipath.com")
# Admin User # Admin User
if not User.objects.filter(email='admin@envipath.com').exists(): if not User.objects.filter(email="admin@envipath.com").exists():
admin = UserManager.create_user("admin", "admin@envipath.com", "SuperSafe", admin = UserManager.create_user(
is_active=True, add_to_group=False, set_setting=False) "admin",
"admin@envipath.com",
"SuperSafe",
is_active=True,
add_to_group=False,
set_setting=False,
)
admin.is_staff = True admin.is_staff = True
admin.is_superuser = True admin.is_superuser = True
admin.save() admin.save()
else: else:
admin = User.objects.get(email='admin@envipath.com') admin = User.objects.get(email="admin@envipath.com")
# System Group # System Group
g = GroupManager.create_group(admin, 'enviPath Users', 'All enviPath Users') g = GroupManager.create_group(admin, "enviPath Users", "All enviPath Users")
g.public = True g.public = True
g.save() g.save()
@ -43,14 +60,20 @@ class Command(BaseCommand):
admin.default_group = g admin.default_group = g
admin.save() admin.save()
if not User.objects.filter(email='user0@envipath.com').exists(): if not User.objects.filter(email="user0@envipath.com").exists():
user0 = UserManager.create_user("user0", "user0@envipath.com", "SuperSafe", user0 = UserManager.create_user(
is_active=True, add_to_group=False, set_setting=False) "user0",
"user0@envipath.com",
"SuperSafe",
is_active=True,
add_to_group=False,
set_setting=False,
)
user0.is_staff = True user0.is_staff = True
user0.is_superuser = True user0.is_superuser = True
user0.save() user0.save()
else: else:
user0 = User.objects.get(email='user0@envipath.com') user0 = User.objects.get(email="user0@envipath.com")
g.user_member.add(user0) g.user_member.add(user0)
g.save() g.save()
@ -61,18 +84,20 @@ class Command(BaseCommand):
return anon, admin, g, user0 return anon, admin, g, user0
def import_package(self, data, owner): def import_package(self, data, owner):
return PackageManager.import_legacy_package(data, owner, keep_ids=True, add_import_timestamp=False, trust_reviewed=True) return PackageManager.import_legacy_package(
data, owner, keep_ids=True, add_import_timestamp=False, trust_reviewed=True
)
def create_default_setting(self, owner, packages): def create_default_setting(self, owner, packages):
s = SettingManager.create_setting( s = SettingManager.create_setting(
owner, owner,
name='Global Default Setting', name="Global Default Setting",
description='Global Default Setting containing BBD Rules and Max 30 Nodes and Max Depth of 8', description="Global Default Setting containing BBD Rules and Max 30 Nodes and Max Depth of 8",
max_nodes=30, max_nodes=30,
max_depth=5, max_depth=5,
rule_packages=packages, rule_packages=packages,
model=None, model=None,
model_threshold=None model_threshold=None,
) )
return s return s
@ -84,54 +109,51 @@ class Command(BaseCommand):
""" """
databases = [ databases = [
{ {
'name': 'PubChem Compound', "name": "PubChem Compound",
'full_name': 'PubChem Compound Database', "full_name": "PubChem Compound Database",
'description': 'Chemical database of small organic molecules', "description": "Chemical database of small organic molecules",
'base_url': 'https://pubchem.ncbi.nlm.nih.gov', "base_url": "https://pubchem.ncbi.nlm.nih.gov",
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}' "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/compound/{id}",
}, },
{ {
'name': 'PubChem Substance', "name": "PubChem Substance",
'full_name': 'PubChem Substance Database', "full_name": "PubChem Substance Database",
'description': 'Database of chemical substances', "description": "Database of chemical substances",
'base_url': 'https://pubchem.ncbi.nlm.nih.gov', "base_url": "https://pubchem.ncbi.nlm.nih.gov",
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/substance/{id}' "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/substance/{id}",
}, },
{ {
'name': 'ChEBI', "name": "ChEBI",
'full_name': 'Chemical Entities of Biological Interest', "full_name": "Chemical Entities of Biological Interest",
'description': 'Dictionary of molecular entities', "description": "Dictionary of molecular entities",
'base_url': 'https://www.ebi.ac.uk/chebi', "base_url": "https://www.ebi.ac.uk/chebi",
'url_pattern': 'https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{id}' "url_pattern": "https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{id}",
}, },
{ {
'name': 'RHEA', "name": "RHEA",
'full_name': 'RHEA Reaction Database', "full_name": "RHEA Reaction Database",
'description': 'Comprehensive resource of biochemical reactions', "description": "Comprehensive resource of biochemical reactions",
'base_url': 'https://www.rhea-db.org', "base_url": "https://www.rhea-db.org",
'url_pattern': 'https://www.rhea-db.org/rhea/{id}' "url_pattern": "https://www.rhea-db.org/rhea/{id}",
}, },
{ {
'name': 'KEGG Reaction', "name": "KEGG Reaction",
'full_name': 'KEGG Reaction Database', "full_name": "KEGG Reaction Database",
'description': 'Database of biochemical reactions', "description": "Database of biochemical reactions",
'base_url': 'https://www.genome.jp', "base_url": "https://www.genome.jp",
'url_pattern': 'https://www.genome.jp/entry/{id}' "url_pattern": "https://www.genome.jp/entry/{id}",
}, },
{ {
'name': 'UniProt', "name": "UniProt",
'full_name': 'MetaCyc Metabolic Pathway Database', "full_name": "MetaCyc Metabolic Pathway Database",
'description': 'UniProt is a freely accessible database of protein sequence and functional information', "description": "UniProt is a freely accessible database of protein sequence and functional information",
'base_url': 'https://www.uniprot.org', "base_url": "https://www.uniprot.org",
'url_pattern': 'https://www.uniprot.org/uniprotkb?query="{id}"' "url_pattern": 'https://www.uniprot.org/uniprotkb?query="{id}"',
} },
] ]
for db_info in databases: for db_info in databases:
ExternalDatabase.objects.get_or_create( ExternalDatabase.objects.get_or_create(name=db_info["name"], defaults=db_info)
name=db_info['name'],
defaults=db_info
)
@transaction.atomic @transaction.atomic
def handle(self, *args, **options): def handle(self, *args, **options):
@ -142,20 +164,24 @@ class Command(BaseCommand):
# Import Packages # Import Packages
packages = [ packages = [
'EAWAG-BBD.json', "EAWAG-BBD.json",
'EAWAG-SOIL.json', "EAWAG-SOIL.json",
'EAWAG-SLUDGE.json', "EAWAG-SLUDGE.json",
'EAWAG-SEDIMENT.json', "EAWAG-SEDIMENT.json",
] ]
mapping = {} mapping = {}
for p in packages: for p in packages:
print(f"Importing {p}...") print(f"Importing {p}...")
package_data = json.loads(open(s.BASE_DIR / 'fixtures' / 'packages' / '2025-07-18' / p, encoding='utf-8').read()) package_data = json.loads(
open(
s.BASE_DIR / "fixtures" / "packages" / "2025-07-18" / p, encoding="utf-8"
).read()
)
imported_package = self.import_package(package_data, admin) imported_package = self.import_package(package_data, admin)
mapping[p.replace('.json', '')] = imported_package mapping[p.replace(".json", "")] = imported_package
setting = self.create_default_setting(admin, [mapping['EAWAG-BBD']]) setting = self.create_default_setting(admin, [mapping["EAWAG-BBD"]])
setting.public = True setting.public = True
setting.save() setting.save()
setting.make_global_default() setting.make_global_default()
@ -171,26 +197,28 @@ class Command(BaseCommand):
usp.save() usp.save()
# Create Model Package # Create Model Package
pack = PackageManager.create_package(admin, "Public Prediction Models", pack = PackageManager.create_package(
"Package to make Prediction Models publicly available") admin,
"Public Prediction Models",
"Package to make Prediction Models publicly available",
)
pack.reviewed = True pack.reviewed = True
pack.save() pack.save()
# Create RR # Create RR
ml_model = MLRelativeReasoning.create( ml_model = MLRelativeReasoning.create(
package=pack, package=pack,
rule_packages=[mapping['EAWAG-BBD']], rule_packages=[mapping["EAWAG-BBD"]],
data_packages=[mapping['EAWAG-BBD']], data_packages=[mapping["EAWAG-BBD"]],
eval_packages=[], eval_packages=[],
threshold=0.5, threshold=0.5,
name='ECC - BBD - T0.5', name="ECC - BBD - T0.5",
description='ML Relative Reasoning', description="ML Relative Reasoning",
) )
ml_model.build_dataset() ml_model.build_dataset()
ml_model.build_model() ml_model.build_model()
# ml_model.evaluate_model()
# If available, create EnviFormerModel # If available, create EnviFormerModel
if s.ENVIFORMER_PRESENT: if s.ENVIFORMER_PRESENT:
enviFormer_model = EnviFormer.create(pack, 'EnviFormer - T0.5', 'EnviFormer Model with Threshold 0.5', 0.5) EnviFormer.create(pack, "EnviFormer - T0.5", "EnviFormer Model with Threshold 0.5", 0.5)

View File

@ -12,11 +12,28 @@ class Command(BaseCommand):
the below command would be used: the below command would be used:
`python manage.py create_ml_models enviformer mlrr -d bbd soil -e sludge `python manage.py create_ml_models enviformer mlrr -d bbd soil -e sludge
""" """
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("model_names", nargs="+", type=str, help="The names of models to train. Options are: enviformer, mlrr") parser.add_argument(
parser.add_argument("-d", "--data-packages", nargs="+", type=str, help="Packages for training") "model_names",
parser.add_argument("-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[]) nargs="+",
parser.add_argument("-r", "--rule-packages", nargs="*", type=str, help="Rule Packages mandatory for MLRR", default=[]) type=str,
help="The names of models to train. Options are: enviformer, mlrr",
)
parser.add_argument(
"-d", "--data-packages", nargs="+", type=str, help="Packages for training"
)
parser.add_argument(
"-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[]
)
parser.add_argument(
"-r",
"--rule-packages",
nargs="*",
type=str,
help="Rule Packages mandatory for MLRR",
default=[],
)
@transaction.atomic @transaction.atomic
def handle(self, *args, **options): def handle(self, *args, **options):
@ -28,7 +45,9 @@ class Command(BaseCommand):
sludge = Package.objects.filter(name="EAWAG-SLUDGE")[0] sludge = Package.objects.filter(name="EAWAG-SLUDGE")[0]
sediment = Package.objects.filter(name="EAWAG-SEDIMENT")[0] sediment = Package.objects.filter(name="EAWAG-SEDIMENT")[0]
except IndexError: except IndexError:
raise IndexError("Can't find correct packages. They should be created with the bootstrap command") raise IndexError(
"Can't find correct packages. They should be created with the bootstrap command"
)
def decode_packages(package_list): def decode_packages(package_list):
"""Decode package strings into their respective packages""" """Decode package strings into their respective packages"""
@ -52,15 +71,27 @@ class Command(BaseCommand):
data_packages = decode_packages(options["data_packages"]) data_packages = decode_packages(options["data_packages"])
eval_packages = decode_packages(options["eval_packages"]) eval_packages = decode_packages(options["eval_packages"])
rule_packages = decode_packages(options["rule_packages"]) rule_packages = decode_packages(options["rule_packages"])
for model_name in options['model_names']: for model_name in options["model_names"]:
model_name = model_name.lower() model_name = model_name.lower()
if model_name == "enviformer" and s.ENVIFORMER_PRESENT: if model_name == "enviformer" and s.ENVIFORMER_PRESENT:
model = EnviFormer.create(pack, data_packages=data_packages, eval_packages=eval_packages, threshold=0.5, model = EnviFormer.create(
name="EnviFormer - T0.5", description="EnviFormer transformer") pack,
data_packages=data_packages,
eval_packages=eval_packages,
threshold=0.5,
name="EnviFormer - T0.5",
description="EnviFormer transformer",
)
elif model_name == "mlrr": elif model_name == "mlrr":
model = MLRelativeReasoning.create(package=pack, rule_packages=rule_packages, model = MLRelativeReasoning.create(
data_packages=data_packages, eval_packages=eval_packages, threshold=0.5, package=pack,
name='ECC - BBD - T0.5', description='ML Relative Reasoning') rule_packages=rule_packages,
data_packages=data_packages,
eval_packages=eval_packages,
threshold=0.5,
name="ECC - BBD - T0.5",
description="ML Relative Reasoning",
)
else: else:
raise ValueError(f"Cannot create model of type {model_name}, unknown model type") raise ValueError(f"Cannot create model of type {model_name}, unknown model type")
# Build the dataset for the model, train it, evaluate it and save it # Build the dataset for the model, train it, evaluate it and save it

View File

@ -1,57 +1,58 @@
from csv import DictReader from csv import DictReader
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.models import * from epdb.models import Compound, CompoundStructure, Reaction, ExternalDatabase, ExternalIdentifier
class Command(BaseCommand): class Command(BaseCommand):
STR_TO_MODEL = { STR_TO_MODEL = {
'Compound': Compound, "Compound": Compound,
'CompoundStructure': CompoundStructure, "CompoundStructure": CompoundStructure,
'Reaction': Reaction, "Reaction": Reaction,
} }
STR_TO_DATABASE = { STR_TO_DATABASE = {
'ChEBI': ExternalDatabase.objects.get(name='ChEBI'), "ChEBI": ExternalDatabase.objects.get(name="ChEBI"),
'RHEA': ExternalDatabase.objects.get(name='RHEA'), "RHEA": ExternalDatabase.objects.get(name="RHEA"),
'KEGG Reaction': ExternalDatabase.objects.get(name='KEGG Reaction'), "KEGG Reaction": ExternalDatabase.objects.get(name="KEGG Reaction"),
'PubChem Compound': ExternalDatabase.objects.get(name='PubChem Compound'), "PubChem Compound": ExternalDatabase.objects.get(name="PubChem Compound"),
'PubChem Substance': ExternalDatabase.objects.get(name='PubChem Substance'), "PubChem Substance": ExternalDatabase.objects.get(name="PubChem Substance"),
} }
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--data', "--data",
type=str, type=str,
help='Path of the ID Mapping file.', help="Path of the ID Mapping file.",
required=True, required=True,
) )
parser.add_argument( parser.add_argument(
'--replace-host', "--replace-host",
type=str, type=str,
help='Replace https://envipath.org/ with this host, e.g. http://localhost:8000/', help="Replace https://envipath.org/ with this host, e.g. http://localhost:8000/",
) )
@transaction.atomic @transaction.atomic
def handle(self, *args, **options): def handle(self, *args, **options):
with open(options['data']) as fh: with open(options["data"]) as fh:
reader = DictReader(fh) reader = DictReader(fh)
for row in reader: for row in reader:
clz = self.STR_TO_MODEL[row['model']] clz = self.STR_TO_MODEL[row["model"]]
url = row['url'] url = row["url"]
if options['replace_host']: if options["replace_host"]:
url = url.replace('https://envipath.org/', options['replace_host']) url = url.replace("https://envipath.org/", options["replace_host"])
instance = clz.objects.get(url=url) instance = clz.objects.get(url=url)
db = self.STR_TO_DATABASE[row['identifier_type']] db = self.STR_TO_DATABASE[row["identifier_type"]]
ExternalIdentifier.objects.create( ExternalIdentifier.objects.create(
content_object=instance, content_object=instance,
database=db, database=db,
identifier_value=row['identifier_value'], identifier_value=row["identifier_value"],
url=db.url_pattern.format(id=row['identifier_value']), url=db.url_pattern.format(id=row["identifier_value"]),
is_primary=False is_primary=False,
) )

View File

@ -1,27 +1,29 @@
import json
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.logic import PackageManager from epdb.logic import PackageManager
from epdb.models import * from epdb.models import User
class Command(BaseCommand): class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--data', "--data",
type=str, type=str,
help='Path of the Package to import.', help="Path of the Package to import.",
required=True, required=True,
) )
parser.add_argument( parser.add_argument(
'--owner', "--owner",
type=str, type=str,
help='Username of the desired Owner.', help="Username of the desired Owner.",
required=True, required=True,
) )
@transaction.atomic @transaction.atomic
def handle(self, *args, **options): def handle(self, *args, **options):
owner = User.objects.get(username=options['owner']) owner = User.objects.get(username=options["owner"])
package_data = json.load(open(options['data'])) package_data = json.load(open(options["data"]))
PackageManager.import_legacy_package(package_data, owner) PackageManager.import_legacy_package(package_data, owner)

View File

@ -6,46 +6,45 @@ from django.db.models.functions import Replace
class Command(BaseCommand): class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--old', "--old",
type=str, type=str,
help='Old Host, most likely https://envipath.org/', help="Old Host, most likely https://envipath.org/",
required=True, required=True,
) )
parser.add_argument( parser.add_argument(
'--new', "--new",
type=str, type=str,
help='New Host, most likely http://localhost:8000/', help="New Host, most likely http://localhost:8000/",
required=True, required=True,
) )
def handle(self, *args, **options): def handle(self, *args, **options):
MODELS = [ MODELS = [
'User', "User",
'Group', "Group",
'Package', "Package",
'Compound', "Compound",
'CompoundStructure', "CompoundStructure",
'Pathway', "Pathway",
'Edge', "Edge",
'Node', "Node",
'Reaction', "Reaction",
'SimpleAmbitRule', "SimpleAmbitRule",
'SimpleRDKitRule', "SimpleRDKitRule",
'ParallelRule', "ParallelRule",
'SequentialRule', "SequentialRule",
'Scenario', "Scenario",
'Setting', "Setting",
'MLRelativeReasoning', "MLRelativeReasoning",
'RuleBasedRelativeReasoning', "RuleBasedRelativeReasoning",
'EnviFormer', "EnviFormer",
'ApplicabilityDomain', "ApplicabilityDomain",
] ]
for model in MODELS: for model in MODELS:
obj_cls = apps.get_model("epdb", model) obj_cls = apps.get_model("epdb", model)
print(f"Localizing urls for {model}") print(f"Localizing urls for {model}")
obj_cls.objects.update( obj_cls.objects.update(
url=Replace(F('url'), Value(options['old']), Value(options['new'])) url=Replace(F("url"), Value(options["old"]), Value(options["new"]))
) )

View File

@ -3,22 +3,25 @@ from django.shortcuts import redirect
from django.urls import reverse from django.urls import reverse
from urllib.parse import quote from urllib.parse import quote
class LoginRequiredMiddleware: class LoginRequiredMiddleware:
def __init__(self, get_response): def __init__(self, get_response):
self.get_response = get_response self.get_response = get_response
self.exempt_urls = [ self.exempt_urls = [
reverse('login'), reverse("login"),
reverse('logout'), reverse("logout"),
reverse('admin:login'), reverse("admin:login"),
reverse('admin:index'), reverse("admin:index"),
] + getattr(settings, 'LOGIN_EXEMPT_URLS', []) ] + getattr(settings, "LOGIN_EXEMPT_URLS", [])
def __call__(self, request): def __call__(self, request):
if not request.user.is_authenticated: if not request.user.is_authenticated:
path = request.path_info path = request.path_info
if not any(path.startswith(url) for url in self.exempt_urls): if not any(path.startswith(url) for url in self.exempt_urls):
if request.method == 'GET': if request.method == "GET":
if request.get_full_path() and request.get_full_path() != '/': if request.get_full_path() and request.get_full_path() != "/":
return redirect(f"{settings.LOGIN_URL}?next={quote(request.get_full_path())}") return redirect(
f"{settings.LOGIN_URL}?next={quote(request.get_full_path())}"
)
return redirect(settings.LOGIN_URL) return redirect(settings.LOGIN_URL)
return self.get_response(request) return self.get_response(request)

File diff suppressed because it is too large Load Diff

View File

@ -2,55 +2,57 @@ import logging
from typing import Optional from typing import Optional
from celery import shared_task from celery import shared_task
from epdb.models import Pathway, Node, Edge, EPModel, Setting from epdb.models import Pathway, Node, EPModel, Setting
from epdb.logic import SPathway from epdb.logic import SPathway
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@shared_task(queue='background') @shared_task(queue="background")
def mul(a, b): def mul(a, b):
return a * b return a * b
@shared_task(queue='predict') @shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str): def predict_simple(model_pk: int, smiles: str):
mod = EPModel.objects.get(id=model_pk) mod = EPModel.objects.get(id=model_pk)
res = mod.predict(smiles) res = mod.predict(smiles)
return res return res
@shared_task(queue='background') @shared_task(queue="background")
def send_registration_mail(user_pk: int): def send_registration_mail(user_pk: int):
pass pass
@shared_task(queue='model') @shared_task(queue="model")
def build_model(model_pk: int): def build_model(model_pk: int):
mod = EPModel.objects.get(id=model_pk) mod = EPModel.objects.get(id=model_pk)
mod.build_dataset() mod.build_dataset()
mod.build_model() mod.build_model()
@shared_task(queue='model') @shared_task(queue="model")
def evaluate_model(model_pk: int): def evaluate_model(model_pk: int):
mod = EPModel.objects.get(id=model_pk) mod = EPModel.objects.get(id=model_pk)
mod.evaluate_model() mod.evaluate_model()
@shared_task(queue='model') @shared_task(queue="model")
def retrain(model_pk: int): def retrain(model_pk: int):
mod = EPModel.objects.get(id=model_pk) mod = EPModel.objects.get(id=model_pk)
mod.retrain() mod.retrain()
@shared_task(queue='predict') @shared_task(queue="predict")
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway: def predict(
pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None
) -> Pathway:
pw = Pathway.objects.get(id=pw_pk) pw = Pathway.objects.get(id=pw_pk)
setting = Setting.objects.get(id=pred_setting_pk) setting = Setting.objects.get(id=pred_setting_pk)
pw.kv.update(**{'status': 'running'}) pw.kv.update(**{"status": "running"})
pw.save() pw.save()
try: try:
@ -74,12 +76,10 @@ def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_
else: else:
raise ValueError("Neither limit nor node_pk given!") raise ValueError("Neither limit nor node_pk given!")
except Exception as e: except Exception as e:
pw.kv.update({'status': 'failed'}) pw.kv.update({"status": "failed"})
pw.save() pw.save()
raise e raise e
pw.kv.update(**{'status': 'completed'}) pw.kv.update(**{"status": "completed"})
pw.save() pw.save()

View File

@ -2,6 +2,7 @@ from django import template
register = template.Library() register = template.Library()
@register.filter @register.filter
def classname(obj): def classname(obj):
return obj.__class__.__name__ return obj.__class__.__name__

View File

@ -1,3 +0,0 @@
from django.test import TestCase
# Create your tests here.

View File

@ -3,97 +3,177 @@ from django.contrib.auth import views as auth_views
from . import views as v from . import views as v
UUID = '[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}' UUID = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
urlpatterns = [ urlpatterns = [
# Home # Home
re_path(r'^$', v.index, name='index'), re_path(r"^$", v.index, name="index"),
# Login # Login
re_path(r'^login', v.login, name='login'), re_path(r"^login", v.login, name="login"),
re_path(r'^logout', v.logout, name='logout'), re_path(r"^logout", v.logout, name="logout"),
re_path(r'^register', v.register, name='register'), re_path(r"^register", v.register, name="register"),
# Built-In views
# Built In views path(
path('password_reset/', auth_views.PasswordResetView.as_view( "password_reset/",
template_name='static/password_reset_form.html' auth_views.PasswordResetView.as_view(template_name="static/password_reset_form.html"),
), name='password_reset'), name="password_reset",
),
path('password_reset/done/', auth_views.PasswordResetDoneView.as_view( path(
template_name='static/password_reset_done.html' "password_reset/done/",
), name='password_reset_done'), auth_views.PasswordResetDoneView.as_view(template_name="static/password_reset_done.html"),
name="password_reset_done",
path('reset/<uidb64>/<token>/', auth_views.PasswordResetConfirmView.as_view( ),
template_name='static/password_reset_confirm.html' path(
), name='password_reset_confirm'), "reset/<uidb64>/<token>/",
auth_views.PasswordResetConfirmView.as_view(
path('reset/done/', auth_views.PasswordResetCompleteView.as_view( template_name="static/password_reset_confirm.html"
template_name='static/password_reset_complete.html' ),
), name='password_reset_complete'), name="password_reset_confirm",
),
path(
"reset/done/",
auth_views.PasswordResetCompleteView.as_view(
template_name="static/password_reset_complete.html"
),
name="password_reset_complete",
),
# Top level urls # Top level urls
re_path(r'^package$', v.packages, name='packages'), re_path(r"^package$", v.packages, name="packages"),
re_path(r'^compound$', v.compounds, name='compounds'), re_path(r"^compound$", v.compounds, name="compounds"),
re_path(r'^rule$', v.rules, name='rules'), re_path(r"^rule$", v.rules, name="rules"),
re_path(r'^reaction$', v.reactions, name='reactions'), re_path(r"^reaction$", v.reactions, name="reactions"),
re_path(r'^pathway$', v.pathways, name='pathways'), re_path(r"^pathway$", v.pathways, name="pathways"),
re_path(r'^scenario$', v.scenarios, name='scenarios'), re_path(r"^scenario$", v.scenarios, name="scenarios"),
re_path(r'^model$', v.models, name='model'), re_path(r"^model$", v.models, name="model"),
re_path(r'^user$', v.users, name='users'), re_path(r"^user$", v.users, name="users"),
re_path(r'^group$', v.groups, name='groups'), re_path(r"^group$", v.groups, name="groups"),
re_path(r'^search$', v.search, name='search'), re_path(r"^search$", v.search, name="search"),
# User Detail # User Detail
re_path(rf'^user/(?P<user_uuid>{UUID})', v.user, name='user'), re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
# Group Detail # Group Detail
re_path(rf'^group/(?P<group_uuid>{UUID})$', v.group, name='group detail'), re_path(rf"^group/(?P<group_uuid>{UUID})$", v.group, name="group detail"),
# "in package" urls # "in package" urls
re_path(rf'^package/(?P<package_uuid>{UUID})$', v.package, name='package detail'), re_path(rf"^package/(?P<package_uuid>{UUID})$", v.package, name="package detail"),
# Compound # Compound
re_path(rf'^package/(?P<package_uuid>{UUID})/compound$', v.package_compounds, name='package compound list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})$', v.package_compound, name='package compound detail'), rf"^package/(?P<package_uuid>{UUID})/compound$",
v.package_compounds,
name="package compound list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})$",
v.package_compound,
name="package compound detail",
),
# Compound Structure # Compound Structure
re_path(rf'^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure$', v.package_compound_structures, name='package compound structure list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure/(?P<structure_uuid>{UUID})$', v.package_compound_structure, name='package compound structure detail'), rf"^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure$",
v.package_compound_structures,
name="package compound structure list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure/(?P<structure_uuid>{UUID})$",
v.package_compound_structure,
name="package compound structure detail",
),
# Rule # Rule
re_path(rf'^package/(?P<package_uuid>{UUID})/rule$', v.package_rules, name='package rule list'), re_path(rf"^package/(?P<package_uuid>{UUID})/rule$", v.package_rules, name="package rule list"),
re_path(rf'^package/(?P<package_uuid>{UUID})/rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/simple-ambit-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'), rf"^package/(?P<package_uuid>{UUID})/rule/(?P<rule_uuid>{UUID})$",
re_path(rf'^package/(?P<package_uuid>{UUID})/simple-rdkit-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'), v.package_rule,
re_path(rf'^package/(?P<package_uuid>{UUID})/parallel-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'), name="package rule detail",
re_path(rf'^package/(?P<package_uuid>{UUID})/sequential-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'), ),
re_path(
rf"^package/(?P<package_uuid>{UUID})/simple-ambit-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/simple-rdkit-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/parallel-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/sequential-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
# Reaction # Reaction
re_path(rf'^package/(?P<package_uuid>{UUID})/reaction$', v.package_reactions, name='package reaction list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/reaction/(?P<reaction_uuid>{UUID})$', v.package_reaction, name='package reaction detail'), rf"^package/(?P<package_uuid>{UUID})/reaction$",
v.package_reactions,
name="package reaction list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/reaction/(?P<reaction_uuid>{UUID})$",
v.package_reaction,
name="package reaction detail",
),
# # Pathway # # Pathway
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway$', v.package_pathways, name='package pathway list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})$', v.package_pathway, name='package pathway detail'), rf"^package/(?P<package_uuid>{UUID})/pathway$",
v.package_pathways,
name="package pathway list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})$",
v.package_pathway,
name="package pathway detail",
),
# Pathway Nodes # Pathway Nodes
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node$', v.package_pathway_nodes, name='package pathway node list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node/(?P<node_uuid>{UUID})$', v.package_pathway_node, name='package pathway node detail'), rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node$",
v.package_pathway_nodes,
name="package pathway node list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node/(?P<node_uuid>{UUID})$",
v.package_pathway_node,
name="package pathway node detail",
),
# Pathway Edges # Pathway Edges
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge$', v.package_pathway_edges, name='package pathway edge list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge/(?P<edge_uuid>{UUID})$', v.package_pathway_edge, name='package pathway edge detail'), rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge$",
v.package_pathway_edges,
name="package pathway edge list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge/(?P<edge_uuid>{UUID})$",
v.package_pathway_edge,
name="package pathway edge detail",
),
# Scenario # Scenario
re_path(rf'^package/(?P<package_uuid>{UUID})/scenario$', v.package_scenarios, name='package scenario list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/scenario/(?P<scenario_uuid>{UUID})$', v.package_scenario, name='package scenario detail'), rf"^package/(?P<package_uuid>{UUID})/scenario$",
v.package_scenarios,
name="package scenario list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/scenario/(?P<scenario_uuid>{UUID})$",
v.package_scenario,
name="package scenario detail",
),
# Model # Model
re_path(rf'^package/(?P<package_uuid>{UUID})/model$', v.package_models, name='package model list'), re_path(
re_path(rf'^package/(?P<package_uuid>{UUID})/model/(?P<model_uuid>{UUID})$', v.package_model,name='package model detail'), rf"^package/(?P<package_uuid>{UUID})/model$", v.package_models, name="package model list"
),
re_path(r'^setting$', v.settings, name='settings'), re_path(
re_path(rf'^setting/(?P<setting_uuid>{UUID})', v.setting, name='setting'), rf"^package/(?P<package_uuid>{UUID})/model/(?P<model_uuid>{UUID})$",
v.package_model,
re_path(r'^indigo/info$', v.indigo, name='indigo_info'), name="package model detail",
re_path(r'^indigo/aromatize$', v.aromatize, name='indigo_aromatize'), ),
re_path(r'^indigo/dearomatize$', v.dearomatize, name='indigo_dearomatize'), re_path(r"^setting$", v.settings, name="settings"),
re_path(r'^indigo/layout$', v.layout, name='indigo_layout'), re_path(rf"^setting/(?P<setting_uuid>{UUID})", v.setting, name="setting"),
re_path(r"^indigo/info$", v.indigo, name="indigo_info"),
re_path(r'^depict$', v.depict, name='depict'), re_path(r"^indigo/aromatize$", v.aromatize, name="indigo_aromatize"),
re_path(r"^indigo/dearomatize$", v.dearomatize, name="indigo_dearomatize"),
re_path(r"^indigo/layout$", v.layout, name="indigo_layout"),
re_path(r"^depict$", v.depict, name="depict"),
# OAuth Stuff # OAuth Stuff
path("o/userinfo/", v.userinfo, name="oauth_userinfo"), path("o/userinfo/", v.userinfo, name="oauth_userinfo"),
] ]

File diff suppressed because it is too large Load Diff

View File

@ -13,91 +13,80 @@ class CompoundTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(CompoundTest, cls).setUpClass() super(CompoundTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self): def test_smoke(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name='Afoxolaner', name="Afoxolaner",
description='No Desc' description="No Desc",
) )
self.assertEqual(c.default_structure.smiles, self.assertEqual(
'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F') c.default_structure.smiles,
self.assertEqual(c.name, 'Afoxolaner') "C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
self.assertEqual(c.description, 'No Desc') )
self.assertEqual(c.name, "Afoxolaner")
self.assertEqual(c.description, "No Desc")
def test_missing_smiles(self): def test_missing_smiles(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Compound.create( _ = Compound.create(self.package, smiles=None, name="Afoxolaner", description="No Desc")
self.package,
smiles=None,
name='Afoxolaner',
description='No Desc'
)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Compound.create( _ = Compound.create(self.package, smiles="", name="Afoxolaner", description="No Desc")
self.package,
smiles='',
name='Afoxolaner',
description='No Desc'
)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Compound.create( _ = Compound.create(self.package, smiles=" ", name="Afoxolaner", description="No Desc")
self.package,
smiles=' ',
name='Afoxolaner',
description='No Desc'
)
def test_smiles_are_trimmed(self): def test_smiles_are_trimmed(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles=' C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F ', smiles=" C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F ",
name='Afoxolaner', name="Afoxolaner",
description='No Desc' description="No Desc",
) )
self.assertEqual(c.default_structure.smiles, self.assertEqual(
'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F') c.default_structure.smiles,
"C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
)
def test_name_and_description_optional(self): def test_name_and_description_optional(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
) )
self.assertEqual(c.name, 'Compound 1') self.assertEqual(c.name, "Compound 1")
self.assertEqual(c.description, 'no description') self.assertEqual(c.description, "no description")
def test_empty_name_and_description_are_ignored(self): def test_empty_name_and_description_are_ignored(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name='', name="",
description='', description="",
) )
self.assertEqual(c.name, 'Compound 1') self.assertEqual(c.name, "Compound 1")
self.assertEqual(c.description, 'no description') self.assertEqual(c.description, "no description")
def test_deduplication(self): def test_deduplication(self):
c1 = Compound.create( c1 = Compound.create(
self.package, self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name='Afoxolaner', name="Afoxolaner",
description='No Desc' description="No Desc",
) )
c2 = Compound.create( c2 = Compound.create(
self.package, self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name='Afoxolaner', name="Afoxolaner",
description='No Desc' description="No Desc",
) )
# Check if create detects that this Compound already exist # Check if create detects that this Compound already exist
@ -109,36 +98,36 @@ class CompoundTest(TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Compound.create( _ = Compound.create(
self.package, self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name='Afoxolaner', name="Afoxolaner",
description='No Desc' description="No Desc",
) )
def test_create_with_standardized_smiles(self): def test_create_with_standardized_smiles(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name='Standardized SMILES', name="Standardized SMILES",
description='No Desc' description="No Desc",
) )
self.assertEqual(len(c.structures.all()), 1) self.assertEqual(len(c.structures.all()), 1)
cs = c.structures.all()[0] cs = c.structures.all()[0]
self.assertEqual(cs.normalized_structure, True) self.assertEqual(cs.normalized_structure, True)
self.assertEqual(cs.smiles, 'O=C(O)C1=CC=C([N+](=O)[O-])C=C1') self.assertEqual(cs.smiles, "O=C(O)C1=CC=C([N+](=O)[O-])C=C1")
def test_create_with_non_standardized_smiles(self): def test_create_with_non_standardized_smiles(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='[O-][N+](=O)c1ccc(C(=O)[O-])cc1', smiles="[O-][N+](=O)c1ccc(C(=O)[O-])cc1",
name='Non Standardized SMILES', name="Non Standardized SMILES",
description='No Desc' description="No Desc",
) )
self.assertEqual(len(c.structures.all()), 2) self.assertEqual(len(c.structures.all()), 2)
for cs in c.structures.all(): for cs in c.structures.all():
if cs.normalized_structure: if cs.normalized_structure:
self.assertEqual(cs.smiles, 'O=C(O)C1=CC=C([N+](=O)[O-])C=C1') self.assertEqual(cs.smiles, "O=C(O)C1=CC=C([N+](=O)[O-])C=C1")
break break
else: else:
# Loop finished without break, lets fail... # Loop finished without break, lets fail...
@ -147,51 +136,54 @@ class CompoundTest(TestCase):
def test_add_structure_smoke(self): def test_add_structure_smoke(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name='Standardized SMILES', name="Standardized SMILES",
description='No Desc' description="No Desc",
) )
c.add_structure('[O-][N+](=O)c1ccc(C(=O)[O-])cc1', 'Non Standardized SMILES') c.add_structure("[O-][N+](=O)c1ccc(C(=O)[O-])cc1", "Non Standardized SMILES")
self.assertEqual(len(c.structures.all()), 2) self.assertEqual(len(c.structures.all()), 2)
def test_add_structure_with_different_normalized_smiles(self): def test_add_structure_with_different_normalized_smiles(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name='Standardized SMILES', name="Standardized SMILES",
description='No Desc' description="No Desc",
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
c.add_structure( c.add_structure(
'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', "C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
'Different Standardized SMILES') "Different Standardized SMILES",
)
def test_delete(self): def test_delete(self):
c = Compound.create( c = Compound.create(
self.package, self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name='Standardization Test', name="Standardization Test",
description='No Desc' description="No Desc",
) )
c.delete() c.delete()
self.assertEqual(Compound.objects.filter(package=self.package).count(), 0) self.assertEqual(Compound.objects.filter(package=self.package).count(), 0)
self.assertEqual(CompoundStructure.objects.filter(compound__package=self.package).count(), 0) self.assertEqual(
CompoundStructure.objects.filter(compound__package=self.package).count(), 0
)
def test_set_as_default_structure(self): def test_set_as_default_structure(self):
c1 = Compound.create( c1 = Compound.create(
self.package, self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1', smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name='Standardized SMILES', name="Standardized SMILES",
description='No Desc' description="No Desc",
) )
default_structure = c1.default_structure default_structure = c1.default_structure
c2 = c1.add_structure('[O-][N+](=O)c1ccc(C(=O)[O-])cc1', 'Non Standardized SMILES') c2 = c1.add_structure("[O-][N+](=O)c1ccc(C(=O)[O-])cc1", "Non Standardized SMILES")
c1.set_default_structure(c2) c1.set_default_structure(c2)
self.assertNotEqual(default_structure, c2) self.assertNotEqual(default_structure, c2)

View File

@ -1,6 +1,5 @@
from django.test import TestCase from django.test import TestCase
from django.test import TestCase
from epdb.logic import PackageManager from epdb.logic import PackageManager
from epdb.models import Compound, User, Reaction from epdb.models import Compound, User, Reaction
@ -12,50 +11,47 @@ class CopyTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(CopyTest, cls).setUpClass() super(CopyTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Source Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Source Package", "No Desc")
cls.AFOXOLANER = Compound.create( cls.AFOXOLANER = Compound.create(
cls.package, cls.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F', smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name='Afoxolaner', name="Afoxolaner",
description='Test compound for copying' description="Test compound for copying",
) )
cls.FOUR_NITROBENZOIC_ACID = Compound.create( cls.FOUR_NITROBENZOIC_ACID = Compound.create(
cls.package, cls.package,
smiles='[O-][N+](=O)c1ccc(C(=O)[O-])cc1', # Normalized: O=C(O)C1=CC=C([N+](=O)[O-])C=C1', smiles="[O-][N+](=O)c1ccc(C(=O)[O-])cc1", # Normalized: O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Test Compound', name="Test Compound",
description='Compound with multiple structures' description="Compound with multiple structures",
) )
cls.ETHANOL = Compound.create( cls.ETHANOL = Compound.create(
cls.package, cls.package, smiles="CCO", name="Ethanol", description="Simple alcohol"
smiles='CCO',
name='Ethanol',
description='Simple alcohol'
) )
cls.target_package = PackageManager.create_package(cls.user, 'Target Package', 'No Desc') cls.target_package = PackageManager.create_package(cls.user, "Target Package", "No Desc")
cls.reaction_educt = Compound.create( cls.reaction_educt = Compound.create(
cls.package, cls.package,
smiles='C(CCl)Cl', smiles="C(CCl)Cl",
name='1,2-Dichloroethane', name="1,2-Dichloroethane",
description='Eawag BBD compound c0001' description="Eawag BBD compound c0001",
).default_structure ).default_structure
cls.reaction_product = Compound.create( cls.reaction_product = Compound.create(
cls.package, cls.package,
smiles='C(CO)Cl', smiles="C(CO)Cl",
name='2-Chloroethanol', name="2-Chloroethanol",
description='Eawag BBD compound c0005' description="Eawag BBD compound c0005",
).default_structure ).default_structure
cls.REACTION = Reaction.create( cls.REACTION = Reaction.create(
package=cls.package, package=cls.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=[cls.reaction_educt], educts=[cls.reaction_educt],
products=[cls.reaction_product], products=[cls.reaction_product],
multi_step=False multi_step=False,
) )
def test_compound_copy_basic(self): def test_compound_copy_basic(self):
@ -68,7 +64,9 @@ class CopyTest(TestCase):
self.assertEqual(self.AFOXOLANER.description, copied_compound.description) self.assertEqual(self.AFOXOLANER.description, copied_compound.description)
self.assertEqual(copied_compound.package, self.target_package) self.assertEqual(copied_compound.package, self.target_package)
self.assertEqual(self.AFOXOLANER.package, self.package) self.assertEqual(self.AFOXOLANER.package, self.package)
self.assertEqual(self.AFOXOLANER.default_structure.smiles, copied_compound.default_structure.smiles) self.assertEqual(
self.AFOXOLANER.default_structure.smiles, copied_compound.default_structure.smiles
)
def test_compound_copy_with_multiple_structures(self): def test_compound_copy_with_multiple_structures(self):
"""Test copying a compound with multiple structures""" """Test copying a compound with multiple structures"""
@ -86,7 +84,7 @@ class CopyTest(TestCase):
self.assertIsNotNone(copied_compound.default_structure) self.assertIsNotNone(copied_compound.default_structure)
self.assertEqual( self.assertEqual(
copied_compound.default_structure.smiles, copied_compound.default_structure.smiles,
self.FOUR_NITROBENZOIC_ACID.default_structure.smiles self.FOUR_NITROBENZOIC_ACID.default_structure.smiles,
) )
def test_compound_copy_preserves_aliases(self): def test_compound_copy_preserves_aliases(self):
@ -95,15 +93,15 @@ class CopyTest(TestCase):
original_compound = self.ETHANOL original_compound = self.ETHANOL
# Add aliases if the method exists # Add aliases if the method exists
if hasattr(original_compound, 'add_alias'): if hasattr(original_compound, "add_alias"):
original_compound.add_alias('Ethyl alcohol') original_compound.add_alias("Ethyl alcohol")
original_compound.add_alias('Grain alcohol') original_compound.add_alias("Grain alcohol")
mapping = dict() mapping = dict()
copied_compound = original_compound.copy(self.target_package, mapping) copied_compound = original_compound.copy(self.target_package, mapping)
# Verify aliases were copied if they exist # Verify aliases were copied if they exist
if hasattr(original_compound, 'aliases') and hasattr(copied_compound, 'aliases'): if hasattr(original_compound, "aliases") and hasattr(copied_compound, "aliases"):
original_aliases = original_compound.aliases original_aliases = original_compound.aliases
copied_aliases = copied_compound.aliases copied_aliases = copied_compound.aliases
self.assertEqual(original_aliases, copied_aliases) self.assertEqual(original_aliases, copied_aliases)
@ -113,10 +111,10 @@ class CopyTest(TestCase):
original_compound = self.ETHANOL original_compound = self.ETHANOL
# Add external identifiers if the methods exist # Add external identifiers if the methods exist
if hasattr(original_compound, 'add_cas_number'): if hasattr(original_compound, "add_cas_number"):
original_compound.add_cas_number('64-17-5') original_compound.add_cas_number("64-17-5")
if hasattr(original_compound, 'add_pubchem_compound_id'): if hasattr(original_compound, "add_pubchem_compound_id"):
original_compound.add_pubchem_compound_id('702') original_compound.add_pubchem_compound_id("702")
mapping = dict() mapping = dict()
copied_compound = original_compound.copy(self.target_package, mapping) copied_compound = original_compound.copy(self.target_package, mapping)
@ -146,7 +144,9 @@ class CopyTest(TestCase):
self.assertEqual(original_structure.smiles, copied_structure.smiles) self.assertEqual(original_structure.smiles, copied_structure.smiles)
self.assertEqual(original_structure.canonical_smiles, copied_structure.canonical_smiles) self.assertEqual(original_structure.canonical_smiles, copied_structure.canonical_smiles)
self.assertEqual(original_structure.inchikey, copied_structure.inchikey) self.assertEqual(original_structure.inchikey, copied_structure.inchikey)
self.assertEqual(original_structure.normalized_structure, copied_structure.normalized_structure) self.assertEqual(
original_structure.normalized_structure, copied_structure.normalized_structure
)
# Verify they are different objects # Verify they are different objects
self.assertNotEqual(original_structure.uuid, copied_structure.uuid) self.assertNotEqual(original_structure.uuid, copied_structure.uuid)
@ -177,7 +177,9 @@ class CopyTest(TestCase):
self.assertEqual(orig_educt.compound.package, self.package) self.assertEqual(orig_educt.compound.package, self.package)
self.assertEqual(orig_educt.smiles, copy_educt.smiles) self.assertEqual(orig_educt.smiles, copy_educt.smiles)
for orig_product, copy_product in zip(self.REACTION.products.all(), copied_reaction.products.all()): for orig_product, copy_product in zip(
self.REACTION.products.all(), copied_reaction.products.all()
):
self.assertNotEqual(orig_product.uuid, copy_product.uuid) self.assertNotEqual(orig_product.uuid, copy_product.uuid)
self.assertEqual(orig_product.name, copy_product.name) self.assertEqual(orig_product.name, copy_product.name)
self.assertEqual(orig_product.description, copy_product.description) self.assertEqual(orig_product.description, copy_product.description)

View File

@ -11,21 +11,21 @@ class DatasetTest(TestCase):
def setUp(self): def setUp(self):
self.cs1 = Compound.create( self.cs1 = Compound.create(
self.package, self.package,
name='2,6-Dibromohydroquinone', name="2,6-Dibromohydroquinone",
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b', description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b",
smiles='C1=C(C(=C(C=C1O)Br)O)Br', smiles="C1=C(C(=C(C=C1O)Br)O)Br",
).default_structure ).default_structure
self.cs2 = Compound.create( self.cs2 = Compound.create(
self.package, self.package,
smiles='O=C(O)CC(=O)/C=C(/Br)C(=O)O', smiles="O=C(O)CC(=O)/C=C(/Br)C(=O)O",
).default_structure ).default_structure
self.rule1 = Rule.create( self.rule1 = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\[#6:3]=[#6:2](\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]', smirks="[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\\[#6:3]=[#6:2](\\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]",
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6' description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6",
) )
self.reaction1 = Reaction.create( self.reaction1 = Reaction.create(
@ -33,14 +33,14 @@ class DatasetTest(TestCase):
educts=[self.cs1], educts=[self.cs1],
products=[self.cs2], products=[self.cs2],
rules=[self.rule1], rules=[self.rule1],
multi_step=False multi_step=False,
) )
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(DatasetTest, cls).setUpClass() super(DatasetTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self): def test_smoke(self):
reactions = [r for r in Reaction.objects.filter(package=self.package)] reactions = [r for r in Reaction.objects.filter(package=self.package)]

View File

@ -1,18 +1,19 @@
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from django.test import TestCase from django.test import TestCase, tag
from epdb.logic import PackageManager from epdb.logic import PackageManager
from epdb.models import User, EnviFormer, Package from epdb.models import User, EnviFormer, Package
@tag("slow")
class EnviFormerTest(TestCase): class EnviFormerTest(TestCase):
fixtures = ["test_fixtures.jsonl.gz"] fixtures = ["test_fixtures.jsonl.gz"]
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(EnviFormerTest, cls).setUpClass() super(EnviFormerTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name='Fixtures') cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_model_flow(self): def test_model_flow(self):
"""Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference""" """Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference"""
@ -21,11 +22,14 @@ class EnviFormerTest(TestCase):
threshold = float(0.5) threshold = float(0.5)
data_package_objs = [self.BBD_SUBSET] data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET] eval_packages_objs = [self.BBD_SUBSET]
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold) mod = EnviFormer.create(
self.package, data_package_objs, eval_packages_objs, threshold=threshold
)
mod.build_dataset() mod.build_dataset()
mod.build_model() mod.build_model()
mod.multigen_eval = True mod.multigen_eval = True
mod.save() mod.save()
mod.evaluate_model() mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@ -4,8 +4,7 @@ from utilities.chem import FormatConverter
class FormatConverterTestCase(TestCase): class FormatConverterTestCase(TestCase):
def test_standardization(self): def test_standardization(self):
smiles = 'C[n+]1c([N-](C))cccc1' smiles = "C[n+]1c([N-](C))cccc1"
standardized_smiles = FormatConverter.standardize(smiles) standardized_smiles = FormatConverter.standardize(smiles)
self.assertEqual(standardized_smiles, 'CN=C1C=CC=CN1C') self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")

View File

@ -4,7 +4,7 @@ import numpy as np
from django.test import TestCase from django.test import TestCase
from epdb.logic import PackageManager from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package from epdb.models import User, MLRelativeReasoning, Package
class ModelTest(TestCase): class ModelTest(TestCase):
@ -13,9 +13,9 @@ class ModelTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(ModelTest, cls).setUpClass() super(ModelTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name='Fixtures') cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_smoke(self): def test_smoke(self):
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
@ -32,8 +32,8 @@ class ModelTest(TestCase):
data_package_objs, data_package_objs,
eval_packages_objs, eval_packages_objs,
threshold=threshold, threshold=threshold,
name='ECC - BBD - 0.5', name="ECC - BBD - 0.5",
description='Created MLRelativeReasoning in Testcase', description="Created MLRelativeReasoning in Testcase",
) )
# mod = RuleBasedRelativeReasoning.create( # mod = RuleBasedRelativeReasoning.create(
@ -54,7 +54,7 @@ class ModelTest(TestCase):
mod.save() mod.save()
mod.evaluate_model() mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C') results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
products = dict() products = dict()
for r in results: for r in results:
@ -62,8 +62,11 @@ class ModelTest(TestCase):
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability) products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
expected = { expected = {
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)), ("CC=O", "CCNC(=O)C1=CC(C)=CC=C1"): (
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)), "bt0243-4301",
np.float64(0.33333333333333337),
),
("CC1=CC=CC(C(=O)O)=C1", "CCNCC"): ("bt0430-4011", np.float64(0.25)),
} }
self.assertEqual(products, expected) self.assertEqual(products, expected)

View File

@ -1,4 +1,3 @@
import json
from django.test import TestCase from django.test import TestCase
from networkx.utils.misc import graphs_equal from networkx.utils.misc import graphs_equal
from epdb.logic import PackageManager, SPathway from epdb.logic import PackageManager, SPathway
@ -12,9 +11,11 @@ class MultiGenTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(MultiGenTest, cls).setUpClass() super(MultiGenTest, cls).setUpClass()
cls.user: 'User' = User.objects.get(username='anonymous') cls.user: "User" = User.objects.get(username="anonymous")
cls.package: 'Package' = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package: "Package" = PackageManager.create_package(
cls.BBD_SUBSET: 'Package' = Package.objects.get(name='Fixtures') cls.user, "Anon Test Package", "No Desc"
)
cls.BBD_SUBSET: "Package" = Package.objects.get(name="Fixtures")
def test_equal_pathways(self): def test_equal_pathways(self):
"""Test that two identical pathways return a precision and recall of 1.0""" """Test that two identical pathways return a precision and recall of 1.0"""
@ -23,14 +24,23 @@ class MultiGenTest(TestCase):
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
continue continue
score, precision, recall = multigen_eval(pathway, pathway) score, precision, recall = multigen_eval(pathway, pathway)
self.assertEqual(precision, 1.0, f"Precision should be one for identical pathways. " self.assertEqual(
f"Failed on pathway: {pathway.name}") precision,
self.assertEqual(recall, 1.0, f"Recall should be one for identical pathways. " 1.0,
f"Failed on pathway: {pathway.name}") f"Precision should be one for identical pathways. "
f"Failed on pathway: {pathway.name}",
)
self.assertEqual(
recall,
1.0,
f"Recall should be one for identical pathways. Failed on pathway: {pathway.name}",
)
def test_intermediates(self): def test_intermediates(self):
"""Test that an intermediate can be correctly identified and the metrics are correctly adjusted""" """Test that an intermediate can be correctly identified and the metrics are correctly adjusted"""
score, precision, recall, intermediates = multigen_eval(*self.intermediate_case(), return_intermediates=True) score, precision, recall, intermediates = multigen_eval(
*self.intermediate_case(), return_intermediates=True
)
self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate") self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
self.assertEqual(precision, 1, "Precision should be 1") self.assertEqual(precision, 1, "Precision should be 1")
self.assertEqual(recall, 1, "Recall should be 1") self.assertEqual(recall, 1, "Recall should be 1")
@ -49,7 +59,9 @@ class MultiGenTest(TestCase):
def test_all(self): def test_all(self):
"""Test an intermediate, false-positive and false-negative together""" """Test an intermediate, false-positive and false-negative together"""
score, precision, recall, intermediates = multigen_eval(*self.all_case(), return_intermediates=True) score, precision, recall, intermediates = multigen_eval(
*self.all_case(), return_intermediates=True
)
self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate") self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6") self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6")
self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75") self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75")
@ -57,19 +69,22 @@ class MultiGenTest(TestCase):
def test_shallow_pathway(self): def test_shallow_pathway(self):
pathways = self.BBD_SUBSET.pathways.all() pathways = self.BBD_SUBSET.pathways.all()
for pathway in pathways: for pathway in pathways:
pathway_name = pathway.name
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
continue continue
shallow_pathway = graph_from_pathway(SPathway.from_pathway(pathway)) shallow_pathway = graph_from_pathway(SPathway.from_pathway(pathway))
pathway = graph_from_pathway(pathway) pathway = graph_from_pathway(pathway)
if not graphs_equal(shallow_pathway, pathway): if not graphs_equal(shallow_pathway, pathway):
print('\n\nS', shallow_pathway.adj) print("\n\nS", shallow_pathway.adj)
print('\n\nPW', pathway.adj) print("\n\nPW", pathway.adj)
# print(shallow_pathway.nodes, pathway.nodes) # print(shallow_pathway.nodes, pathway.nodes)
# print(shallow_pathway.graph, pathway.graph) # print(shallow_pathway.graph, pathway.graph)
self.assertTrue(graphs_equal(shallow_pathway, pathway), f"Networkx graph from shallow pathway not " self.assertTrue(
f"equal to pathway for pathway {pathway.name}") graphs_equal(shallow_pathway, pathway),
f"Networkx graph from shallow pathway not "
f"equal to pathway for pathway {pathway.name}",
)
def test_graph_edit_eval(self): def test_graph_edit_eval(self):
"""Performs all the previous tests but with graph_edit_eval """Performs all the previous tests but with graph_edit_eval
@ -79,10 +94,16 @@ class MultiGenTest(TestCase):
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
continue continue
score = pathway_edit_eval(pathway, pathway) score = pathway_edit_eval(pathway, pathway)
self.assertEqual(score, 0.0, "Pathway edit distance should be zero for identical pathways. " self.assertEqual(
f"Failed on pathway: {pathway.name}") score,
0.0,
"Pathway edit distance should be zero for identical pathways. "
f"Failed on pathway: {pathway.name}",
)
inter_score = pathway_edit_eval(*self.intermediate_case()) inter_score = pathway_edit_eval(*self.intermediate_case())
self.assertAlmostEqual(inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case") self.assertAlmostEqual(
inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case"
)
fp_score = pathway_edit_eval(*self.fp_case()) fp_score = pathway_edit_eval(*self.fp_case())
self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case") self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case")
fn_score = pathway_edit_eval(*self.fn_case()) fn_score = pathway_edit_eval(*self.fn_case())
@ -93,22 +114,30 @@ class MultiGenTest(TestCase):
def intermediate_case(self): def intermediate_case(self):
"""Create an example with an intermediate in the predicted pathway""" """Create an example with an intermediate in the predicted pathway"""
true_pathway = Pathway.create(self.package, "CCO") true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)]) true_pathway.add_edge(
[true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)]
)
pred_pathway = Pathway.create(self.package, "CCO") pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], pred_pathway.add_edge(
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)]) [pred_pathway.root_nodes.all()[0]],
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)],
)
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)]) pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
return true_pathway, pred_pathway return true_pathway, pred_pathway
def fp_case(self): def fp_case(self):
"""Create an example with an extra compound in the predicted pathway""" """Create an example with an extra compound in the predicted pathway"""
true_pathway = Pathway.create(self.package, "CCO") true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]], true_pathway.add_edge(
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)]) [true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
)
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)]) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway = Pathway.create(self.package, "CCO") pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], pred_pathway.add_edge(
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)]) [pred_pathway.root_nodes.all()[0]],
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)],
)
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)]) pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)]) pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)])
return true_pathway, pred_pathway return true_pathway, pred_pathway
@ -116,22 +145,30 @@ class MultiGenTest(TestCase):
def fn_case(self): def fn_case(self):
"""Create an example with a missing compound in the predicted pathway""" """Create an example with a missing compound in the predicted pathway"""
true_pathway = Pathway.create(self.package, "CCO") true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]], true_pathway.add_edge(
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)]) [true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
)
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)]) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway = Pathway.create(self.package, "CCO") pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)]) pred_pathway.add_edge(
[pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)]
)
return true_pathway, pred_pathway return true_pathway, pred_pathway
def all_case(self): def all_case(self):
"""Create an example with an intermediate, extra compound and missing compound""" """Create an example with an intermediate, extra compound and missing compound"""
true_pathway = Pathway.create(self.package, "CCO") true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]], true_pathway.add_edge(
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)]) [true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
)
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)]) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)])
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)]) true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway = Pathway.create(self.package, "CCO") pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)]) pred_pathway.add_edge(
[pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)]
)
pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)]) pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)])
pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)]) pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)])
return true_pathway, pred_pathway return true_pathway, pred_pathway

View File

@ -10,127 +10,127 @@ class ReactionTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(ReactionTest, cls).setUpClass() super(ReactionTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self): def test_smoke(self):
educt = Compound.create( educt = Compound.create(
self.package, self.package,
smiles='C(CCl)Cl', smiles="C(CCl)Cl",
name='1,2-Dichloroethane', name="1,2-Dichloroethane",
description='Eawag BBD compound c0001' description="Eawag BBD compound c0001",
).default_structure ).default_structure
product = Compound.create( product = Compound.create(
self.package, self.package,
smiles='C(CO)Cl', smiles="C(CO)Cl",
name='2-Chloroethanol', name="2-Chloroethanol",
description='Eawag BBD compound c0005' description="Eawag BBD compound c0005",
).default_structure ).default_structure
r = Reaction.create( r = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=[educt], educts=[educt],
products=[product], products=[product],
multi_step=False multi_step=False,
) )
self.assertEqual(r.smirks(), 'C(CCl)Cl>>C(CO)Cl') self.assertEqual(r.smirks(), "C(CCl)Cl>>C(CO)Cl")
self.assertEqual(r.name, 'Eawag BBD reaction r0001') self.assertEqual(r.name, "Eawag BBD reaction r0001")
self.assertEqual(r.description, 'no description') self.assertEqual(r.description, "no description")
def test_string_educts_and_products(self): def test_string_educts_and_products(self):
r = Reaction.create( r = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
multi_step=False multi_step=False,
) )
self.assertEqual(r.smirks(), 'C(CCl)Cl>>C(CO)Cl') self.assertEqual(r.smirks(), "C(CCl)Cl>>C(CO)Cl")
def test_missing_smiles(self): def test_missing_smiles(self):
educt = Compound.create( educt = Compound.create(
self.package, self.package,
smiles='C(CCl)Cl', smiles="C(CCl)Cl",
name='1,2-Dichloroethane', name="1,2-Dichloroethane",
description='Eawag BBD compound c0001' description="Eawag BBD compound c0001",
).default_structure ).default_structure
product = Compound.create( product = Compound.create(
self.package, self.package,
smiles='C(CO)Cl', smiles="C(CO)Cl",
name='2-Chloroethanol', name="2-Chloroethanol",
description='Eawag BBD compound c0005' description="Eawag BBD compound c0005",
).default_structure ).default_structure
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Reaction.create( _ = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=[educt], educts=[educt],
products=[], products=[],
multi_step=False multi_step=False,
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Reaction.create( _ = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=[], educts=[],
products=[product], products=[product],
multi_step=False multi_step=False,
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Reaction.create( _ = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=[], educts=[],
products=[], products=[],
multi_step=False multi_step=False,
) )
def test_empty_name_and_description_are_ignored(self): def test_empty_name_and_description_are_ignored(self):
r = Reaction.create( r = Reaction.create(
package=self.package, package=self.package,
name='', name="",
description='', description="",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
multi_step=False, multi_step=False,
) )
self.assertEqual(r.name, 'no name') self.assertEqual(r.name, "no name")
self.assertEqual(r.description, 'no description') self.assertEqual(r.description, "no description")
def test_deduplication(self): def test_deduplication(self):
rule = Rule.create( rule = Rule.create(
package=self.package, package=self.package,
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
name='bt0022-2833', name="bt0022-2833",
description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative', description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
) )
r1 = Reaction.create( r1 = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
rules=[rule], rules=[rule],
multi_step=False multi_step=False,
) )
r2 = Reaction.create( r2 = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
rules=[rule], rules=[rule],
multi_step=False multi_step=False,
) )
# Check if create detects that this Compound already exist # Check if create detects that this Compound already exist
@ -141,18 +141,18 @@ class ReactionTest(TestCase):
def test_deduplication_without_rules(self): def test_deduplication_without_rules(self):
r1 = Reaction.create( r1 = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
multi_step=False multi_step=False,
) )
r2 = Reaction.create( r2 = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
multi_step=False multi_step=False,
) )
# Check if create detects that this Compound already exist # Check if create detects that this Compound already exist
@ -164,19 +164,19 @@ class ReactionTest(TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Reaction.create( _ = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['ASDF'], educts=["ASDF"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
multi_step=False multi_step=False,
) )
def test_delete(self): def test_delete(self):
r = Reaction.create( r = Reaction.create(
package=self.package, package=self.package,
name='Eawag BBD reaction r0001', name="Eawag BBD reaction r0001",
educts=['C(CCl)Cl'], educts=["C(CCl)Cl"],
products=['C(CO)Cl'], products=["C(CO)Cl"],
multi_step=False multi_step=False,
) )
r.delete() r.delete()

View File

@ -10,73 +10,79 @@ class RuleTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(RuleTest, cls).setUpClass() super(RuleTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self): def test_smoke(self):
r = Rule.create( r = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
name='bt0022-2833', name="bt0022-2833",
description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative', description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
) )
self.assertEqual(r.smirks, self.assertEqual(
'[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]') r.smirks,
self.assertEqual(r.name, 'bt0022-2833') "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
self.assertEqual(r.description, )
'Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative') self.assertEqual(r.name, "bt0022-2833")
self.assertEqual(
r.description,
"Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
)
def test_smirks_are_trimmed(self): def test_smirks_are_trimmed(self):
r = Rule.create( r = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
name='bt0022-2833', name="bt0022-2833",
description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative', description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
smirks=' [H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4] ', smirks=" [H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4] ",
) )
self.assertEqual(r.smirks, self.assertEqual(
'[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]') r.smirks,
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
def test_name_and_description_optional(self): def test_name_and_description_optional(self):
r = Rule.create( r = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
) )
self.assertRegex(r.name, 'Rule \\d+') self.assertRegex(r.name, "Rule \\d+")
self.assertEqual(r.description, 'no description') self.assertEqual(r.description, "no description")
def test_empty_name_and_description_are_ignored(self): def test_empty_name_and_description_are_ignored(self):
r = Rule.create( r = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name='', name="",
description='', description="",
) )
self.assertRegex(r.name, 'Rule \\d+') self.assertRegex(r.name, "Rule \\d+")
self.assertEqual(r.description, 'no description') self.assertEqual(r.description, "no description")
def test_deduplication(self): def test_deduplication(self):
r1 = Rule.create( r1 = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name='', name="",
description='', description="",
) )
r2 = Rule.create( r2 = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name='', name="",
description='', description="",
) )
self.assertEqual(r1.pk, r2.pk) self.assertEqual(r1.pk, r2.pk)
@ -84,21 +90,21 @@ class RuleTest(TestCase):
def test_valid_smirks(self): def test_valid_smirks(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
r = Rule.create( Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='This is not a valid SMIRKS', smirks="This is not a valid SMIRKS",
name='', name="",
description='', description="",
) )
def test_delete(self): def test_delete(self):
r = Rule.create( r = Rule.create(
rule_type='SimpleAmbitRule', rule_type="SimpleAmbitRule",
package=self.package, package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]', smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name='', name="",
description='', description="",
) )
r.delete() r.delete()

File diff suppressed because it is too large Load Diff

View File

@ -12,34 +12,32 @@ class SimpleAmbitRuleTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(SimpleAmbitRuleTest, cls).setUpClass() super(SimpleAmbitRuleTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Simple Ambit Rule Test Package', cls.package = PackageManager.create_package(
'Test Package for SimpleAmbitRule') cls.user, "Simple Ambit Rule Test Package", "Test Package for SimpleAmbitRule"
)
def test_create_basic_rule(self): def test_create_basic_rule(self):
"""Test creating a basic SimpleAmbitRule with minimal parameters.""" """Test creating a basic SimpleAmbitRule with minimal parameters."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
package=self.package,
smirks=smirks
)
self.assertIsInstance(rule, SimpleAmbitRule) self.assertIsInstance(rule, SimpleAmbitRule)
self.assertEqual(rule.smirks, smirks) self.assertEqual(rule.smirks, smirks)
self.assertEqual(rule.package, self.package) self.assertEqual(rule.package, self.package)
self.assertRegex(rule.name, r'Rule \d+') self.assertRegex(rule.name, r"Rule \d+")
self.assertEqual(rule.description, 'no description') self.assertEqual(rule.description, "no description")
self.assertIsNone(rule.reactant_filter_smarts) self.assertIsNone(rule.reactant_filter_smarts)
self.assertIsNone(rule.product_filter_smarts) self.assertIsNone(rule.product_filter_smarts)
def test_create_with_all_parameters(self): def test_create_with_all_parameters(self):
"""Test creating SimpleAmbitRule with all parameters.""" """Test creating SimpleAmbitRule with all parameters."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
name = 'Test Rule' name = "Test Rule"
description = 'A test biotransformation rule' description = "A test biotransformation rule"
reactant_filter = '[CH2X4]' reactant_filter = "[CH2X4]"
product_filter = '[OH]' product_filter = "[OH]"
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(
package=self.package, package=self.package,
@ -47,7 +45,7 @@ class SimpleAmbitRuleTest(TestCase):
description=description, description=description,
smirks=smirks, smirks=smirks,
reactant_filter_smarts=reactant_filter, reactant_filter_smarts=reactant_filter,
product_filter_smarts=product_filter product_filter_smarts=product_filter,
) )
self.assertEqual(rule.name, name) self.assertEqual(rule.name, name)
@ -60,127 +58,114 @@ class SimpleAmbitRuleTest(TestCase):
"""Test that SMIRKS is required for rule creation.""" """Test that SMIRKS is required for rule creation."""
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(package=self.package, smirks=None) SimpleAmbitRule.create(package=self.package, smirks=None)
self.assertIn('SMIRKS is required', str(cm.exception)) self.assertIn("SMIRKS is required", str(cm.exception))
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(package=self.package, smirks='') SimpleAmbitRule.create(package=self.package, smirks="")
self.assertIn('SMIRKS is required', str(cm.exception)) self.assertIn("SMIRKS is required", str(cm.exception))
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(package=self.package, smirks=' ') SimpleAmbitRule.create(package=self.package, smirks=" ")
self.assertIn('SMIRKS is required', str(cm.exception)) self.assertIn("SMIRKS is required", str(cm.exception))
@patch('epdb.models.FormatConverter.is_valid_smirks') @patch("epdb.models.FormatConverter.is_valid_smirks")
def test_invalid_smirks_validation(self, mock_is_valid): def test_invalid_smirks_validation(self, mock_is_valid):
"""Test validation of SMIRKS format.""" """Test validation of SMIRKS format."""
mock_is_valid.return_value = False mock_is_valid.return_value = False
invalid_smirks = 'invalid_smirks_string' invalid_smirks = "invalid_smirks_string"
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create( SimpleAmbitRule.create(package=self.package, smirks=invalid_smirks)
package=self.package,
smirks=invalid_smirks
)
self.assertIn(f'SMIRKS "{invalid_smirks}" is invalid', str(cm.exception)) self.assertIn(f'SMIRKS "{invalid_smirks}" is invalid', str(cm.exception))
mock_is_valid.assert_called_once_with(invalid_smirks) mock_is_valid.assert_called_once_with(invalid_smirks)
def test_smirks_trimming(self): def test_smirks_trimming(self):
"""Test that SMIRKS strings are trimmed during creation.""" """Test that SMIRKS strings are trimmed during creation."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
smirks_with_whitespace = f' {smirks} ' smirks_with_whitespace = f" {smirks} "
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks=smirks_with_whitespace)
package=self.package,
smirks=smirks_with_whitespace
)
self.assertEqual(rule.smirks, smirks) self.assertEqual(rule.smirks, smirks)
def test_empty_name_and_description_handling(self): def test_empty_name_and_description_handling(self):
"""Test that empty name and description are handled appropriately.""" """Test that empty name and description are handled appropriately."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(
package=self.package, package=self.package, smirks=smirks, name="", description=" "
smirks=smirks,
name='',
description=' '
) )
self.assertRegex(rule.name, r'Rule \d+') self.assertRegex(rule.name, r"Rule \d+")
self.assertEqual(rule.description, 'no description') self.assertEqual(rule.description, "no description")
def test_deduplication_basic(self): def test_deduplication_basic(self):
"""Test that identical rules are deduplicated.""" """Test that identical rules are deduplicated."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule1 = SimpleAmbitRule.create( rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks, name="Rule 1")
package=self.package,
smirks=smirks,
name='Rule 1'
)
rule2 = SimpleAmbitRule.create( rule2 = SimpleAmbitRule.create(
package=self.package, package=self.package,
smirks=smirks, smirks=smirks,
name='Rule 2' # Different name, but same SMIRKS name="Rule 2", # Different name, but same SMIRKS
) )
self.assertEqual(rule1.pk, rule2.pk) self.assertEqual(rule1.pk, rule2.pk)
self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 1) self.assertEqual(
SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 1
)
def test_deduplication_with_filters(self): def test_deduplication_with_filters(self):
"""Test deduplication with filter SMARTS.""" """Test deduplication with filter SMARTS."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
reactant_filter = '[CH2X4]' reactant_filter = "[CH2X4]"
product_filter = '[OH]' product_filter = "[OH]"
rule1 = SimpleAmbitRule.create( rule1 = SimpleAmbitRule.create(
package=self.package, package=self.package,
smirks=smirks, smirks=smirks,
reactant_filter_smarts=reactant_filter, reactant_filter_smarts=reactant_filter,
product_filter_smarts=product_filter product_filter_smarts=product_filter,
) )
rule2 = SimpleAmbitRule.create( rule2 = SimpleAmbitRule.create(
package=self.package, package=self.package,
smirks=smirks, smirks=smirks,
reactant_filter_smarts=reactant_filter, reactant_filter_smarts=reactant_filter,
product_filter_smarts=product_filter product_filter_smarts=product_filter,
) )
self.assertEqual(rule1.pk, rule2.pk) self.assertEqual(rule1.pk, rule2.pk)
def test_no_deduplication_different_filters(self): def test_no_deduplication_different_filters(self):
"""Test that rules with different filters are not deduplicated.""" """Test that rules with different filters are not deduplicated."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule1 = SimpleAmbitRule.create( rule1 = SimpleAmbitRule.create(
package=self.package, package=self.package, smirks=smirks, reactant_filter_smarts="[CH2X4]"
smirks=smirks,
reactant_filter_smarts='[CH2X4]'
) )
rule2 = SimpleAmbitRule.create( rule2 = SimpleAmbitRule.create(
package=self.package, package=self.package, smirks=smirks, reactant_filter_smarts="[CH3X4]"
smirks=smirks,
reactant_filter_smarts='[CH3X4]'
) )
self.assertNotEqual(rule1.pk, rule2.pk) self.assertNotEqual(rule1.pk, rule2.pk)
self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 2) self.assertEqual(
SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 2
)
def test_filter_smarts_trimming(self): def test_filter_smarts_trimming(self):
"""Test that filter SMARTS are trimmed and handled correctly.""" """Test that filter SMARTS are trimmed and handled correctly."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
# Test with whitespace-only filters (should be treated as None) # Test with whitespace-only filters (should be treated as None)
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(
package=self.package, package=self.package,
smirks=smirks, smirks=smirks,
reactant_filter_smarts=' ', reactant_filter_smarts=" ",
product_filter_smarts=' ' product_filter_smarts=" ",
) )
self.assertIsNone(rule.reactant_filter_smarts) self.assertIsNone(rule.reactant_filter_smarts)
@ -188,94 +173,85 @@ class SimpleAmbitRuleTest(TestCase):
def test_url_property(self): def test_url_property(self):
"""Test the URL property generation.""" """Test the URL property generation."""
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
expected_url = f'{self.package.url}/simple-ambit-rule/{rule.uuid}' expected_url = f"{self.package.url}/simple-ambit-rule/{rule.uuid}"
self.assertEqual(rule.url, expected_url) self.assertEqual(rule.url, expected_url)
@patch('epdb.models.FormatConverter.apply') @patch("epdb.models.FormatConverter.apply")
def test_apply_method(self, mock_apply): def test_apply_method(self, mock_apply):
"""Test the apply method delegates to FormatConverter.""" """Test the apply method delegates to FormatConverter."""
mock_apply.return_value = ['product1', 'product2'] mock_apply.return_value = ["product1", "product2"]
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
test_smiles = 'CCO' test_smiles = "CCO"
result = rule.apply(test_smiles) result = rule.apply(test_smiles)
mock_apply.assert_called_once_with(test_smiles, rule.smirks) mock_apply.assert_called_once_with(test_smiles, rule.smirks)
self.assertEqual(result, ['product1', 'product2']) self.assertEqual(result, ["product1", "product2"])
def test_reactants_smarts_property(self): def test_reactants_smarts_property(self):
"""Test reactants_smarts property extracts correct part of SMIRKS.""" """Test reactants_smarts property extracts correct part of SMIRKS."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
expected_reactants = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]' expected_reactants = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]"
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
package=self.package,
smirks=smirks
)
self.assertEqual(rule.reactants_smarts, expected_reactants) self.assertEqual(rule.reactants_smarts, expected_reactants)
def test_products_smarts_property(self): def test_products_smarts_property(self):
"""Test products_smarts property extracts correct part of SMIRKS.""" """Test products_smarts property extracts correct part of SMIRKS."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
expected_products = '[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]' expected_products = "[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
package=self.package,
smirks=smirks
)
self.assertEqual(rule.products_smarts, expected_products) self.assertEqual(rule.products_smarts, expected_products)
@patch('epdb.models.Package.objects') @patch("epdb.models.Package.objects")
def test_related_reactions_property(self, mock_package_objects): def test_related_reactions_property(self, mock_package_objects):
"""Test related_reactions property returns correct queryset.""" """Test related_reactions property returns correct queryset."""
mock_qs = MagicMock() mock_qs = MagicMock()
mock_package_objects.filter.return_value = mock_qs mock_package_objects.filter.return_value = mock_qs
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
# Instead of directly assigning, patch the property or use with patch.object # Instead of directly assigning, patch the property or use with patch.object
with patch.object(type(rule), 'reaction_rule', new_callable=PropertyMock) as mock_reaction_rule: with patch.object(
mock_reaction_rule.return_value.filter.return_value.order_by.return_value = ['reaction1', 'reaction2'] type(rule), "reaction_rule", new_callable=PropertyMock
) as mock_reaction_rule:
mock_reaction_rule.return_value.filter.return_value.order_by.return_value = [
"reaction1",
"reaction2",
]
result = rule.related_reactions result = rule.related_reactions
mock_package_objects.filter.assert_called_once_with(reviewed=True) mock_package_objects.filter.assert_called_once_with(reviewed=True)
mock_reaction_rule.return_value.filter.assert_called_once_with(package__in=mock_qs) mock_reaction_rule.return_value.filter.assert_called_once_with(package__in=mock_qs)
mock_reaction_rule.return_value.filter.return_value.order_by.assert_called_once_with('name') mock_reaction_rule.return_value.filter.return_value.order_by.assert_called_once_with(
self.assertEqual(result, ['reaction1', 'reaction2']) "name"
)
self.assertEqual(result, ["reaction1", "reaction2"])
@patch('epdb.models.Pathway.objects') @patch("epdb.models.Pathway.objects")
@patch('epdb.models.Edge.objects') @patch("epdb.models.Edge.objects")
def test_related_pathways_property(self, mock_edge_objects, mock_pathway_objects): def test_related_pathways_property(self, mock_edge_objects, mock_pathway_objects):
"""Test related_pathways property returns correct queryset.""" """Test related_pathways property returns correct queryset."""
mock_related_reactions = ['reaction1', 'reaction2'] mock_related_reactions = ["reaction1", "reaction2"]
with patch.object(SimpleAmbitRule, "related_reactions", new_callable=PropertyMock) as mock_prop: with patch.object(
SimpleAmbitRule, "related_reactions", new_callable=PropertyMock
) as mock_prop:
mock_prop.return_value = mock_related_reactions mock_prop.return_value = mock_related_reactions
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
# Mock Edge objects query # Mock Edge objects query
mock_edge_values = MagicMock() mock_edge_values = MagicMock()
mock_edge_values.values.return_value = ['pathway_id1', 'pathway_id2'] mock_edge_values.values.return_value = ["pathway_id1", "pathway_id2"]
mock_edge_objects.filter.return_value = mock_edge_values mock_edge_objects.filter.return_value = mock_edge_values
# Mock Pathway objects query # Mock Pathway objects query
@ -285,52 +261,49 @@ class SimpleAmbitRuleTest(TestCase):
result = rule.related_pathways result = rule.related_pathways
mock_edge_objects.filter.assert_called_once_with(edge_label__in=mock_related_reactions) mock_edge_objects.filter.assert_called_once_with(edge_label__in=mock_related_reactions)
mock_edge_values.values.assert_called_once_with('pathway_id') mock_edge_values.values.assert_called_once_with("pathway_id")
mock_pathway_objects.filter.assert_called_once() mock_pathway_objects.filter.assert_called_once()
self.assertEqual(result, mock_pathway_qs) self.assertEqual(result, mock_pathway_qs)
@patch('epdb.models.IndigoUtils.smirks_to_svg') @patch("epdb.models.IndigoUtils.smirks_to_svg")
def test_as_svg_property(self, mock_smirks_to_svg): def test_as_svg_property(self, mock_smirks_to_svg):
"""Test as_svg property calls IndigoUtils correctly.""" """Test as_svg property calls IndigoUtils correctly."""
mock_smirks_to_svg.return_value = '<svg>test_svg</svg>' mock_smirks_to_svg.return_value = "<svg>test_svg</svg>"
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
result = rule.as_svg result = rule.as_svg
mock_smirks_to_svg.assert_called_once_with(rule.smirks, True, width=800, height=400) mock_smirks_to_svg.assert_called_once_with(rule.smirks, True, width=800, height=400)
self.assertEqual(result, '<svg>test_svg</svg>') self.assertEqual(result, "<svg>test_svg</svg>")
def test_atomic_transaction(self): def test_atomic_transaction(self):
"""Test that rule creation is atomic.""" """Test that rule creation is atomic."""
smirks = '[H:1][C:2]>>[H:1][O:2]' smirks = "[H:1][C:2]>>[H:1][O:2]"
# This should work normally # This should work normally
rule = SimpleAmbitRule.create(package=self.package, smirks=smirks) rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
self.assertIsInstance(rule, SimpleAmbitRule) self.assertIsInstance(rule, SimpleAmbitRule)
# Test transaction rollback on error # Test transaction rollback on error
with patch('epdb.models.SimpleAmbitRule.save', side_effect=Exception('Database error')): with patch("epdb.models.SimpleAmbitRule.save", side_effect=Exception("Database error")):
with self.assertRaises(Exception): with self.assertRaises(Exception):
SimpleAmbitRule.create(package=self.package, smirks='[H:3][C:4]>>[H:3][O:4]') SimpleAmbitRule.create(package=self.package, smirks="[H:3][C:4]>>[H:3][O:4]")
# Verify no partial data was saved # Verify no partial data was saved
self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package).count(), 1) self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package).count(), 1)
def test_multiple_duplicate_warning(self): def test_multiple_duplicate_warning(self):
"""Test logging when multiple duplicates are found.""" """Test logging when multiple duplicates are found."""
smirks = '[H:1][C:2]>>[H:1][O:2]' smirks = "[H:1][C:2]>>[H:1][O:2]"
# Create first rule # Create first rule
rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks) rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks)
# Manually create a duplicate to simulate the error condition # Manually create a duplicate to simulate the error condition
rule2 = SimpleAmbitRule(package=self.package, smirks=smirks, name='Manual Rule') rule2 = SimpleAmbitRule(package=self.package, smirks=smirks, name="Manual Rule")
rule2.save() rule2.save()
with patch('epdb.models.logger') as mock_logger: with patch("epdb.models.logger") as mock_logger:
# This should find the existing rule and log an error about multiple matches # This should find the existing rule and log an error about multiple matches
result = SimpleAmbitRule.create(package=self.package, smirks=smirks) result = SimpleAmbitRule.create(package=self.package, smirks=smirks)
@ -339,24 +312,28 @@ class SimpleAmbitRuleTest(TestCase):
# Should log an error about multiple matches # Should log an error about multiple matches
mock_logger.error.assert_called() mock_logger.error.assert_called()
self.assertIn('More than one rule matched', mock_logger.error.call_args[0][0]) self.assertIn("More than one rule matched", mock_logger.error.call_args[0][0])
def test_model_fields(self): def test_model_fields(self):
"""Test model field properties.""" """Test model field properties."""
rule = SimpleAmbitRule.create( rule = SimpleAmbitRule.create(
package=self.package, package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]', smirks="[H:1][C:2]>>[H:1][O:2]",
reactant_filter_smarts='[CH3]', reactant_filter_smarts="[CH3]",
product_filter_smarts='[OH]' product_filter_smarts="[OH]",
) )
# Test field properties # Test field properties
self.assertFalse(rule._meta.get_field('smirks').blank) self.assertFalse(rule._meta.get_field("smirks").blank)
self.assertFalse(rule._meta.get_field('smirks').null) self.assertFalse(rule._meta.get_field("smirks").null)
self.assertTrue(rule._meta.get_field('reactant_filter_smarts').null) self.assertTrue(rule._meta.get_field("reactant_filter_smarts").null)
self.assertTrue(rule._meta.get_field('product_filter_smarts').null) self.assertTrue(rule._meta.get_field("product_filter_smarts").null)
# Test verbose names # Test verbose names
self.assertEqual(rule._meta.get_field('smirks').verbose_name, 'SMIRKS') self.assertEqual(rule._meta.get_field("smirks").verbose_name, "SMIRKS")
self.assertEqual(rule._meta.get_field('reactant_filter_smarts').verbose_name, 'Reactant Filter SMARTS') self.assertEqual(
self.assertEqual(rule._meta.get_field('product_filter_smarts').verbose_name, 'Product Filter SMARTS') rule._meta.get_field("reactant_filter_smarts").verbose_name, "Reactant Filter SMARTS"
)
self.assertEqual(
rule._meta.get_field("product_filter_smarts").verbose_name, "Product Filter SMARTS"
)

View File

@ -1,32 +1,29 @@
from django.test import TestCase from django.test import TestCase
from epdb.logic import SNode, SEdge, SPathway from epdb.logic import SNode, SEdge
class SObjectTest(TestCase): class SObjectTest(TestCase):
def setUp(self): def setUp(self):
pass pass
def test_snode_eq(self): def test_snode_eq(self):
snode1 = SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0) snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
snode2 = SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0) snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
assert snode1 == snode2 assert snode1 == snode2
def test_snode_hash(self): def test_snode_hash(self):
pass pass
def test_sedge_eq(self): def test_sedge_eq(self):
sedge1 = SEdge([SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)], sedge1 = SEdge(
[SNode('CN1C(=O)NC2=C(C1=O)N(C)C=N2', 1), SNode('C=O', 1)], [SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
rule=None) [SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
sedge2 = SEdge([SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)], rule=None,
[SNode('CN1C(=O)NC2=C(C1=O)N(C)C=N2', 1), SNode('C=O', 1)], )
rule=None) sedge2 = SEdge(
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
rule=None,
)
assert sedge1 == sedge2 assert sedge1 == sedge2
def test_sedge_hash(self):
pass
def test_spathway(self):
pw = SPathway()

View File

@ -3,7 +3,7 @@ from django.urls import reverse
from envipy_additional_information import Temperature, Interval from envipy_additional_information import Temperature, Interval
from epdb.logic import UserManager, PackageManager from epdb.logic import UserManager, PackageManager
from epdb.models import Compound, Scenario, ExternalIdentifier, ExternalDatabase from epdb.models import Compound, Scenario, ExternalDatabase
class CompoundViewTest(TestCase): class CompoundViewTest(TestCase):
@ -12,21 +12,28 @@ class CompoundViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(CompoundViewTest, cls).setUpClass() super(CompoundViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", cls.user1 = UserManager.create_user(
set_setting=False, add_to_group=True, is_active=True) "user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self): def setUp(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
def test_create_compound(self): def test_create_compound(self):
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -38,17 +45,18 @@ class CompoundViewTest(TestCase):
self.assertEqual(c.name, "1,2-Dichloroethane") self.assertEqual(c.name, "1,2-Dichloroethane")
self.assertEqual(c.description, "Eawag BBD compound c0001") self.assertEqual(c.description, "Eawag BBD compound c0001")
self.assertEqual(c.default_structure.smiles, "C(CCl)Cl") self.assertEqual(c.default_structure.smiles, "C(CCl)Cl")
self.assertEqual(c.default_structure.canonical_smiles, 'ClCCCl') self.assertEqual(c.default_structure.canonical_smiles, "ClCCCl")
self.assertEqual(c.structures.all().count(), 2) self.assertEqual(c.structures.all().count(), 2)
self.assertEqual(self.user1_default_package.compounds.count(), 1) self.assertEqual(self.user1_default_package.compounds.count(), 1)
# Adding the same rule again should return the existing one, hence not increasing the number of rules # Adding the same rule again should return the existing one, hence not increasing the number of rules
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.url, compound_url) self.assertEqual(response.url, compound_url)
@ -57,11 +65,12 @@ class CompoundViewTest(TestCase):
# Adding the same rule in a different package should create a new rule # Adding the same rule in a different package should create a new rule
response = self.client.post( response = self.client.post(
reverse("package compound list", kwargs={'package_uuid': self.package.uuid}), { reverse("package compound list", kwargs={"package_uuid": self.package.uuid}),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -69,11 +78,12 @@ class CompoundViewTest(TestCase):
# adding another reaction should increase count # adding another reaction should increase count
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "2-Chloroethanol", "compound-name": "2-Chloroethanol",
"compound-description": "Eawag BBD compound c0005", "compound-description": "Eawag BBD compound c0005",
"compound-smiles": "C(CO)Cl", "compound-smiles": "C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -82,11 +92,12 @@ class CompoundViewTest(TestCase):
# Edit # Edit
def test_edit_rule(self): def test_edit_rule(self):
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -95,13 +106,17 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url) c = Compound.objects.get(url=compound_url)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(self.user1_default_package.uuid), "package compound detail",
'compound_uuid': str(c.uuid) kwargs={
}), { "package_uuid": str(self.user1_default_package.uuid),
"compound_uuid": str(c.uuid),
},
),
{
"compound-name": "Test Compound Adjusted", "compound-name": "Test Compound Adjusted",
"compound-description": "New Description", "compound-description": "New Description",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -121,7 +136,7 @@ class CompoundViewTest(TestCase):
"Test Desc", "Test Desc",
"2025-10", "2025-10",
"soil", "soil",
[Temperature(interval=Interval(start=20, end=30))] [Temperature(interval=Interval(start=20, end=30))],
) )
s2 = Scenario.create( s2 = Scenario.create(
@ -130,15 +145,16 @@ class CompoundViewTest(TestCase):
"Test Desc2", "Test Desc2",
"2025-10", "2025-10",
"soil", "soil",
[Temperature(interval=Interval(start=10, end=20))] [Temperature(interval=Interval(start=10, end=20))],
) )
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -147,36 +163,35 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url) c = Compound.objects.get(url=compound_url)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package compound detail",
'compound_uuid': str(c.uuid) kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
}), { ),
"selected-scenarios": [s1.url, s2.url] {"selected-scenarios": [s1.url, s2.url]},
}
) )
self.assertEqual(len(c.scenarios.all()), 2) self.assertEqual(len(c.scenarios.all()), 2)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package compound detail",
'compound_uuid': str(c.uuid) kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
}), { ),
"selected-scenarios": [s1.url] {"selected-scenarios": [s1.url]},
}
) )
self.assertEqual(len(c.scenarios.all()), 1) self.assertEqual(len(c.scenarios.all()), 1)
self.assertEqual(c.scenarios.first().url, s1.url) self.assertEqual(c.scenarios.first().url, s1.url)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package compound detail",
'compound_uuid': str(c.uuid) kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
}), { ),
{
# We have to set an empty string to avoid that the parameter is removed # We have to set an empty string to avoid that the parameter is removed
"selected-scenarios": "" "selected-scenarios": ""
} },
) )
self.assertEqual(len(c.scenarios.all()), 0) self.assertEqual(len(c.scenarios.all()), 0)
@ -184,11 +199,12 @@ class CompoundViewTest(TestCase):
# #
def test_copy(self): def test_copy(self):
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -196,12 +212,13 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url) c = Compound.objects.get(url=compound_url)
response = self.client.post( response = self.client.post(
reverse("package detail", kwargs={ reverse(
'package_uuid': str(self.package.uuid), "package detail",
}), { kwargs={
"hidden": "copy", "package_uuid": str(self.package.uuid),
"object_to_copy": c.url },
} ),
{"hidden": "copy", "object_to_copy": c.url},
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -215,44 +232,48 @@ class CompoundViewTest(TestCase):
# Copy to the same package should fail # Copy to the same package should fail
response = self.client.post( response = self.client.post(
reverse("package detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package detail",
}), { kwargs={
"hidden": "copy", "package_uuid": str(c.package.uuid),
"object_to_copy": c.url },
} ),
{"hidden": "copy", "object_to_copy": c.url},
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], f"Can't copy object {compound_url} to the same package!") self.assertEqual(
response.json()["error"], f"Can't copy object {compound_url} to the same package!"
)
def test_references(self): def test_references(self):
ext_db, _ = ExternalDatabase.objects.get_or_create( ext_db, _ = ExternalDatabase.objects.get_or_create(
name='PubChem Compound', name="PubChem Compound",
defaults={ defaults={
'full_name': 'PubChem Compound Database', "full_name": "PubChem Compound Database",
'description': 'Chemical database of small organic molecules', "description": "Chemical database of small organic molecules",
'base_url': 'https://pubchem.ncbi.nlm.nih.gov', "base_url": "https://pubchem.ncbi.nlm.nih.gov",
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}' "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/compound/{id}",
} },
) )
ext_db2, _ = ExternalDatabase.objects.get_or_create( ext_db2, _ = ExternalDatabase.objects.get_or_create(
name='PubChem Substance', name="PubChem Substance",
defaults={ defaults={
'full_name': 'PubChem Substance Database', "full_name": "PubChem Substance Database",
'description': 'Database of chemical substances', "description": "Database of chemical substances",
'base_url': 'https://pubchem.ncbi.nlm.nih.gov', "base_url": "https://pubchem.ncbi.nlm.nih.gov",
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/substance/{id}' "url_pattern": "https://pubchem.ncbi.nlm.nih.gov/substance/{id}",
} },
) )
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -260,42 +281,49 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url) c = Compound.objects.get(url=compound_url)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package compound detail",
'compound_uuid': str(c.uuid), kwargs={
}), { "package_uuid": str(c.package.uuid),
'selected-database': ext_db.pk, "compound_uuid": str(c.uuid),
'identifier': '25154249' },
} ),
{"selected-database": ext_db.pk, "identifier": "25154249"},
) )
self.assertEqual(c.external_identifiers.count(), 1) self.assertEqual(c.external_identifiers.count(), 1)
self.assertEqual(c.external_identifiers.first().database, ext_db) self.assertEqual(c.external_identifiers.first().database, ext_db)
self.assertEqual(c.external_identifiers.first().identifier_value, '25154249') self.assertEqual(c.external_identifiers.first().identifier_value, "25154249")
self.assertEqual(c.external_identifiers.first().url, 'https://pubchem.ncbi.nlm.nih.gov/compound/25154249') self.assertEqual(
c.external_identifiers.first().url, "https://pubchem.ncbi.nlm.nih.gov/compound/25154249"
)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package compound detail",
'compound_uuid': str(c.uuid), kwargs={
}), { "package_uuid": str(c.package.uuid),
'selected-database': ext_db2.pk, "compound_uuid": str(c.uuid),
'identifier': '25154249' },
} ),
{"selected-database": ext_db2.pk, "identifier": "25154249"},
) )
self.assertEqual(c.external_identifiers.count(), 2) self.assertEqual(c.external_identifiers.count(), 2)
self.assertEqual(c.external_identifiers.last().database, ext_db2) self.assertEqual(c.external_identifiers.last().database, ext_db2)
self.assertEqual(c.external_identifiers.last().identifier_value, '25154249') self.assertEqual(c.external_identifiers.last().identifier_value, "25154249")
self.assertEqual(c.external_identifiers.last().url, 'https://pubchem.ncbi.nlm.nih.gov/substance/25154249') self.assertEqual(
c.external_identifiers.last().url, "https://pubchem.ncbi.nlm.nih.gov/substance/25154249"
)
def test_delete(self): def test_delete(self):
response = self.client.post( response = self.client.post(
reverse("compounds"), { reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane", "compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001", "compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl", "compound-smiles": "C(CCl)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -304,12 +332,11 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url) c = Compound.objects.get(url=compound_url)
response = self.client.post( response = self.client.post(
reverse("package compound detail", kwargs={ reverse(
'package_uuid': str(c.package.uuid), "package compound detail",
'compound_uuid': str(c.uuid) kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
}), { ),
"hidden": "delete" {"hidden": "delete"},
}
) )
self.assertEqual(self.user1_default_package.compounds.count(), 0) self.assertEqual(self.user1_default_package.compounds.count(), 0)

View File

@ -2,8 +2,8 @@ from django.test import TestCase, override_settings
from django.urls import reverse from django.urls import reverse
from django.conf import settings as s from django.conf import settings as s
from epdb.logic import UserManager, PackageManager from epdb.logic import UserManager
from epdb.models import Pathway, Edge, Package, User from epdb.models import Package, User
@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models") @override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models")
@ -13,10 +13,16 @@ class PathwayViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(PathwayViewTest, cls).setUpClass() super(PathwayViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", cls.user1 = UserManager.create_user(
set_setting=True, add_to_group=True, is_active=True) "user1",
"user1@envipath.com",
"SuperSafe",
set_setting=True,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package cls.user1_default_package = cls.user1.default_package
cls.model_package = Package.objects.get(name='Fixtures') cls.model_package = Package.objects.get(name="Fixtures")
def setUp(self): def setUp(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
@ -24,90 +30,96 @@ class PathwayViewTest(TestCase):
def test_predict(self): def test_predict(self):
self.client.force_login(User.objects.get(username="admin")) self.client.force_login(User.objects.get(username="admin"))
response = self.client.get( response = self.client.get(
reverse("package model detail", kwargs={ reverse(
'package_uuid': str(self.model_package.uuid), "package model detail",
'model_uuid': str(self.model_package.models.first().uuid) kwargs={
}), { "package_uuid": str(self.model_package.uuid),
'classify': 'ILikeCats!', "model_uuid": str(self.model_package.models.first().uuid),
'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO', },
} ),
{
"classify": "ILikeCats!",
"smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO",
},
) )
expected = [ expected = [
{ {
'products': [ "products": [["O=C(O)C1=CC(CO)=CC=C1", "CCNCC"]],
[ "probability": 0.25,
'O=C(O)C1=CC(CO)=CC=C1', "btrule": {
'CCNCC' "url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/0e6e9290-b658-4450-b291-3ec19fa19206",
] "name": "bt0430-4011",
], },
'probability': 0.25, },
'btrule': { {
'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/0e6e9290-b658-4450-b291-3ec19fa19206', "products": [["CCNC(=O)C1=CC(CO)=CC=C1", "CC=O"]],
'name': 'bt0430-4011' "probability": 0.0,
} "btrule": {
}, { "url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/27a3a353-0b66-4228-bd16-e407949e90df",
'products': [ "name": "bt0243-4301",
[ },
'CCNC(=O)C1=CC(CO)=CC=C1', },
'CC=O' {
] "products": [["CCN(CC)C(=O)C1=CC(C=O)=CC=C1"]],
], 'probability': 0.0, "probability": 0.75,
'btrule': { "btrule": {
'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/27a3a353-0b66-4228-bd16-e407949e90df', "url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/2f2e0c39-e109-4836-959f-2bda2524f022",
'name': 'bt0243-4301' "name": "bt0001-3568",
} },
}, { },
'products': [
[
'CCN(CC)C(=O)C1=CC(C=O)=CC=C1'
]
], 'probability': 0.75,
'btrule': {
'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/2f2e0c39-e109-4836-959f-2bda2524f022',
'name': 'bt0001-3568'
}
}
] ]
actual = response.json() actual = response.json()
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
response = self.client.get( response = self.client.get(
reverse("package model detail", kwargs={ reverse(
'package_uuid': str(self.model_package.uuid), "package model detail",
'model_uuid': str(self.model_package.models.first().uuid) kwargs={
}), { "package_uuid": str(self.model_package.uuid),
'classify': 'ILikeCats!', "model_uuid": str(self.model_package.models.first().uuid),
'smiles': '', },
} ),
{
"classify": "ILikeCats!",
"smiles": "",
},
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], 'Received empty SMILES') self.assertEqual(response.json()["error"], "Received empty SMILES")
response = self.client.get( response = self.client.get(
reverse("package model detail", kwargs={ reverse(
'package_uuid': str(self.model_package.uuid), "package model detail",
'model_uuid': str(self.model_package.models.first().uuid) kwargs={
}), { "package_uuid": str(self.model_package.uuid),
'classify': 'ILikeCats!', "model_uuid": str(self.model_package.models.first().uuid),
'smiles': ' ', # Input should be stripped },
} ),
{
"classify": "ILikeCats!",
"smiles": " ", # Input should be stripped
},
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], 'Received empty SMILES') self.assertEqual(response.json()["error"], "Received empty SMILES")
response = self.client.get( response = self.client.get(
reverse("package model detail", kwargs={ reverse(
'package_uuid': str(self.model_package.uuid), "package model detail",
'model_uuid': str(self.model_package.models.first().uuid) kwargs={
}), { "package_uuid": str(self.model_package.uuid),
'classify': 'ILikeCats!', "model_uuid": str(self.model_package.models.first().uuid),
'smiles': 'RandomInput', },
} ),
{
"classify": "ILikeCats!",
"smiles": "RandomInput",
},
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], '"RandomInput" is not a valid SMILES') self.assertEqual(response.json()["error"], '"RandomInput" is not a valid SMILES')

View File

@ -13,19 +13,34 @@ class PackageViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(PackageViewTest, cls).setUpClass() super(PackageViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", cls.user1 = UserManager.create_user(
set_setting=False, add_to_group=True, is_active=True) "user1",
cls.user2 = UserManager.create_user("user2", "user2@envipath.com", "SuperSafe", "user1@envipath.com",
set_setting=False, add_to_group=True, is_active=True) "SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user2 = UserManager.create_user(
"user2",
"user2@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
def setUp(self): def setUp(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
def test_create_package(self): def test_create_package(self):
response = self.client.post(reverse("packages"), { response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package", "package-name": "Test Package",
"package-description": "Just a Description", "package-description": "Just a Description",
}) },
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
package_url = response.url package_url = response.url
@ -41,13 +56,12 @@ class PackageViewTest(TestCase):
file = SimpleUploadedFile( file = SimpleUploadedFile(
"Fixture_Package.json", "Fixture_Package.json",
open(s.FIXTURE_DIRS[0] / "Fixture_Package.json", "rb").read(), open(s.FIXTURE_DIRS[0] / "Fixture_Package.json", "rb").read(),
content_type="application/json" content_type="application/json",
) )
response = self.client.post(reverse("packages"), { response = self.client.post(
"file": file, reverse("packages"), {"file": file, "hidden": "import-package-json"}
"hidden": "import-package-json" )
})
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
package_url = response.url package_url = response.url
@ -67,13 +81,12 @@ class PackageViewTest(TestCase):
file = SimpleUploadedFile( file = SimpleUploadedFile(
"EAWAG-BBD.json", "EAWAG-BBD.json",
open(s.FIXTURE_DIRS[0] / "packages" / "2025-07-18" / "EAWAG-BBD.json", "rb").read(), open(s.FIXTURE_DIRS[0] / "packages" / "2025-07-18" / "EAWAG-BBD.json", "rb").read(),
content_type="application/json" content_type="application/json",
) )
response = self.client.post(reverse("packages"), { response = self.client.post(
"file": file, reverse("packages"), {"file": file, "hidden": "import-legacy-package-json"}
"hidden": "import-legacy-package-json" )
})
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
package_url = response.url package_url = response.url
@ -90,17 +103,23 @@ class PackageViewTest(TestCase):
self.assertEqual(upp.permission, Permission.ALL[0]) self.assertEqual(upp.permission, Permission.ALL[0])
def test_edit_package(self): def test_edit_package(self):
response = self.client.post(reverse("packages"), { response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package", "package-name": "Test Package",
"package-description": "Just a Description", "package-description": "Just a Description",
}) },
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
package_url = response.url package_url = response.url
self.client.post(package_url, { self.client.post(
package_url,
{
"package-name": "New Name", "package-name": "New Name",
"package-description": "New Description", "package-description": "New Description",
}) },
)
p = Package.objects.get(url=package_url) p = Package.objects.get(url=package_url)
@ -108,10 +127,13 @@ class PackageViewTest(TestCase):
self.assertEqual(p.description, "New Description") self.assertEqual(p.description, "New Description")
def test_edit_package_permissions(self): def test_edit_package_permissions(self):
response = self.client.post(reverse("packages"), { response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package", "package-name": "Test Package",
"package-description": "Just a Description", "package-description": "Just a Description",
}) },
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
package_url = response.url package_url = response.url
p = Package.objects.get(url=package_url) p = Package.objects.get(url=package_url)
@ -119,57 +141,63 @@ class PackageViewTest(TestCase):
with self.assertRaises(UserPackagePermission.DoesNotExist): with self.assertRaises(UserPackagePermission.DoesNotExist):
UserPackagePermission.objects.get(package=p, user=self.user2) UserPackagePermission.objects.get(package=p, user=self.user2)
self.client.post(package_url, { self.client.post(
package_url,
{
"grantee": self.user2.url, "grantee": self.user2.url,
"read": "on", "read": "on",
"write": "on", "write": "on",
}) },
)
upp = UserPackagePermission.objects.get(package=p, user=self.user2) upp = UserPackagePermission.objects.get(package=p, user=self.user2)
self.assertEqual(upp.permission, Permission.WRITE[0]) self.assertEqual(upp.permission, Permission.WRITE[0])
def test_publish_package(self): def test_publish_package(self):
response = self.client.post(reverse("packages"), { response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package", "package-name": "Test Package",
"package-description": "Just a Description", "package-description": "Just a Description",
}) },
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
package_url = response.url package_url = response.url
p = Package.objects.get(url=package_url) p = Package.objects.get(url=package_url)
self.client.post(package_url, { self.client.post(package_url, {"hidden": "publish-package"})
"hidden": "publish-package"
})
self.assertEqual(Group.objects.filter(public=True).count(), 1) self.assertEqual(Group.objects.filter(public=True).count(), 1)
g = Group.objects.get(public=True) g = Group.objects.get(public=True)
gpp = GroupPackagePermission.objects.get(package=p, group=g) gpp = GroupPackagePermission.objects.get(package=p, group=g)
self.assertEqual(gpp.permission, Permission.READ[0]) self.assertEqual(gpp.permission, Permission.READ[0])
def test_set_package_license(self): def test_set_package_license(self):
response = self.client.post(reverse("packages"), { response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package", "package-name": "Test Package",
"package-description": "Just a Description", "package-description": "Just a Description",
}) },
)
package_url = response.url package_url = response.url
p = Package.objects.get(url=package_url) p = Package.objects.get(url=package_url)
self.client.post(package_url, { self.client.post(package_url, {"license": "no-license"})
"license": "no-license"
})
self.assertIsNone(p.license) self.assertIsNone(p.license)
# TODO test others # TODO test others
def test_delete_package(self): def test_delete_package(self):
response = self.client.post(reverse("packages"), { response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package", "package-name": "Test Package",
"package-description": "Just a Description", "package-description": "Just a Description",
}) },
)
package_url = response.url package_url = response.url
p = Package.objects.get(url=package_url) p = Package.objects.get(url=package_url)
@ -182,11 +210,11 @@ class PackageViewTest(TestCase):
def test_delete_default_package(self): def test_delete_default_package(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
# Try to delete the default package # Try to delete the default package
response = self.client.post(self.user1.default_package.url, { response = self.client.post(self.user1.default_package.url, {"hidden": "delete"})
"hidden": "delete"
})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertTrue(f'You cannot delete the default package. ' self.assertTrue(
f'If you want to delete this package you have to ' "You cannot delete the default package. "
f'set another default package first' in response.content.decode()) "If you want to delete this package you have to "
"set another default package first" in response.content.decode()
)

View File

@ -5,6 +5,7 @@ from django.conf import settings as s
from epdb.logic import UserManager, PackageManager from epdb.logic import UserManager, PackageManager
from epdb.models import Pathway, Edge from epdb.models import Pathway, Edge
@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models") @override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models")
class PathwayViewTest(TestCase): class PathwayViewTest(TestCase):
fixtures = ["test_fixtures_incl_model.jsonl.gz"] fixtures = ["test_fixtures_incl_model.jsonl.gz"]
@ -12,41 +13,52 @@ class PathwayViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(PathwayViewTest, cls).setUpClass() super(PathwayViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", cls.user1 = UserManager.create_user(
set_setting=True, add_to_group=True, is_active=True) "user1",
"user1@envipath.com",
"SuperSafe",
set_setting=True,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self): def setUp(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
def test_predict_pathway(self): def test_predict_pathway(self):
response = self.client.post(reverse("pathways"), { response = self.client.post(
'name': 'Test Pathway', reverse("pathways"),
'description': 'Just a Description', {
'predict': 'predict', "name": "Test Pathway",
'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO', "description": "Just a Description",
}) "predict": "predict",
"smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO",
},
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
pathway_url = response.url pathway_url = response.url
pw = Pathway.objects.get(url=pathway_url) pw = Pathway.objects.get(url=pathway_url)
self.assertEqual(self.user1_default_package, pw.package) self.assertEqual(self.user1_default_package, pw.package)
self.assertEqual(pw.name, 'Test Pathway') self.assertEqual(pw.name, "Test Pathway")
self.assertEqual(pw.description, 'Just a Description') self.assertEqual(pw.description, "Just a Description")
self.assertEqual(len(pw.root_nodes), 1) self.assertEqual(len(pw.root_nodes), 1)
self.assertEqual(pw.root_nodes.first().default_node_label.smiles, 'CCN(CC)C(=O)C1=CC(CO)=CC=C1') self.assertEqual(
pw.root_nodes.first().default_node_label.smiles, "CCN(CC)C(=O)C1=CC(CO)=CC=C1"
)
first_level_nodes = { first_level_nodes = {
# Edge 1 # Edge 1
'CCN(CC)C(=O)C1=CC(C=O)=CC=C1', "CCN(CC)C(=O)C1=CC(C=O)=CC=C1",
# Edge 2 # Edge 2
'CCNC(=O)C1=CC(CO)=CC=C1', "CCNC(=O)C1=CC(CO)=CC=C1",
'CC=O', "CC=O",
# Edge 3 # Edge 3
'CCNCC', "CCNCC",
'O=C(O)C1=CC(CO)=CC=C1', "O=C(O)C1=CC(CO)=CC=C1",
} }
predicted_nodes = set() predicted_nodes = set()
@ -60,32 +72,36 @@ class PathwayViewTest(TestCase):
def test_predict_package_pathway(self): def test_predict_package_pathway(self):
response = self.client.post( response = self.client.post(
reverse("package pathway list", kwargs={'package_uuid': str(self.package.uuid)}), { reverse("package pathway list", kwargs={"package_uuid": str(self.package.uuid)}),
'name': 'Test Pathway', {
'description': 'Just a Description', "name": "Test Pathway",
'predict': 'predict', "description": "Just a Description",
'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO', "predict": "predict",
}) "smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO",
},
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
pathway_url = response.url pathway_url = response.url
pw = Pathway.objects.get(url=pathway_url) pw = Pathway.objects.get(url=pathway_url)
self.assertEqual(self.package, pw.package) self.assertEqual(self.package, pw.package)
self.assertEqual(pw.name, 'Test Pathway') self.assertEqual(pw.name, "Test Pathway")
self.assertEqual(pw.description, 'Just a Description') self.assertEqual(pw.description, "Just a Description")
self.assertEqual(len(pw.root_nodes), 1) self.assertEqual(len(pw.root_nodes), 1)
self.assertEqual(pw.root_nodes.first().default_node_label.smiles, 'CCN(CC)C(=O)C1=CC(CO)=CC=C1') self.assertEqual(
pw.root_nodes.first().default_node_label.smiles, "CCN(CC)C(=O)C1=CC(CO)=CC=C1"
)
first_level_nodes = { first_level_nodes = {
# Edge 1 # Edge 1
'CCN(CC)C(=O)C1=CC(C=O)=CC=C1', "CCN(CC)C(=O)C1=CC(C=O)=CC=C1",
# Edge 2 # Edge 2
'CCNC(=O)C1=CC(CO)=CC=C1', "CCNC(=O)C1=CC(CO)=CC=C1",
'CC=O', "CC=O",
# Edge 3 # Edge 3
'CCNCC', "CCNCC",
'O=C(O)C1=CC(CO)=CC=C1', "O=C(O)C1=CC(CO)=CC=C1",
} }
predicted_nodes = set() predicted_nodes = set()

View File

@ -3,7 +3,7 @@ from django.urls import reverse
from envipy_additional_information import Temperature, Interval from envipy_additional_information import Temperature, Interval
from epdb.logic import UserManager, PackageManager from epdb.logic import UserManager, PackageManager
from epdb.models import Reaction, Scenario, ExternalIdentifier, ExternalDatabase from epdb.models import Reaction, Scenario, ExternalDatabase
class ReactionViewTest(TestCase): class ReactionViewTest(TestCase):
@ -12,21 +12,28 @@ class ReactionViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(ReactionViewTest, cls).setUpClass() super(ReactionViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", cls.user1 = UserManager.create_user(
set_setting=False, add_to_group=True, is_active=True) "user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self): def setUp(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
def test_create_reaction(self): def test_create_reaction(self):
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -42,11 +49,12 @@ class ReactionViewTest(TestCase):
# Adding the same rule again should return the existing one, hence not increasing the number of rules # Adding the same rule again should return the existing one, hence not increasing the number of rules
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.url, reaction_url) self.assertEqual(response.url, reaction_url)
@ -55,11 +63,12 @@ class ReactionViewTest(TestCase):
# Adding the same rule in a different package should create a new rule # Adding the same rule in a different package should create a new rule
response = self.client.post( response = self.client.post(
reverse("package reaction list", kwargs={'package_uuid': self.package.uuid}), { reverse("package reaction list", kwargs={"package_uuid": self.package.uuid}),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -67,11 +76,12 @@ class ReactionViewTest(TestCase):
# adding another reaction should increase count # adding another reaction should increase count
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0002", "reaction-name": "Eawag BBD reaction r0002",
"reaction-description": "Description for Eawag BBD reaction r0002", "reaction-description": "Description for Eawag BBD reaction r0002",
"reaction-smirks": "C(CO)Cl>>C(C=O)Cl", "reaction-smirks": "C(CO)Cl>>C(C=O)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -80,11 +90,12 @@ class ReactionViewTest(TestCase):
# Edit # Edit
def test_edit_rule(self): def test_edit_rule(self):
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -93,13 +104,17 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url) r = Reaction.objects.get(url=reaction_url)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(self.user1_default_package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid) kwargs={
}), { "package_uuid": str(self.user1_default_package.uuid),
"reaction_uuid": str(r.uuid),
},
),
{
"reaction-name": "Test Reaction Adjusted", "reaction-name": "Test Reaction Adjusted",
"reaction-description": "New Description", "reaction-description": "New Description",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -119,7 +134,7 @@ class ReactionViewTest(TestCase):
"Test Desc", "Test Desc",
"2025-10", "2025-10",
"soil", "soil",
[Temperature(interval=Interval(start=20, end=30))] [Temperature(interval=Interval(start=20, end=30))],
) )
s2 = Scenario.create( s2 = Scenario.create(
@ -128,15 +143,16 @@ class ReactionViewTest(TestCase):
"Test Desc2", "Test Desc2",
"2025-10", "2025-10",
"soil", "soil",
[Temperature(interval=Interval(start=10, end=20))] [Temperature(interval=Interval(start=10, end=20))],
) )
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -144,47 +160,47 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url) r = Reaction.objects.get(url=reaction_url)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
}), { ),
"selected-scenarios": [s1.url, s2.url] {"selected-scenarios": [s1.url, s2.url]},
}
) )
self.assertEqual(len(r.scenarios.all()), 2) self.assertEqual(len(r.scenarios.all()), 2)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
}), { ),
"selected-scenarios": [s1.url] {"selected-scenarios": [s1.url]},
}
) )
self.assertEqual(len(r.scenarios.all()), 1) self.assertEqual(len(r.scenarios.all()), 1)
self.assertEqual(r.scenarios.first().url, s1.url) self.assertEqual(r.scenarios.first().url, s1.url)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
}), { ),
{
# We have to set an empty string to avoid that the parameter is removed # We have to set an empty string to avoid that the parameter is removed
"selected-scenarios": "" "selected-scenarios": ""
} },
) )
self.assertEqual(len(r.scenarios.all()), 0) self.assertEqual(len(r.scenarios.all()), 0)
def test_copy(self): def test_copy(self):
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -192,12 +208,13 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url) r = Reaction.objects.get(url=reaction_url)
response = self.client.post( response = self.client.post(
reverse("package detail", kwargs={ reverse(
'package_uuid': str(self.package.uuid), "package detail",
}), { kwargs={
"hidden": "copy", "package_uuid": str(self.package.uuid),
"object_to_copy": r.url },
} ),
{"hidden": "copy", "object_to_copy": r.url},
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -211,44 +228,48 @@ class ReactionViewTest(TestCase):
# Copy to the same package should fail # Copy to the same package should fail
response = self.client.post( response = self.client.post(
reverse("package detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package detail",
}), { kwargs={
"hidden": "copy", "package_uuid": str(r.package.uuid),
"object_to_copy": r.url },
} ),
{"hidden": "copy", "object_to_copy": r.url},
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], f"Can't copy object {reaction_url} to the same package!") self.assertEqual(
response.json()["error"], f"Can't copy object {reaction_url} to the same package!"
)
def test_references(self): def test_references(self):
ext_db, _ = ExternalDatabase.objects.get_or_create( ext_db, _ = ExternalDatabase.objects.get_or_create(
name='KEGG Reaction', name="KEGG Reaction",
defaults={ defaults={
'full_name': 'KEGG Reaction Database', "full_name": "KEGG Reaction Database",
'description': 'Database of biochemical reactions', "description": "Database of biochemical reactions",
'base_url': 'https://www.genome.jp', "base_url": "https://www.genome.jp",
'url_pattern': 'https://www.genome.jp/entry/{id}' "url_pattern": "https://www.genome.jp/entry/{id}",
} },
) )
ext_db2, _ = ExternalDatabase.objects.get_or_create( ext_db2, _ = ExternalDatabase.objects.get_or_create(
name='RHEA', name="RHEA",
defaults={ defaults={
'full_name': 'RHEA Reaction Database', "full_name": "RHEA Reaction Database",
'description': 'Comprehensive resource of biochemical reactions', "description": "Comprehensive resource of biochemical reactions",
'base_url': 'https://www.rhea-db.org', "base_url": "https://www.rhea-db.org",
'url_pattern': 'https://www.rhea-db.org/rhea/{id}' "url_pattern": "https://www.rhea-db.org/rhea/{id}",
}, },
) )
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -256,45 +277,49 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url) r = Reaction.objects.get(url=reaction_url)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid), kwargs={
}), { "package_uuid": str(r.package.uuid),
'selected-database': ext_db.pk, "reaction_uuid": str(r.uuid),
'identifier': 'C12345' },
} ),
{"selected-database": ext_db.pk, "identifier": "C12345"},
) )
self.assertEqual(r.external_identifiers.count(), 1) self.assertEqual(r.external_identifiers.count(), 1)
self.assertEqual(r.external_identifiers.first().database, ext_db) self.assertEqual(r.external_identifiers.first().database, ext_db)
self.assertEqual(r.external_identifiers.first().identifier_value, 'C12345') self.assertEqual(r.external_identifiers.first().identifier_value, "C12345")
# TODO Fixture contains old url template there the real test fails, use old value instead # TODO Fixture contains old url template there the real test fails, use old value instead
# self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/C12345') # self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/C12345')
self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/reaction+C12345') self.assertEqual(
r.external_identifiers.first().url, "https://www.genome.jp/entry/reaction+C12345"
)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid), kwargs={
}), { "package_uuid": str(r.package.uuid),
'selected-database': ext_db2.pk, "reaction_uuid": str(r.uuid),
'identifier': '60116' },
} ),
{"selected-database": ext_db2.pk, "identifier": "60116"},
) )
self.assertEqual(r.external_identifiers.count(), 2) self.assertEqual(r.external_identifiers.count(), 2)
self.assertEqual(r.external_identifiers.last().database, ext_db2) self.assertEqual(r.external_identifiers.last().database, ext_db2)
self.assertEqual(r.external_identifiers.last().identifier_value, '60116') self.assertEqual(r.external_identifiers.last().identifier_value, "60116")
self.assertEqual(r.external_identifiers.last().url, 'https://www.rhea-db.org/rhea/60116') self.assertEqual(r.external_identifiers.last().url, "https://www.rhea-db.org/rhea/60116")
def test_delete(self): def test_delete(self):
response = self.client.post( response = self.client.post(
reverse("reactions"), { reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001", "reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001", "reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl", "reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -302,12 +327,11 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url) r = Reaction.objects.get(url=reaction_url)
response = self.client.post( response = self.client.post(
reverse("package reaction detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package reaction detail",
'reaction_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
}), { ),
"hidden": "delete" {"hidden": "delete"},
}
) )
self.assertEqual(self.user1_default_package.reactions.count(), 0) self.assertEqual(self.user1_default_package.reactions.count(), 0)

View File

@ -12,22 +12,29 @@ class RuleViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(RuleViewTest, cls).setUpClass() super(RuleViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe", cls.user1 = UserManager.create_user(
set_setting=False, add_to_group=True, is_active=True) "user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack') cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self): def setUp(self):
self.client.force_login(self.user1) self.client.force_login(self.user1)
def test_create_rule(self): def test_create_rule(self):
response = self.client.post( response = self.client.post(
reverse("rules"), { reverse("rules"),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -38,18 +45,21 @@ class RuleViewTest(TestCase):
self.assertEqual(r.package, self.user1_default_package) self.assertEqual(r.package, self.user1_default_package)
self.assertEqual(r.name, "Test Rule") self.assertEqual(r.name, "Test Rule")
self.assertEqual(r.description, "Just a Description") self.assertEqual(r.description, "Just a Description")
self.assertEqual(r.smirks, self.assertEqual(
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]") r.smirks,
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
self.assertEqual(self.user1_default_package.rules.count(), 1) self.assertEqual(self.user1_default_package.rules.count(), 1)
# Adding the same rule again should return the existing one, hence not increasing the number of rules # Adding the same rule again should return the existing one, hence not increasing the number of rules
response = self.client.post( response = self.client.post(
reverse("rules"), { reverse("rules"),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.url, rule_url) self.assertEqual(response.url, rule_url)
@ -58,12 +68,13 @@ class RuleViewTest(TestCase):
# Adding the same rule in a different package should create a new rule # Adding the same rule in a different package should create a new rule
response = self.client.post( response = self.client.post(
reverse("package rule list", kwargs={'package_uuid': self.package.uuid}), { reverse("package rule list", kwargs={"package_uuid": self.package.uuid}),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -72,12 +83,13 @@ class RuleViewTest(TestCase):
# Edit # Edit
def test_edit_rule(self): def test_edit_rule(self):
response = self.client.post( response = self.client.post(
reverse("rules"), { reverse("rules"),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -86,13 +98,17 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url) r = Rule.objects.get(url=rule_url)
response = self.client.post( response = self.client.post(
reverse("package rule detail", kwargs={ reverse(
'package_uuid': str(self.user1_default_package.uuid), "package rule detail",
'rule_uuid': str(r.uuid) kwargs={
}), { "package_uuid": str(self.user1_default_package.uuid),
"rule_uuid": str(r.uuid),
},
),
{
"rule-name": "Test Rule Adjusted", "rule-name": "Test Rule Adjusted",
"rule-description": "New Description", "rule-description": "New Description",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -108,7 +124,7 @@ class RuleViewTest(TestCase):
"Test Desc", "Test Desc",
"2025-10", "2025-10",
"soil", "soil",
[Temperature(interval=Interval(start=20, end=30))] [Temperature(interval=Interval(start=20, end=30))],
) )
s2 = Scenario.create( s2 = Scenario.create(
@ -117,16 +133,17 @@ class RuleViewTest(TestCase):
"Test Desc2", "Test Desc2",
"2025-10", "2025-10",
"soil", "soil",
[Temperature(interval=Interval(start=10, end=20))] [Temperature(interval=Interval(start=10, end=20))],
) )
response = self.client.post( response = self.client.post(
reverse("rules"), { reverse("rules"),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -134,48 +151,48 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url) r = Rule.objects.get(url=rule_url)
response = self.client.post( response = self.client.post(
reverse("package rule detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package rule detail",
'rule_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
}), { ),
"selected-scenarios": [s1.url, s2.url] {"selected-scenarios": [s1.url, s2.url]},
}
) )
self.assertEqual(len(r.scenarios.all()), 2) self.assertEqual(len(r.scenarios.all()), 2)
response = self.client.post( response = self.client.post(
reverse("package rule detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package rule detail",
'rule_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
}), { ),
"selected-scenarios": [s1.url] {"selected-scenarios": [s1.url]},
}
) )
self.assertEqual(len(r.scenarios.all()), 1) self.assertEqual(len(r.scenarios.all()), 1)
self.assertEqual(r.scenarios.first().url, s1.url) self.assertEqual(r.scenarios.first().url, s1.url)
response = self.client.post( response = self.client.post(
reverse("package rule detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package rule detail",
'rule_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
}), { ),
{
# We have to set an empty string to avoid that the parameter is removed # We have to set an empty string to avoid that the parameter is removed
"selected-scenarios": "" "selected-scenarios": ""
} },
) )
self.assertEqual(len(r.scenarios.all()), 0) self.assertEqual(len(r.scenarios.all()), 0)
def test_copy(self): def test_copy(self):
response = self.client.post( response = self.client.post(
reverse("rules"), { reverse("rules"),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -183,12 +200,13 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url) r = Rule.objects.get(url=rule_url)
response = self.client.post( response = self.client.post(
reverse("package detail", kwargs={ reverse(
'package_uuid': str(self.package.uuid), "package detail",
}), { kwargs={
"hidden": "copy", "package_uuid": str(self.package.uuid),
"object_to_copy": r.url },
} ),
{"hidden": "copy", "object_to_copy": r.url},
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -202,26 +220,29 @@ class RuleViewTest(TestCase):
# Copy to the same package should fail # Copy to the same package should fail
response = self.client.post( response = self.client.post(
reverse("package detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package detail",
}), { kwargs={
"hidden": "copy", "package_uuid": str(r.package.uuid),
"object_to_copy": r.url },
} ),
{"hidden": "copy", "object_to_copy": r.url},
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], f"Can't copy object {rule_url} to the same package!") self.assertEqual(
response.json()["error"], f"Can't copy object {rule_url} to the same package!"
)
def test_delete(self): def test_delete(self):
response = self.client.post( response = self.client.post(
reverse("rules"), { reverse("rules"),
{
"rule-name": "Test Rule", "rule-name": "Test Rule",
"rule-description": "Just a Description", "rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]", "rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule", "rule-type": "SimpleAmbitRule",
} },
) )
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
@ -229,12 +250,11 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url) r = Rule.objects.get(url=rule_url)
response = self.client.post( response = self.client.post(
reverse("package rule detail", kwargs={ reverse(
'package_uuid': str(r.package.uuid), "package rule detail",
'rule_uuid': str(r.uuid) kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
}), { ),
"hidden": "delete" {"hidden": "delete"},
}
) )
self.assertEqual(self.user1_default_package.rules.count(), 0) self.assertEqual(self.user1_default_package.rules.count(), 0)

View File

@ -11,70 +11,81 @@ class UserViewTest(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(UserViewTest, cls).setUpClass() super(UserViewTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous') cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc') cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name='Fixtures') cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_login_with_valid_credentials(self): def test_login_with_valid_credentials(self):
response = self.client.post(reverse("login"), { response = self.client.post(
reverse("login"),
{
"username": "user0", "username": "user0",
"password": 'SuperSafe', "password": "SuperSafe",
}) },
)
self.assertRedirects(response, reverse("index")) self.assertRedirects(response, reverse("index"))
self.assertTrue(response.wsgi_request.user.is_authenticated) self.assertTrue(response.wsgi_request.user.is_authenticated)
def test_login_with_invalid_credentials(self): def test_login_with_invalid_credentials(self):
response = self.client.post(reverse("login"), { response = self.client.post(
reverse("login"),
{
"username": "user0", "username": "user0",
"password": "wrongpassword", "password": "wrongpassword",
}) },
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertFalse(response.wsgi_request.user.is_authenticated) self.assertFalse(response.wsgi_request.user.is_authenticated)
def test_register(self): def test_register(self):
response = self.client.post(reverse("register"), { response = self.client.post(
reverse("register"),
{
"username": "user1", "username": "user1",
"email": "user1@envipath.com", "email": "user1@envipath.com",
"password": "SuperSafe", "password": "SuperSafe",
"rpassword": "SuperSafe", "rpassword": "SuperSafe",
}) },
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
# TODO currently fails as the fixture does not provide a global setting... # TODO currently fails as the fixture does not provide a global setting...
self.assertContains(response, "Registration failed!") self.assertContains(response, "Registration failed!")
def test_register_password_mismatch(self): def test_register_password_mismatch(self):
response = self.client.post(reverse("register"), { response = self.client.post(
reverse("register"),
{
"username": "user1", "username": "user1",
"email": "user1@envipath.com", "email": "user1@envipath.com",
"password": "SuperSafe", "password": "SuperSafe",
"rpassword": "SuperSaf3", "rpassword": "SuperSaf3",
}) },
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, "Registration failed, provided passwords differ") self.assertContains(response, "Registration failed, provided passwords differ")
def test_logout(self): def test_logout(self):
response = self.client.post(reverse("login"), { response = self.client.post(
"username": "user0", reverse("login"), {"username": "user0", "password": "SuperSafe", "login": "true"}
"password": 'SuperSafe', )
"login": "true"
})
self.assertTrue(response.wsgi_request.user.is_authenticated) self.assertTrue(response.wsgi_request.user.is_authenticated)
response = self.client.post(reverse('logout'), { response = self.client.post(
reverse("logout"),
{
"logout": "true", "logout": "true",
}) },
)
self.assertFalse(response.wsgi_request.user.is_authenticated) self.assertFalse(response.wsgi_request.user.is_authenticated)
def test_next_param_properly_handled(self): def test_next_param_properly_handled(self):
response = self.client.get(reverse('packages')) response = self.client.get(reverse("packages"))
self.assertRedirects(response, f"{reverse('login')}/?next=/package") self.assertRedirects(response, f"{reverse('login')}/?next=/package")
response = self.client.post(reverse('login'), { response = self.client.post(
"username": "user0", reverse("login"),
"password": 'SuperSafe', {"username": "user0", "password": "SuperSafe", "login": "true", "next": "/package"},
"login": "true", )
"next": "/package"
})
self.assertRedirects(response, reverse('packages')) self.assertRedirects(response, reverse("packages"))

View File

@ -1,13 +0,0 @@
import abc
from enviPy.epdb import Pathway
class PredictionSchema(abc.ABC):
pass
class DFS(PredictionSchema):
def __init__(self, pw: Pathway, settings=None):
self.setting = settings or pw.prediction_settings
def predict(self):
pass

View File

@ -2,12 +2,11 @@ import logging
import re import re
from abc import ABC from abc import ABC
from collections import defaultdict from collections import defaultdict
from typing import List, Optional, Dict from typing import List, Optional, Dict, TYPE_CHECKING
from indigo import Indigo, IndigoException, IndigoObject from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer from indigo.renderer import IndigoRenderer
from rdkit import Chem from rdkit import Chem, rdBase
from rdkit import RDLogger
from rdkit.Chem import MACCSkeys, Descriptors from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem import rdChemReactions from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D from rdkit.Chem.Draw import rdMolDraw2D
@ -15,9 +14,11 @@ from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.rdmolops import GetMolFrags from rdkit.Chem.rdmolops import GetMolFrags
from rdkit.Contrib.IFG import ifg from rdkit.Contrib.IFG import ifg
logger = logging.getLogger(__name__) if TYPE_CHECKING:
RDLogger.DisableLog('rdApp.*') from epdb.models import Rule
logger = logging.getLogger(__name__)
rdBase.DisableLog("rdApp.*")
# from rdkit import rdBase # from rdkit import rdBase
# rdBase.LogToPythonLogger() # rdBase.LogToPythonLogger()
@ -28,7 +29,6 @@ RDLogger.DisableLog('rdApp.*')
class ProductSet(object): class ProductSet(object):
def __init__(self, product_set: List[str]): def __init__(self, product_set: List[str]):
self.product_set = product_set self.product_set = product_set
@ -42,15 +42,18 @@ class ProductSet(object):
return iter(self.product_set) return iter(self.product_set)
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(other.product_set) return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(
other.product_set
)
def __hash__(self): def __hash__(self):
return hash('-'.join(sorted(self.product_set))) return hash("-".join(sorted(self.product_set)))
class PredictionResult(object): class PredictionResult(object):
def __init__(
def __init__(self, product_sets: List['ProductSet'], probability: float, rule: Optional['Rule'] = None): self, product_sets: List["ProductSet"], probability: float, rule: Optional["Rule"] = None
):
self.product_sets = product_sets self.product_sets = product_sets
self.probability = probability self.probability = probability
self.rule = rule self.rule = rule
@ -66,7 +69,6 @@ class PredictionResult(object):
class FormatConverter(object): class FormatConverter(object):
@staticmethod @staticmethod
def mass(smiles): def mass(smiles):
return Descriptors.MolWt(FormatConverter.from_smiles(smiles)) return Descriptors.MolWt(FormatConverter.from_smiles(smiles))
@ -127,7 +129,7 @@ class FormatConverter(object):
if kekulize: if kekulize:
try: try:
mol = Chem.Kekulize(mol) mol = Chem.Kekulize(mol)
except: except Exception:
mol = Chem.Mol(mol.ToBinary()) mol = Chem.Mol(mol.ToBinary())
if not mol.GetNumConformers(): if not mol.GetNumConformers():
@ -139,8 +141,8 @@ class FormatConverter(object):
opts.clearBackground = False opts.clearBackground = False
drawer.DrawMolecule(mol) drawer.DrawMolecule(mol)
drawer.FinishDrawing() drawer.FinishDrawing()
svg = drawer.GetDrawingText().replace('svg:', '') svg = drawer.GetDrawingText().replace("svg:", "")
svg = re.sub("<\?xml.*\?>", '', svg) svg = re.sub("<\?xml.*\?>", "", svg)
return svg return svg
@ -151,7 +153,7 @@ class FormatConverter(object):
if kekulize: if kekulize:
try: try:
Chem.Kekulize(mol) Chem.Kekulize(mol)
except: except Exception:
mc = Chem.Mol(mol.ToBinary()) mc = Chem.Mol(mol.ToBinary())
if not mc.GetNumConformers(): if not mc.GetNumConformers():
@ -178,7 +180,7 @@ class FormatConverter(object):
smiles = tmp_smiles smiles = tmp_smiles
if change is False: if change is False:
print(f"nothing changed") print("nothing changed")
return smiles return smiles
@ -198,7 +200,9 @@ class FormatConverter(object):
parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol)
# try to neutralize molecule # try to neutralize molecule
uncharger = rdMolStandardize.Uncharger() # annoying, but necessary as no convenience method exists uncharger = (
rdMolStandardize.Uncharger()
) # annoying, but necessary as no convenience method exists
uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol)
# note that no attempt is made at reionization at this step # note that no attempt is made at reionization at this step
@ -239,17 +243,24 @@ class FormatConverter(object):
try: try:
rdChemReactions.ReactionFromSmarts(smirks) rdChemReactions.ReactionFromSmarts(smirks)
return True return True
except: except Exception:
return False return False
@staticmethod @staticmethod
def apply(smiles: str, smirks: str, preprocess_smiles: bool = True, bracketize: bool = True, def apply(
standardize: bool = True, kekulize: bool = True, remove_stereo: bool = True) -> List['ProductSet']: smiles: str,
logger.debug(f'Applying {smirks} on {smiles}') smirks: str,
preprocess_smiles: bool = True,
bracketize: bool = True,
standardize: bool = True,
kekulize: bool = True,
remove_stereo: bool = True,
) -> List["ProductSet"]:
logger.debug(f"Applying {smirks} on {smiles}")
# If explicitly wanted or rule generates multiple products add brackets around products to capture all # If explicitly wanted or rule generates multiple products add brackets around products to capture all
if bracketize: # or "." in smirks: if bracketize: # or "." in smirks:
smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")" smirks = smirks.split(">>")[0] + ">>(" + smirks.split(">>")[1] + ")"
# List of ProductSet objects # List of ProductSet objects
pss = set() pss = set()
@ -274,7 +285,9 @@ class FormatConverter(object):
Chem.SanitizeMol(product) Chem.SanitizeMol(product)
product = GetMolFrags(product, asMols=True) product = GetMolFrags(product, asMols=True)
for p in product: for p in product:
p = FormatConverter.standardize(Chem.MolToSmiles(p), remove_stereo=remove_stereo) p = FormatConverter.standardize(
Chem.MolToSmiles(p), remove_stereo=remove_stereo
)
prods.append(p) prods.append(p)
# if kekulize: # if kekulize:
@ -300,9 +313,8 @@ class FormatConverter(object):
# # bond.SetIsAromatic(False) # # bond.SetIsAromatic(False)
# Chem.Kekulize(product) # Chem.Kekulize(product)
except ValueError as e: except ValueError as e:
logger.error(f'Sanitizing and converting failed:\n{e}') logger.error(f"Sanitizing and converting failed:\n{e}")
continue continue
if len(prods): if len(prods):
@ -310,7 +322,7 @@ class FormatConverter(object):
pss.add(ps) pss.add(ps)
except Exception as e: except Exception as e:
logger.error(f'Applying {smirks} on {smiles} failed:\n{e}') logger.error(f"Applying {smirks} on {smiles} failed:\n{e}")
return pss return pss
@ -340,22 +352,19 @@ class FormatConverter(object):
smi_p = Chem.MolToSmiles(mol, kekuleSmiles=True) smi_p = Chem.MolToSmiles(mol, kekuleSmiles=True)
smi_p = Chem.CanonSmiles(smi_p) smi_p = Chem.CanonSmiles(smi_p)
if '~' in smi_p: if "~" in smi_p:
smi_p1 = smi_p.replace('~', '') smi_p1 = smi_p.replace("~", "")
parsed_smiles.append(smi_p1) parsed_smiles.append(smi_p1)
else: else:
parsed_smiles.append(smi_p) parsed_smiles.append(smi_p)
except Exception as e: except Exception:
errors += 1 errors += 1
pass pass
return parsed_smiles, errors return parsed_smiles, errors
class Standardizer(ABC): class Standardizer(ABC):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
@ -364,7 +373,6 @@ class Standardizer(ABC):
class RuleStandardizer(Standardizer): class RuleStandardizer(Standardizer):
def __init__(self, name, smirks): def __init__(self, name, smirks):
super().__init__(name) super().__init__(name)
self.smirks = smirks self.smirks = smirks
@ -373,8 +381,8 @@ class RuleStandardizer(Standardizer):
standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks))) standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks)))
if len(standardized_smiles) > 1: if len(standardized_smiles) > 1:
logger.warning(f'{self.smirks} generated more than 1 compound {standardized_smiles}') logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
print(f'{self.smirks} generated more than 1 compound {standardized_smiles}') print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
standardized_smiles = standardized_smiles[:1] standardized_smiles = standardized_smiles[:1]
if standardized_smiles: if standardized_smiles:
@ -384,7 +392,6 @@ class RuleStandardizer(Standardizer):
class RegExStandardizer(Standardizer): class RegExStandardizer(Standardizer):
def __init__(self, name, replacements: dict): def __init__(self, name, replacements: dict):
super().__init__(name) super().__init__(name)
self.replacements = replacements self.replacements = replacements
@ -404,28 +411,39 @@ class RegExStandardizer(Standardizer):
return super().standardize(smi) return super().standardize(smi)
FLATTEN = [ FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]
RegExStandardizer("Remove Stereo", {"@": ""})
]
UN_CIS_TRANS = [ UN_CIS_TRANS = [RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})]
RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})
]
BASIC = [ BASIC = [
RuleStandardizer("ammoniumstandardization", "[H][N+:1]([H])([H])[#6:2]>>[H][#7:1]([H])-[#6:2]"), RuleStandardizer("ammoniumstandardization", "[H][N+:1]([H])([H])[#6:2]>>[H][#7:1]([H])-[#6:2]"),
RuleStandardizer("cyanate", "[H][#8:1][C:2]#[N:3]>>[#8-:1][C:2]#[N:3]"), RuleStandardizer("cyanate", "[H][#8:1][C:2]#[N:3]>>[#8-:1][C:2]#[N:3]"),
RuleStandardizer("deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"), RuleStandardizer("deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"),
RuleStandardizer("forNOOH", "[H][#8:1]-[#7+:2](-[*:3])=[O:4]>>[#8-:1]-[#7+:2](-[*:3])=[O:4]"), RuleStandardizer("forNOOH", "[H][#8:1]-[#7+:2](-[*:3])=[O:4]>>[#8-:1]-[#7+:2](-[*:3])=[O:4]"),
RuleStandardizer("Hydroxylprotonation", "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]"), RuleStandardizer(
RuleStandardizer("phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"), "Hydroxylprotonation",
RuleStandardizer("PicricAcid", "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]",
"[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]"), ),
RuleStandardizer("Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"), RuleStandardizer(
RuleStandardizer("Sulfate2", "phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"
"[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]"), ),
RuleStandardizer("Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"), RuleStandardizer(
RuleStandardizer("Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"), "PicricAcid",
"[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]",
),
RuleStandardizer(
"Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"
),
RuleStandardizer(
"Sulfate2",
"[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]",
),
RuleStandardizer(
"Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"
),
RuleStandardizer(
"Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"
),
] ]
ENHANCED = BASIC + [ ENHANCED = BASIC + [
@ -433,28 +451,30 @@ ENHANCED = BASIC + [
] ]
EXOTIC = ENHANCED + [ EXOTIC = ENHANCED + [
RuleStandardizer("ThioPhosphate1", "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]") RuleStandardizer(
"ThioPhosphate1",
"[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]",
)
] ]
COA_CUTTER = [ COA_CUTTER = [
RuleStandardizer("CutCoEnzymeAOff", RuleStandardizer(
"CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]") "CutCoEnzymeAOff",
"CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]",
)
] ]
ENOL_KETO = [ ENOL_KETO = [RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")]
RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")
]
MATCH_STANDARDIZER = EXOTIC + FLATTEN + UN_CIS_TRANS + COA_CUTTER + ENOL_KETO MATCH_STANDARDIZER = EXOTIC + FLATTEN + UN_CIS_TRANS + COA_CUTTER + ENOL_KETO
class IndigoUtils(object): class IndigoUtils(object):
@staticmethod @staticmethod
def layout(mol_data): def layout(mol_data):
i = Indigo() i = Indigo()
try: try:
if mol_data.startswith('$RXN') or '>>' in mol_data: if mol_data.startswith("$RXN") or ">>" in mol_data:
rxn = i.loadQueryReaction(mol_data) rxn = i.loadQueryReaction(mol_data)
rxn.layout() rxn.layout()
return rxn.rxnfile() return rxn.rxnfile()
@ -462,14 +482,14 @@ class IndigoUtils(object):
mol = i.loadQueryMolecule(mol_data) mol = i.loadQueryMolecule(mol_data)
mol.layout() mol.layout()
return mol.molfile() return mol.molfile()
except IndigoException as e: except IndigoException:
try: try:
logger.info("layout() failed, trying loadReactionSMARTS as fallback!") logger.info("layout() failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data) rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.layout() rxn.layout()
return rxn.molfile() return rxn.molfile()
except IndigoException as e2: except IndigoException as e2:
logger.error(f'layout() failed due to {e2}!') logger.error(f"layout() failed due to {e2}!")
@staticmethod @staticmethod
def load_reaction_SMARTS(mol): def load_reaction_SMARTS(mol):
@ -479,7 +499,7 @@ class IndigoUtils(object):
def aromatize(mol_data, is_query): def aromatize(mol_data, is_query):
i = Indigo() i = Indigo()
try: try:
if mol_data.startswith('$RXN'): if mol_data.startswith("$RXN"):
if is_query: if is_query:
rxn = i.loadQueryReaction(mol_data) rxn = i.loadQueryReaction(mol_data)
else: else:
@ -495,20 +515,20 @@ class IndigoUtils(object):
mol.aromatize() mol.aromatize()
return mol.molfile() return mol.molfile()
except IndigoException as e: except IndigoException:
try: try:
logger.info("Aromatizing failed, trying loadReactionSMARTS as fallback!") logger.info("Aromatizing failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data) rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.aromatize() rxn.aromatize()
return rxn.molfile() return rxn.molfile()
except IndigoException as e2: except IndigoException as e2:
logger.error(f'Aromatizing failed due to {e2}!') logger.error(f"Aromatizing failed due to {e2}!")
@staticmethod @staticmethod
def dearomatize(mol_data, is_query): def dearomatize(mol_data, is_query):
i = Indigo() i = Indigo()
try: try:
if mol_data.startswith('$RXN'): if mol_data.startswith("$RXN"):
if is_query: if is_query:
rxn = i.loadQueryReaction(mol_data) rxn = i.loadQueryReaction(mol_data)
else: else:
@ -524,14 +544,14 @@ class IndigoUtils(object):
mol.dearomatize() mol.dearomatize()
return mol.molfile() return mol.molfile()
except IndigoException as e: except IndigoException:
try: try:
logger.info("De-Aromatizing failed, trying loadReactionSMARTS as fallback!") logger.info("De-Aromatizing failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data) rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.dearomatize() rxn.dearomatize()
return rxn.molfile() return rxn.molfile()
except IndigoException as e2: except IndigoException as e2:
logger.error(f'De-Aromatizing failed due to {e2}!') logger.error(f"De-Aromatizing failed due to {e2}!")
@staticmethod @staticmethod
def sanitize_functional_group(functional_group: str): def sanitize_functional_group(functional_group: str):
@ -543,7 +563,7 @@ class IndigoUtils(object):
# special environment handling (amines, hydroxy, esters, ethers) # special environment handling (amines, hydroxy, esters, ethers)
# the higher substituted should not contain H env. # the higher substituted should not contain H env.
if functional_group == '[C]=O': if functional_group == "[C]=O":
functional_group = "[H][C](=O)[CX4,c]" functional_group = "[H][C](=O)[CX4,c]"
# aldamines # aldamines
@ -577,15 +597,20 @@ class IndigoUtils(object):
functional_group = "[nH1,nX2](a)a" # pyrrole (with H) or pyridine (no other connections); currently overlaps with neighboring aromatic atoms functional_group = "[nH1,nX2](a)a" # pyrrole (with H) or pyridine (no other connections); currently overlaps with neighboring aromatic atoms
# substituted aromatic nitrogen # substituted aromatic nitrogen
functional_group = functional_group.replace("N*(R)R", functional_group = functional_group.replace(
"n(a)a") # substituent will be before N*; currently overlaps with neighboring aromatic atoms "N*(R)R", "n(a)a"
) # substituent will be before N*; currently overlaps with neighboring aromatic atoms
# pyridinium # pyridinium
if functional_group == "RN*(R)(R)(R)R": if functional_group == "RN*(R)(R)(R)R":
functional_group = "[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms functional_group = (
"[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms
)
# N-oxide # N-oxide
if functional_group == "[H]ON*(R)(R)(R)R": if functional_group == "[H]ON*(R)(R)(R)R":
functional_group = "[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms functional_group = (
"[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms
)
# other aromatic hetero atoms # other aromatic hetero atoms
functional_group = functional_group.replace("C*", "c") functional_group = functional_group.replace("C*", "c")
@ -598,7 +623,9 @@ class IndigoUtils(object):
# other replacement, to accomodate for the standardization rules in enviPath # other replacement, to accomodate for the standardization rules in enviPath
# This is not the perfect way to do it; there should be a way to replace substructure SMARTS in SMARTS? # This is not the perfect way to do it; there should be a way to replace substructure SMARTS in SMARTS?
# nitro groups are broken, due to charge handling. this SMARTS matches both forms (formal charges and hypervalent); Ertl-CDK still treats both forms separately... # nitro groups are broken, due to charge handling. this SMARTS matches both forms (formal charges and hypervalent); Ertl-CDK still treats both forms separately...
functional_group = functional_group.replace("[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]") functional_group = functional_group.replace(
"[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]"
)
functional_group = functional_group.replace("O=N(=O)R", "[CX4,c][NX3](~[OX1])~[OX1]") functional_group = functional_group.replace("O=N(=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
# carboxylic acid: this SMARTS matches both neutral and anionic form; includes COOH in larger functional_groups # carboxylic acid: this SMARTS matches both neutral and anionic form; includes COOH in larger functional_groups
functional_group = functional_group.replace("[H]OC(=O)", "[OD1]C(=O)") functional_group = functional_group.replace("[H]OC(=O)", "[OD1]C(=O)")
@ -616,7 +643,9 @@ class IndigoUtils(object):
return functional_group return functional_group
@staticmethod @staticmethod
def _colorize(indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool): def _colorize(
indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool
):
indigo.setOption("render-atom-color-property", "color") indigo.setOption("render-atom-color-property", "color")
indigo.setOption("aromaticity-model", "generic") indigo.setOption("aromaticity-model", "generic")
@ -646,7 +675,6 @@ class IndigoUtils(object):
for match in matcher.iterateMatches(query): for match in matcher.iterateMatches(query):
if match is not None: if match is not None:
for atom in query.iterateAtoms(): for atom in query.iterateAtoms():
mappedAtom = match.mapAtom(atom) mappedAtom = match.mapAtom(atom)
if mappedAtom is None or mappedAtom.index() in environment: if mappedAtom is None or mappedAtom.index() in environment:
@ -655,7 +683,7 @@ class IndigoUtils(object):
counts[mappedAtom.index()] = max(v, counts[mappedAtom.index()]) counts[mappedAtom.index()] = max(v, counts[mappedAtom.index()])
except IndigoException as e: except IndigoException as e:
logger.debug(f'Colorizing failed due to {e}') logger.debug(f"Colorizing failed due to {e}")
for k, v in counts.items(): for k, v in counts.items():
if is_reaction: if is_reaction:
@ -669,8 +697,9 @@ class IndigoUtils(object):
molecule.addDataSGroup([k], [], "color", color) molecule.addDataSGroup([k], [], "color", color)
@staticmethod @staticmethod
def mol_to_svg(mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None): def mol_to_svg(
mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None
):
if functional_groups is None: if functional_groups is None:
functional_groups = {} functional_groups = {}
@ -682,7 +711,7 @@ class IndigoUtils(object):
i.setOption("render-image-size", width, height) i.setOption("render-image-size", width, height)
i.setOption("render-bond-line-width", 2.0) i.setOption("render-bond-line-width", 2.0)
if '~' in mol_data: if "~" in mol_data:
mol = i.loadSmarts(mol_data) mol = i.loadSmarts(mol_data)
else: else:
mol = i.loadMolecule(mol_data) mol = i.loadMolecule(mol_data)
@ -690,11 +719,17 @@ class IndigoUtils(object):
if len(functional_groups.keys()) > 0: if len(functional_groups.keys()) > 0:
IndigoUtils._colorize(i, mol, functional_groups, False) IndigoUtils._colorize(i, mol, functional_groups, False)
return renderer.renderToBuffer(mol).decode('UTF-8') return renderer.renderToBuffer(mol).decode("UTF-8")
@staticmethod @staticmethod
def smirks_to_svg(smirks: str, is_query_smirks, width: int = 0, height: int = 0, def smirks_to_svg(
educt_functional_groups: Dict[str, int] = None, product_functional_groups: Dict[str, int] = None): smirks: str,
is_query_smirks,
width: int = 0,
height: int = 0,
educt_functional_groups: Dict[str, int] = None,
product_functional_groups: Dict[str, int] = None,
):
if educt_functional_groups is None: if educt_functional_groups is None:
educt_functional_groups = {} educt_functional_groups = {}
@ -721,18 +756,18 @@ class IndigoUtils(object):
for prod in obj.iterateProducts(): for prod in obj.iterateProducts():
IndigoUtils._colorize(i, prod, product_functional_groups, True) IndigoUtils._colorize(i, prod, product_functional_groups, True)
return renderer.renderToBuffer(obj).decode('UTF-8') return renderer.renderToBuffer(obj).decode("UTF-8")
if __name__ == '__main__': if __name__ == "__main__":
data = { data = {
"struct": "\n Ketcher 2172510 12D 1 1.00000 0.00000 0\n\n 6 6 0 0 0 999 V2000\n 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 1 2 2 0 0 0 0\n 2 3 1 0 0 0 0\n 3 4 2 0 0 0 0\n 4 5 1 0 0 0 0\n 5 6 2 0 0 0 0\n 6 1 1 0 0 0 0\nM END\n", "struct": "\n Ketcher 2172510 12D 1 1.00000 0.00000 0\n\n 6 6 0 0 0 999 V2000\n 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 1 2 2 0 0 0 0\n 2 3 1 0 0 0 0\n 3 4 2 0 0 0 0\n 4 5 1 0 0 0 0\n 5 6 2 0 0 0 0\n 6 1 1 0 0 0 0\nM END\n",
"options": { "options": {
"smart-layout": True, "smart-layout": True,
"ignore-stereochemistry-errors": True, "ignore-stereochemistry-errors": True,
"mass-skip-error-on-pseudoatoms": False, "mass-skip-error-on-pseudoatoms": False,
"gross-formula-add-rsites": True "gross-formula-add-rsites": True,
} },
} }
print(IndigoUtils.aromatize(data['struct'], False)) print(IndigoUtils.aromatize(data["struct"], False))

View File

@ -1,83 +0,0 @@
import json
import requests
class AMBITResult:
def __init__(self, *args, **kwargs):
self.smiles = kwargs['smiles']
self.tps = []
for bt in kwargs['products']:
if len(bt['products']):
self.tps.append(bt)
self.probs = None
def __str__(self):
x = self.smiles + "\n"
total_bts = len(self.tps)
for i, tp in enumerate(self.tps):
prob = ""
if self.probs:
prob = f" (p={self.probs[tp['id']]})"
if i == total_bts - 1:
x += f"\t└── {tp['name']}{prob}\n"
else:
x += f"\t├── {tp['name']}{prob}\n"
total_products = len(tp['products'])
for j, p in enumerate(tp['products']):
if j == total_products - 1:
if i == total_bts - 1:
x += f"\t\t└── {p}"
else:
x += f"\t\t└── {p}\n"
else:
if i == total_bts - 1:
x += f"\t\t├── {p}\n"
else:
x += f"\t\t├── {p}\n"
return x
def set_probs(self, probs):
self.probs = probs
class AMBIT:
def __init__(self, host, rules=None):
self.host = host
self.rules = rules
self.ambit_params = {
'singlePos': True,
'split': False,
}
def batch_apply(self, smiles: list):
payload = {
'smiles': smiles,
'rules': self.rules,
}
payload.update(**self.ambit_params)
res = self._execute(payload)
tps = list()
for r in res['result']:
ar = AMBITResult(**r)
if len(ar.tps):
tps.append(ar)
else:
tps.append(None)
return tps
def apply(self, smiles: str):
return self.batch_apply([smiles])[0]
def _execute(self, payload):
res = requests.post(self.host + '/ambit', data=json.dumps(payload))
res.raise_for_status()
return res.json()

View File

@ -8,9 +8,9 @@ from epdb.models import Package
# Map HTTP methods to required permissions # Map HTTP methods to required permissions
DEFAULT_METHOD_PERMISSIONS = { DEFAULT_METHOD_PERMISSIONS = {
'GET': 'read', "GET": "read",
'POST': 'write', "POST": "write",
'DELETE': 'write', "DELETE": "write",
} }
@ -22,6 +22,7 @@ def package_permission_required(method_permissions=None):
@wraps(view_func) @wraps(view_func)
def _wrapped_view(request, package_uuid, *args, **kwargs): def _wrapped_view(request, package_uuid, *args, **kwargs):
from epdb.views import _anonymous_or_real from epdb.views import _anonymous_or_real
user = _anonymous_or_real(request) user = _anonymous_or_real(request)
permission_required = method_permissions[request.method] permission_required = method_permissions[request.method]
@ -30,11 +31,12 @@ def package_permission_required(method_permissions=None):
if not PackageManager.has_package_permission(user, package_uuid, permission_required): if not PackageManager.has_package_permission(user, package_uuid, permission_required):
from epdb.views import error from epdb.views import error
return error( return error(
request, request,
"Operation failed!", "Operation failed!",
f"Couldn't perform the desired operation as {user.username} does not have the required permissions!", f"Couldn't perform the desired operation as {user.username} does not have the required permissions!",
code=403 code=403,
) )
return view_func(request, package_uuid, *args, **kwargs) return view_func(request, package_uuid, *args, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -1,37 +1,35 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import numpy as np
from numpy.random import default_rng
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
import logging import logging
from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import List, Dict, Set, Tuple from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
import networkx as nx import networkx as nx
import numpy as np
from numpy.random import default_rng
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from utilities.chem import FormatConverter, PredictionResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from dataclasses import dataclass, field from epdb.models import Rule, CompoundStructure, Reaction
from utilities.chem import FormatConverter, PredictionResult
class Dataset: class Dataset:
def __init__(
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None): self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns self.columns: List[str] = columns
self.num_labels: int = num_labels self.num_labels: int = num_labels
@ -41,9 +39,9 @@ class Dataset:
self.data = data self.data = data
self.num_features: int = len(columns) - self.num_labels self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices('feature_') self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices('trig_') self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices('obs_') self._observed: Tuple[int, int] = self._block_indices("obs_")
def _block_indices(self, prefix) -> Tuple[int, int]: def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = [] indices: List[int] = []
@ -62,7 +60,7 @@ class Dataset:
self.data.append(row) self.data.append(row)
def times_triggered(self, rule_uuid) -> int: def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f'trig_{rule_uuid}') idx = self.columns.index(f"trig_{rule_uuid}")
times_triggered = 0 times_triggered = 0
for row in self.data: for row in self.data:
@ -89,12 +87,12 @@ class Dataset:
def __iter__(self): def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data)) return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(
def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]: self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[Dataset, List[List[PredictionResult]]]:
classify_data = [] classify_data = []
classify_products = [] classify_products = []
for struct in structures: for struct in structures:
if isinstance(struct, str): if isinstance(struct, str):
struct_id = None struct_id = None
struct_smiles = struct struct_smiles = struct
@ -119,10 +117,14 @@ class Dataset:
classify_data.append([struct_id] + features + trig + ([-1] * len(trig))) classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods) classify_products.append(prods)
return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products return Dataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod @staticmethod
def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset: def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset:
_structures = set() _structures = set()
for r in reactions: for r in reactions:
@ -155,12 +157,11 @@ class Dataset:
for prod_set in product_sets: for prod_set in product_sets:
for smi in prod_set: for smi in prod_set:
try: try:
smi = FormatConverter.standardize(smi, remove_stereo=True) smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception: except Exception:
# :shrug: # :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}') logger.debug(f"Standardizing SMILES failed for {smi}")
pass pass
triggered[key].add(smi) triggered[key].add(smi)
@ -188,7 +189,7 @@ class Dataset:
smi = FormatConverter.standardize(smi, remove_stereo=True) smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e: except Exception as e:
# :shrug: # :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}') logger.debug(f"Standardizing SMILES failed for {smi}")
pass pass
standardized_products.append(smi) standardized_products.append(smi)
@ -224,19 +225,22 @@ class Dataset:
obs.append(0) obs.append(0)
if ds is None: if ds is None:
header = ['structure_id'] + \ header = (
[f'feature_{i}' for i, _ in enumerate(feat)] \ ["structure_id"]
+ [f'trig_{r.uuid}' for r in applicable_rules] \ + [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f'obs_{r.uuid}' for r in applicable_rules] + [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = Dataset(header, len(applicable_rules)) ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs) ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds return ds
def X(self, exclude_id_col=True, na_replacement=0): def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))) res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None: if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res] res = [[x if x is not None else na_replacement for x in row] for row in res]
return res return res
@ -247,14 +251,12 @@ class Dataset:
res = [[x if x is not None else na_replacement for x in row] for row in res] res = [[x if x is not None else na_replacement for x in row] for row in res]
return res return res
def y(self, na_replacement=0): def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None))) res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None: if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res] res = [[x if x is not None else na_replacement for x in row] for row in res]
return res return res
def __getitem__(self, key): def __getitem__(self, key):
if not isinstance(key, tuple): if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]") raise TypeError("Dataset must be indexed with dataset[rows, columns]")
@ -271,42 +273,50 @@ class Dataset:
if isinstance(col_key, int): if isinstance(col_key, int):
res = [row[col_key] for row in rows] res = [row[col_key] for row in rows]
else: else:
res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice) res = [
else [row[i] for i in col_key] for row in rows] [row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res return res
def save(self, path: 'Path'): def save(self, path: "Path"):
import pickle import pickle
with open(path, "wb") as fh: with open(path, "wb") as fh:
pickle.dump(self, fh) pickle.dump(self, fh)
@staticmethod @staticmethod
def load(path: 'Path') -> 'Dataset': def load(path: "Path") -> "Dataset":
import pickle import pickle
return pickle.load(open(path, "rb")) return pickle.load(open(path, "rb"))
def to_arff(self, path: 'Path'): def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n" arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n" arff += "\n"
for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]: for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]:
if c == 'structure_id': if c == "structure_id":
arff += f"@attribute {c} string\n" arff += f"@attribute {c} string\n"
else: else:
arff += f"@attribute {c} {{0,1}}\n" arff += f"@attribute {c} {{0,1}}\n"
arff += f"\n@data\n" arff += "\n@data\n"
for d in self.data: for d in self.data:
ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]]) ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]]) xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f'{ys},{xs}\n' arff += f"{ys},{xs}\n"
with open(path, "w") as fh: with open(path, "w") as fh:
fh.write(arff) fh.write(arff)
fh.flush() fh.flush()
def __repr__(self): def __repr__(self):
return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>" return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
class SparseLabelECC(BaseEstimator, ClassifierMixin): class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -315,8 +325,11 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
Removes labels that are constant across all samples in training. Removes labels that are constant across all samples in training.
""" """
def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42), def __init__(
num_chains: int = 10): self,
base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
num_chains: int = 10,
):
self.base_clf = base_clf self.base_clf = base_clf
self.num_chains = num_chains self.num_chains = num_chains
@ -384,16 +397,16 @@ class BinaryRelevance:
if self.classifiers is None: if self.classifiers is None:
self.classifiers = [] self.classifiers = []
for l in range(len(Y[0])): for label in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, l])] X_l = X[~np.isnan(Y[:, label])]
Y_l = (Y[~np.isnan(Y[:, l]), l]) Y_l = Y[~np.isnan(Y[:, label]), label]
if len(X_l) == 0: # all labels are nan -> predict 0 if len(X_l) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0) clf = DummyClassifier(strategy="constant", constant=0)
clf.fit([X[0]], [0]) clf.fit([X[0]], [0])
self.classifiers.append(clf) self.classifiers.append(clf)
continue continue
elif len(np.unique(Y_l)) == 1: # only one class -> predict that class elif len(np.unique(Y_l)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent') clf = DummyClassifier(strategy="most_frequent")
else: else:
clf = copy.deepcopy(self.clf) clf = copy.deepcopy(self.clf)
clf.fit(X_l, Y_l) clf.fit(X_l, Y_l)
@ -439,17 +452,19 @@ class MissingValuesClassifierChain:
X_p = X[~np.isnan(Y[:, p])] X_p = X[~np.isnan(Y[:, p])]
Y_p = Y[~np.isnan(Y[:, p]), p] Y_p = Y[~np.isnan(Y[:, p]), p]
if len(X_p) == 0: # all labels are nan -> predict 0 if len(X_p) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0) clf = DummyClassifier(strategy="constant", constant=0)
self.classifiers.append(clf.fit([X[0]], [0])) self.classifiers.append(clf.fit([X[0]], [0]))
elif len(np.unique(Y_p)) == 1: # only one class -> predict that class elif len(np.unique(Y_p)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent') clf = DummyClassifier(strategy="most_frequent")
self.classifiers.append(clf.fit(X_p, Y_p)) self.classifiers.append(clf.fit(X_p, Y_p))
else: else:
clf = copy.deepcopy(self.base_clf) clf = copy.deepcopy(self.base_clf)
self.classifiers.append(clf.fit(X_p, Y_p)) self.classifiers.append(clf.fit(X_p, Y_p))
newcol = Y[:, p] newcol = Y[:, p]
pred = clf.predict(X) pred = clf.predict(X)
newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions newcol[np.isnan(newcol)] = pred[
np.isnan(newcol)
] # fill in missing values with clf predictions
X = np.column_stack((X, newcol)) X = np.column_stack((X, newcol))
def predict(self, X): def predict(self, X):
@ -541,13 +556,10 @@ class RelativeReasoning:
# We've seen more than self.min_count wins, more wins than loosing, no looses and no ties # We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
if ( if (
countwin >= self.min_count and countwin >= self.min_count
countwin > countloose and and countwin > countloose
( and (countloose <= self.max_count or self.max_count < 0)
countloose <= self.max_count or and countboth == 0
self.max_count < 0
) and
countboth == 0
): ):
self.winmap[i].append(j) self.winmap[i].append(j)
@ -557,13 +569,13 @@ class RelativeReasoning:
# Loop through all instances # Loop through all instances
for inst_idx, inst in enumerate(X): for inst_idx, inst in enumerate(X):
# Loop through all "triggered" features # Loop through all "triggered" features
for i, t in enumerate(inst[self.start_index: self.end_index + 1]): for i, t in enumerate(inst[self.start_index : self.end_index + 1]):
# Set label # Set label
res[inst_idx][i] = t res[inst_idx][i] = t
# If we predict a 1, check if the rule gets dominated by another # If we predict a 1, check if the rule gets dominated by another
if t: if t:
# Second loop to check other triggered rules # Second loop to check other triggered rules
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]): for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]):
if i != i2: if i != i2:
# Check if rule idx is in "dominated by" list # Check if rule idx is in "dominated by" list
if i2 in self.winmap.get(i, []): if i2 in self.winmap.get(i, []):
@ -579,7 +591,6 @@ class RelativeReasoning:
class ApplicabilityDomainPCA(PCA): class ApplicabilityDomainPCA(PCA):
def __init__(self, num_neighbours: int = 5): def __init__(self, num_neighbours: int = 5):
super().__init__(n_components=num_neighbours) super().__init__(n_components=num_neighbours)
self.scaler = StandardScaler() self.scaler = StandardScaler()
@ -587,7 +598,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None self.min_vals = None
self.max_vals = None self.max_vals = None
def build(self, train_dataset: 'Dataset'): def build(self, train_dataset: "Dataset"):
# transform # transform
X_scaled = self.scaler.fit_transform(train_dataset.X()) X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca # fit pca
@ -601,7 +612,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled) instances_pca = self.transform(instances_scaled)
return instances_pca return instances_pca
def is_applicable(self, classify_instances: 'Dataset'): def is_applicable(self, classify_instances: "Dataset"):
instances_pca = self.__transform(classify_instances.X()) instances_pca = self.__transform(classify_instances.X())
is_applicable = [] is_applicable = []
@ -632,6 +643,7 @@ def graph_from_pathway(data):
"""Convert Pathway or SPathway to networkx""" """Convert Pathway or SPathway to networkx"""
from epdb.models import Pathway from epdb.models import Pathway
from epdb.logic import SPathway from epdb.logic import SPathway
graph = nx.DiGraph() graph = nx.DiGraph()
co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation
@ -645,7 +657,9 @@ def graph_from_pathway(data):
def get_sources_targets(): def get_sources_targets():
if isinstance(data, Pathway): if isinstance(data, Pathway):
return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()] return [n.node for n in edge.start_nodes.constrained_target.all()], [
n.node for n in edge.end_nodes.constrained_target.all()
]
elif isinstance(data, SPathway): elif isinstance(data, SPathway):
return edge.educts, edge.products return edge.educts, edge.products
else: else:
@ -662,7 +676,7 @@ def graph_from_pathway(data):
def get_probability(): def get_probability():
try: try:
if isinstance(data, Pathway): if isinstance(data, Pathway):
return edge.kv.get('probability') return edge.kv.get("probability")
elif isinstance(data, SPathway): elif isinstance(data, SPathway):
return edge.probability return edge.probability
else: else:
@ -680,17 +694,29 @@ def graph_from_pathway(data):
for source in sources: for source in sources:
source_smiles, source_depth = get_smiles_depth(source) source_smiles, source_depth = get_smiles_depth(source)
if source_smiles not in graph: if source_smiles not in graph:
graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles, graph.add_node(
root=source_smiles in root_smiles) source_smiles,
depth=source_depth,
smiles=source_smiles,
root=source_smiles in root_smiles,
)
else: else:
graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"]) graph.nodes[source_smiles]["depth"] = min(
source_depth, graph.nodes[source_smiles]["depth"]
)
for target in targets: for target in targets:
target_smiles, target_depth = get_smiles_depth(target) target_smiles, target_depth = get_smiles_depth(target)
if target_smiles not in graph and target_smiles not in co2: if target_smiles not in graph and target_smiles not in co2:
graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles, graph.add_node(
root=target_smiles in root_smiles) target_smiles,
depth=target_depth,
smiles=target_smiles,
root=target_smiles in root_smiles,
)
elif target_smiles not in co2: elif target_smiles not in co2:
graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"]) graph.nodes[target_smiles]["depth"] = min(
target_depth, graph.nodes[target_smiles]["depth"]
)
if target_smiles not in co2 and target_smiles != source_smiles: if target_smiles not in co2 and target_smiles != source_smiles:
graph.add_edge(source_smiles, target_smiles, probability=probability) graph.add_edge(source_smiles, target_smiles, probability=probability)
return graph return graph
@ -710,7 +736,9 @@ def set_pathway_eval_weight(pathway):
node_eval_weights = {} node_eval_weights = {}
for node in pathway.nodes: for node in pathway.nodes:
# Scale score according to depth level # Scale score according to depth level
node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0 node_eval_weights[node] = (
1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
)
return node_eval_weights return node_eval_weights
@ -731,8 +759,11 @@ def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
shortest_path_list.append(shortest_path_nodes) shortest_path_list.append(shortest_path_nodes)
if shortest_path_list: if shortest_path_list:
shortest_path_nodes = min(shortest_path_list, key=len) shortest_path_nodes = min(shortest_path_list, key=len)
num_ints = sum(1 for shortest_path_node in shortest_path_nodes if num_ints = sum(
shortest_path_node in intermediates) 1
for shortest_path_node in shortest_path_nodes
if shortest_path_node in intermediates
)
pred_pathway.nodes[node]["depth"] -= num_ints pred_pathway.nodes[node]["depth"] -= num_ints
return pred_pathway return pred_pathway
@ -879,6 +910,11 @@ def pathway_edit_eval(data_pathway, pred_pathway):
data_pathway = initialise_pathway(data_pathway) data_pathway = initialise_pathway(data_pathway)
pred_pathway = initialise_pathway(pred_pathway) pred_pathway = initialise_pathway(pred_pathway)
roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0]) roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
return nx.graph_edit_distance(data_pathway, pred_pathway, return nx.graph_edit_distance(
node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost, data_pathway,
node_ins_cost=node_ins_del_cost, roots=roots) pred_pathway,
node_subst_cost=node_subst_cost,
node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost,
roots=roots,
)

View File

@ -23,11 +23,11 @@ def install_wheel(wheel_path):
def extract_package_name_from_wheel(wheel_filename): def extract_package_name_from_wheel(wheel_filename):
# Example: my_plugin-0.1.0-py3-none-any.whl -> my_plugin # Example: my_plugin-0.1.0-py3-none-any.whl -> my_plugin
return wheel_filename.split('-')[0] return wheel_filename.split("-")[0]
def ensure_plugins_installed(): def ensure_plugins_installed():
wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, '*.whl')) wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, "*.whl"))
for wheel_path in wheel_files: for wheel_path in wheel_files:
wheel_filename = os.path.basename(wheel_path) wheel_filename = os.path.basename(wheel_path)
@ -45,7 +45,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
plugins = {} plugins = {}
for entry_point in importlib.metadata.entry_points(group='enviPy_plugins'): for entry_point in importlib.metadata.entry_points(group="enviPy_plugins"):
try: try:
plugin_class = entry_point.load() plugin_class = entry_point.load()
if _cls: if _cls: