diff --git a/epapi/tests/utils/test_validation_errors.py b/epapi/tests/utils/test_validation_errors.py index c2a98be6..4404a6be 100644 --- a/epapi/tests/utils/test_validation_errors.py +++ b/epapi/tests/utils/test_validation_errors.py @@ -51,47 +51,30 @@ class ValidationErrorUtilityTests(TestCase): self.assertIn("active", formatted) self.assertIn("inactive", formatted) - def test_format_string_type_error(self): - """Test formatting of string type validation error.""" + def test_format_type_errors(self): + """Test formatting of type validation errors (string, int, float).""" + test_cases = [ + # (field_type, invalid_value, expected_message) + # Note: We don't check exact error_type as Pydantic may use different types + # (e.g., int_type vs int_parsing) but we verify the formatted message is correct + (str, 123, "Please enter a valid string"), + (int, "not_a_number", "Please enter a valid int"), + (float, "not_a_float", "Please enter a valid float"), + ] - class TestModel(BaseModel): - name: str + for field_type, invalid_value, expected_message in test_cases: + with self.subTest(field_type=field_type.__name__): - try: - TestModel(name=123) - except ValidationError as e: - errors = e.errors() - self.assertEqual(len(errors), 1) - formatted = format_validation_error(errors[0]) - self.assertEqual(formatted, "Please enter a valid string") + class TestModel(BaseModel): + field: field_type - def test_format_int_type_error(self): - """Test formatting of integer type validation error.""" - - class TestModel(BaseModel): - count: int - - try: - TestModel(count="not_a_number") - except ValidationError as e: - errors = e.errors() - self.assertEqual(len(errors), 1) - formatted = format_validation_error(errors[0]) - self.assertEqual(formatted, "Please enter a valid int") - - def test_format_float_type_error(self): - """Test formatting of float type validation error.""" - - class TestModel(BaseModel): - value: float - - try: - TestModel(value="not_a_float") - except ValidationError as e: - errors = e.errors() - self.assertEqual(len(errors), 1) - formatted = format_validation_error(errors[0]) - self.assertEqual(formatted, "Please enter a valid float") + try: + TestModel(field=invalid_value) + except ValidationError as e: + errors = e.errors() + self.assertEqual(len(errors), 1) + formatted = format_validation_error(errors[0]) + self.assertEqual(formatted, expected_message) def test_format_value_error(self): """Test formatting of value error from custom validator.""" @@ -114,6 +97,19 @@ class ValidationErrorUtilityTests(TestCase): formatted = format_validation_error(errors[0]) self.assertEqual(formatted, "Age must be positive") + def test_format_unknown_error_type_fallback(self): + """Test that unknown error types fall back to default formatting.""" + # Mock an error with an unknown type + mock_error = { + "type": "unknown_custom_type", + "msg": "Input should be a valid email address", + "ctx": {}, + } + + formatted = format_validation_error(mock_error) + # Should use the else branch which does replacements on the message + self.assertEqual(formatted, "Please enter a valid email address") + def test_handle_validation_error_structure(self): """Test that handle_validation_error raises HttpError with correct structure.""" diff --git a/epapi/tests/v1/test_additional_information.py b/epapi/tests/v1/test_additional_information.py index 3439b52e..8f66250e 100644 --- a/epapi/tests/v1/test_additional_information.py +++ b/epapi/tests/v1/test_additional_information.py @@ -380,13 +380,6 @@ class AdditionalInformationAPITests(TestCase): self.assertEqual(items[0]["uuid"], item1_uuid) self.assertEqual(items[0]["data"]["interval"]["start"], 15) - def test_unauthenticated_access_returns_401(self): - """Test that unauthenticated requests return 401.""" - # Don't log in - response = self.client.get(f"/api/v1/scenario/{self.scenario.uuid}/information/") - - self.assertEqual(response.status_code, 401) - def test_list_info_denied_without_permission(self): """User cannot list info for scenario in package they don't have access to""" self.client.force_login(self.user) @@ -445,16 +438,6 @@ class AdditionalInformationAPITests(TestCase): ) self.assertEqual(response.status_code, 403) - def test_anonymous_user_denied(self): - """Anonymous users cannot access private scenarios""" - # Ensure scenario package is not reviewed (private) - self.scenario.package.reviewed = False - self.scenario.package.save() - - # Unauthenticated users get 401 from the auth layer - response = self.client.get(f"/api/v1/scenario/{self.scenario.uuid}/information/") - self.assertEqual(response.status_code, 401) - def test_nonexistent_scenario_returns_404(self): """Test operations on non-existent scenario return 404.""" self.client.force_login(self.user) diff --git a/epapi/tests/v1/test_api_permissions.py b/epapi/tests/v1/test_api_permissions.py index 200a761f..57de6215 100644 --- a/epapi/tests/v1/test_api_permissions.py +++ b/epapi/tests/v1/test_api_permissions.py @@ -261,13 +261,6 @@ class GlobalCompoundListPermissionTest(APIPermissionTestBase): self.assertEqual(response.status_code, 200) payload = response.json() - # user2 should see compounds from: - # - reviewed_package (public) - # - unreviewed_package_read (READ permission) - # - unreviewed_package_write (WRITE permission) - # - unreviewed_package_all (ALL permission) - # - group_package (via group membership) - # Total: 5 compounds self.assertEqual(payload["total_items"], 5) visible_uuids = {item["uuid"] for item in payload["items"]} @@ -303,54 +296,6 @@ class GlobalCompoundListPermissionTest(APIPermissionTestBase): # user1 owns all packages, so sees all compounds self.assertEqual(payload["total_items"], 7) - def test_read_permission_allows_viewing(self): - """READ permission allows viewing compounds.""" - self.client.force_login(self.user2) - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, 200) - payload = response.json() - - # Check that read_compound is included - uuids = [item["uuid"] for item in payload["items"]] - self.assertIn(str(self.read_compound.uuid), uuids) - - def test_write_permission_allows_viewing(self): - """WRITE permission also allows viewing compounds.""" - self.client.force_login(self.user2) - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, 200) - payload = response.json() - - # Check that write_compound is included - uuids = [item["uuid"] for item in payload["items"]] - self.assertIn(str(self.write_compound.uuid), uuids) - - def test_all_permission_allows_viewing(self): - """ALL permission allows viewing compounds.""" - self.client.force_login(self.user2) - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, 200) - payload = response.json() - - # Check that all_compound is included - uuids = [item["uuid"] for item in payload["items"]] - self.assertIn(str(self.all_compound.uuid), uuids) - - def test_group_permission_allows_viewing(self): - """Group membership grants access to group-permitted packages.""" - self.client.force_login(self.user2) - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, 200) - payload = response.json() - - # Check that group_compound is included - uuids = [item["uuid"] for item in payload["items"]] - self.assertIn(str(self.group_compound.uuid), uuids) - @tag("api", "end2end") class PackageScopedCompoundListPermissionTest(APIPermissionTestBase): diff --git a/epapi/tests/v1/test_contract_get_entities.py b/epapi/tests/v1/test_contract_get_entities.py index 3bdeed59..0ec515f8 100644 --- a/epapi/tests/v1/test_contract_get_entities.py +++ b/epapi/tests/v1/test_contract_get_entities.py @@ -134,7 +134,7 @@ class BaseTestAPIGetPaginated: f"({self.total_reviewed} <= {self.default_page_size})" ) - response = self.client.get(self.global_endpoint, {"page": 2}) + response = self.client.get(self.global_endpoint, {"page": 2, "review_status": True}) self.assertEqual(response.status_code, 200) payload = response.json() diff --git a/epapi/tests/v1/test_scenario_creation.py b/epapi/tests/v1/test_scenario_creation.py index 3cd16579..735bae12 100644 --- a/epapi/tests/v1/test_scenario_creation.py +++ b/epapi/tests/v1/test_scenario_creation.py @@ -119,7 +119,7 @@ class ScenarioCreationAPITests(TestCase): "scenario_type": "biodegradation", "additional_information": [ { - "type": "SomeValidType", + "type": "invalid_type_name", "data": None, # This should cause a validation error } ], @@ -131,8 +131,8 @@ class ScenarioCreationAPITests(TestCase): content_type="application/json", ) - # Should return 400 for validation errors - self.assertIn(response.status_code, [400, 422]) + # Should return 422 for validation errors + self.assertEqual(response.status_code, 422) def test_create_scenario_success(self): """Test that valid scenario creation returns 200.""" @@ -162,30 +162,6 @@ class ScenarioCreationAPITests(TestCase): self.assertEqual(scenario.package, self.package) self.assertEqual(scenario.scenario_type, "biodegradation") - def test_create_scenario_with_additional_information(self): - """Test creating scenario with valid additional information models.""" - self.client.force_login(self.user) - - # This test will succeed if the registry has valid models - # For now, test with empty additional_information - payload = { - "name": "Scenario with AI", - "description": "Test with additional information", - "scenario_date": "2024-01-01", - "scenario_type": "biodegradation", - "additional_information": [], - } - - response = self.client.post( - f"/api/v1/package/{self.package.uuid}/scenario/", - data=json.dumps(payload), - content_type="application/json", - ) - - self.assertEqual(response.status_code, 200) - data = response.json() - self.assertEqual(data["name"], "Scenario with AI") - def test_create_scenario_auto_name(self): """Test that empty name triggers auto-generation.""" self.client.force_login(self.user) @@ -265,7 +241,7 @@ class ScenarioCreationAPITests(TestCase): "scenario_type": "biodegradation", "additional_information": [ { - "type": "SomeType", + "type": "invalid_type_name", "data": "string instead of dict", # Wrong type } ], @@ -277,8 +253,8 @@ class ScenarioCreationAPITests(TestCase): content_type="application/json", ) - # Should return 400 for validation errors - self.assertIn(response.status_code, [400, 422]) + # Should return 422 for validation errors + self.assertEqual(response.status_code, 422) def test_create_scenario_default_values(self): """Test that default values are applied correctly.""" diff --git a/epapi/tests/v1/test_token_auth.py b/epapi/tests/v1/test_token_auth.py new file mode 100644 index 00000000..0e220895 --- /dev/null +++ b/epapi/tests/v1/test_token_auth.py @@ -0,0 +1,94 @@ +from datetime import timedelta + +from django.test import TestCase, tag +from django.utils import timezone + +from epdb.logic import PackageManager, UserManager +from epdb.models import APIToken + + +@tag("api", "auth") +class BearerTokenAuthTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.user = UserManager.create_user( + "token-user", + "token-user@envipath.com", + "SuperSafe", + set_setting=False, + add_to_group=False, + is_active=True, + ) + + default_pkg = cls.user.default_package + cls.user.default_package = None + cls.user.save() + if default_pkg: + default_pkg.delete() + + cls.unreviewed_package = PackageManager.create_package( + cls.user, "Token Auth Package", "Package for token auth tests" + ) + + def _auth_header(self, raw_token): + return {"HTTP_AUTHORIZATION": f"Bearer {raw_token}"} + + def test_valid_token_allows_access(self): + _, raw_token = APIToken.create_token(self.user, name="Valid Token", expires_days=1) + + response = self.client.get("/api/v1/compounds/", **self._auth_header(raw_token)) + + self.assertEqual(response.status_code, 200) + + def test_expired_token_rejected(self): + token, raw_token = APIToken.create_token(self.user, name="Expired Token", expires_days=1) + token.expires_at = timezone.now() - timedelta(days=1) + token.save(update_fields=["expires_at"]) + + response = self.client.get("/api/v1/compounds/", **self._auth_header(raw_token)) + + self.assertEqual(response.status_code, 401) + + def test_inactive_token_rejected(self): + token, raw_token = APIToken.create_token(self.user, name="Inactive Token", expires_days=1) + token.is_active = False + token.save(update_fields=["is_active"]) + + response = self.client.get("/api/v1/compounds/", **self._auth_header(raw_token)) + + self.assertEqual(response.status_code, 401) + + def test_invalid_token_rejected(self): + response = self.client.get("/api/v1/compounds/", HTTP_AUTHORIZATION="Bearer invalid-token") + + self.assertEqual(response.status_code, 401) + + def test_no_token_rejected(self): + self.client.logout() + response = self.client.get("/api/v1/compounds/") + + self.assertEqual(response.status_code, 401) + + def test_bearer_populates_request_user_for_packages(self): + response = self.client.get("/api/v1/packages/") + self.assertEqual(response.status_code, 200) + payload = response.json() + uuids = {item["uuid"] for item in payload["items"]} + self.assertNotIn(str(self.unreviewed_package.uuid), uuids) + + _, raw_token = APIToken.create_token(self.user, name="Package Token", expires_days=1) + response = self.client.get("/api/v1/packages/", **self._auth_header(raw_token)) + + self.assertEqual(response.status_code, 200) + payload = response.json() + uuids = {item["uuid"] for item in payload["items"]} + self.assertIn(str(self.unreviewed_package.uuid), uuids) + + def test_session_auth_still_works_without_bearer(self): + self.client.force_login(self.user) + response = self.client.get("/api/v1/packages/") + + self.assertEqual(response.status_code, 200) + payload = response.json() + uuids = {item["uuid"] for item in payload["items"]} + self.assertIn(str(self.unreviewed_package.uuid), uuids) diff --git a/epapi/v1/auth.py b/epapi/v1/auth.py index 72569ce0..6e0daed7 100644 --- a/epapi/v1/auth.py +++ b/epapi/v1/auth.py @@ -1,8 +1,34 @@ +import hashlib + from ninja.security import HttpBearer from ninja.errors import HttpError +from epdb.models import APIToken + class BearerTokenAuth(HttpBearer): def authenticate(self, request, token): - # FIXME: placeholder; implement it in O(1) time - raise HttpError(401, "Invalid or expired token") + if token is None: + return None + + hashed_token = hashlib.sha256(token.encode()).hexdigest() + user = APIToken.authenticate(hashed_token, hashed=True) + if not user: + raise HttpError(401, "Invalid or expired token") + + request.user = user + return user + + +class OptionalBearerTokenAuth: + """Bearer auth that allows unauthenticated access. + + Validates the Bearer token if present (401 on invalid token), + otherwise lets the request through for anonymous/session access. + """ + + def __init__(self): + self._bearer = BearerTokenAuth() + + def __call__(self, request): + return self._bearer(request) or request.user diff --git a/epapi/v1/endpoints/packages.py b/epapi/v1/endpoints/packages.py index ea08569c..a251b1bf 100644 --- a/epapi/v1/endpoints/packages.py +++ b/epapi/v1/endpoints/packages.py @@ -3,6 +3,7 @@ from ninja import Router from ninja_extra.pagination import paginate import logging +from ..auth import OptionalBearerTokenAuth from ..dal import get_user_packages_for_read from ..pagination import EnhancedPageNumberPagination from ..schemas import PackageOutSchema, SelfReviewStatusFilter @@ -11,7 +12,11 @@ router = Router() logger = logging.getLogger(__name__) -@router.get("/packages/", response=EnhancedPageNumberPagination.Output[PackageOutSchema], auth=None) +@router.get( + "/packages/", + response=EnhancedPageNumberPagination.Output[PackageOutSchema], + auth=OptionalBearerTokenAuth(), +) @paginate( EnhancedPageNumberPagination, page_size=s.API_PAGINATION_DEFAULT_PAGE_SIZE, diff --git a/epdb/api.py b/epdb/api.py index 646d873f..4e7191a7 100644 --- a/epdb/api.py +++ b/epdb/api.py @@ -2,20 +2,12 @@ from typing import List from django.contrib.auth import get_user_model from ninja import Router, Schema, Field -from ninja.errors import HttpError from ninja.pagination import paginate -from ninja.security import HttpBearer + +from epapi.v1.auth import BearerTokenAuth from .logic import PackageManager -from .models import User, Compound, APIToken - - -class BearerTokenAuth(HttpBearer): - def authenticate(self, request, token): - for token_obj in APIToken.objects.select_related("user").all(): - if token_obj.check_token(token) and token_obj.is_valid(): - return token_obj.user - raise HttpError(401, "Invalid or expired token") +from .models import User, Compound def _anonymous_or_real(request): diff --git a/epdb/logic.py b/epdb/logic.py index edd84a2f..9718309b 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -679,7 +679,7 @@ class PackageManager(object): ai_data = json.loads(res.model_dump_json()) ai_data["uuid"] = f"{uuid4()}" new_add_inf[res_cls_name].append(ai_data) - except ValidationError: + except (ValidationError, ValueError): logger.error(f"Failed to convert {name} with {addinf_data}") scen.additional_information = new_add_inf diff --git a/epdb/management/commands/create_api_token.py b/epdb/management/commands/create_api_token.py new file mode 100644 index 00000000..1bf7b880 --- /dev/null +++ b/epdb/management/commands/create_api_token.py @@ -0,0 +1,92 @@ +from django.conf import settings as s +from django.contrib.auth import get_user_model +from django.core.management.base import BaseCommand, CommandError + +from epdb.models import APIToken + + +class Command(BaseCommand): + help = "Create an API token for a user" + + def add_arguments(self, parser): + parser.add_argument( + "--username", + required=True, + help="Username of the user who will own the token", + ) + parser.add_argument( + "--name", + required=True, + help="Descriptive name for the token", + ) + parser.add_argument( + "--expires-days", + type=int, + default=90, + help="Days until expiration (0 for no expiration)", + ) + parser.add_argument( + "--inactive", + action="store_true", + help="Create the token as inactive", + ) + parser.add_argument( + "--curl", + action="store_true", + help="Print a curl example using the token", + ) + parser.add_argument( + "--base-url", + default=None, + help="Base URL for curl example (default SERVER_URL or http://localhost:8000)", + ) + parser.add_argument( + "--endpoint", + default="/api/v1/compounds/", + help="Endpoint path for curl example", + ) + + def handle(self, *args, **options): + username = options["username"] + name = options["name"] + expires_days = options["expires_days"] + + if expires_days < 0: + raise CommandError("--expires-days must be >= 0") + + if expires_days == 0: + expires_days = None + + user_model = get_user_model() + try: + user = user_model.objects.get(username=username) + except user_model.DoesNotExist as exc: + raise CommandError(f"User not found for username '{username}'") from exc + + token, raw_token = APIToken.create_token(user, name=name, expires_days=expires_days) + + if options["inactive"]: + token.is_active = False + token.save(update_fields=["is_active"]) + + self.stdout.write(f"User: {user.username} ({user.email})") + self.stdout.write(f"Token name: {token.name}") + self.stdout.write(f"Token id: {token.id}") + if token.expires_at: + self.stdout.write(f"Expires at: {token.expires_at.isoformat()}") + else: + self.stdout.write("Expires at: never") + self.stdout.write(f"Active: {token.is_active}") + self.stdout.write("Raw token:") + self.stdout.write(raw_token) + + if options["curl"]: + base_url = ( + options["base_url"] or getattr(s, "SERVER_URL", None) or "http://localhost:8000" + ) + endpoint = options["endpoint"] + endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}" + url = f"{base_url.rstrip('/')}{endpoint}" + curl_cmd = f'curl -H "Authorization: Bearer {raw_token}" "{url}"' + self.stdout.write("Curl:") + self.stdout.write(curl_cmd) diff --git a/epdb/models.py b/epdb/models.py index cdbcd64c..20f3348b 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -170,17 +170,18 @@ class APIToken(TimeStampedModel): return token, raw_key @classmethod - def authenticate(cls, raw_key: str) -> Optional[User]: + def authenticate(cls, token: str, *, hashed: bool = False) -> Optional[User]: """ Authenticate a user using an API token. Args: - raw_key: Raw token key + token: Raw token key or SHA-256 hash (when hashed=True) + hashed: Whether the token is already hashed Returns: User if token is valid, None otherwise """ - hashed_key = hashlib.sha256(raw_key.encode()).hexdigest() + hashed_key = token if hashed else hashlib.sha256(token.encode()).hexdigest() try: token = cls.objects.select_related("user").get(hashed_key=hashed_key) diff --git a/epdb/views.py b/epdb/views.py index 612ed2b3..fadbb1f4 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -2600,9 +2600,11 @@ def user(request, user_uuid): if is_hidden_method and request.POST["hidden"] == "request-api-token": name = request.POST.get("name", "No Name") - valid_for = min(max(int(request.POST.get("valid-for", 90)), 1), 90) + expires_days = min(max(int(request.POST.get("valid-for", 90)), 1), 90) - token, raw_token = APIToken.create_token(request.user, name=name, valid_for=valid_for) + token, raw_token = APIToken.create_token( + request.user, name=name, expires_days=expires_days + ) return JsonResponse( {"raw_token": raw_token, "token": {"id": token.id, "name": token.name}} diff --git a/pyproject.toml b/pyproject.toml index c2b14a82..fc486267 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,8 +86,8 @@ build = { sequence = [ ], help = "Build frontend assets and collect static files" } # Database tasks -db-up = { cmd = "docker compose -f docker-compose.dev.yml up -d", help = "Start PostgreSQL database using Docker Compose" } -db-down = { cmd = "docker compose -f docker-compose.dev.yml down", help = "Stop PostgreSQL database" } +db-up = { cmd = "docker compose -p envipath -f docker-compose.dev.yml up -d", help = "Start PostgreSQL database using Docker Compose" } +db-down = { cmd = "docker compose -p envipath -f docker-compose.dev.yml down", help = "Stop PostgreSQL database" } # Celery tasks celery = { cmd = "celery -A envipath worker -l INFO -Q predict,model,background", help = "Start Celery worker for async task processing" } diff --git a/scripts/dev_server.py b/scripts/dev_server.py index fc853151..45a39fa6 100755 --- a/scripts/dev_server.py +++ b/scripts/dev_server.py @@ -11,6 +11,8 @@ import signal import subprocess import sys import time +import os +import dotenv def find_pnpm(): @@ -65,6 +67,7 @@ class DevServerManager: bufsize=1, ) self.processes.append((process, description)) + print(" ".join(command)) print(f"✓ Started {description} (PID: {process.pid})") return process except Exception as e: @@ -146,6 +149,7 @@ class DevServerManager: def main(): """Main entry point.""" + dotenv.load_dotenv() manager = DevServerManager() manager.register_cleanup() @@ -174,9 +178,10 @@ def main(): time.sleep(1) # Start Django dev server + port = os.environ.get("DJANGO_PORT", "8000") django_process = manager.start_process( - ["uv", "run", "python", "manage.py", "runserver"], - "Django server", + ["uv", "run", "python", "manage.py", "runserver", f"0:{port}"], + f"Django server on port {port}", shell=False, )