[Chore] Linted Files (#150)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
This commit is contained in:
2025-10-09 07:25:13 +13:00
parent 22f0bbe10b
commit afeb56622c
50 changed files with 5616 additions and 4408 deletions

View File

@ -2,4 +2,4 @@
# Django starts so that shared_task will use this app.
from .celery import app as celery_app
__all__ = ('celery_app',)
__all__ = ("celery_app",)

View File

@ -4,8 +4,6 @@ from ninja import NinjaAPI
api = NinjaAPI()
from ninja import NinjaAPI
api_v1 = NinjaAPI(title="API V1 Docs", urls_namespace="api-v1")
api_legacy = NinjaAPI(title="Legacy API Docs", urls_namespace="api-legacy")

View File

@ -11,6 +11,6 @@ import os
from django.core.asgi import get_asgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings")
application = get_asgi_application()

View File

@ -4,15 +4,15 @@ from celery import Celery
from celery.signals import setup_logging
# Set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings")
app = Celery('envipath')
app = Celery("envipath")
# Using a string here means the worker doesn't have to serialize
# the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix.
app.config_from_object('django.conf:settings', namespace='CELERY')
app.config_from_object("django.conf:settings", namespace="CELERY")
@setup_logging.connect

View File

@ -14,6 +14,7 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.conf import settings as s
from django.contrib import admin
from django.urls import include, path
@ -21,7 +22,6 @@ from django.urls import include, path
from .api import api_v1, api_legacy
urlpatterns = [
path("", include("epdb.urls")),
path("", include("migration.urls")),
path("admin/", admin.site.urls),

View File

@ -11,6 +11,6 @@ import os
from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'envipath.settings')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "envipath.settings")
application = get_wsgi_application()

View File

@ -18,7 +18,7 @@ from .models import (
Scenario,
Setting,
ExternalDatabase,
ExternalIdentifier
ExternalIdentifier,
)
@ -39,7 +39,7 @@ class GroupPackagePermissionAdmin(admin.ModelAdmin):
class EPAdmin(admin.ModelAdmin):
search_fields = ['name', 'description']
search_fields = ["name", "description"]
class PackageAdmin(EPAdmin):

View File

@ -21,7 +21,7 @@ class BearerTokenAuth(HttpBearer):
def _anonymous_or_real(request):
if request.user.is_authenticated and not request.user.is_anonymous:
return request.user
return get_user_model().objects.get(username='anonymous')
return get_user_model().objects.get(username="anonymous")
router = Router(auth=BearerTokenAuth())
@ -85,7 +85,9 @@ def get_package(request, package_uuid):
try:
return PackageManager.get_package_by_id(request.auth, package_id=package_uuid)
except ValueError:
return 403, {'message': f'Getting Package with id {package_uuid} failed due to insufficient rights!'}
return 403, {
"message": f"Getting Package with id {package_uuid} failed due to insufficient rights!"
}
@router.get("/compound", response={200: List[CompoundSchema], 403: Error})
@ -97,7 +99,9 @@ def get_compounds(request):
return qs
@router.get("/package/{uuid:package_uuid}/compound", response={200: List[CompoundSchema], 403: Error})
@router.get(
"/package/{uuid:package_uuid}/compound", response={200: List[CompoundSchema], 403: Error}
)
@paginate
def get_package_compounds(request, package_uuid):
try:
@ -105,4 +109,5 @@ def get_package_compounds(request, package_uuid):
return Compound.objects.filter(package=p)
except ValueError:
return 403, {
'message': f'Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!'}
"message": f"Getting Compounds for Package with id {package_uuid} failed due to insufficient rights!"
}

View File

@ -2,8 +2,8 @@ from django.apps import AppConfig
class EPDBConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'epdb'
default_auto_field = "django.db.models.BigAutoField"
name = "epdb"
def ready(self):
import epdb.signals # noqa: F401

View File

@ -1,5 +0,0 @@
from django import forms
class EmailLoginForm(forms.Form):
email = forms.EmailField()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -5,32 +5,49 @@ from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.logic import UserManager, GroupManager, PackageManager, SettingManager
from epdb.models import UserSettingPermission, MLRelativeReasoning, EnviFormer, Permission, User, ExternalDatabase
from epdb.models import (
UserSettingPermission,
MLRelativeReasoning,
EnviFormer,
Permission,
User,
ExternalDatabase,
)
class Command(BaseCommand):
def create_users(self):
# Anonymous User
if not User.objects.filter(email='anon@envipath.com').exists():
anon = UserManager.create_user("anonymous", "anon@envipath.com", "SuperSafe",
is_active=True, add_to_group=False, set_setting=False)
if not User.objects.filter(email="anon@envipath.com").exists():
anon = UserManager.create_user(
"anonymous",
"anon@envipath.com",
"SuperSafe",
is_active=True,
add_to_group=False,
set_setting=False,
)
else:
anon = User.objects.get(email='anon@envipath.com')
anon = User.objects.get(email="anon@envipath.com")
# Admin User
if not User.objects.filter(email='admin@envipath.com').exists():
admin = UserManager.create_user("admin", "admin@envipath.com", "SuperSafe",
is_active=True, add_to_group=False, set_setting=False)
if not User.objects.filter(email="admin@envipath.com").exists():
admin = UserManager.create_user(
"admin",
"admin@envipath.com",
"SuperSafe",
is_active=True,
add_to_group=False,
set_setting=False,
)
admin.is_staff = True
admin.is_superuser = True
admin.save()
else:
admin = User.objects.get(email='admin@envipath.com')
admin = User.objects.get(email="admin@envipath.com")
# System Group
g = GroupManager.create_group(admin, 'enviPath Users', 'All enviPath Users')
g = GroupManager.create_group(admin, "enviPath Users", "All enviPath Users")
g.public = True
g.save()
@ -43,14 +60,20 @@ class Command(BaseCommand):
admin.default_group = g
admin.save()
if not User.objects.filter(email='user0@envipath.com').exists():
user0 = UserManager.create_user("user0", "user0@envipath.com", "SuperSafe",
is_active=True, add_to_group=False, set_setting=False)
if not User.objects.filter(email="user0@envipath.com").exists():
user0 = UserManager.create_user(
"user0",
"user0@envipath.com",
"SuperSafe",
is_active=True,
add_to_group=False,
set_setting=False,
)
user0.is_staff = True
user0.is_superuser = True
user0.save()
else:
user0 = User.objects.get(email='user0@envipath.com')
user0 = User.objects.get(email="user0@envipath.com")
g.user_member.add(user0)
g.save()
@ -61,18 +84,20 @@ class Command(BaseCommand):
return anon, admin, g, user0
def import_package(self, data, owner):
return PackageManager.import_legacy_package(data, owner, keep_ids=True, add_import_timestamp=False, trust_reviewed=True)
return PackageManager.import_legacy_package(
data, owner, keep_ids=True, add_import_timestamp=False, trust_reviewed=True
)
def create_default_setting(self, owner, packages):
s = SettingManager.create_setting(
owner,
name='Global Default Setting',
description='Global Default Setting containing BBD Rules and Max 30 Nodes and Max Depth of 8',
name="Global Default Setting",
description="Global Default Setting containing BBD Rules and Max 30 Nodes and Max Depth of 8",
max_nodes=30,
max_depth=5,
rule_packages=packages,
model=None,
model_threshold=None
model_threshold=None,
)
return s
@ -84,54 +109,51 @@ class Command(BaseCommand):
"""
databases = [
{
'name': 'PubChem Compound',
'full_name': 'PubChem Compound Database',
'description': 'Chemical database of small organic molecules',
'base_url': 'https://pubchem.ncbi.nlm.nih.gov',
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}'
"name": "PubChem Compound",
"full_name": "PubChem Compound Database",
"description": "Chemical database of small organic molecules",
"base_url": "https://pubchem.ncbi.nlm.nih.gov",
"url_pattern": "https://pubchem.ncbi.nlm.nih.gov/compound/{id}",
},
{
'name': 'PubChem Substance',
'full_name': 'PubChem Substance Database',
'description': 'Database of chemical substances',
'base_url': 'https://pubchem.ncbi.nlm.nih.gov',
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/substance/{id}'
"name": "PubChem Substance",
"full_name": "PubChem Substance Database",
"description": "Database of chemical substances",
"base_url": "https://pubchem.ncbi.nlm.nih.gov",
"url_pattern": "https://pubchem.ncbi.nlm.nih.gov/substance/{id}",
},
{
'name': 'ChEBI',
'full_name': 'Chemical Entities of Biological Interest',
'description': 'Dictionary of molecular entities',
'base_url': 'https://www.ebi.ac.uk/chebi',
'url_pattern': 'https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{id}'
"name": "ChEBI",
"full_name": "Chemical Entities of Biological Interest",
"description": "Dictionary of molecular entities",
"base_url": "https://www.ebi.ac.uk/chebi",
"url_pattern": "https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{id}",
},
{
'name': 'RHEA',
'full_name': 'RHEA Reaction Database',
'description': 'Comprehensive resource of biochemical reactions',
'base_url': 'https://www.rhea-db.org',
'url_pattern': 'https://www.rhea-db.org/rhea/{id}'
"name": "RHEA",
"full_name": "RHEA Reaction Database",
"description": "Comprehensive resource of biochemical reactions",
"base_url": "https://www.rhea-db.org",
"url_pattern": "https://www.rhea-db.org/rhea/{id}",
},
{
'name': 'KEGG Reaction',
'full_name': 'KEGG Reaction Database',
'description': 'Database of biochemical reactions',
'base_url': 'https://www.genome.jp',
'url_pattern': 'https://www.genome.jp/entry/{id}'
"name": "KEGG Reaction",
"full_name": "KEGG Reaction Database",
"description": "Database of biochemical reactions",
"base_url": "https://www.genome.jp",
"url_pattern": "https://www.genome.jp/entry/{id}",
},
{
'name': 'UniProt',
'full_name': 'MetaCyc Metabolic Pathway Database',
'description': 'UniProt is a freely accessible database of protein sequence and functional information',
'base_url': 'https://www.uniprot.org',
'url_pattern': 'https://www.uniprot.org/uniprotkb?query="{id}"'
}
"name": "UniProt",
"full_name": "MetaCyc Metabolic Pathway Database",
"description": "UniProt is a freely accessible database of protein sequence and functional information",
"base_url": "https://www.uniprot.org",
"url_pattern": 'https://www.uniprot.org/uniprotkb?query="{id}"',
},
]
for db_info in databases:
ExternalDatabase.objects.get_or_create(
name=db_info['name'],
defaults=db_info
)
ExternalDatabase.objects.get_or_create(name=db_info["name"], defaults=db_info)
@transaction.atomic
def handle(self, *args, **options):
@ -142,20 +164,24 @@ class Command(BaseCommand):
# Import Packages
packages = [
'EAWAG-BBD.json',
'EAWAG-SOIL.json',
'EAWAG-SLUDGE.json',
'EAWAG-SEDIMENT.json',
"EAWAG-BBD.json",
"EAWAG-SOIL.json",
"EAWAG-SLUDGE.json",
"EAWAG-SEDIMENT.json",
]
mapping = {}
for p in packages:
print(f"Importing {p}...")
package_data = json.loads(open(s.BASE_DIR / 'fixtures' / 'packages' / '2025-07-18' / p, encoding='utf-8').read())
package_data = json.loads(
open(
s.BASE_DIR / "fixtures" / "packages" / "2025-07-18" / p, encoding="utf-8"
).read()
)
imported_package = self.import_package(package_data, admin)
mapping[p.replace('.json', '')] = imported_package
mapping[p.replace(".json", "")] = imported_package
setting = self.create_default_setting(admin, [mapping['EAWAG-BBD']])
setting = self.create_default_setting(admin, [mapping["EAWAG-BBD"]])
setting.public = True
setting.save()
setting.make_global_default()
@ -171,26 +197,28 @@ class Command(BaseCommand):
usp.save()
# Create Model Package
pack = PackageManager.create_package(admin, "Public Prediction Models",
"Package to make Prediction Models publicly available")
pack = PackageManager.create_package(
admin,
"Public Prediction Models",
"Package to make Prediction Models publicly available",
)
pack.reviewed = True
pack.save()
# Create RR
ml_model = MLRelativeReasoning.create(
package=pack,
rule_packages=[mapping['EAWAG-BBD']],
data_packages=[mapping['EAWAG-BBD']],
rule_packages=[mapping["EAWAG-BBD"]],
data_packages=[mapping["EAWAG-BBD"]],
eval_packages=[],
threshold=0.5,
name='ECC - BBD - T0.5',
description='ML Relative Reasoning',
name="ECC - BBD - T0.5",
description="ML Relative Reasoning",
)
ml_model.build_dataset()
ml_model.build_model()
# ml_model.evaluate_model()
# If available, create EnviFormerModel
if s.ENVIFORMER_PRESENT:
enviFormer_model = EnviFormer.create(pack, 'EnviFormer - T0.5', 'EnviFormer Model with Threshold 0.5', 0.5)
EnviFormer.create(pack, "EnviFormer - T0.5", "EnviFormer Model with Threshold 0.5", 0.5)

View File

@ -12,11 +12,28 @@ class Command(BaseCommand):
the below command would be used:
`python manage.py create_ml_models enviformer mlrr -d bbd soil -e sludge
"""
def add_arguments(self, parser):
parser.add_argument("model_names", nargs="+", type=str, help="The names of models to train. Options are: enviformer, mlrr")
parser.add_argument("-d", "--data-packages", nargs="+", type=str, help="Packages for training")
parser.add_argument("-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[])
parser.add_argument("-r", "--rule-packages", nargs="*", type=str, help="Rule Packages mandatory for MLRR", default=[])
parser.add_argument(
"model_names",
nargs="+",
type=str,
help="The names of models to train. Options are: enviformer, mlrr",
)
parser.add_argument(
"-d", "--data-packages", nargs="+", type=str, help="Packages for training"
)
parser.add_argument(
"-e", "--eval-packages", nargs="*", type=str, help="Packages for evaluation", default=[]
)
parser.add_argument(
"-r",
"--rule-packages",
nargs="*",
type=str,
help="Rule Packages mandatory for MLRR",
default=[],
)
@transaction.atomic
def handle(self, *args, **options):
@ -28,7 +45,9 @@ class Command(BaseCommand):
sludge = Package.objects.filter(name="EAWAG-SLUDGE")[0]
sediment = Package.objects.filter(name="EAWAG-SEDIMENT")[0]
except IndexError:
raise IndexError("Can't find correct packages. They should be created with the bootstrap command")
raise IndexError(
"Can't find correct packages. They should be created with the bootstrap command"
)
def decode_packages(package_list):
"""Decode package strings into their respective packages"""
@ -52,15 +71,27 @@ class Command(BaseCommand):
data_packages = decode_packages(options["data_packages"])
eval_packages = decode_packages(options["eval_packages"])
rule_packages = decode_packages(options["rule_packages"])
for model_name in options['model_names']:
for model_name in options["model_names"]:
model_name = model_name.lower()
if model_name == "enviformer" and s.ENVIFORMER_PRESENT:
model = EnviFormer.create(pack, data_packages=data_packages, eval_packages=eval_packages, threshold=0.5,
name="EnviFormer - T0.5", description="EnviFormer transformer")
model = EnviFormer.create(
pack,
data_packages=data_packages,
eval_packages=eval_packages,
threshold=0.5,
name="EnviFormer - T0.5",
description="EnviFormer transformer",
)
elif model_name == "mlrr":
model = MLRelativeReasoning.create(package=pack, rule_packages=rule_packages,
data_packages=data_packages, eval_packages=eval_packages, threshold=0.5,
name='ECC - BBD - T0.5', description='ML Relative Reasoning')
model = MLRelativeReasoning.create(
package=pack,
rule_packages=rule_packages,
data_packages=data_packages,
eval_packages=eval_packages,
threshold=0.5,
name="ECC - BBD - T0.5",
description="ML Relative Reasoning",
)
else:
raise ValueError(f"Cannot create model of type {model_name}, unknown model type")
# Build the dataset for the model, train it, evaluate it and save it

View File

@ -1,57 +1,58 @@
from csv import DictReader
from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.models import *
from epdb.models import Compound, CompoundStructure, Reaction, ExternalDatabase, ExternalIdentifier
class Command(BaseCommand):
STR_TO_MODEL = {
'Compound': Compound,
'CompoundStructure': CompoundStructure,
'Reaction': Reaction,
"Compound": Compound,
"CompoundStructure": CompoundStructure,
"Reaction": Reaction,
}
STR_TO_DATABASE = {
'ChEBI': ExternalDatabase.objects.get(name='ChEBI'),
'RHEA': ExternalDatabase.objects.get(name='RHEA'),
'KEGG Reaction': ExternalDatabase.objects.get(name='KEGG Reaction'),
'PubChem Compound': ExternalDatabase.objects.get(name='PubChem Compound'),
'PubChem Substance': ExternalDatabase.objects.get(name='PubChem Substance'),
"ChEBI": ExternalDatabase.objects.get(name="ChEBI"),
"RHEA": ExternalDatabase.objects.get(name="RHEA"),
"KEGG Reaction": ExternalDatabase.objects.get(name="KEGG Reaction"),
"PubChem Compound": ExternalDatabase.objects.get(name="PubChem Compound"),
"PubChem Substance": ExternalDatabase.objects.get(name="PubChem Substance"),
}
def add_arguments(self, parser):
parser.add_argument(
'--data',
"--data",
type=str,
help='Path of the ID Mapping file.',
help="Path of the ID Mapping file.",
required=True,
)
parser.add_argument(
'--replace-host',
"--replace-host",
type=str,
help='Replace https://envipath.org/ with this host, e.g. http://localhost:8000/',
help="Replace https://envipath.org/ with this host, e.g. http://localhost:8000/",
)
@transaction.atomic
def handle(self, *args, **options):
with open(options['data']) as fh:
with open(options["data"]) as fh:
reader = DictReader(fh)
for row in reader:
clz = self.STR_TO_MODEL[row['model']]
clz = self.STR_TO_MODEL[row["model"]]
url = row['url']
if options['replace_host']:
url = url.replace('https://envipath.org/', options['replace_host'])
url = row["url"]
if options["replace_host"]:
url = url.replace("https://envipath.org/", options["replace_host"])
instance = clz.objects.get(url=url)
db = self.STR_TO_DATABASE[row['identifier_type']]
db = self.STR_TO_DATABASE[row["identifier_type"]]
ExternalIdentifier.objects.create(
content_object=instance,
database=db,
identifier_value=row['identifier_value'],
url=db.url_pattern.format(id=row['identifier_value']),
is_primary=False
identifier_value=row["identifier_value"],
url=db.url_pattern.format(id=row["identifier_value"]),
is_primary=False,
)

View File

@ -1,27 +1,29 @@
import json
from django.core.management.base import BaseCommand
from django.db import transaction
from epdb.logic import PackageManager
from epdb.models import *
from epdb.models import User
class Command(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
'--data',
"--data",
type=str,
help='Path of the Package to import.',
help="Path of the Package to import.",
required=True,
)
parser.add_argument(
'--owner',
"--owner",
type=str,
help='Username of the desired Owner.',
help="Username of the desired Owner.",
required=True,
)
@transaction.atomic
def handle(self, *args, **options):
owner = User.objects.get(username=options['owner'])
package_data = json.load(open(options['data']))
owner = User.objects.get(username=options["owner"])
package_data = json.load(open(options["data"]))
PackageManager.import_legacy_package(package_data, owner)

View File

@ -6,46 +6,45 @@ from django.db.models.functions import Replace
class Command(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
'--old',
"--old",
type=str,
help='Old Host, most likely https://envipath.org/',
help="Old Host, most likely https://envipath.org/",
required=True,
)
parser.add_argument(
'--new',
"--new",
type=str,
help='New Host, most likely http://localhost:8000/',
help="New Host, most likely http://localhost:8000/",
required=True,
)
def handle(self, *args, **options):
MODELS = [
'User',
'Group',
'Package',
'Compound',
'CompoundStructure',
'Pathway',
'Edge',
'Node',
'Reaction',
'SimpleAmbitRule',
'SimpleRDKitRule',
'ParallelRule',
'SequentialRule',
'Scenario',
'Setting',
'MLRelativeReasoning',
'RuleBasedRelativeReasoning',
'EnviFormer',
'ApplicabilityDomain',
"User",
"Group",
"Package",
"Compound",
"CompoundStructure",
"Pathway",
"Edge",
"Node",
"Reaction",
"SimpleAmbitRule",
"SimpleRDKitRule",
"ParallelRule",
"SequentialRule",
"Scenario",
"Setting",
"MLRelativeReasoning",
"RuleBasedRelativeReasoning",
"EnviFormer",
"ApplicabilityDomain",
]
for model in MODELS:
obj_cls = apps.get_model("epdb", model)
print(f"Localizing urls for {model}")
obj_cls.objects.update(
url=Replace(F('url'), Value(options['old']), Value(options['new']))
url=Replace(F("url"), Value(options["old"]), Value(options["new"]))
)

View File

@ -3,22 +3,25 @@ from django.shortcuts import redirect
from django.urls import reverse
from urllib.parse import quote
class LoginRequiredMiddleware:
def __init__(self, get_response):
self.get_response = get_response
self.exempt_urls = [
reverse('login'),
reverse('logout'),
reverse('admin:login'),
reverse('admin:index'),
] + getattr(settings, 'LOGIN_EXEMPT_URLS', [])
reverse("login"),
reverse("logout"),
reverse("admin:login"),
reverse("admin:index"),
] + getattr(settings, "LOGIN_EXEMPT_URLS", [])
def __call__(self, request):
if not request.user.is_authenticated:
path = request.path_info
if not any(path.startswith(url) for url in self.exempt_urls):
if request.method == 'GET':
if request.get_full_path() and request.get_full_path() != '/':
return redirect(f"{settings.LOGIN_URL}?next={quote(request.get_full_path())}")
if request.method == "GET":
if request.get_full_path() and request.get_full_path() != "/":
return redirect(
f"{settings.LOGIN_URL}?next={quote(request.get_full_path())}"
)
return redirect(settings.LOGIN_URL)
return self.get_response(request)

