feat: enhance AccessDeniedError with detailed context and improve exception handling

• Enhanced AccessDeniedError class to include user, action, and resource context
  - Added constructor parameters for action, resource, and user
  - Generate detailed error messages showing user principal, attributes, and attempted resource
  - Backward compatible with existing usage (falls back to generic message)

• Updated exception handling in server.py
  - Import AccessDeniedError from access_control module
  - Return proper 403 status codes with detailed error messages
  - Separate handling for PermissionError (generic) vs AccessDeniedError (detailed)

• Enhanced error context at raise sites
  - Updated routing_tables/common.py to pass action, resource, and user context
  - Updated agents persistence to include context in access denied errors
  - Provides better debugging information for access control issues

• Added comprehensive unit tests
  - Created tests/unit/server/test_server.py with 13 test cases
  - Covers AccessDeniedError with and without context
  - Tests all exception types (ValidationError, BadRequestError, AuthenticationRequiredError, etc.)
  - Validates proper HTTP status codes and error message formats

Resolves access control error visibility issues where 500 errors were returned
instead of proper 403 responses with actionable error messages.

Signed-off-by: Akram Ben Aissi <<akram.benaissi@gmail.com>>
This commit is contained in:
Akram Ben Aissi 2025-07-03 10:26:48 +02:00
parent aa273944fd
commit 31f85076ad
5 changed files with 217 additions and 7 deletions

View file

@ -106,4 +106,21 @@ def is_action_allowed(
class AccessDeniedError(RuntimeError):
pass
def __init__(self, action: str | None = None, resource: ProtectedResource | None = None, user: User | None = None):
self.action = action
self.resource = resource
self.user = user
# Build detailed error message
if action and resource and user:
resource_info = f"{resource.type}::{resource.identifier}"
user_info = f"'{user.principal}'"
if user.attributes:
attrs = ", ".join([f"{k}={v}" for k, v in user.attributes.items()])
user_info += f" (attributes: {attrs})"
message = f"User {user_info} cannot perform action '{action}' on resource '{resource_info}'"
else:
message = "Insufficient permissions"
super().__init__(message)

View file

@ -175,8 +175,9 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
raise AccessDeniedError()
user = get_authenticated_user()
if not is_action_allowed(self.policy, "delete", obj, user):
raise AccessDeniedError("delete", obj, user)
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
@ -193,7 +194,7 @@ class CommonRoutingTableImpl(RoutingTable):
# If object supports access control but no attributes set, use creator's attributes
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError()
raise AccessDeniedError("create", obj, creator)
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")

View file

@ -9,6 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import os
import ssl
import sys
@ -31,6 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
@ -116,7 +118,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=400, detail=str(exc))
elif isinstance(exc, PermissionError):
elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
@ -236,7 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
result.url = route
return result
except Exception as e:
logger.exception(f"Error executing endpoint {route=} {method=}")
if logger.isEnabledFor(logging.DEBUG):
logger.exception(f"Error executing endpoint {route=} {method=}")
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
sig = inspect.signature(func)

View file

@ -53,7 +53,7 @@ class AgentPersistence:
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError()
raise AccessDeniedError("create", session_info, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",

View file

@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
from fastapi import HTTPException
from openai import BadRequestError
from pydantic import ValidationError
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.server.server import translate_exception
class TestTranslateException:
"""Test cases for the translate_exception function."""
def test_translate_access_denied_error(self):
"""Test that AccessDeniedError is translated to 403 HTTP status."""
exc = AccessDeniedError()
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert result.detail == "Permission denied: Insufficient permissions"
def test_translate_access_denied_error_with_context(self):
"""Test that AccessDeniedError with context includes detailed information."""
from llama_stack.distribution.datatypes import User
# Create mock user and resource
user = User("test-user", {"roles": ["user"], "teams": ["dev"]})
# Create a simple mock object that implements the ProtectedResource protocol
class MockResource:
def __init__(self, type: str, identifier: str, owner=None):
self.type = type
self.identifier = identifier
self.owner = owner
resource = MockResource("vector_db", "test-db")
exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail
assert "create" in result.detail
assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail
def test_translate_permission_error(self):
"""Test that PermissionError is translated to 403 HTTP status."""
exc = PermissionError("Permission denied")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert result.detail == "Permission denied: Permission denied"
def test_translate_value_error(self):
"""Test that ValueError is translated to 400 HTTP status."""
exc = ValueError("Invalid input")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert result.detail == "Invalid value: Invalid input"
def test_translate_bad_request_error(self):
"""Test that BadRequestError is translated to 400 HTTP status."""
# Create a mock response for BadRequestError
mock_response = Mock()
mock_response.status_code = 400
mock_response.headers = {}
exc = BadRequestError("Bad request", response=mock_response, body="Bad request")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert result.detail == "Bad request"
def test_translate_authentication_required_error(self):
"""Test that AuthenticationRequiredError is translated to 401 HTTP status."""
exc = AuthenticationRequiredError("Authentication required")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 401
assert result.detail == "Authentication required: Authentication required"
def test_translate_timeout_error(self):
"""Test that TimeoutError is translated to 504 HTTP status."""
exc = TimeoutError("Operation timed out")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 504
assert result.detail == "Operation timed out: Operation timed out"
def test_translate_asyncio_timeout_error(self):
"""Test that asyncio.TimeoutError is translated to 504 HTTP status."""
exc = TimeoutError()
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 504
assert result.detail == "Operation timed out: "
def test_translate_not_implemented_error(self):
"""Test that NotImplementedError is translated to 501 HTTP status."""
exc = NotImplementedError("Not implemented")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 501
assert result.detail == "Not implemented: Not implemented"
def test_translate_validation_error(self):
"""Test that ValidationError is translated to 400 HTTP status with proper format."""
# Create a mock validation error using proper Pydantic error format
exc = ValidationError.from_exception_data(
"TestModel",
[
{
"loc": ("field", "nested"),
"msg": "field required",
"type": "missing",
}
],
)
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert "errors" in result.detail
assert len(result.detail["errors"]) == 1
assert result.detail["errors"][0]["loc"] == ["field", "nested"]
assert result.detail["errors"][0]["msg"] == "Field required"
assert result.detail["errors"][0]["type"] == "missing"
def test_translate_generic_exception(self):
"""Test that generic exceptions are translated to 500 HTTP status."""
exc = Exception("Unexpected error")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 500
assert result.detail == "Internal server error: An unexpected error occurred."
def test_translate_runtime_error(self):
"""Test that RuntimeError is translated to 500 HTTP status."""
exc = RuntimeError("Runtime error")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 500
assert result.detail == "Internal server error: An unexpected error occurred."
def test_multiple_access_denied_scenarios(self):
"""Test various scenarios that should result in 403 status codes."""
# Test AccessDeniedError (uses enhanced message)
exc1 = AccessDeniedError()
result1 = translate_exception(exc1)
assert isinstance(result1, HTTPException)
assert result1.status_code == 403
assert result1.detail == "Permission denied: Insufficient permissions"
# Test PermissionError (uses generic message)
exc2 = PermissionError("No permission")
result2 = translate_exception(exc2)
assert isinstance(result2, HTTPException)
assert result2.status_code == 403
assert result2.detail == "Permission denied: No permission"
exc3 = PermissionError("Access denied")
result3 = translate_exception(exc3)
assert isinstance(result3, HTTPException)
assert result3.status_code == 403
assert result3.detail == "Permission denied: Access denied"