From f4950f4ef0965d21443660d0e3da1a853c30c197 Mon Sep 17 00:00:00 2001 From: Akram Ben Aissi Date: Thu, 3 Jul 2025 19:50:49 +0200 Subject: [PATCH] fix: AccessDeniedError leads to HTTP 500 instead of error 403 (#2595) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves access control error visibility issues where 500 errors were returned instead of proper 403 responses with actionable error messages. • Enhance AccessDeniedError with detailed context and improve exception handling • Enhanced AccessDeniedError class to include user, action, and resource context - Added constructor parameters for action, resource, and user - Generate detailed error messages showing user principal, attributes, and attempted resource - Backward compatible with existing usage (falls back to generic message) • Updated exception handling in server.py - Import AccessDeniedError from access_control module - Return proper 403 status codes with detailed error messages - Handle PermissionError (generic message) and AccessDeniedError (detailed message) in the same 403 branch; only the message carried by the exception differs • Enhanced error context at raise sites - Updated routing_tables/common.py to pass action, resource, and user context - Updated agents persistence to include context in access denied errors - Provides better debugging information for access control issues • Added comprehensive unit tests - Created tests/unit/server/test_server.py with 13 test cases - Covers AccessDeniedError with and without context - Tests all exception types (ValidationError, BadRequestError, AuthenticationRequiredError, etc.) - Validates proper HTTP status codes and error message formats # What does this PR do? 
## Test Plan ``` server: port: 8321 access_policy: - permit: principal: admin actions: [create, read, delete] when: user with admin in groups - permit: actions: [read] when: user with system:authenticated in roles ``` then: ``` curl --request POST --url http://localhost:8321/v1/vector-dbs \ --header "Authorization: Bearer your-bearer" \ --data '{ "vector_db_id": "my_demo_vector_db", "embedding_model": "ibm-granite/granite-embedding-125m-english", "embedding_dimension": 768, "provider_id": "milvus" }' ``` depending if user is in group admin or not, you should get the `AccessDeniedError`. Before this PR, this was leading to an error 500 and `Traceback` displayed in the logs. After the PR, logs display a simpler error (unless DEBUG logging is set) and a 403 Forbidden error is returned on the HTTP side. --------- Signed-off-by: Akram Ben Aissi <> --- .../access_control/access_control.py | 24 ++- .../distribution/routing_tables/common.py | 7 +- llama_stack/distribution/server/server.py | 9 +- .../agents/meta_reference/persistence.py | 2 +- tests/unit/fixtures.py | 8 +- .../agents/test_persistence_access_control.py | 3 +- .../providers/vector_io/test_sqlite_vec.py | 2 +- tests/unit/server/test_access_control.py | 5 +- tests/unit/server/test_server.py | 187 ++++++++++++++++++ 9 files changed, 232 insertions(+), 15 deletions(-) create mode 100644 tests/unit/server/test_server.py diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/distribution/access_control/access_control.py index 84d506d8f..075152ce4 100644 --- a/llama_stack/distribution/access_control/access_control.py +++ b/llama_stack/distribution/access_control/access_control.py @@ -106,4 +106,26 @@ def is_action_allowed( class AccessDeniedError(RuntimeError): - pass + def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None): + self.action = action + self.resource = resource + self.user = user + + message = 
_build_access_denied_message(action, resource, user) + super().__init__(message) + + +def _build_access_denied_message(action: str | None, resource: ProtectedResource | None, user: User | None) -> str: + """Build detailed error message for access denied scenarios.""" + if action and resource and user: + resource_info = f"{resource.type}::{resource.identifier}" + user_info = f"'{user.principal}'" + if user.attributes: + attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()]) + user_info += f" (attributes: {attrs})" + + message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'" + else: + message = "Insufficient permissions" + + return message diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index b79c8a2a8..7f7de32fe 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -175,8 +175,9 @@ class CommonRoutingTableImpl(RoutingTable): return obj async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: - if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()): - raise AccessDeniedError() + user = get_authenticated_user() + if not is_action_allowed(self.policy, "delete", obj, user): + raise AccessDeniedError("delete", obj, user) await self.dist_registry.delete(obj.type, obj.identifier) await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) @@ -193,7 +194,7 @@ class CommonRoutingTableImpl(RoutingTable): # If object supports access control but no attributes set, use creator's attributes creator = get_authenticated_user() if not is_action_allowed(self.policy, "create", obj, creator): - raise AccessDeniedError() + raise AccessDeniedError("create", obj, creator) if creator: obj.owner = creator logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") diff --git 
a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 83407a25f..681ab320d 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,6 +9,7 @@ import asyncio import functools import inspect import json +import logging import os import ssl import sys @@ -31,6 +32,7 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context @@ -116,7 +118,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): return HTTPException(status_code=400, detail=str(exc)) - elif isinstance(exc, PermissionError): + elif isinstance(exc, PermissionError | AccessDeniedError): return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, asyncio.TimeoutError | TimeoutError): return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") @@ -236,7 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: result.url = route return result except Exception as e: - logger.exception(f"Error executing endpoint {route=} {method=}") + if logger.isEnabledFor(logging.DEBUG): + logger.exception(f"Error executing endpoint {route=} {method=}") + else: + logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}") raise translate_exception(e) from e sig = inspect.signature(func) diff --git 
a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 717387008..cda535937 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -53,7 +53,7 @@ class AgentPersistence: identifier=name, # should this be qualified in any way? ) if not is_action_allowed(self.policy, "create", session_info, user): - raise AccessDeniedError() + raise AccessDeniedError("create", session_info, user) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index 7174d2e78..4e50c5e08 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest +import pytest_asyncio from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def sqlite_kvstore(tmp_path): db_path = tmp_path / "test_kv.db" kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix()) @@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path): yield kvstore -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def disk_dist_registry(sqlite_kvstore): registry = DiskDistributionRegistry(sqlite_kvstore) await registry.initialize() yield registry -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def cached_disk_dist_registry(sqlite_kvstore): registry = CachedDiskDistributionRegistry(sqlite_kvstore) await registry.initialize() diff --git 
a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py index d5b876a09..656d1e53c 100644 --- a/tests/unit/providers/agents/test_persistence_access_control.py +++ b/tests/unit/providers/agents/test_persistence_access_control.py @@ -9,6 +9,7 @@ from datetime import datetime from unittest.mock import patch import pytest +import pytest_asyncio from llama_stack.apis.agents import Turn from llama_stack.apis.inference import CompletionMessage, StopReason @@ -16,7 +17,7 @@ from llama_stack.distribution.datatypes import User from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo -@pytest.fixture +@pytest_asyncio.fixture async def test_setup(sqlite_kvstore): agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={}) yield agent_persistence diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index bbac717c7..5d9d92cf3 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -148,7 +148,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!" 
-@pytest.fixture(scope="session") +@pytest_asyncio.fixture(scope="session") async def sqlite_vec_adapter(sqlite_connection): config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None) diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index f9ad47b0c..af03ddacb 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +import pytest_asyncio import yaml from pydantic import TypeAdapter, ValidationError @@ -26,7 +27,7 @@ def _return_model(model): return model -@pytest.fixture +@pytest_asyncio.fixture async def test_setup(cached_disk_dist_registry): mock_inference = Mock() mock_inference.__provider_spec__ = MagicMock() @@ -245,7 +246,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set assert model.identifier == "auto-access-model" -@pytest.fixture +@pytest_asyncio.fixture async def test_setup_with_access_policy(cached_disk_dist_registry): mock_inference = Mock() mock_inference.__provider_spec__ = MagicMock() diff --git a/tests/unit/server/test_server.py b/tests/unit/server/test_server.py new file mode 100644 index 000000000..d17d58b8a --- /dev/null +++ b/tests/unit/server/test_server.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import Mock + +from fastapi import HTTPException +from openai import BadRequestError +from pydantic import ValidationError + +from llama_stack.distribution.access_control.access_control import AccessDeniedError +from llama_stack.distribution.datatypes import AuthenticationRequiredError +from llama_stack.distribution.server.server import translate_exception + + +class TestTranslateException: + """Test cases for the translate_exception function.""" + + def test_translate_access_denied_error(self): + """Test that AccessDeniedError is translated to 403 HTTP status.""" + exc = AccessDeniedError() + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 403 + assert result.detail == "Permission denied: Insufficient permissions" + + def test_translate_access_denied_error_with_context(self): + """Test that AccessDeniedError with context includes detailed information.""" + from llama_stack.distribution.datatypes import User + + # Create mock user and resource + user = User("test-user", {"roles": ["user"], "teams": ["dev"]}) + + # Create a simple mock object that implements the ProtectedResource protocol + class MockResource: + def __init__(self, type: str, identifier: str, owner=None): + self.type = type + self.identifier = identifier + self.owner = owner + + resource = MockResource("vector_db", "test-db") + + exc = AccessDeniedError("create", resource, user) + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 403 + assert "test-user" in result.detail + assert "vector_db::test-db" in result.detail + assert "create" in result.detail + assert "roles=['user']" in result.detail + assert "teams=['dev']" in result.detail + + def test_translate_permission_error(self): + """Test that PermissionError is translated to 403 HTTP status.""" + exc = PermissionError("Permission denied") + result = translate_exception(exc) + + assert isinstance(result, 
HTTPException) + assert result.status_code == 403 + assert result.detail == "Permission denied: Permission denied" + + def test_translate_value_error(self): + """Test that ValueError is translated to 400 HTTP status.""" + exc = ValueError("Invalid input") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 400 + assert result.detail == "Invalid value: Invalid input" + + def test_translate_bad_request_error(self): + """Test that BadRequestError is translated to 400 HTTP status.""" + # Create a mock response for BadRequestError + mock_response = Mock() + mock_response.status_code = 400 + mock_response.headers = {} + + exc = BadRequestError("Bad request", response=mock_response, body="Bad request") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 400 + assert result.detail == "Bad request" + + def test_translate_authentication_required_error(self): + """Test that AuthenticationRequiredError is translated to 401 HTTP status.""" + exc = AuthenticationRequiredError("Authentication required") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 401 + assert result.detail == "Authentication required: Authentication required" + + def test_translate_timeout_error(self): + """Test that TimeoutError is translated to 504 HTTP status.""" + exc = TimeoutError("Operation timed out") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 504 + assert result.detail == "Operation timed out: Operation timed out" + + def test_translate_asyncio_timeout_error(self): + """Test that asyncio.TimeoutError is translated to 504 HTTP status.""" + exc = TimeoutError() + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 504 + assert result.detail == "Operation timed out: " + + def 
test_translate_not_implemented_error(self): + """Test that NotImplementedError is translated to 501 HTTP status.""" + exc = NotImplementedError("Not implemented") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 501 + assert result.detail == "Not implemented: Not implemented" + + def test_translate_validation_error(self): + """Test that ValidationError is translated to 400 HTTP status with proper format.""" + # Create a mock validation error using proper Pydantic error format + exc = ValidationError.from_exception_data( + "TestModel", + [ + { + "loc": ("field", "nested"), + "msg": "field required", + "type": "missing", + } + ], + ) + + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 400 + assert "errors" in result.detail + assert len(result.detail["errors"]) == 1 + assert result.detail["errors"][0]["loc"] == ["field", "nested"] + assert result.detail["errors"][0]["msg"] == "Field required" + assert result.detail["errors"][0]["type"] == "missing" + + def test_translate_generic_exception(self): + """Test that generic exceptions are translated to 500 HTTP status.""" + exc = Exception("Unexpected error") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 500 + assert result.detail == "Internal server error: An unexpected error occurred." + + def test_translate_runtime_error(self): + """Test that RuntimeError is translated to 500 HTTP status.""" + exc = RuntimeError("Runtime error") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 500 + assert result.detail == "Internal server error: An unexpected error occurred." 
+ + def test_multiple_access_denied_scenarios(self): + """Test various scenarios that should result in 403 status codes.""" + # Test AccessDeniedError (uses enhanced message) + exc1 = AccessDeniedError() + result1 = translate_exception(exc1) + assert isinstance(result1, HTTPException) + assert result1.status_code == 403 + assert result1.detail == "Permission denied: Insufficient permissions" + + # Test PermissionError (uses generic message) + exc2 = PermissionError("No permission") + result2 = translate_exception(exc2) + assert isinstance(result2, HTTPException) + assert result2.status_code == 403 + assert result2.detail == "Permission denied: No permission" + + exc3 = PermissionError("Access denied") + result3 = translate_exception(exc3) + assert isinstance(result3, HTTPException) + assert result3.status_code == 403 + assert result3.detail == "Permission denied: Access denied"