[Feature] Create API Key Authenticaton for v1 API (#327)

Add API key authentication to v1 API
Also includes:
- management command to create keys for users
- Improvements to API tests

Minor:
- more robust way to start docker dev container.

Reviewed-on: enviPath/enviPy#327
Co-authored-by: Tobias O <tobias.olenyi@envipath.com>
Co-committed-by: Tobias O <tobias.olenyi@envipath.com>
This commit is contained in:
2026-02-11 02:29:54 +13:00
committed by jebus
parent c0cfdb9255
commit 5789f20e7f
15 changed files with 282 additions and 165 deletions

View File

@ -51,47 +51,30 @@ class ValidationErrorUtilityTests(TestCase):
self.assertIn("active", formatted)
self.assertIn("inactive", formatted)
def test_format_string_type_error(self):
"""Test formatting of string type validation error."""
def test_format_type_errors(self):
"""Test formatting of type validation errors (string, int, float)."""
test_cases = [
# (field_type, invalid_value, expected_message)
# Note: We don't check exact error_type as Pydantic may use different types
# (e.g., int_type vs int_parsing) but we verify the formatted message is correct
(str, 123, "Please enter a valid string"),
(int, "not_a_number", "Please enter a valid int"),
(float, "not_a_float", "Please enter a valid float"),
]
class TestModel(BaseModel):
name: str
for field_type, invalid_value, expected_message in test_cases:
with self.subTest(field_type=field_type.__name__):
try:
TestModel(name=123)
except ValidationError as e:
errors = e.errors()
self.assertEqual(len(errors), 1)
formatted = format_validation_error(errors[0])
self.assertEqual(formatted, "Please enter a valid string")
class TestModel(BaseModel):
field: field_type
def test_format_int_type_error(self):
"""Test formatting of integer type validation error."""
class TestModel(BaseModel):
count: int
try:
TestModel(count="not_a_number")
except ValidationError as e:
errors = e.errors()
self.assertEqual(len(errors), 1)
formatted = format_validation_error(errors[0])
self.assertEqual(formatted, "Please enter a valid int")
def test_format_float_type_error(self):
"""Test formatting of float type validation error."""
class TestModel(BaseModel):
value: float
try:
TestModel(value="not_a_float")
except ValidationError as e:
errors = e.errors()
self.assertEqual(len(errors), 1)
formatted = format_validation_error(errors[0])
self.assertEqual(formatted, "Please enter a valid float")
try:
TestModel(field=invalid_value)
except ValidationError as e:
errors = e.errors()
self.assertEqual(len(errors), 1)
formatted = format_validation_error(errors[0])
self.assertEqual(formatted, expected_message)
def test_format_value_error(self):
"""Test formatting of value error from custom validator."""
@ -114,6 +97,19 @@ class ValidationErrorUtilityTests(TestCase):
formatted = format_validation_error(errors[0])
self.assertEqual(formatted, "Age must be positive")
def test_format_unknown_error_type_fallback(self):
"""Test that unknown error types fall back to default formatting."""
# Mock an error with an unknown type
mock_error = {
"type": "unknown_custom_type",
"msg": "Input should be a valid email address",
"ctx": {},
}
formatted = format_validation_error(mock_error)
# Should use the else branch which does replacements on the message
self.assertEqual(formatted, "Please enter a valid email address")
def test_handle_validation_error_structure(self):
"""Test that handle_validation_error raises HttpError with correct structure."""

View File

@ -380,13 +380,6 @@ class AdditionalInformationAPITests(TestCase):
self.assertEqual(items[0]["uuid"], item1_uuid)
self.assertEqual(items[0]["data"]["interval"]["start"], 15)
def test_unauthenticated_access_returns_401(self):
"""Test that unauthenticated requests return 401."""
# Don't log in
response = self.client.get(f"/api/v1/scenario/{self.scenario.uuid}/information/")
self.assertEqual(response.status_code, 401)
def test_list_info_denied_without_permission(self):
"""User cannot list info for scenario in package they don't have access to"""
self.client.force_login(self.user)
@ -445,16 +438,6 @@ class AdditionalInformationAPITests(TestCase):
)
self.assertEqual(response.status_code, 403)
def test_anonymous_user_denied(self):
"""Anonymous users cannot access private scenarios"""
# Ensure scenario package is not reviewed (private)
self.scenario.package.reviewed = False
self.scenario.package.save()
# Unauthenticated users get 401 from the auth layer
response = self.client.get(f"/api/v1/scenario/{self.scenario.uuid}/information/")
self.assertEqual(response.status_code, 401)
def test_nonexistent_scenario_returns_404(self):
"""Test operations on non-existent scenario return 404."""
self.client.force_login(self.user)