File diff suppressed because it is too large Load Diff

View File

@ -2,55 +2,57 @@ import logging
from typing import Optional
from celery import shared_task
from epdb.models import Pathway, Node, Edge, EPModel, Setting
from epdb.models import Pathway, Node, EPModel, Setting
from epdb.logic import SPathway
logger = logging.getLogger(__name__)
@shared_task(queue='background')
@shared_task(queue="background")
def mul(a, b):
return a * b
@shared_task(queue='predict')
@shared_task(queue="predict")
def predict_simple(model_pk: int, smiles: str):
mod = EPModel.objects.get(id=model_pk)
res = mod.predict(smiles)
return res
@shared_task(queue='background')
@shared_task(queue="background")
def send_registration_mail(user_pk: int):
pass
@shared_task(queue='model')
@shared_task(queue="model")
def build_model(model_pk: int):
mod = EPModel.objects.get(id=model_pk)
mod.build_dataset()
mod.build_model()
@shared_task(queue='model')
@shared_task(queue="model")
def evaluate_model(model_pk: int):
mod = EPModel.objects.get(id=model_pk)
mod.evaluate_model()
@shared_task(queue='model')
@shared_task(queue="model")
def retrain(model_pk: int):
mod = EPModel.objects.get(id=model_pk)
mod.retrain()
@shared_task(queue='predict')
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
@shared_task(queue="predict")
def predict(
pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None
) -> Pathway:
pw = Pathway.objects.get(id=pw_pk)
setting = Setting.objects.get(id=pred_setting_pk)
pw.kv.update(**{'status': 'running'})
pw.kv.update(**{"status": "running"})
pw.save()
try:
@ -74,12 +76,10 @@ def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_
else:
raise ValueError("Neither limit nor node_pk given!")
except Exception as e:
pw.kv.update({'status': 'failed'})
pw.kv.update({"status": "failed"})
pw.save()
raise e
pw.kv.update(**{'status': 'completed'})
pw.kv.update(**{"status": "completed"})
pw.save()

View File

@ -2,6 +2,7 @@ from django import template
register = template.Library()
@register.filter
def classname(obj):
return obj.__class__.__name__

View File

@ -1,3 +0,0 @@
from django.test import TestCase
# Create your tests here.

View File

@ -3,97 +3,177 @@ from django.contrib.auth import views as auth_views
from . import views as v
UUID = '[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
UUID = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
urlpatterns = [
# Home
re_path(r'^$', v.index, name='index'),
re_path(r"^$", v.index, name="index"),
# Login
re_path(r'^login', v.login, name='login'),
re_path(r'^logout', v.logout, name='logout'),
re_path(r'^register', v.register, name='register'),
# Built In views
path('password_reset/', auth_views.PasswordResetView.as_view(
template_name='static/password_reset_form.html'
), name='password_reset'),
path('password_reset/done/', auth_views.PasswordResetDoneView.as_view(
template_name='static/password_reset_done.html'
), name='password_reset_done'),
path('reset/<uidb64>/<token>/', auth_views.PasswordResetConfirmView.as_view(
template_name='static/password_reset_confirm.html'
), name='password_reset_confirm'),
path('reset/done/', auth_views.PasswordResetCompleteView.as_view(
template_name='static/password_reset_complete.html'
), name='password_reset_complete'),
re_path(r"^login", v.login, name="login"),
re_path(r"^logout", v.logout, name="logout"),
re_path(r"^register", v.register, name="register"),
# Built-In views
path(
"password_reset/",
auth_views.PasswordResetView.as_view(template_name="static/password_reset_form.html"),
name="password_reset",
),
path(
"password_reset/done/",
auth_views.PasswordResetDoneView.as_view(template_name="static/password_reset_done.html"),
name="password_reset_done",
),
path(
"reset/<uidb64>/<token>/",
auth_views.PasswordResetConfirmView.as_view(
template_name="static/password_reset_confirm.html"
),
name="password_reset_confirm",
),
path(
"reset/done/",
auth_views.PasswordResetCompleteView.as_view(
template_name="static/password_reset_complete.html"
),
name="password_reset_complete",
),
# Top level urls
re_path(r'^package$', v.packages, name='packages'),
re_path(r'^compound$', v.compounds, name='compounds'),
re_path(r'^rule$', v.rules, name='rules'),
re_path(r'^reaction$', v.reactions, name='reactions'),
re_path(r'^pathway$', v.pathways, name='pathways'),
re_path(r'^scenario$', v.scenarios, name='scenarios'),
re_path(r'^model$', v.models, name='model'),
re_path(r'^user$', v.users, name='users'),
re_path(r'^group$', v.groups, name='groups'),
re_path(r'^search$', v.search, name='search'),
re_path(r"^package$", v.packages, name="packages"),
re_path(r"^compound$", v.compounds, name="compounds"),
re_path(r"^rule$", v.rules, name="rules"),
re_path(r"^reaction$", v.reactions, name="reactions"),
re_path(r"^pathway$", v.pathways, name="pathways"),
re_path(r"^scenario$", v.scenarios, name="scenarios"),
re_path(r"^model$", v.models, name="model"),
re_path(r"^user$", v.users, name="users"),
re_path(r"^group$", v.groups, name="groups"),
re_path(r"^search$", v.search, name="search"),
# User Detail
re_path(rf'^user/(?P<user_uuid>{UUID})', v.user, name='user'),
re_path(rf"^user/(?P<user_uuid>{UUID})", v.user, name="user"),
# Group Detail
re_path(rf'^group/(?P<group_uuid>{UUID})$', v.group, name='group detail'),
re_path(rf"^group/(?P<group_uuid>{UUID})$", v.group, name="group detail"),
# "in package" urls
re_path(rf'^package/(?P<package_uuid>{UUID})$', v.package, name='package detail'),
re_path(rf"^package/(?P<package_uuid>{UUID})$", v.package, name="package detail"),
# Compound
re_path(rf'^package/(?P<package_uuid>{UUID})/compound$', v.package_compounds, name='package compound list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})$', v.package_compound, name='package compound detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/compound$",
v.package_compounds,
name="package compound list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})$",
v.package_compound,
name="package compound detail",
),
# Compound Structure
re_path(rf'^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure$', v.package_compound_structures, name='package compound structure list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure/(?P<structure_uuid>{UUID})$', v.package_compound_structure, name='package compound structure detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure$",
v.package_compound_structures,
name="package compound structure list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/compound/(?P<compound_uuid>{UUID})/structure/(?P<structure_uuid>{UUID})$",
v.package_compound_structure,
name="package compound structure detail",
),
# Rule
re_path(rf'^package/(?P<package_uuid>{UUID})/rule$', v.package_rules, name='package rule list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'),
re_path(rf'^package/(?P<package_uuid>{UUID})/simple-ambit-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'),
re_path(rf'^package/(?P<package_uuid>{UUID})/simple-rdkit-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'),
re_path(rf'^package/(?P<package_uuid>{UUID})/parallel-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'),
re_path(rf'^package/(?P<package_uuid>{UUID})/sequential-rule/(?P<rule_uuid>{UUID})$', v.package_rule, name='package rule detail'),
re_path(rf"^package/(?P<package_uuid>{UUID})/rule$", v.package_rules, name="package rule list"),
re_path(
rf"^package/(?P<package_uuid>{UUID})/rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/simple-ambit-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/simple-rdkit-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/parallel-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/sequential-rule/(?P<rule_uuid>{UUID})$",
v.package_rule,
name="package rule detail",
),
# Reaction
re_path(rf'^package/(?P<package_uuid>{UUID})/reaction$', v.package_reactions, name='package reaction list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/reaction/(?P<reaction_uuid>{UUID})$', v.package_reaction, name='package reaction detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/reaction$",
v.package_reactions,
name="package reaction list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/reaction/(?P<reaction_uuid>{UUID})$",
v.package_reaction,
name="package reaction detail",
),
# # Pathway
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway$', v.package_pathways, name='package pathway list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})$', v.package_pathway, name='package pathway detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway$",
v.package_pathways,
name="package pathway list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})$",
v.package_pathway,
name="package pathway detail",
),
# Pathway Nodes
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node$', v.package_pathway_nodes, name='package pathway node list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node/(?P<node_uuid>{UUID})$', v.package_pathway_node, name='package pathway node detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node$",
v.package_pathway_nodes,
name="package pathway node list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/node/(?P<node_uuid>{UUID})$",
v.package_pathway_node,
name="package pathway node detail",
),
# Pathway Edges
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge$', v.package_pathway_edges, name='package pathway edge list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge/(?P<edge_uuid>{UUID})$', v.package_pathway_edge, name='package pathway edge detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge$",
v.package_pathway_edges,
name="package pathway edge list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/pathway/(?P<pathway_uuid>{UUID})/edge/(?P<edge_uuid>{UUID})$",
v.package_pathway_edge,
name="package pathway edge detail",
),
# Scenario
re_path(rf'^package/(?P<package_uuid>{UUID})/scenario$', v.package_scenarios, name='package scenario list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/scenario/(?P<scenario_uuid>{UUID})$', v.package_scenario, name='package scenario detail'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/scenario$",
v.package_scenarios,
name="package scenario list",
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/scenario/(?P<scenario_uuid>{UUID})$",
v.package_scenario,
name="package scenario detail",
),
# Model
re_path(rf'^package/(?P<package_uuid>{UUID})/model$', v.package_models, name='package model list'),
re_path(rf'^package/(?P<package_uuid>{UUID})/model/(?P<model_uuid>{UUID})$', v.package_model,name='package model detail'),
re_path(r'^setting$', v.settings, name='settings'),
re_path(rf'^setting/(?P<setting_uuid>{UUID})', v.setting, name='setting'),
re_path(r'^indigo/info$', v.indigo, name='indigo_info'),
re_path(r'^indigo/aromatize$', v.aromatize, name='indigo_aromatize'),
re_path(r'^indigo/dearomatize$', v.dearomatize, name='indigo_dearomatize'),
re_path(r'^indigo/layout$', v.layout, name='indigo_layout'),
re_path(r'^depict$', v.depict, name='depict'),
re_path(
rf"^package/(?P<package_uuid>{UUID})/model$", v.package_models, name="package model list"
),
re_path(
rf"^package/(?P<package_uuid>{UUID})/model/(?P<model_uuid>{UUID})$",
v.package_model,
name="package model detail",
),
re_path(r"^setting$", v.settings, name="settings"),
re_path(rf"^setting/(?P<setting_uuid>{UUID})", v.setting, name="setting"),
re_path(r"^indigo/info$", v.indigo, name="indigo_info"),
re_path(r"^indigo/aromatize$", v.aromatize, name="indigo_aromatize"),
re_path(r"^indigo/dearomatize$", v.dearomatize, name="indigo_dearomatize"),
re_path(r"^indigo/layout$", v.layout, name="indigo_layout"),
re_path(r"^depict$", v.depict, name="depict"),
# OAuth Stuff
path("o/userinfo/", v.userinfo, name="oauth_userinfo"),
]

File diff suppressed because it is too large Load Diff

View File

@ -13,91 +13,80 @@ class CompoundTest(TestCase):
@classmethod
def setUpClass(cls):
super(CompoundTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self):
c = Compound.create(
self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
name='Afoxolaner',
description='No Desc'
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name="Afoxolaner",
description="No Desc",
)
self.assertEqual(c.default_structure.smiles,
'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F')
self.assertEqual(c.name, 'Afoxolaner')
self.assertEqual(c.description, 'No Desc')
self.assertEqual(
c.default_structure.smiles,
"C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
)
self.assertEqual(c.name, "Afoxolaner")
self.assertEqual(c.description, "No Desc")
def test_missing_smiles(self):
with self.assertRaises(ValueError):
_ = Compound.create(
self.package,
smiles=None,
name='Afoxolaner',
description='No Desc'
)
_ = Compound.create(self.package, smiles=None, name="Afoxolaner", description="No Desc")
with self.assertRaises(ValueError):
_ = Compound.create(
self.package,
smiles='',
name='Afoxolaner',
description='No Desc'
)
_ = Compound.create(self.package, smiles="", name="Afoxolaner", description="No Desc")
with self.assertRaises(ValueError):
_ = Compound.create(
self.package,
smiles=' ',
name='Afoxolaner',
description='No Desc'
)
_ = Compound.create(self.package, smiles=" ", name="Afoxolaner", description="No Desc")
def test_smiles_are_trimmed(self):
c = Compound.create(
self.package,
smiles=' C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F ',
name='Afoxolaner',
description='No Desc'
smiles=" C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F ",
name="Afoxolaner",
description="No Desc",
)
self.assertEqual(c.default_structure.smiles,
'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F')
self.assertEqual(
c.default_structure.smiles,
"C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
)
def test_name_and_description_optional(self):
c = Compound.create(
self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
)
self.assertEqual(c.name, 'Compound 1')
self.assertEqual(c.description, 'no description')
self.assertEqual(c.name, "Compound 1")
self.assertEqual(c.description, "no description")
def test_empty_name_and_description_are_ignored(self):
c = Compound.create(
self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
name='',
description='',
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name="",
description="",
)
self.assertEqual(c.name, 'Compound 1')
self.assertEqual(c.description, 'no description')
self.assertEqual(c.name, "Compound 1")
self.assertEqual(c.description, "no description")
def test_deduplication(self):
c1 = Compound.create(
self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
name='Afoxolaner',
description='No Desc'
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name="Afoxolaner",
description="No Desc",
)
c2 = Compound.create(
self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
name='Afoxolaner',
description='No Desc'
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name="Afoxolaner",
description="No Desc",
)
# Check if create detects that this Compound already exist
@ -109,36 +98,36 @@ class CompoundTest(TestCase):
with self.assertRaises(ValueError):
_ = Compound.create(
self.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
name='Afoxolaner',
description='No Desc'
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name="Afoxolaner",
description="No Desc",
)
def test_create_with_standardized_smiles(self):
c = Compound.create(
self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Standardized SMILES',
description='No Desc'
smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name="Standardized SMILES",
description="No Desc",
)
self.assertEqual(len(c.structures.all()), 1)
cs = c.structures.all()[0]
self.assertEqual(cs.normalized_structure, True)
self.assertEqual(cs.smiles, 'O=C(O)C1=CC=C([N+](=O)[O-])C=C1')
self.assertEqual(cs.smiles, "O=C(O)C1=CC=C([N+](=O)[O-])C=C1")
def test_create_with_non_standardized_smiles(self):
c = Compound.create(
self.package,
smiles='[O-][N+](=O)c1ccc(C(=O)[O-])cc1',
name='Non Standardized SMILES',
description='No Desc'
smiles="[O-][N+](=O)c1ccc(C(=O)[O-])cc1",
name="Non Standardized SMILES",
description="No Desc",
)
self.assertEqual(len(c.structures.all()), 2)
for cs in c.structures.all():
if cs.normalized_structure:
self.assertEqual(cs.smiles, 'O=C(O)C1=CC=C([N+](=O)[O-])C=C1')
self.assertEqual(cs.smiles, "O=C(O)C1=CC=C([N+](=O)[O-])C=C1")
break
else:
# Loop finished without break, lets fail...
@ -147,51 +136,54 @@ class CompoundTest(TestCase):
def test_add_structure_smoke(self):
c = Compound.create(
self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Standardized SMILES',
description='No Desc'
smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name="Standardized SMILES",
description="No Desc",
)
c.add_structure('[O-][N+](=O)c1ccc(C(=O)[O-])cc1', 'Non Standardized SMILES')
c.add_structure("[O-][N+](=O)c1ccc(C(=O)[O-])cc1", "Non Standardized SMILES")
self.assertEqual(len(c.structures.all()), 2)
def test_add_structure_with_different_normalized_smiles(self):
c = Compound.create(
self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Standardized SMILES',
description='No Desc'
smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name="Standardized SMILES",
description="No Desc",
)
with self.assertRaises(ValueError):
c.add_structure(
'C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
'Different Standardized SMILES')
"C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
"Different Standardized SMILES",
)
def test_delete(self):
c = Compound.create(
self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Standardization Test',
description='No Desc'
smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name="Standardization Test",
description="No Desc",
)
c.delete()
self.assertEqual(Compound.objects.filter(package=self.package).count(), 0)
self.assertEqual(CompoundStructure.objects.filter(compound__package=self.package).count(), 0)
self.assertEqual(
CompoundStructure.objects.filter(compound__package=self.package).count(), 0
)
def test_set_as_default_structure(self):
c1 = Compound.create(
self.package,
smiles='O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Standardized SMILES',
description='No Desc'
smiles="O=C(O)C1=CC=C([N+](=O)[O-])C=C1",
name="Standardized SMILES",
description="No Desc",
)
default_structure = c1.default_structure
c2 = c1.add_structure('[O-][N+](=O)c1ccc(C(=O)[O-])cc1', 'Non Standardized SMILES')
c2 = c1.add_structure("[O-][N+](=O)c1ccc(C(=O)[O-])cc1", "Non Standardized SMILES")
c1.set_default_structure(c2)
self.assertNotEqual(default_structure, c2)