View File

@ -261,13 +261,6 @@ class GlobalCompoundListPermissionTest(APIPermissionTestBase):
self.assertEqual(response.status_code, 200)
payload = response.json()
# user2 should see compounds from:
# - reviewed_package (public)
# - unreviewed_package_read (READ permission)
# - unreviewed_package_write (WRITE permission)
# - unreviewed_package_all (ALL permission)
# - group_package (via group membership)
# Total: 5 compounds
self.assertEqual(payload["total_items"], 5)
visible_uuids = {item["uuid"] for item in payload["items"]}
@ -303,54 +296,6 @@ class GlobalCompoundListPermissionTest(APIPermissionTestBase):
# user1 owns all packages, so sees all compounds
self.assertEqual(payload["total_items"], 7)
def test_read_permission_allows_viewing(self):
"""READ permission allows viewing compounds."""
self.client.force_login(self.user2)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
payload = response.json()
# Check that read_compound is included
uuids = [item["uuid"] for item in payload["items"]]
self.assertIn(str(self.read_compound.uuid), uuids)
def test_write_permission_allows_viewing(self):
"""WRITE permission also allows viewing compounds."""
self.client.force_login(self.user2)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
payload = response.json()
# Check that write_compound is included
uuids = [item["uuid"] for item in payload["items"]]
self.assertIn(str(self.write_compound.uuid), uuids)
def test_all_permission_allows_viewing(self):
"""ALL permission allows viewing compounds."""
self.client.force_login(self.user2)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
payload = response.json()
# Check that all_compound is included
uuids = [item["uuid"] for item in payload["items"]]
self.assertIn(str(self.all_compound.uuid), uuids)
def test_group_permission_allows_viewing(self):
"""Group membership grants access to group-permitted packages."""
self.client.force_login(self.user2)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
payload = response.json()
# Check that group_compound is included
uuids = [item["uuid"] for item in payload["items"]]
self.assertIn(str(self.group_compound.uuid), uuids)
@tag("api", "end2end")
class PackageScopedCompoundListPermissionTest(APIPermissionTestBase):

View File

@ -134,7 +134,7 @@ class BaseTestAPIGetPaginated:
f"({self.total_reviewed} <= {self.default_page_size})"
)
response = self.client.get(self.global_endpoint, {"page": 2})
response = self.client.get(self.global_endpoint, {"page": 2, "review_status": True})
self.assertEqual(response.status_code, 200)
payload = response.json()

View File

@ -119,7 +119,7 @@ class ScenarioCreationAPITests(TestCase):
"scenario_type": "biodegradation",
"additional_information": [
{
"type": "SomeValidType",
"type": "invalid_type_name",
"data": None, # This should cause a validation error
}
],
@ -131,8 +131,8 @@ class ScenarioCreationAPITests(TestCase):
content_type="application/json",
)
# Should return 400 for validation errors
self.assertIn(response.status_code, [400, 422])
# Should return 422 for validation errors
self.assertEqual(response.status_code, 422)
def test_create_scenario_success(self):
"""Test that valid scenario creation returns 200."""
@ -162,30 +162,6 @@ class ScenarioCreationAPITests(TestCase):
self.assertEqual(scenario.package, self.package)
self.assertEqual(scenario.scenario_type, "biodegradation")
def test_create_scenario_with_additional_information(self):
"""Test creating scenario with valid additional information models."""
self.client.force_login(self.user)
# This test will succeed if the registry has valid models
# For now, test with empty additional_information
payload = {
"name": "Scenario with AI",
"description": "Test with additional information",
"scenario_date": "2024-01-01",
"scenario_type": "biodegradation",
"additional_information": [],
}
response = self.client.post(
f"/api/v1/package/{self.package.uuid}/scenario/",
data=json.dumps(payload),
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
data = response.json()
self.assertEqual(data["name"], "Scenario with AI")
def test_create_scenario_auto_name(self):
"""Test that empty name triggers auto-generation."""
self.client.force_login(self.user)
@ -265,7 +241,7 @@ class ScenarioCreationAPITests(TestCase):
"scenario_type": "biodegradation",
"additional_information": [
{
"type": "SomeType",
"type": "invalid_type_name",
"data": "string instead of dict", # Wrong type
}
],
@ -277,8 +253,8 @@ class ScenarioCreationAPITests(TestCase):
content_type="application/json",
)
# Should return 400 for validation errors
self.assertIn(response.status_code, [400, 422])
# Should return 422 for validation errors
self.assertEqual(response.status_code, 422)
def test_create_scenario_default_values(self):
"""Test that default values are applied correctly."""

View File

@ -0,0 +1,94 @@
from datetime import timedelta
from django.test import TestCase, tag
from django.utils import timezone
from epdb.logic import PackageManager, UserManager
from epdb.models import APIToken
@tag("api", "auth")
class BearerTokenAuthTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.user = UserManager.create_user(
"token-user",
"token-user@envipath.com",
"SuperSafe",
set_setting=False,
add_to_group=False,
is_active=True,
)
default_pkg = cls.user.default_package
cls.user.default_package = None
cls.user.save()
if default_pkg:
default_pkg.delete()
cls.unreviewed_package = PackageManager.create_package(
cls.user, "Token Auth Package", "Package for token auth tests"
)
def _auth_header(self, raw_token):
return {"HTTP_AUTHORIZATION": f"Bearer {raw_token}"}
def test_valid_token_allows_access(self):
_, raw_token = APIToken.create_token(self.user, name="Valid Token", expires_days=1)
response = self.client.get("/api/v1/compounds/", **self._auth_header(raw_token))
self.assertEqual(response.status_code, 200)
def test_expired_token_rejected(self):
token, raw_token = APIToken.create_token(self.user, name="Expired Token", expires_days=1)
token.expires_at = timezone.now() - timedelta(days=1)
token.save(update_fields=["expires_at"])
response = self.client.get("/api/v1/compounds/", **self._auth_header(raw_token))
self.assertEqual(response.status_code, 401)
def test_inactive_token_rejected(self):
token, raw_token = APIToken.create_token(self.user, name="Inactive Token", expires_days=1)
token.is_active = False
token.save(update_fields=["is_active"])
response = self.client.get("/api/v1/compounds/", **self._auth_header(raw_token))
self.assertEqual(response.status_code, 401)
def test_invalid_token_rejected(self):
response = self.client.get("/api/v1/compounds/", HTTP_AUTHORIZATION="Bearer invalid-token")
self.assertEqual(response.status_code, 401)
def test_no_token_rejected(self):
self.client.logout()
response = self.client.get("/api/v1/compounds/")
self.assertEqual(response.status_code, 401)
def test_bearer_populates_request_user_for_packages(self):
response = self.client.get("/api/v1/packages/")
self.assertEqual(response.status_code, 200)
payload = response.json()
uuids = {item["uuid"] for item in payload["items"]}
self.assertNotIn(str(self.unreviewed_package.uuid), uuids)
_, raw_token = APIToken.create_token(self.user, name="Package Token", expires_days=1)
response = self.client.get("/api/v1/packages/", **self._auth_header(raw_token))
self.assertEqual(response.status_code, 200)
payload = response.json()
uuids = {item["uuid"] for item in payload["items"]}
self.assertIn(str(self.unreviewed_package.uuid), uuids)
def test_session_auth_still_works_without_bearer(self):
self.client.force_login(self.user)
response = self.client.get("/api/v1/packages/")
self.assertEqual(response.status_code, 200)
payload = response.json()
uuids = {item["uuid"] for item in payload["items"]}
self.assertIn(str(self.unreviewed_package.uuid), uuids)