View File

@ -1,6 +1,5 @@
from django.test import TestCase
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import Compound, User, Reaction
@ -12,50 +11,47 @@ class CopyTest(TestCase):
@classmethod
def setUpClass(cls):
super(CopyTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Source Package', 'No Desc')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Source Package", "No Desc")
cls.AFOXOLANER = Compound.create(
cls.package,
smiles='C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F',
name='Afoxolaner',
description='Test compound for copying'
smiles="C1C(=NOC1(C2=CC(=CC(=C2)Cl)C(F)(F)F)C(F)(F)F)C3=CC=C(C4=CC=CC=C43)C(=O)NCC(=O)NCC(F)(F)F",
name="Afoxolaner",
description="Test compound for copying",
)
cls.FOUR_NITROBENZOIC_ACID = Compound.create(
cls.package,
smiles='[O-][N+](=O)c1ccc(C(=O)[O-])cc1', # Normalized: O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name='Test Compound',
description='Compound with multiple structures'
smiles="[O-][N+](=O)c1ccc(C(=O)[O-])cc1", # Normalized: O=C(O)C1=CC=C([N+](=O)[O-])C=C1',
name="Test Compound",
description="Compound with multiple structures",
)
cls.ETHANOL = Compound.create(
cls.package,
smiles='CCO',
name='Ethanol',
description='Simple alcohol'
cls.package, smiles="CCO", name="Ethanol", description="Simple alcohol"
)
cls.target_package = PackageManager.create_package(cls.user, 'Target Package', 'No Desc')
cls.target_package = PackageManager.create_package(cls.user, "Target Package", "No Desc")
cls.reaction_educt = Compound.create(
cls.package,
smiles='C(CCl)Cl',
name='1,2-Dichloroethane',
description='Eawag BBD compound c0001'
smiles="C(CCl)Cl",
name="1,2-Dichloroethane",
description="Eawag BBD compound c0001",
).default_structure
cls.reaction_product = Compound.create(
cls.package,
smiles='C(CO)Cl',
name='2-Chloroethanol',
description='Eawag BBD compound c0005'
smiles="C(CO)Cl",
name="2-Chloroethanol",
description="Eawag BBD compound c0005",
).default_structure
cls.REACTION = Reaction.create(
package=cls.package,
name='Eawag BBD reaction r0001',
name="Eawag BBD reaction r0001",
educts=[cls.reaction_educt],
products=[cls.reaction_product],
multi_step=False
multi_step=False,
)
def test_compound_copy_basic(self):
@ -68,7 +64,9 @@ class CopyTest(TestCase):
self.assertEqual(self.AFOXOLANER.description, copied_compound.description)
self.assertEqual(copied_compound.package, self.target_package)
self.assertEqual(self.AFOXOLANER.package, self.package)
self.assertEqual(self.AFOXOLANER.default_structure.smiles, copied_compound.default_structure.smiles)
self.assertEqual(
self.AFOXOLANER.default_structure.smiles, copied_compound.default_structure.smiles
)
def test_compound_copy_with_multiple_structures(self):
"""Test copying a compound with multiple structures"""
@ -86,7 +84,7 @@ class CopyTest(TestCase):
self.assertIsNotNone(copied_compound.default_structure)
self.assertEqual(
copied_compound.default_structure.smiles,
self.FOUR_NITROBENZOIC_ACID.default_structure.smiles
self.FOUR_NITROBENZOIC_ACID.default_structure.smiles,
)
def test_compound_copy_preserves_aliases(self):
@ -95,15 +93,15 @@ class CopyTest(TestCase):
original_compound = self.ETHANOL
# Add aliases if the method exists
if hasattr(original_compound, 'add_alias'):
original_compound.add_alias('Ethyl alcohol')
original_compound.add_alias('Grain alcohol')
if hasattr(original_compound, "add_alias"):
original_compound.add_alias("Ethyl alcohol")
original_compound.add_alias("Grain alcohol")
mapping = dict()
copied_compound = original_compound.copy(self.target_package, mapping)
# Verify aliases were copied if they exist
if hasattr(original_compound, 'aliases') and hasattr(copied_compound, 'aliases'):
if hasattr(original_compound, "aliases") and hasattr(copied_compound, "aliases"):
original_aliases = original_compound.aliases
copied_aliases = copied_compound.aliases
self.assertEqual(original_aliases, copied_aliases)
@ -113,10 +111,10 @@ class CopyTest(TestCase):
original_compound = self.ETHANOL
# Add external identifiers if the methods exist
if hasattr(original_compound, 'add_cas_number'):
original_compound.add_cas_number('64-17-5')
if hasattr(original_compound, 'add_pubchem_compound_id'):
original_compound.add_pubchem_compound_id('702')
if hasattr(original_compound, "add_cas_number"):
original_compound.add_cas_number("64-17-5")
if hasattr(original_compound, "add_pubchem_compound_id"):
original_compound.add_pubchem_compound_id("702")
mapping = dict()
copied_compound = original_compound.copy(self.target_package, mapping)
@ -146,7 +144,9 @@ class CopyTest(TestCase):
self.assertEqual(original_structure.smiles, copied_structure.smiles)
self.assertEqual(original_structure.canonical_smiles, copied_structure.canonical_smiles)
self.assertEqual(original_structure.inchikey, copied_structure.inchikey)
self.assertEqual(original_structure.normalized_structure, copied_structure.normalized_structure)
self.assertEqual(
original_structure.normalized_structure, copied_structure.normalized_structure
)
# Verify they are different objects
self.assertNotEqual(original_structure.uuid, copied_structure.uuid)
@ -177,7 +177,9 @@ class CopyTest(TestCase):
self.assertEqual(orig_educt.compound.package, self.package)
self.assertEqual(orig_educt.smiles, copy_educt.smiles)
for orig_product, copy_product in zip(self.REACTION.products.all(), copied_reaction.products.all()):
for orig_product, copy_product in zip(
self.REACTION.products.all(), copied_reaction.products.all()
):
self.assertNotEqual(orig_product.uuid, copy_product.uuid)
self.assertEqual(orig_product.name, copy_product.name)
self.assertEqual(orig_product.description, copy_product.description)

View File

@ -11,21 +11,21 @@ class DatasetTest(TestCase):
def setUp(self):
self.cs1 = Compound.create(
self.package,
name='2,6-Dibromohydroquinone',
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b',
smiles='C1=C(C(=C(C=C1O)Br)O)Br',
name="2,6-Dibromohydroquinone",
description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/compound/d6435251-1a54-4327-b4b1-fd6e9a8f4dc9/structure/d8a0225c-dbb5-4e6c-a642-730081c09c5b",
smiles="C1=C(C(=C(C=C1O)Br)O)Br",
).default_structure
self.cs2 = Compound.create(
self.package,
smiles='O=C(O)CC(=O)/C=C(/Br)C(=O)O',
smiles="O=C(O)CC(=O)/C=C(/Br)C(=O)O",
).default_structure
self.rule1 = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
smirks='[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\[#6:3]=[#6:2](\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]',
description='http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6'
smirks="[#8:8]([H])-[c:4]1[c:3]([H])[c:2](-[#1,#17,#35:9])[c:1](-[#8:7]([H]))[c:6](-[#1,#17,#35])[c:5]([H])1>>[#8-]-[#6:6](=O)-[#6:5]-[#6:4](=[O:8])\\[#6:3]=[#6:2](\\[#1,#17,#35:9])-[#6:1](-[#8-])=[O:7]",
description="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1/simple-ambit-rule/f6a56c0f-a4a0-4ee3-b006-d765b4767cf6",
)
self.reaction1 = Reaction.create(
@ -33,14 +33,14 @@ class DatasetTest(TestCase):
educts=[self.cs1],
products=[self.cs2],
rules=[self.rule1],
multi_step=False
multi_step=False,
)
@classmethod
def setUpClass(cls):
super(DatasetTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self):
reactions = [r for r in Reaction.objects.filter(package=self.package)]

View File