View File

@ -1,8 +1,34 @@
import hashlib
from ninja.security import HttpBearer
from ninja.errors import HttpError
from epdb.models import APIToken
class BearerTokenAuth(HttpBearer):
def authenticate(self, request, token):
# FIXME: placeholder; implement it in O(1) time
raise HttpError(401, "Invalid or expired token")
if token is None:
return None
hashed_token = hashlib.sha256(token.encode()).hexdigest()
user = APIToken.authenticate(hashed_token, hashed=True)
if not user:
raise HttpError(401, "Invalid or expired token")
request.user = user
return user
class OptionalBearerTokenAuth:
"""Bearer auth that allows unauthenticated access.
Validates the Bearer token if present (401 on invalid token),
otherwise lets the request through for anonymous/session access.
"""
def __init__(self):
self._bearer = BearerTokenAuth()
def __call__(self, request):
return self._bearer(request) or request.user

View File

@ -3,6 +3,7 @@ from ninja import Router
from ninja_extra.pagination import paginate
import logging
from ..auth import OptionalBearerTokenAuth
from ..dal import get_user_packages_for_read
from ..pagination import EnhancedPageNumberPagination
from ..schemas import PackageOutSchema, SelfReviewStatusFilter
@ -11,7 +12,11 @@ router = Router()
logger = logging.getLogger(__name__)
@router.get("/packages/", response=EnhancedPageNumberPagination.Output[PackageOutSchema], auth=None)
@router.get(
"/packages/",
response=EnhancedPageNumberPagination.Output[PackageOutSchema],
auth=OptionalBearerTokenAuth(),
)
@paginate(
EnhancedPageNumberPagination,
page_size=s.API_PAGINATION_DEFAULT_PAGE_SIZE,

View File

@ -2,20 +2,12 @@ from typing import List
from django.contrib.auth import get_user_model
from ninja import Router, Schema, Field
from ninja.errors import HttpError
from ninja.pagination import paginate
from ninja.security import HttpBearer
from epapi.v1.auth import BearerTokenAuth
from .logic import PackageManager
from .models import User, Compound, APIToken
class BearerTokenAuth(HttpBearer):
def authenticate(self, request, token):
for token_obj in APIToken.objects.select_related("user").all():
if token_obj.check_token(token) and token_obj.is_valid():
return token_obj.user
raise HttpError(401, "Invalid or expired token")
from .models import User, Compound
def _anonymous_or_real(request):

View File

@ -679,7 +679,7 @@ class PackageManager(object):
ai_data = json.loads(res.model_dump_json())
ai_data["uuid"] = f"{uuid4()}"
new_add_inf[res_cls_name].append(ai_data)
except ValidationError:
except (ValidationError, ValueError):
logger.error(f"Failed to convert {name} with {addinf_data}")
scen.additional_information = new_add_inf

View File

@ -0,0 +1,92 @@
from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from epdb.models import APIToken
class Command(BaseCommand):
help = "Create an API token for a user"
def add_arguments(self, parser):
parser.add_argument(
"--username",
required=True,
help="Username of the user who will own the token",
)
parser.add_argument(
"--name",
required=True,
help="Descriptive name for the token",
)
parser.add_argument(
"--expires-days",
type=int,
default=90,
help="Days until expiration (0 for no expiration)",
)
parser.add_argument(
"--inactive",
action="store_true",
help="Create the token as inactive",
)
parser.add_argument(
"--curl",
action="store_true",
help="Print a curl example using the token",
)
parser.add_argument(
"--base-url",
default=None,
help="Base URL for curl example (default SERVER_URL or http://localhost:8000)",
)
parser.add_argument(
"--endpoint",
default="/api/v1/compounds/",
help="Endpoint path for curl example",
)
def handle(self, *args, **options):
username = options["username"]
name = options["name"]
expires_days = options["expires_days"]
if expires_days < 0:
raise CommandError("--expires-days must be >= 0")
if expires_days == 0:
expires_days = None
user_model = get_user_model()
try:
user = user_model.objects.get(username=username)
except user_model.DoesNotExist as exc:
raise CommandError(f"User not found for username '{username}'") from exc
token, raw_token = APIToken.create_token(user, name=name, expires_days=expires_days)
if options["inactive"]:
token.is_active = False
token.save(update_fields=["is_active"])
self.stdout.write(f"User: {user.username} ({user.email})")
self.stdout.write(f"Token name: {token.name}")
self.stdout.write(f"Token id: {token.id}")
if token.expires_at:
self.stdout.write(f"Expires at: {token.expires_at.isoformat()}")
else:
self.stdout.write("Expires at: never")
self.stdout.write(f"Active: {token.is_active}")
self.stdout.write("Raw token:")
self.stdout.write(raw_token)
if options["curl"]:
base_url = (
options["base_url"] or getattr(s, "SERVER_URL", None) or "http://localhost:8000"
)
endpoint = options["endpoint"]
endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}"
url = f"{base_url.rstrip('/')}{endpoint}"
curl_cmd = f'curl -H "Authorization: Bearer {raw_token}" "{url}"'
self.stdout.write("Curl:")
self.stdout.write(curl_cmd)