@ -1,18 +1,19 @@
from tempfile import TemporaryDirectory
from django.test import TestCase
from django.test import TestCase, tag
from epdb.logic import PackageManager
from epdb.models import User, EnviFormer, Package
@tag("slow")
class EnviFormerTest(TestCase):
fixtures = ["test_fixtures.jsonl.gz"]
@classmethod
def setUpClass(cls):
super(EnviFormerTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_model_flow(self):
"""Test the full flow of EnviFormer, dataset build -> model finetune -> model evaluate -> model inference"""
@ -21,11 +22,14 @@ class EnviFormerTest(TestCase):
threshold = float(0.5)
data_package_objs = [self.BBD_SUBSET]
eval_packages_objs = [self.BBD_SUBSET]
mod = EnviFormer.create(self.package, data_package_objs, eval_packages_objs, threshold=threshold)
mod = EnviFormer.create(
self.package, data_package_objs, eval_packages_objs, threshold=threshold
)
mod.build_dataset()
mod.build_model()
mod.multigen_eval = True
mod.save()
mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

View File

@ -4,8 +4,7 @@ from utilities.chem import FormatConverter
class FormatConverterTestCase(TestCase):
def test_standardization(self):
smiles = 'C[n+]1c([N-](C))cccc1'
smiles = "C[n+]1c([N-](C))cccc1"
standardized_smiles = FormatConverter.standardize(smiles)
self.assertEqual(standardized_smiles, 'CN=C1C=CC=CN1C')
self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")

View File

@ -4,7 +4,7 @@ import numpy as np
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import User, MLRelativeReasoning, RuleBasedRelativeReasoning, Package
from epdb.models import User, MLRelativeReasoning, Package
class ModelTest(TestCase):
@ -13,9 +13,9 @@ class ModelTest(TestCase):
@classmethod
def setUpClass(cls):
super(ModelTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_smoke(self):
with TemporaryDirectory() as tmpdir:
@ -32,8 +32,8 @@ class ModelTest(TestCase):
data_package_objs,
eval_packages_objs,
threshold=threshold,
name='ECC - BBD - 0.5',
description='Created MLRelativeReasoning in Testcase',
name="ECC - BBD - 0.5",
description="Created MLRelativeReasoning in Testcase",
)
# mod = RuleBasedRelativeReasoning.create(
@ -54,7 +54,7 @@ class ModelTest(TestCase):
mod.save()
mod.evaluate_model()
results = mod.predict('CCN(CC)C(=O)C1=CC(=CC=C1)C')
results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")
products = dict()
for r in results:
@ -62,8 +62,11 @@ class ModelTest(TestCase):
products[tuple(sorted(ps.product_set))] = (r.rule.name, r.probability)
expected = {
('CC=O', 'CCNC(=O)C1=CC(C)=CC=C1'): ('bt0243-4301', np.float64(0.33333333333333337)),
('CC1=CC=CC(C(=O)O)=C1', 'CCNCC'): ('bt0430-4011', np.float64(0.25)),
("CC=O", "CCNC(=O)C1=CC(C)=CC=C1"): (
"bt0243-4301",
np.float64(0.33333333333333337),
),
("CC1=CC=CC(C(=O)O)=C1", "CCNCC"): ("bt0430-4011", np.float64(0.25)),
}
self.assertEqual(products, expected)

View File

@ -1,4 +1,3 @@
import json
from django.test import TestCase
from networkx.utils.misc import graphs_equal
from epdb.logic import PackageManager, SPathway
@ -12,9 +11,11 @@ class MultiGenTest(TestCase):
@classmethod
def setUpClass(cls):
super(MultiGenTest, cls).setUpClass()
cls.user: 'User' = User.objects.get(username='anonymous')
cls.package: 'Package' = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.BBD_SUBSET: 'Package' = Package.objects.get(name='Fixtures')
cls.user: "User" = User.objects.get(username="anonymous")
cls.package: "Package" = PackageManager.create_package(
cls.user, "Anon Test Package", "No Desc"
)
cls.BBD_SUBSET: "Package" = Package.objects.get(name="Fixtures")
def test_equal_pathways(self):
"""Test that two identical pathways return a precision and recall of 1.0"""
@ -23,14 +24,23 @@ class MultiGenTest(TestCase):
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
continue
score, precision, recall = multigen_eval(pathway, pathway)
self.assertEqual(precision, 1.0, f"Precision should be one for identical pathways. "
f"Failed on pathway: {pathway.name}")
self.assertEqual(recall, 1.0, f"Recall should be one for identical pathways. "
f"Failed on pathway: {pathway.name}")
self.assertEqual(
precision,
1.0,
f"Precision should be one for identical pathways. "
f"Failed on pathway: {pathway.name}",
)
self.assertEqual(
recall,
1.0,
f"Recall should be one for identical pathways. Failed on pathway: {pathway.name}",
)
def test_intermediates(self):
"""Test that an intermediate can be correctly identified and the metrics are correctly adjusted"""
score, precision, recall, intermediates = multigen_eval(*self.intermediate_case(), return_intermediates=True)
score, precision, recall, intermediates = multigen_eval(
*self.intermediate_case(), return_intermediates=True
)
self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
self.assertEqual(precision, 1, "Precision should be 1")
self.assertEqual(recall, 1, "Recall should be 1")
@ -49,7 +59,9 @@ class MultiGenTest(TestCase):
def test_all(self):
"""Test an intermediate, false-positive and false-negative together"""
score, precision, recall, intermediates = multigen_eval(*self.all_case(), return_intermediates=True)
score, precision, recall, intermediates = multigen_eval(
*self.all_case(), return_intermediates=True
)
self.assertEqual(len(intermediates), 1, "There should be 1 found intermediate")
self.assertAlmostEqual(precision, 0.6, 3, "Precision should be 0.6")
self.assertAlmostEqual(recall, 0.75, 3, "Recall should be 0.75")
@ -57,19 +69,22 @@ class MultiGenTest(TestCase):
def test_shallow_pathway(self):
pathways = self.BBD_SUBSET.pathways.all()
for pathway in pathways:
pathway_name = pathway.name
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
continue
shallow_pathway = graph_from_pathway(SPathway.from_pathway(pathway))
pathway = graph_from_pathway(pathway)
if not graphs_equal(shallow_pathway, pathway):
print('\n\nS', shallow_pathway.adj)
print('\n\nPW', pathway.adj)
print("\n\nS", shallow_pathway.adj)
print("\n\nPW", pathway.adj)
# print(shallow_pathway.nodes, pathway.nodes)
# print(shallow_pathway.graph, pathway.graph)
self.assertTrue(graphs_equal(shallow_pathway, pathway), f"Networkx graph from shallow pathway not "
f"equal to pathway for pathway {pathway.name}")
self.assertTrue(
graphs_equal(shallow_pathway, pathway),
f"Networkx graph from shallow pathway not "
f"equal to pathway for pathway {pathway.name}",
)
def test_graph_edit_eval(self):
"""Performs all the previous tests but with graph_edit_eval
@ -79,10 +94,16 @@ class MultiGenTest(TestCase):
if len(pathway.edge_set.all()) == 0: # Do not test pathways with no edges
continue
score = pathway_edit_eval(pathway, pathway)
self.assertEqual(score, 0.0, "Pathway edit distance should be zero for identical pathways. "
f"Failed on pathway: {pathway.name}")
self.assertEqual(
score,
0.0,
"Pathway edit distance should be zero for identical pathways. "
f"Failed on pathway: {pathway.name}",
)
inter_score = pathway_edit_eval(*self.intermediate_case())
self.assertAlmostEqual(inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case")
self.assertAlmostEqual(
inter_score, 1.75, 3, "Pathway edit distance failed on intermediate case"
)
fp_score = pathway_edit_eval(*self.fp_case())
self.assertAlmostEqual(fp_score, 1.25, 3, "Pathway edit distance failed on fp case")
fn_score = pathway_edit_eval(*self.fn_case())
@ -93,22 +114,30 @@ class MultiGenTest(TestCase):
def intermediate_case(self):
"""Create an example with an intermediate in the predicted pathway"""
true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)])
true_pathway.add_edge(
[true_pathway.root_nodes.all()[0]], [true_pathway.add_node("CC(=O)O", depth=1)]
)
pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]],
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)])
pred_pathway.add_edge(
[pred_pathway.root_nodes.all()[0]],
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)],
)
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
return true_pathway, pred_pathway
def fp_case(self):
"""Create an example with an extra compound in the predicted pathway"""
true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
true_pathway.add_edge(
[true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
)
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]],
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)])
pred_pathway.add_edge(
[pred_pathway.root_nodes.all()[0]],
[acetaldehyde := pred_pathway.add_node("CC=O", depth=1)],
)
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway.add_edge([acetaldehyde], [pred_pathway.add_node("C", depth=2)])
return true_pathway, pred_pathway
@ -116,22 +145,30 @@ class MultiGenTest(TestCase):
def fn_case(self):
"""Create an example with a missing compound in the predicted pathway"""
true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
true_pathway.add_edge(
[true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
)
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)])
pred_pathway.add_edge(
[pred_pathway.root_nodes.all()[0]], [pred_pathway.add_node("CC=O", depth=1)]
)
return true_pathway, pred_pathway
def all_case(self):
"""Create an example with an intermediate, extra compound and missing compound"""
true_pathway = Pathway.create(self.package, "CCO")
true_pathway.add_edge([true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)])
true_pathway.add_edge(
[true_pathway.root_nodes.all()[0]],
[acetaldehyde := true_pathway.add_node("CC=O", depth=1)],
)
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("C", depth=2)])
true_pathway.add_edge([acetaldehyde], [true_pathway.add_node("CC(=O)O", depth=2)])
pred_pathway = Pathway.create(self.package, "CCO")
pred_pathway.add_edge([pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)])
pred_pathway.add_edge(
[pred_pathway.root_nodes.all()[0]], [methane := pred_pathway.add_node("C", depth=1)]
)
pred_pathway.add_edge([methane], [true_pathway.add_node("CC=O", depth=2)])
pred_pathway.add_edge([methane], [true_pathway.add_node("c1ccccc1", depth=2)])
return true_pathway, pred_pathway

View File

@ -10,127 +10,127 @@ class ReactionTest(TestCase):
@classmethod
def setUpClass(cls):
super(ReactionTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self):
educt = Compound.create(
self.package,
smiles='C(CCl)Cl',
name='1,2-Dichloroethane',
description='Eawag BBD compound c0001'
smiles="C(CCl)Cl",
name="1,2-Dichloroethane",
description="Eawag BBD compound c0001",
).default_structure
product = Compound.create(
self.package,
smiles='C(CO)Cl',
name='2-Chloroethanol',
description='Eawag BBD compound c0005'
smiles="C(CO)Cl",
name="2-Chloroethanol",
description="Eawag BBD compound c0005",
).default_structure
r = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
name="Eawag BBD reaction r0001",
educts=[educt],
products=[product],
multi_step=False
multi_step=False,
)
self.assertEqual(r.smirks(), 'C(CCl)Cl>>C(CO)Cl')
self.assertEqual(r.name, 'Eawag BBD reaction r0001')
self.assertEqual(r.description, 'no description')
self.assertEqual(r.smirks(), "C(CCl)Cl>>C(CO)Cl")
self.assertEqual(r.name, "Eawag BBD reaction r0001")
self.assertEqual(r.description, "no description")
def test_string_educts_and_products(self):
r = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
multi_step=False
name="Eawag BBD reaction r0001",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
multi_step=False,
)
self.assertEqual(r.smirks(), 'C(CCl)Cl>>C(CO)Cl')
self.assertEqual(r.smirks(), "C(CCl)Cl>>C(CO)Cl")
def test_missing_smiles(self):
educt = Compound.create(
self.package,
smiles='C(CCl)Cl',
name='1,2-Dichloroethane',
description='Eawag BBD compound c0001'
smiles="C(CCl)Cl",
name="1,2-Dichloroethane",
description="Eawag BBD compound c0001",
).default_structure
product = Compound.create(
self.package,
smiles='C(CO)Cl',
name='2-Chloroethanol',
description='Eawag BBD compound c0005'
smiles="C(CO)Cl",
name="2-Chloroethanol",
description="Eawag BBD compound c0005",
).default_structure
with self.assertRaises(ValueError):
_ = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
name="Eawag BBD reaction r0001",
educts=[educt],
products=[],
multi_step=False
multi_step=False,
)
with self.assertRaises(ValueError):
_ = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
name="Eawag BBD reaction r0001",
educts=[],
products=[product],
multi_step=False
multi_step=False,
)
with self.assertRaises(ValueError):
_ = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
name="Eawag BBD reaction r0001",
educts=[],
products=[],
multi_step=False
multi_step=False,
)
def test_empty_name_and_description_are_ignored(self):
r = Reaction.create(
package=self.package,
name='',
description='',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
name="",
description="",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
multi_step=False,
)
self.assertEqual(r.name, 'no name')
self.assertEqual(r.description, 'no description')
self.assertEqual(r.name, "no name")
self.assertEqual(r.description, "no description")
def test_deduplication(self):
rule = Rule.create(
package=self.package,
rule_type='SimpleAmbitRule',
name='bt0022-2833',
description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative',
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
rule_type="SimpleAmbitRule",
name="bt0022-2833",
description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
r1 = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
name="Eawag BBD reaction r0001",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
rules=[rule],
multi_step=False
multi_step=False,
)
r2 = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
name="Eawag BBD reaction r0001",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
rules=[rule],
multi_step=False
multi_step=False,
)
# Check if create detects that this Compound already exist
@ -141,18 +141,18 @@ class ReactionTest(TestCase):
def test_deduplication_without_rules(self):
r1 = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
multi_step=False
name="Eawag BBD reaction r0001",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
multi_step=False,
)
r2 = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
multi_step=False
name="Eawag BBD reaction r0001",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
multi_step=False,
)
# Check if create detects that this Compound already exist
@ -164,19 +164,19 @@ class ReactionTest(TestCase):
with self.assertRaises(ValueError):
_ = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['ASDF'],
products=['C(CO)Cl'],
multi_step=False
name="Eawag BBD reaction r0001",
educts=["ASDF"],
products=["C(CO)Cl"],
multi_step=False,
)
def test_delete(self):
r = Reaction.create(
package=self.package,
name='Eawag BBD reaction r0001',
educts=['C(CCl)Cl'],
products=['C(CO)Cl'],
multi_step=False
name="Eawag BBD reaction r0001",
educts=["C(CCl)Cl"],
products=["C(CO)Cl"],
multi_step=False,
)
r.delete()

View File

@ -10,73 +10,79 @@ class RuleTest(TestCase):
@classmethod
def setUpClass(cls):
super(RuleTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
def test_smoke(self):
r = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
name='bt0022-2833',
description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative',
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
name="bt0022-2833",
description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
self.assertEqual(r.smirks,
'[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]')
self.assertEqual(r.name, 'bt0022-2833')
self.assertEqual(r.description,
'Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative')
self.assertEqual(
r.smirks,
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
self.assertEqual(r.name, "bt0022-2833")
self.assertEqual(
r.description,
"Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
)
def test_smirks_are_trimmed(self):
r = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
name='bt0022-2833',
description='Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative',
smirks=' [H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4] ',
name="bt0022-2833",
description="Dihalomethyl derivative + Halomethyl derivative > 1-Halo-1-methylalcohol derivative + 1-Methylalcohol derivative",
smirks=" [H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4] ",
)
self.assertEqual(r.smirks,
'[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]')
self.assertEqual(
r.smirks,
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
def test_name_and_description_optional(self):
r = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
self.assertRegex(r.name, 'Rule \\d+')
self.assertEqual(r.description, 'no description')
self.assertRegex(r.name, "Rule \\d+")
self.assertEqual(r.description, "no description")
def test_empty_name_and_description_are_ignored(self):
r = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
name='',
description='',
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name="",
description="",
)
self.assertRegex(r.name, 'Rule \\d+')
self.assertEqual(r.description, 'no description')
self.assertRegex(r.name, "Rule \\d+")
self.assertEqual(r.description, "no description")
def test_deduplication(self):
r1 = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
name='',
description='',
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name="",
description="",
)
r2 = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
name='',
description='',
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name="",
description="",
)
self.assertEqual(r1.pk, r2.pk)
@ -84,21 +90,21 @@ class RuleTest(TestCase):
def test_valid_smirks(self):
with self.assertRaises(ValueError):
r = Rule.create(
rule_type='SimpleAmbitRule',
Rule.create(
rule_type="SimpleAmbitRule",
package=self.package,
smirks='This is not a valid SMIRKS',
name='',
description='',
smirks="This is not a valid SMIRKS",
name="",
description="",
)
def test_delete(self):
r = Rule.create(
rule_type='SimpleAmbitRule',
rule_type="SimpleAmbitRule",
package=self.package,
smirks='[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]',
name='',
description='',
smirks="[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
name="",
description="",
)
r.delete()

File diff suppressed because it is too large Load Diff

View File

@ -12,34 +12,32 @@ class SimpleAmbitRuleTest(TestCase):
@classmethod
def setUpClass(cls):
super(SimpleAmbitRuleTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Simple Ambit Rule Test Package',
'Test Package for SimpleAmbitRule')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(
cls.user, "Simple Ambit Rule Test Package", "Test Package for SimpleAmbitRule"
)
def test_create_basic_rule(self):
"""Test creating a basic SimpleAmbitRule with minimal parameters."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule = SimpleAmbitRule.create(
package=self.package,
smirks=smirks
)
rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
self.assertIsInstance(rule, SimpleAmbitRule)
self.assertEqual(rule.smirks, smirks)
self.assertEqual(rule.package, self.package)
self.assertRegex(rule.name, r'Rule \d+')
self.assertEqual(rule.description, 'no description')
self.assertRegex(rule.name, r"Rule \d+")
self.assertEqual(rule.description, "no description")
self.assertIsNone(rule.reactant_filter_smarts)
self.assertIsNone(rule.product_filter_smarts)
def test_create_with_all_parameters(self):
"""Test creating SimpleAmbitRule with all parameters."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
name = 'Test Rule'
description = 'A test biotransformation rule'
reactant_filter = '[CH2X4]'
product_filter = '[OH]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
name = "Test Rule"
description = "A test biotransformation rule"
reactant_filter = "[CH2X4]"
product_filter = "[OH]"
rule = SimpleAmbitRule.create(
package=self.package,
@ -47,7 +45,7 @@ class SimpleAmbitRuleTest(TestCase):
description=description,
smirks=smirks,
reactant_filter_smarts=reactant_filter,
product_filter_smarts=product_filter
product_filter_smarts=product_filter,
)
self.assertEqual(rule.name, name)
@ -60,127 +58,114 @@ class SimpleAmbitRuleTest(TestCase):
"""Test that SMIRKS is required for rule creation."""
with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(package=self.package, smirks=None)
self.assertIn('SMIRKS is required', str(cm.exception))
self.assertIn("SMIRKS is required", str(cm.exception))
with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(package=self.package, smirks='')
self.assertIn('SMIRKS is required', str(cm.exception))
SimpleAmbitRule.create(package=self.package, smirks="")
self.assertIn("SMIRKS is required", str(cm.exception))
with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(package=self.package, smirks=' ')
self.assertIn('SMIRKS is required', str(cm.exception))
SimpleAmbitRule.create(package=self.package, smirks=" ")
self.assertIn("SMIRKS is required", str(cm.exception))
@patch('epdb.models.FormatConverter.is_valid_smirks')
@patch("epdb.models.FormatConverter.is_valid_smirks")
def test_invalid_smirks_validation(self, mock_is_valid):
"""Test validation of SMIRKS format."""
mock_is_valid.return_value = False
invalid_smirks = 'invalid_smirks_string'
invalid_smirks = "invalid_smirks_string"
with self.assertRaises(ValueError) as cm:
SimpleAmbitRule.create(
package=self.package,
smirks=invalid_smirks
)
SimpleAmbitRule.create(package=self.package, smirks=invalid_smirks)
self.assertIn(f'SMIRKS "{invalid_smirks}" is invalid', str(cm.exception))
mock_is_valid.assert_called_once_with(invalid_smirks)
def test_smirks_trimming(self):
"""Test that SMIRKS strings are trimmed during creation."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks_with_whitespace = f' {smirks} '
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
smirks_with_whitespace = f" {smirks} "
rule = SimpleAmbitRule.create(
package=self.package,
smirks=smirks_with_whitespace
)
rule = SimpleAmbitRule.create(package=self.package, smirks=smirks_with_whitespace)
self.assertEqual(rule.smirks, smirks)
def test_empty_name_and_description_handling(self):
"""Test that empty name and description are handled appropriately."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
name='',
description=' '
package=self.package, smirks=smirks, name="", description=" "
)
self.assertRegex(rule.name, r'Rule \d+')
self.assertEqual(rule.description, 'no description')
self.assertRegex(rule.name, r"Rule \d+")
self.assertEqual(rule.description, "no description")
def test_deduplication_basic(self):
"""Test that identical rules are deduplicated."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule1 = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
name='Rule 1'
)
rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks, name="Rule 1")
rule2 = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
name='Rule 2' # Different name, but same SMIRKS
name="Rule 2", # Different name, but same SMIRKS
)
self.assertEqual(rule1.pk, rule2.pk)
self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 1)
self.assertEqual(
SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 1
)
def test_deduplication_with_filters(self):
"""Test deduplication with filter SMARTS."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
reactant_filter = '[CH2X4]'
product_filter = '[OH]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
reactant_filter = "[CH2X4]"
product_filter = "[OH]"
rule1 = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
reactant_filter_smarts=reactant_filter,
product_filter_smarts=product_filter
product_filter_smarts=product_filter,
)
rule2 = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
reactant_filter_smarts=reactant_filter,
product_filter_smarts=product_filter
product_filter_smarts=product_filter,
)
self.assertEqual(rule1.pk, rule2.pk)
def test_no_deduplication_different_filters(self):
"""Test that rules with different filters are not deduplicated."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule1 = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
reactant_filter_smarts='[CH2X4]'
package=self.package, smirks=smirks, reactant_filter_smarts="[CH2X4]"
)
rule2 = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
reactant_filter_smarts='[CH3X4]'
package=self.package, smirks=smirks, reactant_filter_smarts="[CH3X4]"
)
self.assertNotEqual(rule1.pk, rule2.pk)
self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 2)
self.assertEqual(
SimpleAmbitRule.objects.filter(package=self.package, smirks=smirks).count(), 2
)
def test_filter_smarts_trimming(self):
"""Test that filter SMARTS are trimmed and handled correctly."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
# Test with whitespace-only filters (should be treated as None)
rule = SimpleAmbitRule.create(
package=self.package,
smirks=smirks,
reactant_filter_smarts=' ',
product_filter_smarts=' '
reactant_filter_smarts=" ",
product_filter_smarts=" ",
)
self.assertIsNone(rule.reactant_filter_smarts)
@ -188,94 +173,85 @@ class SimpleAmbitRuleTest(TestCase):
def test_url_property(self):
"""Test the URL property generation."""
rule = SimpleAmbitRule.create(
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
expected_url = f'{self.package.url}/simple-ambit-rule/{rule.uuid}'
expected_url = f"{self.package.url}/simple-ambit-rule/{rule.uuid}"
self.assertEqual(rule.url, expected_url)
@patch('epdb.models.FormatConverter.apply')
@patch("epdb.models.FormatConverter.apply")
def test_apply_method(self, mock_apply):
"""Test the apply method delegates to FormatConverter."""
mock_apply.return_value = ['product1', 'product2']
mock_apply.return_value = ["product1", "product2"]
rule = SimpleAmbitRule.create(
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
test_smiles = 'CCO'
test_smiles = "CCO"
result = rule.apply(test_smiles)
mock_apply.assert_called_once_with(test_smiles, rule.smirks)
self.assertEqual(result, ['product1', 'product2'])
self.assertEqual(result, ["product1", "product2"])
def test_reactants_smarts_property(self):
"""Test reactants_smarts property extracts correct part of SMIRKS."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
expected_reactants = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
expected_reactants = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]"
rule = SimpleAmbitRule.create(
package=self.package,
smirks=smirks
)
rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
self.assertEqual(rule.reactants_smarts, expected_reactants)
def test_products_smarts_property(self):
"""Test products_smarts property extracts correct part of SMIRKS."""
smirks = '[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
expected_products = '[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]'
smirks = "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
expected_products = "[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]"
rule = SimpleAmbitRule.create(
package=self.package,
smirks=smirks
)
rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
self.assertEqual(rule.products_smarts, expected_products)
@patch('epdb.models.Package.objects')
@patch("epdb.models.Package.objects")
def test_related_reactions_property(self, mock_package_objects):
"""Test related_reactions property returns correct queryset."""
mock_qs = MagicMock()
mock_package_objects.filter.return_value = mock_qs
rule = SimpleAmbitRule.create(
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
# Instead of directly assigning, patch the property or use with patch.object
with patch.object(type(rule), 'reaction_rule', new_callable=PropertyMock) as mock_reaction_rule:
mock_reaction_rule.return_value.filter.return_value.order_by.return_value = ['reaction1', 'reaction2']
with patch.object(
type(rule), "reaction_rule", new_callable=PropertyMock
) as mock_reaction_rule:
mock_reaction_rule.return_value.filter.return_value.order_by.return_value = [
"reaction1",
"reaction2",
]
result = rule.related_reactions
mock_package_objects.filter.assert_called_once_with(reviewed=True)
mock_reaction_rule.return_value.filter.assert_called_once_with(package__in=mock_qs)
mock_reaction_rule.return_value.filter.return_value.order_by.assert_called_once_with('name')
self.assertEqual(result, ['reaction1', 'reaction2'])
mock_reaction_rule.return_value.filter.return_value.order_by.assert_called_once_with(
"name"
)
self.assertEqual(result, ["reaction1", "reaction2"])
@patch('epdb.models.Pathway.objects')
@patch('epdb.models.Edge.objects')
@patch("epdb.models.Pathway.objects")
@patch("epdb.models.Edge.objects")
def test_related_pathways_property(self, mock_edge_objects, mock_pathway_objects):
"""Test related_pathways property returns correct queryset."""
mock_related_reactions = ['reaction1', 'reaction2']
mock_related_reactions = ["reaction1", "reaction2"]
with patch.object(SimpleAmbitRule, "related_reactions", new_callable=PropertyMock) as mock_prop:
with patch.object(
SimpleAmbitRule, "related_reactions", new_callable=PropertyMock
) as mock_prop:
mock_prop.return_value = mock_related_reactions
rule = SimpleAmbitRule.create(
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
# Mock Edge objects query
mock_edge_values = MagicMock()
mock_edge_values.values.return_value = ['pathway_id1', 'pathway_id2']
mock_edge_values.values.return_value = ["pathway_id1", "pathway_id2"]
mock_edge_objects.filter.return_value = mock_edge_values
# Mock Pathway objects query
@ -285,52 +261,49 @@ class SimpleAmbitRuleTest(TestCase):
result = rule.related_pathways
mock_edge_objects.filter.assert_called_once_with(edge_label__in=mock_related_reactions)
mock_edge_values.values.assert_called_once_with('pathway_id')
mock_edge_values.values.assert_called_once_with("pathway_id")
mock_pathway_objects.filter.assert_called_once()
self.assertEqual(result, mock_pathway_qs)
@patch('epdb.models.IndigoUtils.smirks_to_svg')
@patch("epdb.models.IndigoUtils.smirks_to_svg")
def test_as_svg_property(self, mock_smirks_to_svg):
"""Test as_svg property calls IndigoUtils correctly."""
mock_smirks_to_svg.return_value = '<svg>test_svg</svg>'
mock_smirks_to_svg.return_value = "<svg>test_svg</svg>"
rule = SimpleAmbitRule.create(
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]'
)
rule = SimpleAmbitRule.create(package=self.package, smirks="[H:1][C:2]>>[H:1][O:2]")
result = rule.as_svg
mock_smirks_to_svg.assert_called_once_with(rule.smirks, True, width=800, height=400)
self.assertEqual(result, '<svg>test_svg</svg>')
self.assertEqual(result, "<svg>test_svg</svg>")
def test_atomic_transaction(self):
"""Test that rule creation is atomic."""
smirks = '[H:1][C:2]>>[H:1][O:2]'
smirks = "[H:1][C:2]>>[H:1][O:2]"
# This should work normally
rule = SimpleAmbitRule.create(package=self.package, smirks=smirks)
self.assertIsInstance(rule, SimpleAmbitRule)
# Test transaction rollback on error
with patch('epdb.models.SimpleAmbitRule.save', side_effect=Exception('Database error')):
with patch("epdb.models.SimpleAmbitRule.save", side_effect=Exception("Database error")):
with self.assertRaises(Exception):
SimpleAmbitRule.create(package=self.package, smirks='[H:3][C:4]>>[H:3][O:4]')
SimpleAmbitRule.create(package=self.package, smirks="[H:3][C:4]>>[H:3][O:4]")
# Verify no partial data was saved
self.assertEqual(SimpleAmbitRule.objects.filter(package=self.package).count(), 1)
def test_multiple_duplicate_warning(self):
"""Test logging when multiple duplicates are found."""
smirks = '[H:1][C:2]>>[H:1][O:2]'
smirks = "[H:1][C:2]>>[H:1][O:2]"
# Create first rule
rule1 = SimpleAmbitRule.create(package=self.package, smirks=smirks)
# Manually create a duplicate to simulate the error condition
rule2 = SimpleAmbitRule(package=self.package, smirks=smirks, name='Manual Rule')
rule2 = SimpleAmbitRule(package=self.package, smirks=smirks, name="Manual Rule")
rule2.save()
with patch('epdb.models.logger') as mock_logger:
with patch("epdb.models.logger") as mock_logger:
# This should find the existing rule and log an error about multiple matches
result = SimpleAmbitRule.create(package=self.package, smirks=smirks)
@ -339,24 +312,28 @@ class SimpleAmbitRuleTest(TestCase):
# Should log an error about multiple matches
mock_logger.error.assert_called()
self.assertIn('More than one rule matched', mock_logger.error.call_args[0][0])
self.assertIn("More than one rule matched", mock_logger.error.call_args[0][0])
def test_model_fields(self):
"""Test model field properties."""
rule = SimpleAmbitRule.create(
package=self.package,
smirks='[H:1][C:2]>>[H:1][O:2]',
reactant_filter_smarts='[CH3]',
product_filter_smarts='[OH]'
smirks="[H:1][C:2]>>[H:1][O:2]",
reactant_filter_smarts="[CH3]",
product_filter_smarts="[OH]",
)
# Test field properties
self.assertFalse(rule._meta.get_field('smirks').blank)
self.assertFalse(rule._meta.get_field('smirks').null)
self.assertTrue(rule._meta.get_field('reactant_filter_smarts').null)
self.assertTrue(rule._meta.get_field('product_filter_smarts').null)
self.assertFalse(rule._meta.get_field("smirks").blank)
self.assertFalse(rule._meta.get_field("smirks").null)
self.assertTrue(rule._meta.get_field("reactant_filter_smarts").null)
self.assertTrue(rule._meta.get_field("product_filter_smarts").null)
# Test verbose names
self.assertEqual(rule._meta.get_field('smirks').verbose_name, 'SMIRKS')
self.assertEqual(rule._meta.get_field('reactant_filter_smarts').verbose_name, 'Reactant Filter SMARTS')
self.assertEqual(rule._meta.get_field('product_filter_smarts').verbose_name, 'Product Filter SMARTS')
self.assertEqual(rule._meta.get_field("smirks").verbose_name, "SMIRKS")
self.assertEqual(
rule._meta.get_field("reactant_filter_smarts").verbose_name, "Reactant Filter SMARTS"
)
self.assertEqual(
rule._meta.get_field("product_filter_smarts").verbose_name, "Product Filter SMARTS"
)

View File

@ -1,32 +1,29 @@
from django.test import TestCase
from epdb.logic import SNode, SEdge, SPathway
from epdb.logic import SNode, SEdge
class SObjectTest(TestCase):
def setUp(self):
pass
def test_snode_eq(self):
snode1 = SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)
snode2 = SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)
snode1 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
snode2 = SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)
assert snode1 == snode2
def test_snode_hash(self):
pass
def test_sedge_eq(self):
sedge1 = SEdge([SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)],
[SNode('CN1C(=O)NC2=C(C1=O)N(C)C=N2', 1), SNode('C=O', 1)],
rule=None)
sedge2 = SEdge([SNode('CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O', 0)],
[SNode('CN1C(=O)NC2=C(C1=O)N(C)C=N2', 1), SNode('C=O', 1)],
rule=None)
sedge1 = SEdge(
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
rule=None,
)
sedge2 = SEdge(
[SNode("CN1C2C(N(C(N(C)C=2N=C1)=O)C)=O", 0)],
[SNode("CN1C(=O)NC2=C(C1=O)N(C)C=N2", 1), SNode("C=O", 1)],
rule=None,
)
assert sedge1 == sedge2
def test_sedge_hash(self):
pass
def test_spathway(self):
pw = SPathway()

View File

@ -3,7 +3,7 @@ from django.urls import reverse
from envipy_additional_information import Temperature, Interval
from epdb.logic import UserManager, PackageManager
from epdb.models import Compound, Scenario, ExternalIdentifier, ExternalDatabase
from epdb.models import Compound, Scenario, ExternalDatabase
class CompoundViewTest(TestCase):
@ -12,21 +12,28 @@ class CompoundViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(CompoundViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe",
set_setting=False, add_to_group=True, is_active=True)
cls.user1 = UserManager.create_user(
"user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack')
cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self):
self.client.force_login(self.user1)
def test_create_compound(self):
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -38,17 +45,18 @@ class CompoundViewTest(TestCase):
self.assertEqual(c.name, "1,2-Dichloroethane")
self.assertEqual(c.description, "Eawag BBD compound c0001")
self.assertEqual(c.default_structure.smiles, "C(CCl)Cl")
self.assertEqual(c.default_structure.canonical_smiles, 'ClCCCl')
self.assertEqual(c.default_structure.canonical_smiles, "ClCCCl")
self.assertEqual(c.structures.all().count(), 2)
self.assertEqual(self.user1_default_package.compounds.count(), 1)
# Adding the same rule again should return the existing one, hence not increasing the number of rules
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.url, compound_url)
@ -57,11 +65,12 @@ class CompoundViewTest(TestCase):
# Adding the same rule in a different package should create a new rule
response = self.client.post(
reverse("package compound list", kwargs={'package_uuid': self.package.uuid}), {
reverse("package compound list", kwargs={"package_uuid": self.package.uuid}),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -69,11 +78,12 @@ class CompoundViewTest(TestCase):
# adding another reaction should increase count
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "2-Chloroethanol",
"compound-description": "Eawag BBD compound c0005",
"compound-smiles": "C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -82,11 +92,12 @@ class CompoundViewTest(TestCase):
# Edit
def test_edit_rule(self):
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -95,13 +106,17 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(self.user1_default_package.uuid),
'compound_uuid': str(c.uuid)
}), {
reverse(
"package compound detail",
kwargs={
"package_uuid": str(self.user1_default_package.uuid),
"compound_uuid": str(c.uuid),
},
),
{
"compound-name": "Test Compound Adjusted",
"compound-description": "New Description",
}
},
)
self.assertEqual(response.status_code, 302)
@ -121,7 +136,7 @@ class CompoundViewTest(TestCase):
"Test Desc",
"2025-10",
"soil",
[Temperature(interval=Interval(start=20, end=30))]
[Temperature(interval=Interval(start=20, end=30))],
)
s2 = Scenario.create(
@ -130,15 +145,16 @@ class CompoundViewTest(TestCase):
"Test Desc2",
"2025-10",
"soil",
[Temperature(interval=Interval(start=10, end=20))]
[Temperature(interval=Interval(start=10, end=20))],
)
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -147,36 +163,35 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(c.package.uuid),
'compound_uuid': str(c.uuid)
}), {
"selected-scenarios": [s1.url, s2.url]
}
reverse(
"package compound detail",
kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
),
{"selected-scenarios": [s1.url, s2.url]},
)
self.assertEqual(len(c.scenarios.all()), 2)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(c.package.uuid),
'compound_uuid': str(c.uuid)
}), {
"selected-scenarios": [s1.url]
}
reverse(
"package compound detail",
kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
),
{"selected-scenarios": [s1.url]},
)
self.assertEqual(len(c.scenarios.all()), 1)
self.assertEqual(c.scenarios.first().url, s1.url)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(c.package.uuid),
'compound_uuid': str(c.uuid)
}), {
reverse(
"package compound detail",
kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
),
{
# We have to set an empty string to avoid that the parameter is removed
"selected-scenarios": ""
}
},
)
self.assertEqual(len(c.scenarios.all()), 0)
@ -184,11 +199,12 @@ class CompoundViewTest(TestCase):
#
def test_copy(self):
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -196,12 +212,13 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url)
response = self.client.post(
reverse("package detail", kwargs={
'package_uuid': str(self.package.uuid),
}), {
"hidden": "copy",
"object_to_copy": c.url
}
reverse(
"package detail",
kwargs={
"package_uuid": str(self.package.uuid),
},
),
{"hidden": "copy", "object_to_copy": c.url},
)
self.assertEqual(response.status_code, 200)
@ -215,44 +232,48 @@ class CompoundViewTest(TestCase):
# Copy to the same package should fail
response = self.client.post(
reverse("package detail", kwargs={
'package_uuid': str(c.package.uuid),
}), {
"hidden": "copy",
"object_to_copy": c.url
}
reverse(
"package detail",
kwargs={
"package_uuid": str(c.package.uuid),
},
),
{"hidden": "copy", "object_to_copy": c.url},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], f"Can't copy object {compound_url} to the same package!")
self.assertEqual(
response.json()["error"], f"Can't copy object {compound_url} to the same package!"
)
def test_references(self):
ext_db, _ = ExternalDatabase.objects.get_or_create(
name='PubChem Compound',
name="PubChem Compound",
defaults={
'full_name': 'PubChem Compound Database',
'description': 'Chemical database of small organic molecules',
'base_url': 'https://pubchem.ncbi.nlm.nih.gov',
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}'
}
"full_name": "PubChem Compound Database",
"description": "Chemical database of small organic molecules",
"base_url": "https://pubchem.ncbi.nlm.nih.gov",
"url_pattern": "https://pubchem.ncbi.nlm.nih.gov/compound/{id}",
},
)
ext_db2, _ = ExternalDatabase.objects.get_or_create(
name='PubChem Substance',
name="PubChem Substance",
defaults={
'full_name': 'PubChem Substance Database',
'description': 'Database of chemical substances',
'base_url': 'https://pubchem.ncbi.nlm.nih.gov',
'url_pattern': 'https://pubchem.ncbi.nlm.nih.gov/substance/{id}'
}
"full_name": "PubChem Substance Database",
"description": "Database of chemical substances",
"base_url": "https://pubchem.ncbi.nlm.nih.gov",
"url_pattern": "https://pubchem.ncbi.nlm.nih.gov/substance/{id}",
},
)
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -260,42 +281,49 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(c.package.uuid),
'compound_uuid': str(c.uuid),
}), {
'selected-database': ext_db.pk,
'identifier': '25154249'
}
reverse(
"package compound detail",
kwargs={
"package_uuid": str(c.package.uuid),
"compound_uuid": str(c.uuid),
},
),
{"selected-database": ext_db.pk, "identifier": "25154249"},
)
self.assertEqual(c.external_identifiers.count(), 1)
self.assertEqual(c.external_identifiers.first().database, ext_db)
self.assertEqual(c.external_identifiers.first().identifier_value, '25154249')
self.assertEqual(c.external_identifiers.first().url, 'https://pubchem.ncbi.nlm.nih.gov/compound/25154249')
self.assertEqual(c.external_identifiers.first().identifier_value, "25154249")
self.assertEqual(
c.external_identifiers.first().url, "https://pubchem.ncbi.nlm.nih.gov/compound/25154249"
)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(c.package.uuid),
'compound_uuid': str(c.uuid),
}), {
'selected-database': ext_db2.pk,
'identifier': '25154249'
}
reverse(
"package compound detail",
kwargs={
"package_uuid": str(c.package.uuid),
"compound_uuid": str(c.uuid),
},
),
{"selected-database": ext_db2.pk, "identifier": "25154249"},
)
self.assertEqual(c.external_identifiers.count(), 2)
self.assertEqual(c.external_identifiers.last().database, ext_db2)
self.assertEqual(c.external_identifiers.last().identifier_value, '25154249')
self.assertEqual(c.external_identifiers.last().url, 'https://pubchem.ncbi.nlm.nih.gov/substance/25154249')
self.assertEqual(c.external_identifiers.last().identifier_value, "25154249")
self.assertEqual(
c.external_identifiers.last().url, "https://pubchem.ncbi.nlm.nih.gov/substance/25154249"
)
def test_delete(self):
response = self.client.post(
reverse("compounds"), {
reverse("compounds"),
{
"compound-name": "1,2-Dichloroethane",
"compound-description": "Eawag BBD compound c0001",
"compound-smiles": "C(CCl)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -304,12 +332,11 @@ class CompoundViewTest(TestCase):
c = Compound.objects.get(url=compound_url)
response = self.client.post(
reverse("package compound detail", kwargs={
'package_uuid': str(c.package.uuid),
'compound_uuid': str(c.uuid)
}), {
"hidden": "delete"
}
reverse(
"package compound detail",
kwargs={"package_uuid": str(c.package.uuid), "compound_uuid": str(c.uuid)},
),
{"hidden": "delete"},
)
self.assertEqual(self.user1_default_package.compounds.count(), 0)

View File

@ -2,8 +2,8 @@ from django.test import TestCase, override_settings
from django.urls import reverse
from django.conf import settings as s
from epdb.logic import UserManager, PackageManager
from epdb.models import Pathway, Edge, Package, User
from epdb.logic import UserManager
from epdb.models import Package, User
@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models")
@ -13,10 +13,16 @@ class PathwayViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(PathwayViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe",
set_setting=True, add_to_group=True, is_active=True)
cls.user1 = UserManager.create_user(
"user1",
"user1@envipath.com",
"SuperSafe",
set_setting=True,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package
cls.model_package = Package.objects.get(name='Fixtures')
cls.model_package = Package.objects.get(name="Fixtures")
def setUp(self):
self.client.force_login(self.user1)
@ -24,90 +30,96 @@ class PathwayViewTest(TestCase):
def test_predict(self):
self.client.force_login(User.objects.get(username="admin"))
response = self.client.get(
reverse("package model detail", kwargs={
'package_uuid': str(self.model_package.uuid),
'model_uuid': str(self.model_package.models.first().uuid)
}), {
'classify': 'ILikeCats!',
'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO',
}
reverse(
"package model detail",
kwargs={
"package_uuid": str(self.model_package.uuid),
"model_uuid": str(self.model_package.models.first().uuid),
},
),
{
"classify": "ILikeCats!",
"smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO",
},
)
expected = [
{
'products': [
[
'O=C(O)C1=CC(CO)=CC=C1',
'CCNCC'
]
],
'probability': 0.25,
'btrule': {
'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/0e6e9290-b658-4450-b291-3ec19fa19206',
'name': 'bt0430-4011'
}
}, {
'products': [
[
'CCNC(=O)C1=CC(CO)=CC=C1',
'CC=O'
]
], 'probability': 0.0,
'btrule': {
'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/27a3a353-0b66-4228-bd16-e407949e90df',
'name': 'bt0243-4301'
}
}, {
'products': [
[
'CCN(CC)C(=O)C1=CC(C=O)=CC=C1'
]
], 'probability': 0.75,
'btrule': {
'url': 'http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/2f2e0c39-e109-4836-959f-2bda2524f022',
'name': 'bt0001-3568'
}
}
"products": [["O=C(O)C1=CC(CO)=CC=C1", "CCNCC"]],
"probability": 0.25,
"btrule": {
"url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/0e6e9290-b658-4450-b291-3ec19fa19206",
"name": "bt0430-4011",
},
},
{
"products": [["CCNC(=O)C1=CC(CO)=CC=C1", "CC=O"]],
"probability": 0.0,
"btrule": {
"url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/27a3a353-0b66-4228-bd16-e407949e90df",
"name": "bt0243-4301",
},
},
{
"products": [["CCN(CC)C(=O)C1=CC(C=O)=CC=C1"]],
"probability": 0.75,
"btrule": {
"url": "http://localhost:8000/package/1869d3f0-60bb-41fd-b6f8-afa75ffb09d3/simple-ambit-rule/2f2e0c39-e109-4836-959f-2bda2524f022",
"name": "bt0001-3568",
},
},
]
actual = response.json()
self.assertEqual(actual, expected)
response = self.client.get(
reverse("package model detail", kwargs={
'package_uuid': str(self.model_package.uuid),
'model_uuid': str(self.model_package.models.first().uuid)
}), {
'classify': 'ILikeCats!',
'smiles': '',
}
reverse(
"package model detail",
kwargs={
"package_uuid": str(self.model_package.uuid),
"model_uuid": str(self.model_package.models.first().uuid),
},
),
{
"classify": "ILikeCats!",
"smiles": "",
},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], 'Received empty SMILES')
self.assertEqual(response.json()["error"], "Received empty SMILES")
response = self.client.get(
reverse("package model detail", kwargs={
'package_uuid': str(self.model_package.uuid),
'model_uuid': str(self.model_package.models.first().uuid)
}), {
'classify': 'ILikeCats!',
'smiles': ' ', # Input should be stripped
}
reverse(
"package model detail",
kwargs={
"package_uuid": str(self.model_package.uuid),
"model_uuid": str(self.model_package.models.first().uuid),
},
),
{
"classify": "ILikeCats!",
"smiles": " ", # Input should be stripped
},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], 'Received empty SMILES')
self.assertEqual(response.json()["error"], "Received empty SMILES")
response = self.client.get(
reverse("package model detail", kwargs={
'package_uuid': str(self.model_package.uuid),
'model_uuid': str(self.model_package.models.first().uuid)
}), {
'classify': 'ILikeCats!',
'smiles': 'RandomInput',
}
reverse(
"package model detail",
kwargs={
"package_uuid": str(self.model_package.uuid),
"model_uuid": str(self.model_package.models.first().uuid),
},
),
{
"classify": "ILikeCats!",
"smiles": "RandomInput",
},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], '"RandomInput" is not a valid SMILES')
self.assertEqual(response.json()["error"], '"RandomInput" is not a valid SMILES')