View File

@ -170,17 +170,18 @@ class APIToken(TimeStampedModel):
return token, raw_key
@classmethod
def authenticate(cls, raw_key: str) -> Optional[User]:
def authenticate(cls, token: str, *, hashed: bool = False) -> Optional[User]:
"""
Authenticate a user using an API token.
Args:
raw_key: Raw token key
token: Raw token key or SHA-256 hash (when hashed=True)
hashed: Whether the token is already hashed
Returns:
User if token is valid, None otherwise
"""
hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()
hashed_key = token if hashed else hashlib.sha256(token.encode()).hexdigest()
try:
token = cls.objects.select_related("user").get(hashed_key=hashed_key)

View File

@ -2600,9 +2600,11 @@ def user(request, user_uuid):
if is_hidden_method and request.POST["hidden"] == "request-api-token":
name = request.POST.get("name", "No Name")
valid_for = min(max(int(request.POST.get("valid-for", 90)), 1), 90)
expires_days = min(max(int(request.POST.get("valid-for", 90)), 1), 90)
token, raw_token = APIToken.create_token(request.user, name=name, valid_for=valid_for)
token, raw_token = APIToken.create_token(
request.user, name=name, expires_days=expires_days
)
return JsonResponse(
{"raw_token": raw_token, "token": {"id": token.id, "name": token.name}}

View File

@ -86,8 +86,8 @@ build = { sequence = [
], help = "Build frontend assets and collect static files" }
# Database tasks
db-up = { cmd = "docker compose -f docker-compose.dev.yml up -d", help = "Start PostgreSQL database using Docker Compose" }
db-down = { cmd = "docker compose -f docker-compose.dev.yml down", help = "Stop PostgreSQL database" }
db-up = { cmd = "docker compose -p envipath -f docker-compose.dev.yml up -d", help = "Start PostgreSQL database using Docker Compose" }
db-down = { cmd = "docker compose -p envipath -f docker-compose.dev.yml down", help = "Stop PostgreSQL database" }
# Celery tasks
celery = { cmd = "celery -A envipath worker -l INFO -Q predict,model,background", help = "Start Celery worker for async task processing" }

View File

@ -11,6 +11,8 @@ import signal
import subprocess
import sys
import time
import os
import dotenv
def find_pnpm():
@ -65,6 +67,7 @@ class DevServerManager:
bufsize=1,
)
self.processes.append((process, description))
print(" ".join(command))
print(f"✓ Started {description} (PID: {process.pid})")
return process
except Exception as e:
@ -146,6 +149,7 @@ class DevServerManager:
def main():
"""Main entry point."""
dotenv.load_dotenv()
manager = DevServerManager()
manager.register_cleanup()
@ -174,9 +178,10 @@ def main():
time.sleep(1)
# Start Django dev server
port = os.environ.get("DJANGO_PORT", "8000")
django_process = manager.start_process(
["uv", "run", "python", "manage.py", "runserver"],
"Django server",
["uv", "run", "python", "manage.py", "runserver", f"0:{port}"],
f"Django server on port {port}",
shell=False,
)