View File

@ -13,19 +13,34 @@ class PackageViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(PackageViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe",
set_setting=False, add_to_group=True, is_active=True)
cls.user2 = UserManager.create_user("user2", "user2@envipath.com", "SuperSafe",
set_setting=False, add_to_group=True, is_active=True)
cls.user1 = UserManager.create_user(
"user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user2 = UserManager.create_user(
"user2",
"user2@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
def setUp(self):
self.client.force_login(self.user1)
def test_create_package(self):
response = self.client.post(reverse("packages"), {
response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package",
"package-description": "Just a Description",
})
},
)
self.assertEqual(response.status_code, 302)
package_url = response.url
@ -41,13 +56,12 @@ class PackageViewTest(TestCase):
file = SimpleUploadedFile(
"Fixture_Package.json",
open(s.FIXTURE_DIRS[0] / "Fixture_Package.json", "rb").read(),
content_type="application/json"
content_type="application/json",
)
response = self.client.post(reverse("packages"), {
"file": file,
"hidden": "import-package-json"
})
response = self.client.post(
reverse("packages"), {"file": file, "hidden": "import-package-json"}
)
self.assertEqual(response.status_code, 302)
package_url = response.url
@ -67,13 +81,12 @@ class PackageViewTest(TestCase):
file = SimpleUploadedFile(
"EAWAG-BBD.json",
open(s.FIXTURE_DIRS[0] / "packages" / "2025-07-18" / "EAWAG-BBD.json", "rb").read(),
content_type="application/json"
content_type="application/json",
)
response = self.client.post(reverse("packages"), {
"file": file,
"hidden": "import-legacy-package-json"
})
response = self.client.post(
reverse("packages"), {"file": file, "hidden": "import-legacy-package-json"}
)
self.assertEqual(response.status_code, 302)
package_url = response.url
@ -90,17 +103,23 @@ class PackageViewTest(TestCase):
self.assertEqual(upp.permission, Permission.ALL[0])
def test_edit_package(self):
response = self.client.post(reverse("packages"), {
response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package",
"package-description": "Just a Description",
})
},
)
self.assertEqual(response.status_code, 302)
package_url = response.url
self.client.post(package_url, {
self.client.post(
package_url,
{
"package-name": "New Name",
"package-description": "New Description",
})
},
)
p = Package.objects.get(url=package_url)
@ -108,10 +127,13 @@ class PackageViewTest(TestCase):
self.assertEqual(p.description, "New Description")
def test_edit_package_permissions(self):
response = self.client.post(reverse("packages"), {
response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package",
"package-description": "Just a Description",
})
},
)
self.assertEqual(response.status_code, 302)
package_url = response.url
p = Package.objects.get(url=package_url)
@ -119,57 +141,63 @@ class PackageViewTest(TestCase):
with self.assertRaises(UserPackagePermission.DoesNotExist):
UserPackagePermission.objects.get(package=p, user=self.user2)
self.client.post(package_url, {
self.client.post(
package_url,
{
"grantee": self.user2.url,
"read": "on",
"write": "on",
})
},
)
upp = UserPackagePermission.objects.get(package=p, user=self.user2)
self.assertEqual(upp.permission, Permission.WRITE[0])
def test_publish_package(self):
response = self.client.post(reverse("packages"), {
response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package",
"package-description": "Just a Description",
})
},
)
self.assertEqual(response.status_code, 302)
package_url = response.url
p = Package.objects.get(url=package_url)
self.client.post(package_url, {
"hidden": "publish-package"
})
self.client.post(package_url, {"hidden": "publish-package"})
self.assertEqual(Group.objects.filter(public=True).count(), 1)
g = Group.objects.get(public=True)
gpp = GroupPackagePermission.objects.get(package=p, group=g)
self.assertEqual(gpp.permission, Permission.READ[0])
def test_set_package_license(self):
response = self.client.post(reverse("packages"), {
response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package",
"package-description": "Just a Description",
})
},
)
package_url = response.url
p = Package.objects.get(url=package_url)
self.client.post(package_url, {
"license": "no-license"
})
self.client.post(package_url, {"license": "no-license"})
self.assertIsNone(p.license)
# TODO test others
def test_delete_package(self):
response = self.client.post(reverse("packages"), {
response = self.client.post(
reverse("packages"),
{
"package-name": "Test Package",
"package-description": "Just a Description",
})
},
)
package_url = response.url
p = Package.objects.get(url=package_url)
@ -182,11 +210,11 @@ class PackageViewTest(TestCase):
def test_delete_default_package(self):
self.client.force_login(self.user1)
# Try to delete the default package
response = self.client.post(self.user1.default_package.url, {
"hidden": "delete"
})
response = self.client.post(self.user1.default_package.url, {"hidden": "delete"})
self.assertEqual(response.status_code, 400)
self.assertTrue(f'You cannot delete the default package. '
f'If you want to delete this package you have to '
f'set another default package first' in response.content.decode())
self.assertTrue(
"You cannot delete the default package. "
"If you want to delete this package you have to "
"set another default package first" in response.content.decode()
)

View File

@ -5,6 +5,7 @@ from django.conf import settings as s
from epdb.logic import UserManager, PackageManager
from epdb.models import Pathway, Edge
@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models")
class PathwayViewTest(TestCase):
fixtures = ["test_fixtures_incl_model.jsonl.gz"]
@ -12,41 +13,52 @@ class PathwayViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(PathwayViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe",
set_setting=True, add_to_group=True, is_active=True)
cls.user1 = UserManager.create_user(
"user1",
"user1@envipath.com",
"SuperSafe",
set_setting=True,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack')
cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self):
self.client.force_login(self.user1)
def test_predict_pathway(self):
response = self.client.post(reverse("pathways"), {
'name': 'Test Pathway',
'description': 'Just a Description',
'predict': 'predict',
'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO',
})
response = self.client.post(
reverse("pathways"),
{
"name": "Test Pathway",
"description": "Just a Description",
"predict": "predict",
"smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO",
},
)
self.assertEqual(response.status_code, 302)
pathway_url = response.url
pw = Pathway.objects.get(url=pathway_url)
self.assertEqual(self.user1_default_package, pw.package)
self.assertEqual(pw.name, 'Test Pathway')
self.assertEqual(pw.description, 'Just a Description')
self.assertEqual(pw.name, "Test Pathway")
self.assertEqual(pw.description, "Just a Description")
self.assertEqual(len(pw.root_nodes), 1)
self.assertEqual(pw.root_nodes.first().default_node_label.smiles, 'CCN(CC)C(=O)C1=CC(CO)=CC=C1')
self.assertEqual(
pw.root_nodes.first().default_node_label.smiles, "CCN(CC)C(=O)C1=CC(CO)=CC=C1"
)
first_level_nodes = {
# Edge 1
'CCN(CC)C(=O)C1=CC(C=O)=CC=C1',
"CCN(CC)C(=O)C1=CC(C=O)=CC=C1",
# Edge 2
'CCNC(=O)C1=CC(CO)=CC=C1',
'CC=O',
"CCNC(=O)C1=CC(CO)=CC=C1",
"CC=O",
# Edge 3
'CCNCC',
'O=C(O)C1=CC(CO)=CC=C1',
"CCNCC",
"O=C(O)C1=CC(CO)=CC=C1",
}
predicted_nodes = set()
@ -60,32 +72,36 @@ class PathwayViewTest(TestCase):
def test_predict_package_pathway(self):
response = self.client.post(
reverse("package pathway list", kwargs={'package_uuid': str(self.package.uuid)}), {
'name': 'Test Pathway',
'description': 'Just a Description',
'predict': 'predict',
'smiles': 'CCN(CC)C(=O)C1=CC(=CC=C1)CO',
})
reverse("package pathway list", kwargs={"package_uuid": str(self.package.uuid)}),
{
"name": "Test Pathway",
"description": "Just a Description",
"predict": "predict",
"smiles": "CCN(CC)C(=O)C1=CC(=CC=C1)CO",
},
)
self.assertEqual(response.status_code, 302)
pathway_url = response.url
pw = Pathway.objects.get(url=pathway_url)
self.assertEqual(self.package, pw.package)
self.assertEqual(pw.name, 'Test Pathway')
self.assertEqual(pw.description, 'Just a Description')
self.assertEqual(pw.name, "Test Pathway")
self.assertEqual(pw.description, "Just a Description")
self.assertEqual(len(pw.root_nodes), 1)
self.assertEqual(pw.root_nodes.first().default_node_label.smiles, 'CCN(CC)C(=O)C1=CC(CO)=CC=C1')
self.assertEqual(
pw.root_nodes.first().default_node_label.smiles, "CCN(CC)C(=O)C1=CC(CO)=CC=C1"
)
first_level_nodes = {
# Edge 1
'CCN(CC)C(=O)C1=CC(C=O)=CC=C1',
"CCN(CC)C(=O)C1=CC(C=O)=CC=C1",
# Edge 2
'CCNC(=O)C1=CC(CO)=CC=C1',
'CC=O',
"CCNC(=O)C1=CC(CO)=CC=C1",
"CC=O",
# Edge 3
'CCNCC',
'O=C(O)C1=CC(CO)=CC=C1',
"CCNCC",
"O=C(O)C1=CC(CO)=CC=C1",
}
predicted_nodes = set()

View File

@ -3,7 +3,7 @@ from django.urls import reverse
from envipy_additional_information import Temperature, Interval
from epdb.logic import UserManager, PackageManager
from epdb.models import Reaction, Scenario, ExternalIdentifier, ExternalDatabase
from epdb.models import Reaction, Scenario, ExternalDatabase
class ReactionViewTest(TestCase):
@ -12,21 +12,28 @@ class ReactionViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(ReactionViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe",
set_setting=False, add_to_group=True, is_active=True)
cls.user1 = UserManager.create_user(
"user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack')
cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self):
self.client.force_login(self.user1)
def test_create_reaction(self):
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -42,11 +49,12 @@ class ReactionViewTest(TestCase):
# Adding the same rule again should return the existing one, hence not increasing the number of rules
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.url, reaction_url)
@ -55,11 +63,12 @@ class ReactionViewTest(TestCase):
# Adding the same rule in a different package should create a new rule
response = self.client.post(
reverse("package reaction list", kwargs={'package_uuid': self.package.uuid}), {
reverse("package reaction list", kwargs={"package_uuid": self.package.uuid}),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -67,11 +76,12 @@ class ReactionViewTest(TestCase):
# adding another reaction should increase count
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0002",
"reaction-description": "Description for Eawag BBD reaction r0002",
"reaction-smirks": "C(CO)Cl>>C(C=O)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -80,11 +90,12 @@ class ReactionViewTest(TestCase):
# Edit
def test_edit_rule(self):
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -93,13 +104,17 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(self.user1_default_package.uuid),
'reaction_uuid': str(r.uuid)
}), {
reverse(
"package reaction detail",
kwargs={
"package_uuid": str(self.user1_default_package.uuid),
"reaction_uuid": str(r.uuid),
},
),
{
"reaction-name": "Test Reaction Adjusted",
"reaction-description": "New Description",
}
},
)
self.assertEqual(response.status_code, 302)
@ -119,7 +134,7 @@ class ReactionViewTest(TestCase):
"Test Desc",
"2025-10",
"soil",
[Temperature(interval=Interval(start=20, end=30))]
[Temperature(interval=Interval(start=20, end=30))],
)
s2 = Scenario.create(
@ -128,15 +143,16 @@ class ReactionViewTest(TestCase):
"Test Desc2",
"2025-10",
"soil",
[Temperature(interval=Interval(start=10, end=20))]
[Temperature(interval=Interval(start=10, end=20))],
)
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -144,47 +160,47 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(r.package.uuid),
'reaction_uuid': str(r.uuid)
}), {
"selected-scenarios": [s1.url, s2.url]
}
reverse(
"package reaction detail",
kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
),
{"selected-scenarios": [s1.url, s2.url]},
)
self.assertEqual(len(r.scenarios.all()), 2)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(r.package.uuid),
'reaction_uuid': str(r.uuid)
}), {
"selected-scenarios": [s1.url]
}
reverse(
"package reaction detail",
kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
),
{"selected-scenarios": [s1.url]},
)
self.assertEqual(len(r.scenarios.all()), 1)
self.assertEqual(r.scenarios.first().url, s1.url)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(r.package.uuid),
'reaction_uuid': str(r.uuid)
}), {
reverse(
"package reaction detail",
kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
),
{
# We have to set an empty string to avoid that the parameter is removed
"selected-scenarios": ""
}
},
)
self.assertEqual(len(r.scenarios.all()), 0)
def test_copy(self):
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -192,12 +208,13 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url)
response = self.client.post(
reverse("package detail", kwargs={
'package_uuid': str(self.package.uuid),
}), {
"hidden": "copy",
"object_to_copy": r.url
}
reverse(
"package detail",
kwargs={
"package_uuid": str(self.package.uuid),
},
),
{"hidden": "copy", "object_to_copy": r.url},
)
self.assertEqual(response.status_code, 200)
@ -211,44 +228,48 @@ class ReactionViewTest(TestCase):
# Copy to the same package should fail
response = self.client.post(
reverse("package detail", kwargs={
'package_uuid': str(r.package.uuid),
}), {
"hidden": "copy",
"object_to_copy": r.url
}
reverse(
"package detail",
kwargs={
"package_uuid": str(r.package.uuid),
},
),
{"hidden": "copy", "object_to_copy": r.url},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], f"Can't copy object {reaction_url} to the same package!")
self.assertEqual(
response.json()["error"], f"Can't copy object {reaction_url} to the same package!"
)
def test_references(self):
ext_db, _ = ExternalDatabase.objects.get_or_create(
name='KEGG Reaction',
name="KEGG Reaction",
defaults={
'full_name': 'KEGG Reaction Database',
'description': 'Database of biochemical reactions',
'base_url': 'https://www.genome.jp',
'url_pattern': 'https://www.genome.jp/entry/{id}'
}
"full_name": "KEGG Reaction Database",
"description": "Database of biochemical reactions",
"base_url": "https://www.genome.jp",
"url_pattern": "https://www.genome.jp/entry/{id}",
},
)
ext_db2, _ = ExternalDatabase.objects.get_or_create(
name='RHEA',
name="RHEA",
defaults={
'full_name': 'RHEA Reaction Database',
'description': 'Comprehensive resource of biochemical reactions',
'base_url': 'https://www.rhea-db.org',
'url_pattern': 'https://www.rhea-db.org/rhea/{id}'
"full_name": "RHEA Reaction Database",
"description": "Comprehensive resource of biochemical reactions",
"base_url": "https://www.rhea-db.org",
"url_pattern": "https://www.rhea-db.org/rhea/{id}",
},
)
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -256,45 +277,49 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(r.package.uuid),
'reaction_uuid': str(r.uuid),
}), {
'selected-database': ext_db.pk,
'identifier': 'C12345'
}
reverse(
"package reaction detail",
kwargs={
"package_uuid": str(r.package.uuid),
"reaction_uuid": str(r.uuid),
},
),
{"selected-database": ext_db.pk, "identifier": "C12345"},
)
self.assertEqual(r.external_identifiers.count(), 1)
self.assertEqual(r.external_identifiers.first().database, ext_db)
self.assertEqual(r.external_identifiers.first().identifier_value, 'C12345')
self.assertEqual(r.external_identifiers.first().identifier_value, "C12345")
# TODO Fixture contains old url template there the real test fails, use old value instead
# self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/C12345')
self.assertEqual(r.external_identifiers.first().url, 'https://www.genome.jp/entry/reaction+C12345')
self.assertEqual(
r.external_identifiers.first().url, "https://www.genome.jp/entry/reaction+C12345"
)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(r.package.uuid),
'reaction_uuid': str(r.uuid),
}), {
'selected-database': ext_db2.pk,
'identifier': '60116'
}
reverse(
"package reaction detail",
kwargs={
"package_uuid": str(r.package.uuid),
"reaction_uuid": str(r.uuid),
},
),
{"selected-database": ext_db2.pk, "identifier": "60116"},
)
self.assertEqual(r.external_identifiers.count(), 2)
self.assertEqual(r.external_identifiers.last().database, ext_db2)
self.assertEqual(r.external_identifiers.last().identifier_value, '60116')
self.assertEqual(r.external_identifiers.last().url, 'https://www.rhea-db.org/rhea/60116')
self.assertEqual(r.external_identifiers.last().identifier_value, "60116")
self.assertEqual(r.external_identifiers.last().url, "https://www.rhea-db.org/rhea/60116")
def test_delete(self):
response = self.client.post(
reverse("reactions"), {
reverse("reactions"),
{
"reaction-name": "Eawag BBD reaction r0001",
"reaction-description": "Description for Eawag BBD reaction r0001",
"reaction-smirks": "C(CCl)Cl>>C(CO)Cl",
}
},
)
self.assertEqual(response.status_code, 302)
@ -302,12 +327,11 @@ class ReactionViewTest(TestCase):
r = Reaction.objects.get(url=reaction_url)
response = self.client.post(
reverse("package reaction detail", kwargs={
'package_uuid': str(r.package.uuid),
'reaction_uuid': str(r.uuid)
}), {
"hidden": "delete"
}
reverse(
"package reaction detail",
kwargs={"package_uuid": str(r.package.uuid), "reaction_uuid": str(r.uuid)},
),
{"hidden": "delete"},
)
self.assertEqual(self.user1_default_package.reactions.count(), 0)

View File

@ -12,22 +12,29 @@ class RuleViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(RuleViewTest, cls).setUpClass()
cls.user1 = UserManager.create_user("user1", "user1@envipath.com", "SuperSafe",
set_setting=False, add_to_group=True, is_active=True)
cls.user1 = UserManager.create_user(
"user1",
"user1@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=True,
is_active=True,
)
cls.user1_default_package = cls.user1.default_package
cls.package = PackageManager.create_package(cls.user1, 'Test', 'Test Pack')
cls.package = PackageManager.create_package(cls.user1, "Test", "Test Pack")
def setUp(self):
self.client.force_login(self.user1)
def test_create_rule(self):
response = self.client.post(
reverse("rules"), {
reverse("rules"),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.status_code, 302)
@ -38,18 +45,21 @@ class RuleViewTest(TestCase):
self.assertEqual(r.package, self.user1_default_package)
self.assertEqual(r.name, "Test Rule")
self.assertEqual(r.description, "Just a Description")
self.assertEqual(r.smirks,
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]")
self.assertEqual(
r.smirks,
"[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
)
self.assertEqual(self.user1_default_package.rules.count(), 1)
# Adding the same rule again should return the existing one, hence not increasing the number of rules
response = self.client.post(
reverse("rules"), {
reverse("rules"),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.url, rule_url)
@ -58,12 +68,13 @@ class RuleViewTest(TestCase):
# Adding the same rule in a different package should create a new rule
response = self.client.post(
reverse("package rule list", kwargs={'package_uuid': self.package.uuid}), {
reverse("package rule list", kwargs={"package_uuid": self.package.uuid}),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.status_code, 302)
@ -72,12 +83,13 @@ class RuleViewTest(TestCase):
# Edit
def test_edit_rule(self):
response = self.client.post(
reverse("rules"), {
reverse("rules"),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.status_code, 302)
@ -86,13 +98,17 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url)
response = self.client.post(
reverse("package rule detail", kwargs={
'package_uuid': str(self.user1_default_package.uuid),
'rule_uuid': str(r.uuid)
}), {
reverse(
"package rule detail",
kwargs={
"package_uuid": str(self.user1_default_package.uuid),
"rule_uuid": str(r.uuid),
},
),
{
"rule-name": "Test Rule Adjusted",
"rule-description": "New Description",
}
},
)
self.assertEqual(response.status_code, 302)
@ -108,7 +124,7 @@ class RuleViewTest(TestCase):
"Test Desc",
"2025-10",
"soil",
[Temperature(interval=Interval(start=20, end=30))]
[Temperature(interval=Interval(start=20, end=30))],
)
s2 = Scenario.create(
@ -117,16 +133,17 @@ class RuleViewTest(TestCase):
"Test Desc2",
"2025-10",
"soil",
[Temperature(interval=Interval(start=10, end=20))]
[Temperature(interval=Interval(start=10, end=20))],
)
response = self.client.post(
reverse("rules"), {
reverse("rules"),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.status_code, 302)
@ -134,48 +151,48 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url)
response = self.client.post(
reverse("package rule detail", kwargs={
'package_uuid': str(r.package.uuid),
'rule_uuid': str(r.uuid)
}), {
"selected-scenarios": [s1.url, s2.url]
}
reverse(
"package rule detail",
kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
),
{"selected-scenarios": [s1.url, s2.url]},
)
self.assertEqual(len(r.scenarios.all()), 2)
response = self.client.post(
reverse("package rule detail", kwargs={
'package_uuid': str(r.package.uuid),
'rule_uuid': str(r.uuid)
}), {
"selected-scenarios": [s1.url]
}
reverse(
"package rule detail",
kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
),
{"selected-scenarios": [s1.url]},
)
self.assertEqual(len(r.scenarios.all()), 1)
self.assertEqual(r.scenarios.first().url, s1.url)
response = self.client.post(
reverse("package rule detail", kwargs={
'package_uuid': str(r.package.uuid),
'rule_uuid': str(r.uuid)
}), {
reverse(
"package rule detail",
kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
),
{
# We have to set an empty string to avoid that the parameter is removed
"selected-scenarios": ""
}
},
)
self.assertEqual(len(r.scenarios.all()), 0)
def test_copy(self):
response = self.client.post(
reverse("rules"), {
reverse("rules"),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.status_code, 302)
@ -183,12 +200,13 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url)
response = self.client.post(
reverse("package detail", kwargs={
'package_uuid': str(self.package.uuid),
}), {
"hidden": "copy",
"object_to_copy": r.url
}
reverse(
"package detail",
kwargs={
"package_uuid": str(self.package.uuid),
},
),
{"hidden": "copy", "object_to_copy": r.url},
)
self.assertEqual(response.status_code, 200)
@ -202,26 +220,29 @@ class RuleViewTest(TestCase):
# Copy to the same package should fail
response = self.client.post(
reverse("package detail", kwargs={
'package_uuid': str(r.package.uuid),
}), {
"hidden": "copy",
"object_to_copy": r.url
}
reverse(
"package detail",
kwargs={
"package_uuid": str(r.package.uuid),
},
),
{"hidden": "copy", "object_to_copy": r.url},
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['error'], f"Can't copy object {rule_url} to the same package!")
self.assertEqual(
response.json()["error"], f"Can't copy object {rule_url} to the same package!"
)
def test_delete(self):
response = self.client.post(
reverse("rules"), {
reverse("rules"),
{
"rule-name": "Test Rule",
"rule-description": "Just a Description",
"rule-smirks": "[H:5][C:1]([#6:6])([#1,#9,#17,#35,#53:4])[#9,#17,#35,#53]>>[H:5][C:1]([#6:6])([#8])[#1,#9,#17,#35,#53:4]",
"rule-type": "SimpleAmbitRule",
}
},
)
self.assertEqual(response.status_code, 302)
@ -229,12 +250,11 @@ class RuleViewTest(TestCase):
r = Rule.objects.get(url=rule_url)
response = self.client.post(
reverse("package rule detail", kwargs={
'package_uuid': str(r.package.uuid),
'rule_uuid': str(r.uuid)
}), {
"hidden": "delete"
}
reverse(
"package rule detail",
kwargs={"package_uuid": str(r.package.uuid), "rule_uuid": str(r.uuid)},
),
{"hidden": "delete"},
)
self.assertEqual(self.user1_default_package.rules.count(), 0)

View File

@ -11,70 +11,81 @@ class UserViewTest(TestCase):
@classmethod
def setUpClass(cls):
super(UserViewTest, cls).setUpClass()
cls.user = User.objects.get(username='anonymous')
cls.package = PackageManager.create_package(cls.user, 'Anon Test Package', 'No Desc')
cls.BBD_SUBSET = Package.objects.get(name='Fixtures')
cls.user = User.objects.get(username="anonymous")
cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
cls.BBD_SUBSET = Package.objects.get(name="Fixtures")
def test_login_with_valid_credentials(self):
response = self.client.post(reverse("login"), {
response = self.client.post(
reverse("login"),
{
"username": "user0",
"password": 'SuperSafe',
})
"password": "SuperSafe",
},
)
self.assertRedirects(response, reverse("index"))
self.assertTrue(response.wsgi_request.user.is_authenticated)
def test_login_with_invalid_credentials(self):
response = self.client.post(reverse("login"), {
response = self.client.post(
reverse("login"),
{
"username": "user0",
"password": "wrongpassword",
})
},
)
self.assertEqual(response.status_code, 200)
self.assertFalse(response.wsgi_request.user.is_authenticated)
def test_register(self):
response = self.client.post(reverse("register"), {
response = self.client.post(
reverse("register"),
{
"username": "user1",
"email": "user1@envipath.com",
"password": "SuperSafe",
"rpassword": "SuperSafe",
})
},
)
self.assertEqual(response.status_code, 200)
# TODO currently fails as the fixture does not provide a global setting...
self.assertContains(response, "Registration failed!")
def test_register_password_mismatch(self):
response = self.client.post(reverse("register"), {
response = self.client.post(
reverse("register"),
{
"username": "user1",
"email": "user1@envipath.com",
"password": "SuperSafe",
"rpassword": "SuperSaf3",
})
},
)
self.assertEqual(response.status_code, 200)
self.assertContains(response, "Registration failed, provided passwords differ")
def test_logout(self):
response = self.client.post(reverse("login"), {
"username": "user0",
"password": 'SuperSafe',
"login": "true"
})
response = self.client.post(
reverse("login"), {"username": "user0", "password": "SuperSafe", "login": "true"}
)
self.assertTrue(response.wsgi_request.user.is_authenticated)
response = self.client.post(reverse('logout'), {
response = self.client.post(
reverse("logout"),
{
"logout": "true",
})
},
)
self.assertFalse(response.wsgi_request.user.is_authenticated)
def test_next_param_properly_handled(self):
response = self.client.get(reverse('packages'))
response = self.client.get(reverse("packages"))
self.assertRedirects(response, f"{reverse('login')}/?next=/package")
response = self.client.post(reverse('login'), {
"username": "user0",
"password": 'SuperSafe',
"login": "true",
"next": "/package"
})
response = self.client.post(
reverse("login"),
{"username": "user0", "password": "SuperSafe", "login": "true", "next": "/package"},
)
self.assertRedirects(response, reverse('packages'))
self.assertRedirects(response, reverse("packages"))

View File

@ -1,13 +0,0 @@
import abc
from enviPy.epdb import Pathway
class PredictionSchema(abc.ABC):
pass
class DFS(PredictionSchema):
def __init__(self, pw: Pathway, settings=None):
self.setting = settings or pw.prediction_settings
def predict(self):
pass

View File

@ -2,12 +2,11 @@ import logging
import re
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Dict
from typing import List, Optional, Dict, TYPE_CHECKING
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem
from rdkit import RDLogger
from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
@ -15,9 +14,11 @@ from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.rdmolops import GetMolFrags
from rdkit.Contrib.IFG import ifg
logger = logging.getLogger(__name__)
RDLogger.DisableLog('rdApp.*')
if TYPE_CHECKING:
from epdb.models import Rule
logger = logging.getLogger(__name__)
rdBase.DisableLog("rdApp.*")
# from rdkit import rdBase
# rdBase.LogToPythonLogger()
@ -28,7 +29,6 @@ RDLogger.DisableLog('rdApp.*')
class ProductSet(object):
def __init__(self, product_set: List[str]):
self.product_set = product_set
@ -42,15 +42,18 @@ class ProductSet(object):
return iter(self.product_set)
def __eq__(self, other):
return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(other.product_set)
return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(
other.product_set
)
def __hash__(self):
return hash('-'.join(sorted(self.product_set)))
return hash("-".join(sorted(self.product_set)))
class PredictionResult(object):
def __init__(self, product_sets: List['ProductSet'], probability: float, rule: Optional['Rule'] = None):
def __init__(
self, product_sets: List["ProductSet"], probability: float, rule: Optional["Rule"] = None
):
self.product_sets = product_sets
self.probability = probability
self.rule = rule
@ -66,7 +69,6 @@ class PredictionResult(object):
class FormatConverter(object):
@staticmethod
def mass(smiles):
return Descriptors.MolWt(FormatConverter.from_smiles(smiles))
@ -127,7 +129,7 @@ class FormatConverter(object):
if kekulize:
try:
mol = Chem.Kekulize(mol)
except:
except Exception:
mol = Chem.Mol(mol.ToBinary())
if not mol.GetNumConformers():
@ -139,8 +141,8 @@ class FormatConverter(object):
opts.clearBackground = False
drawer.DrawMolecule(mol)
drawer.FinishDrawing()
svg = drawer.GetDrawingText().replace('svg:', '')
svg = re.sub("<\?xml.*\?>", '', svg)
svg = drawer.GetDrawingText().replace("svg:", "")
svg = re.sub("<\?xml.*\?>", "", svg)
return svg
@ -151,7 +153,7 @@ class FormatConverter(object):
if kekulize:
try:
Chem.Kekulize(mol)
except:
except Exception:
mc = Chem.Mol(mol.ToBinary())
if not mc.GetNumConformers():
@ -178,7 +180,7 @@ class FormatConverter(object):
smiles = tmp_smiles
if change is False:
print(f"nothing changed")
print("nothing changed")
return smiles
@ -198,7 +200,9 @@ class FormatConverter(object):
parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol)
# try to neutralize molecule
uncharger = rdMolStandardize.Uncharger() # annoying, but necessary as no convenience method exists
uncharger = (
rdMolStandardize.Uncharger()
) # annoying, but necessary as no convenience method exists
uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol)
# note that no attempt is made at reionization at this step
@ -239,17 +243,24 @@ class FormatConverter(object):
try:
rdChemReactions.ReactionFromSmarts(smirks)
return True
except:
except Exception:
return False
@staticmethod
def apply(smiles: str, smirks: str, preprocess_smiles: bool = True, bracketize: bool = True,
standardize: bool = True, kekulize: bool = True, remove_stereo: bool = True) -> List['ProductSet']:
logger.debug(f'Applying {smirks} on {smiles}')
def apply(
smiles: str,
smirks: str,
preprocess_smiles: bool = True,
bracketize: bool = True,
standardize: bool = True,
kekulize: bool = True,
remove_stereo: bool = True,
) -> List["ProductSet"]:
logger.debug(f"Applying {smirks} on {smiles}")
# If explicitly wanted or rule generates multiple products add brackets around products to capture all
if bracketize: # or "." in smirks:
smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")"
smirks = smirks.split(">>")[0] + ">>(" + smirks.split(">>")[1] + ")"
# List of ProductSet objects
pss = set()
@ -274,7 +285,9 @@ class FormatConverter(object):
Chem.SanitizeMol(product)
product = GetMolFrags(product, asMols=True)
for p in product:
p = FormatConverter.standardize(Chem.MolToSmiles(p), remove_stereo=remove_stereo)
p = FormatConverter.standardize(
Chem.MolToSmiles(p), remove_stereo=remove_stereo
)
prods.append(p)
# if kekulize:
@ -300,9 +313,8 @@ class FormatConverter(object):
# # bond.SetIsAromatic(False)
# Chem.Kekulize(product)
except ValueError as e:
logger.error(f'Sanitizing and converting failed:\n{e}')
logger.error(f"Sanitizing and converting failed:\n{e}")
continue
if len(prods):
@ -310,7 +322,7 @@ class FormatConverter(object):
pss.add(ps)
except Exception as e:
logger.error(f'Applying {smirks} on {smiles} failed:\n{e}')
logger.error(f"Applying {smirks} on {smiles} failed:\n{e}")
return pss
@ -340,22 +352,19 @@ class FormatConverter(object):
smi_p = Chem.MolToSmiles(mol, kekuleSmiles=True)
smi_p = Chem.CanonSmiles(smi_p)
if '~' in smi_p:
smi_p1 = smi_p.replace('~', '')
if "~" in smi_p:
smi_p1 = smi_p.replace("~", "")
parsed_smiles.append(smi_p1)
else:
parsed_smiles.append(smi_p)
except Exception as e:
except Exception:
errors += 1
pass
return parsed_smiles, errors
class Standardizer(ABC):
def __init__(self, name):
self.name = name
@ -364,7 +373,6 @@ class Standardizer(ABC):
class RuleStandardizer(Standardizer):
def __init__(self, name, smirks):
super().__init__(name)
self.smirks = smirks
@ -373,8 +381,8 @@ class RuleStandardizer(Standardizer):
standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks)))
if len(standardized_smiles) > 1:
logger.warning(f'{self.smirks} generated more than 1 compound {standardized_smiles}')
print(f'{self.smirks} generated more than 1 compound {standardized_smiles}')
logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
standardized_smiles = standardized_smiles[:1]
if standardized_smiles:
@ -384,7 +392,6 @@ class RuleStandardizer(Standardizer):
class RegExStandardizer(Standardizer):
def __init__(self, name, replacements: dict):
super().__init__(name)
self.replacements = replacements
@ -404,28 +411,39 @@ class RegExStandardizer(Standardizer):
return super().standardize(smi)
FLATTEN = [
RegExStandardizer("Remove Stereo", {"@": ""})
]
FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]
UN_CIS_TRANS = [
RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})
]
UN_CIS_TRANS = [RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})]
BASIC = [
RuleStandardizer("ammoniumstandardization", "[H][N+:1]([H])([H])[#6:2]>>[H][#7:1]([H])-[#6:2]"),
RuleStandardizer("cyanate", "[H][#8:1][C:2]#[N:3]>>[#8-:1][C:2]#[N:3]"),
RuleStandardizer("deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"),
RuleStandardizer("forNOOH", "[H][#8:1]-[#7+:2](-[*:3])=[O:4]>>[#8-:1]-[#7+:2](-[*:3])=[O:4]"),
RuleStandardizer("Hydroxylprotonation", "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]"),
RuleStandardizer("phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"),
RuleStandardizer("PicricAcid",
"[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]"),
RuleStandardizer("Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"),
RuleStandardizer("Sulfate2",
"[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]"),
RuleStandardizer("Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"),
RuleStandardizer("Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"),
RuleStandardizer(
"Hydroxylprotonation",
"[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]",
),
RuleStandardizer(
"phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"
),
RuleStandardizer(
"PicricAcid",
"[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]",
),
RuleStandardizer(
"Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"
),
RuleStandardizer(
"Sulfate2",
"[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]",
),
RuleStandardizer(
"Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"
),
RuleStandardizer(
"Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"
),
]
ENHANCED = BASIC + [
@ -433,28 +451,30 @@ ENHANCED = BASIC + [
]
EXOTIC = ENHANCED + [
RuleStandardizer("ThioPhosphate1", "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]")
RuleStandardizer(
"ThioPhosphate1",
"[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]",
)
]
COA_CUTTER = [
RuleStandardizer("CutCoEnzymeAOff",
"CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]")
RuleStandardizer(
"CutCoEnzymeAOff",
"CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]",
)
]
ENOL_KETO = [
RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")
]
ENOL_KETO = [RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")]
MATCH_STANDARDIZER = EXOTIC + FLATTEN + UN_CIS_TRANS + COA_CUTTER + ENOL_KETO
class IndigoUtils(object):
@staticmethod
def layout(mol_data):
i = Indigo()
try:
if mol_data.startswith('$RXN') or '>>' in mol_data:
if mol_data.startswith("$RXN") or ">>" in mol_data:
rxn = i.loadQueryReaction(mol_data)
rxn.layout()
return rxn.rxnfile()
@ -462,14 +482,14 @@ class IndigoUtils(object):
mol = i.loadQueryMolecule(mol_data)
mol.layout()
return mol.molfile()
except IndigoException as e:
except IndigoException:
try:
logger.info("layout() failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.layout()
return rxn.molfile()
except IndigoException as e2:
logger.error(f'layout() failed due to {e2}!')
logger.error(f"layout() failed due to {e2}!")
@staticmethod
def load_reaction_SMARTS(mol):
@ -479,7 +499,7 @@ class IndigoUtils(object):
def aromatize(mol_data, is_query):
i = Indigo()
try:
if mol_data.startswith('$RXN'):
if mol_data.startswith("$RXN"):
if is_query:
rxn = i.loadQueryReaction(mol_data)
else:
@ -495,20 +515,20 @@ class IndigoUtils(object):
mol.aromatize()
return mol.molfile()
except IndigoException as e:
except IndigoException:
try:
logger.info("Aromatizing failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.aromatize()
return rxn.molfile()
except IndigoException as e2:
logger.error(f'Aromatizing failed due to {e2}!')
logger.error(f"Aromatizing failed due to {e2}!")
@staticmethod
def dearomatize(mol_data, is_query):
i = Indigo()
try:
if mol_data.startswith('$RXN'):
if mol_data.startswith("$RXN"):
if is_query:
rxn = i.loadQueryReaction(mol_data)
else:
@ -524,14 +544,14 @@ class IndigoUtils(object):
mol.dearomatize()
return mol.molfile()
except IndigoException as e:
except IndigoException:
try:
logger.info("De-Aromatizing failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.dearomatize()
return rxn.molfile()
except IndigoException as e2:
logger.error(f'De-Aromatizing failed due to {e2}!')
logger.error(f"De-Aromatizing failed due to {e2}!")
@staticmethod
def sanitize_functional_group(functional_group: str):
@ -543,7 +563,7 @@ class IndigoUtils(object):
# special environment handling (amines, hydroxy, esters, ethers)
# the higher substituted should not contain H env.
if functional_group == '[C]=O':
if functional_group == "[C]=O":
functional_group = "[H][C](=O)[CX4,c]"
# aldamines
@ -577,15 +597,20 @@ class IndigoUtils(object):
functional_group = "[nH1,nX2](a)a" # pyrrole (with H) or pyridine (no other connections); currently overlaps with neighboring aromatic atoms
# substituted aromatic nitrogen
functional_group = functional_group.replace("N*(R)R",
"n(a)a") # substituent will be before N*; currently overlaps with neighboring aromatic atoms
functional_group = functional_group.replace(
"N*(R)R", "n(a)a"
) # substituent will be before N*; currently overlaps with neighboring aromatic atoms
# pyridinium
if functional_group == "RN*(R)(R)(R)R":
functional_group = "[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms
functional_group = (
"[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms
)
# N-oxide
if functional_group == "[H]ON*(R)(R)(R)R":
functional_group = "[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms
functional_group = (
"[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms
)
# other aromatic hetero atoms
functional_group = functional_group.replace("C*", "c")
@ -598,7 +623,9 @@ class IndigoUtils(object):
# other replacement, to accomodate for the standardization rules in enviPath
# This is not the perfect way to do it; there should be a way to replace substructure SMARTS in SMARTS?
# nitro groups are broken, due to charge handling. this SMARTS matches both forms (formal charges and hypervalent); Ertl-CDK still treats both forms separately...
functional_group = functional_group.replace("[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
functional_group = functional_group.replace(
"[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]"
)
functional_group = functional_group.replace("O=N(=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
# carboxylic acid: this SMARTS matches both neutral and anionic form; includes COOH in larger functional_groups
functional_group = functional_group.replace("[H]OC(=O)", "[OD1]C(=O)")
@ -616,7 +643,9 @@ class IndigoUtils(object):
return functional_group
@staticmethod
def _colorize(indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool):
def _colorize(
indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool
):
indigo.setOption("render-atom-color-property", "color")
indigo.setOption("aromaticity-model", "generic")
@ -646,7 +675,6 @@ class IndigoUtils(object):
for match in matcher.iterateMatches(query):
if match is not None:
for atom in query.iterateAtoms():
mappedAtom = match.mapAtom(atom)
if mappedAtom is None or mappedAtom.index() in environment:
@ -655,7 +683,7 @@ class IndigoUtils(object):
counts[mappedAtom.index()] = max(v, counts[mappedAtom.index()])
except IndigoException as e:
logger.debug(f'Colorizing failed due to {e}')
logger.debug(f"Colorizing failed due to {e}")
for k, v in counts.items():
if is_reaction:
@ -669,8 +697,9 @@ class IndigoUtils(object):
molecule.addDataSGroup([k], [], "color", color)
@staticmethod
def mol_to_svg(mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None):
def mol_to_svg(
mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None
):
if functional_groups is None:
functional_groups = {}
@ -682,7 +711,7 @@ class IndigoUtils(object):
i.setOption("render-image-size", width, height)
i.setOption("render-bond-line-width", 2.0)
if '~' in mol_data:
if "~" in mol_data:
mol = i.loadSmarts(mol_data)
else:
mol = i.loadMolecule(mol_data)
@ -690,11 +719,17 @@ class IndigoUtils(object):
if len(functional_groups.keys()) > 0:
IndigoUtils._colorize(i, mol, functional_groups, False)
return renderer.renderToBuffer(mol).decode('UTF-8')
return renderer.renderToBuffer(mol).decode("UTF-8")
@staticmethod
def smirks_to_svg(smirks: str, is_query_smirks, width: int = 0, height: int = 0,
educt_functional_groups: Dict[str, int] = None, product_functional_groups: Dict[str, int] = None):
def smirks_to_svg(
smirks: str,
is_query_smirks,
width: int = 0,
height: int = 0,
educt_functional_groups: Dict[str, int] = None,
product_functional_groups: Dict[str, int] = None,
):
if educt_functional_groups is None:
educt_functional_groups = {}
@ -721,18 +756,18 @@ class IndigoUtils(object):
for prod in obj.iterateProducts():
IndigoUtils._colorize(i, prod, product_functional_groups, True)
return renderer.renderToBuffer(obj).decode('UTF-8')
return renderer.renderToBuffer(obj).decode("UTF-8")
if __name__ == '__main__':
if __name__ == "__main__":
data = {
"struct": "\n Ketcher 2172510 12D 1 1.00000 0.00000 0\n\n 6 6 0 0 0 999 V2000\n 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 1 2 2 0 0 0 0\n 2 3 1 0 0 0 0\n 3 4 2 0 0 0 0\n 4 5 1 0 0 0 0\n 5 6 2 0 0 0 0\n 6 1 1 0 0 0 0\nM END\n",
"options": {
"smart-layout": True,
"ignore-stereochemistry-errors": True,
"mass-skip-error-on-pseudoatoms": False,
"gross-formula-add-rsites": True
}
"gross-formula-add-rsites": True,
},
}
print(IndigoUtils.aromatize(data['struct'], False))
print(IndigoUtils.aromatize(data["struct"], False))

View File

@ -1,83 +0,0 @@
import json
import requests
class AMBITResult:
def __init__(self, *args, **kwargs):
self.smiles = kwargs['smiles']
self.tps = []
for bt in kwargs['products']:
if len(bt['products']):
self.tps.append(bt)
self.probs = None
def __str__(self):
x = self.smiles + "\n"
total_bts = len(self.tps)
for i, tp in enumerate(self.tps):
prob = ""
if self.probs:
prob = f" (p={self.probs[tp['id']]})"
if i == total_bts - 1:
x += f"\t└── {tp['name']}{prob}\n"
else:
x += f"\t├── {tp['name']}{prob}\n"
total_products = len(tp['products'])
for j, p in enumerate(tp['products']):
if j == total_products - 1:
if i == total_bts - 1:
x += f"\t\t└── {p}"
else:
x += f"\t\t└── {p}\n"
else:
if i == total_bts - 1:
x += f"\t\t├── {p}\n"
else:
x += f"\t\t├── {p}\n"
return x
def set_probs(self, probs):
self.probs = probs
class AMBIT:
def __init__(self, host, rules=None):
self.host = host
self.rules = rules
self.ambit_params = {
'singlePos': True,
'split': False,
}
def batch_apply(self, smiles: list):
payload = {
'smiles': smiles,
'rules': self.rules,
}
payload.update(**self.ambit_params)
res = self._execute(payload)
tps = list()
for r in res['result']:
ar = AMBITResult(**r)
if len(ar.tps):
tps.append(ar)
else:
tps.append(None)
return tps
def apply(self, smiles: str):
return self.batch_apply([smiles])[0]
def _execute(self, payload):
res = requests.post(self.host + '/ambit', data=json.dumps(payload))
res.raise_for_status()
return res.json()

View File

@ -8,9 +8,9 @@ from epdb.models import Package
# Map HTTP methods to required permissions
DEFAULT_METHOD_PERMISSIONS = {
'GET': 'read',
'POST': 'write',
'DELETE': 'write',
"GET": "read",
"POST": "write",
"DELETE": "write",
}
@ -22,6 +22,7 @@ def package_permission_required(method_permissions=None):
@wraps(view_func)
def _wrapped_view(request, package_uuid, *args, **kwargs):
from epdb.views import _anonymous_or_real
user = _anonymous_or_real(request)
permission_required = method_permissions[request.method]
@ -30,11 +31,12 @@ def package_permission_required(method_permissions=None):
if not PackageManager.has_package_permission(user, package_uuid, permission_required):
from epdb.views import error
return error(
request,
"Operation failed!",
f"Couldn't perform the desired operation as {user.username} does not have the required permissions!",
code=403
code=403,
)
return view_func(request, package_uuid, *args, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -1,37 +1,35 @@
from __future__ import annotations
import copy
import numpy as np
from numpy.random import default_rng
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Set, Tuple
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
import networkx as nx
import numpy as np
from numpy.random import default_rng
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler
from utilities.chem import FormatConverter, PredictionResult
logger = logging.getLogger(__name__)
from dataclasses import dataclass, field
from utilities.chem import FormatConverter, PredictionResult
if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction
class Dataset:
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns
self.num_labels: int = num_labels
@ -41,9 +39,9 @@ class Dataset:
self.data = data
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices('feature_')
self._triggered: Tuple[int, int] = self._block_indices('trig_')
self._observed: Tuple[int, int] = self._block_indices('obs_')
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
@ -62,7 +60,7 @@ class Dataset:
self.data.append(row)
def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f'trig_{rule_uuid}')
idx = self.columns.index(f"trig_{rule_uuid}")
times_triggered = 0
for row in self.data:
@ -89,12 +87,12 @@ class Dataset:
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[Dataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
for struct in structures:
if isinstance(struct, str):
struct_id = None
struct_smiles = struct
@ -119,10 +117,14 @@ class Dataset:
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods)
return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products
return Dataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod
def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset:
_structures = set()
for r in reactions:
@ -155,12 +157,11 @@ class Dataset:
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
triggered[key].add(smi)
@ -188,7 +189,7 @@ class Dataset:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
standardized_products.append(smi)
@ -224,19 +225,22 @@ class Dataset:
obs.append(0)
if ds is None:
header = ['structure_id'] + \
[f'feature_{i}' for i, _ in enumerate(feat)] \
+ [f'trig_{r.uuid}' for r in applicable_rules] \
+ [f'obs_{r.uuid}' for r in applicable_rules]
header = (
["structure_id"]
+ [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)))
res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
@ -247,14 +251,12 @@ class Dataset:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
@ -271,42 +273,50 @@ class Dataset:
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
else [row[i] for i in col_key] for row in rows]
res = [
[row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res
def save(self, path: 'Path'):
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: 'Path') -> 'Dataset':
def load(path: "Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: 'Path'):
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"
for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]:
if c == 'structure_id':
if c == "structure_id":
arff += f"@attribute {c} string\n"
else:
arff += f"@attribute {c} {{0,1}}\n"
arff += f"\n@data\n"
arff += "\n@data\n"
for d in self.data:
ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
arff += f'{ys},{xs}\n'
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f"{ys},{xs}\n"
with open(path, "w") as fh:
fh.write(arff)
fh.flush()
def __repr__(self):
return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@ -315,8 +325,11 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
Removes labels that are constant across all samples in training.
"""
def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42),
num_chains: int = 10):
def __init__(
self,
base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
num_chains: int = 10,
):
self.base_clf = base_clf
self.num_chains = num_chains
@ -384,16 +397,16 @@ class BinaryRelevance:
if self.classifiers is None:
self.classifiers = []
for l in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, l])]
Y_l = (Y[~np.isnan(Y[:, l]), l])
for label in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, label])]
Y_l = Y[~np.isnan(Y[:, label]), label]
if len(X_l) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf = DummyClassifier(strategy="constant", constant=0)
clf.fit([X[0]], [0])
self.classifiers.append(clf)
continue
elif len(np.unique(Y_l)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
clf = DummyClassifier(strategy="most_frequent")
else:
clf = copy.deepcopy(self.clf)
clf.fit(X_l, Y_l)
@ -439,17 +452,19 @@ class MissingValuesClassifierChain:
X_p = X[~np.isnan(Y[:, p])]
Y_p = Y[~np.isnan(Y[:, p]), p]
if len(X_p) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf = DummyClassifier(strategy="constant", constant=0)
self.classifiers.append(clf.fit([X[0]], [0]))
elif len(np.unique(Y_p)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
clf = DummyClassifier(strategy="most_frequent")
self.classifiers.append(clf.fit(X_p, Y_p))
else:
clf = copy.deepcopy(self.base_clf)
self.classifiers.append(clf.fit(X_p, Y_p))
newcol = Y[:, p]
pred = clf.predict(X)
newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions
newcol[np.isnan(newcol)] = pred[
np.isnan(newcol)
] # fill in missing values with clf predictions
X = np.column_stack((X, newcol))
def predict(self, X):
@ -541,13 +556,10 @@ class RelativeReasoning:
# We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
if (
countwin >= self.min_count and
countwin > countloose and
(
countloose <= self.max_count or
self.max_count < 0
) and
countboth == 0
countwin >= self.min_count
and countwin > countloose
and (countloose <= self.max_count or self.max_count < 0)
and countboth == 0
):
self.winmap[i].append(j)
@ -579,7 +591,6 @@ class RelativeReasoning:
class ApplicabilityDomainPCA(PCA):
def __init__(self, num_neighbours: int = 5):
super().__init__(n_components=num_neighbours)
self.scaler = StandardScaler()
@ -587,7 +598,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None
self.max_vals = None
def build(self, train_dataset: 'Dataset'):
def build(self, train_dataset: "Dataset"):
# transform
X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca
@ -601,7 +612,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled)
return instances_pca
def is_applicable(self, classify_instances: 'Dataset'):
def is_applicable(self, classify_instances: "Dataset"):
instances_pca = self.__transform(classify_instances.X())
is_applicable = []
@ -632,6 +643,7 @@ def graph_from_pathway(data):
"""Convert Pathway or SPathway to networkx"""
from epdb.models import Pathway
from epdb.logic import SPathway
graph = nx.DiGraph()
co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation
@ -645,7 +657,9 @@ def graph_from_pathway(data):
def get_sources_targets():
if isinstance(data, Pathway):
return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()]
return [n.node for n in edge.start_nodes.constrained_target.all()], [
n.node for n in edge.end_nodes.constrained_target.all()
]
elif isinstance(data, SPathway):
return edge.educts, edge.products
else:
@ -662,7 +676,7 @@ def graph_from_pathway(data):
def get_probability():
try:
if isinstance(data, Pathway):
return edge.kv.get('probability')
return edge.kv.get("probability")
elif isinstance(data, SPathway):
return edge.probability
else:
@ -680,17 +694,29 @@ def graph_from_pathway(data):
for source in sources:
source_smiles, source_depth = get_smiles_depth(source)
if source_smiles not in graph:
graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
root=source_smiles in root_smiles)
graph.add_node(
source_smiles,
depth=source_depth,
smiles=source_smiles,
root=source_smiles in root_smiles,
)
else:
graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
graph.nodes[source_smiles]["depth"] = min(
source_depth, graph.nodes[source_smiles]["depth"]
)
for target in targets:
target_smiles, target_depth = get_smiles_depth(target)
if target_smiles not in graph and target_smiles not in co2:
graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
root=target_smiles in root_smiles)
graph.add_node(
target_smiles,
depth=target_depth,
smiles=target_smiles,
root=target_smiles in root_smiles,
)
elif target_smiles not in co2:
graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
graph.nodes[target_smiles]["depth"] = min(
target_depth, graph.nodes[target_smiles]["depth"]
)
if target_smiles not in co2 and target_smiles != source_smiles:
graph.add_edge(source_smiles, target_smiles, probability=probability)
return graph
@ -710,7 +736,9 @@ def set_pathway_eval_weight(pathway):
node_eval_weights = {}
for node in pathway.nodes:
# Scale score according to depth level
node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
node_eval_weights[node] = (
1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
)
return node_eval_weights
@ -731,8 +759,11 @@ def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
shortest_path_list.append(shortest_path_nodes)
if shortest_path_list:
shortest_path_nodes = min(shortest_path_list, key=len)
num_ints = sum(1 for shortest_path_node in shortest_path_nodes if
shortest_path_node in intermediates)
num_ints = sum(
1
for shortest_path_node in shortest_path_nodes
if shortest_path_node in intermediates
)
pred_pathway.nodes[node]["depth"] -= num_ints
return pred_pathway
@ -879,6 +910,11 @@ def pathway_edit_eval(data_pathway, pred_pathway):
data_pathway = initialise_pathway(data_pathway)
pred_pathway = initialise_pathway(pred_pathway)
roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
return nx.graph_edit_distance(data_pathway, pred_pathway,
node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost, roots=roots)
return nx.graph_edit_distance(
data_pathway,
pred_pathway,
node_subst_cost=node_subst_cost,
node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost,
roots=roots,
)

View File

@ -23,11 +23,11 @@ def install_wheel(wheel_path):
def extract_package_name_from_wheel(wheel_filename):
# Example: my_plugin-0.1.0-py3-none-any.whl -> my_plugin
return wheel_filename.split('-')[0]
return wheel_filename.split("-")[0]
def ensure_plugins_installed():
wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, '*.whl'))
wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, "*.whl"))
for wheel_path in wheel_files:
wheel_filename = os.path.basename(wheel_path)
@ -45,7 +45,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
plugins = {}
for entry_point in importlib.metadata.entry_points(group='enviPy_plugins'):
for entry_point in importlib.metadata.entry_points(group="enviPy_plugins"):
try:
plugin_class = entry_point.load()
if _cls: