mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
api access
- Create BaseServerMiddleware base class for server middleware - Refactor TracingMiddleware to extend BaseServerMiddleware - Consolidate route matching logic in base class - Update server.py to use user_from_scope utility - Add required_scope parameter to WebMethod in schema_utils.py - Create AccessControlMiddleware with simplified scope checking - Update telemetry API to use required_scope protection - Add comprehensive test coverage for access control logic - Integrate access control middleware into server setup - Rename AccessControlMiddleware to AuthorizationMiddleware for better clarity - Update imports and references in server.py and tests - Keep the same functionality and API - Merge authorization logic directly into AuthenticationMiddleware - Remove separate access_control.py file - Update middleware setup in server.py to use single middleware - Rename and update tests to test the merged functionality - AuthenticationMiddleware now handles both authentication and authorization
This commit is contained in:
parent
632cf9eb72
commit
ebea3c8277
7 changed files with 331 additions and 36 deletions
|
@ -504,6 +504,47 @@ created by users sharing a team with them:
|
||||||
description: any user has read access to any resource created by a user with the same team
|
description: any user has read access to any resource created by a user with the same team
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### API Endpoint Authorization with Scopes
|
||||||
|
|
||||||
|
In addition to resource-based access control, Llama Stack supports endpoint-level authorization using OAuth 2.0 style scopes. When authentication is enabled, specific API endpoints require users to have particular scopes in their authentication token.
|
||||||
|
|
||||||
|
**Scope-Gated APIs:**
|
||||||
|
The following APIs are currently gated by scopes:
|
||||||
|
|
||||||
|
- **Telemetry API** (scope: `telemetry.read`):
|
||||||
|
- `POST /telemetry/traces` - Query traces
|
||||||
|
- `GET /telemetry/traces/{trace_id}` - Get trace by ID
|
||||||
|
- `GET /telemetry/traces/{trace_id}/spans/{span_id}` - Get span by ID
|
||||||
|
- `POST /telemetry/spans/{span_id}/tree` - Get span tree
|
||||||
|
- `POST /telemetry/spans` - Query spans
|
||||||
|
- `POST /telemetry/metrics/{metric_name}` - Query metrics
|
||||||
|
|
||||||
|
**Authentication Configuration:**
|
||||||
|
|
||||||
|
For **JWT/OAuth2 providers**, scopes should be included in the JWT's claims:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sub": "user123",
|
||||||
|
"scope": "telemetry.read",
|
||||||
|
"aud": "llama-stack"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For **custom authentication providers**, the endpoint must return user attributes including the `scopes` array:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"principal": "user123",
|
||||||
|
"attributes": {
|
||||||
|
"scopes": ["telemetry.read"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Behavior:**
|
||||||
|
- Users without the required scope receive a 403 Forbidden response
|
||||||
|
- When authentication is disabled, scope checks are bypassed
|
||||||
|
- Endpoints without `required_scope` work normally for all authenticated users
|
||||||
|
|
||||||
### Quota Configuration
|
### Quota Configuration
|
||||||
|
|
||||||
The `quota` section allows you to enable server-side request throttling for both
|
The `quota` section allows you to enable server-side request throttling for both
|
||||||
|
|
|
@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
||||||
# Add this constant near the top of the file, after the imports
|
# Add this constant near the top of the file, after the imports
|
||||||
DEFAULT_TTL_DAYS = 7
|
DEFAULT_TTL_DAYS = 7
|
||||||
|
|
||||||
|
REQUIRED_SCOPE = "telemetry.read"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SpanStatus(Enum):
|
class SpanStatus(Enum):
|
||||||
|
@ -259,7 +261,7 @@ class Telemetry(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces", method="POST")
|
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
self,
|
self,
|
||||||
attribute_filters: list[QueryCondition] | None = None,
|
attribute_filters: list[QueryCondition] | None = None,
|
||||||
|
@ -277,7 +279,7 @@ class Telemetry(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
|
||||||
async def get_trace(self, trace_id: str) -> Trace:
|
async def get_trace(self, trace_id: str) -> Trace:
|
||||||
"""Get a trace by its ID.
|
"""Get a trace by its ID.
|
||||||
|
|
||||||
|
@ -286,7 +288,9 @@ class Telemetry(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
@webmethod(
|
||||||
|
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
|
||||||
|
)
|
||||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||||
"""Get a span by its ID.
|
"""Get a span by its ID.
|
||||||
|
|
||||||
|
@ -296,7 +300,7 @@ class Telemetry(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
|
||||||
async def get_span_tree(
|
async def get_span_tree(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
|
@ -312,7 +316,7 @@ class Telemetry(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans", method="POST")
|
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
|
||||||
async def query_spans(
|
async def query_spans(
|
||||||
self,
|
self,
|
||||||
attribute_filters: list[QueryCondition],
|
attribute_filters: list[QueryCondition],
|
||||||
|
@ -345,7 +349,7 @@ class Telemetry(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
|
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
|
||||||
async def query_metrics(
|
async def query_metrics(
|
||||||
self,
|
self,
|
||||||
metric_name: str,
|
metric_name: str,
|
||||||
|
|
|
@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
|
||||||
if not provider_data:
|
if not provider_data:
|
||||||
return None
|
return None
|
||||||
return provider_data.get("__authenticated_user")
|
return provider_data.get("__authenticated_user")
|
||||||
|
|
||||||
|
|
||||||
|
def user_from_scope(scope: dict) -> User | None:
|
||||||
|
"""Create a User object from ASGI scope data (set by authentication middleware)"""
|
||||||
|
user_attributes = scope.get("user_attributes", {})
|
||||||
|
principal = scope.get("principal", "")
|
||||||
|
|
||||||
|
# auth not enabled
|
||||||
|
if not principal and not user_attributes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return User(principal=principal, attributes=user_attributes)
|
||||||
|
|
|
@ -7,9 +7,12 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from aiohttp import hdrs
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
from llama_stack.distribution.datatypes import AuthenticationConfig, User
|
||||||
|
from llama_stack.distribution.request_headers import user_from_scope
|
||||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||||
|
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
|
||||||
access resources that don't have access_attributes defined.
|
access resources that don't have access_attributes defined.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app, auth_config: AuthenticationConfig):
|
def __init__(self, app, auth_config: AuthenticationConfig, impls):
|
||||||
self.app = app
|
self.app = app
|
||||||
|
self.impls = impls
|
||||||
self.auth_provider = create_auth_provider(auth_config)
|
self.auth_provider = create_auth_provider(auth_config)
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
async def __call__(self, scope, receive, send):
|
||||||
if scope["type"] == "http":
|
if scope["type"] == "http":
|
||||||
|
# First, handle authentication
|
||||||
headers = dict(scope.get("headers", []))
|
headers = dict(scope.get("headers", []))
|
||||||
auth_header = headers.get(b"authorization", b"").decode()
|
auth_header = headers.get(b"authorization", b"").decode()
|
||||||
|
|
||||||
|
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
|
||||||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Scope-based API access control
|
||||||
|
path = scope.get("path", "")
|
||||||
|
method = scope.get("method", hdrs.METH_GET)
|
||||||
|
|
||||||
|
if not hasattr(self, "route_impls"):
|
||||||
|
self.route_impls = initialize_route_impls(self.impls)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||||
|
except ValueError:
|
||||||
|
# If no matching endpoint is found, pass through to FastAPI
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
if webmethod.required_scope:
|
||||||
|
user = user_from_scope(scope)
|
||||||
|
if not _has_required_scope(webmethod.required_scope, user):
|
||||||
|
return await self._send_auth_error(
|
||||||
|
send,
|
||||||
|
f"Access denied: user does not have required scope: {webmethod.required_scope}",
|
||||||
|
status=403,
|
||||||
|
)
|
||||||
|
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
async def _send_auth_error(self, send, message):
|
async def _send_auth_error(self, send, message, status=401):
|
||||||
await send(
|
await send(
|
||||||
{
|
{
|
||||||
"type": "http.response.start",
|
"type": "http.response.start",
|
||||||
"status": 401,
|
"status": status,
|
||||||
"headers": [[b"content-type", b"application/json"]],
|
"headers": [[b"content-type", b"application/json"]],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
error_msg = json.dumps({"error": {"message": message}}).encode()
|
error_key = "message" if status == 401 else "detail"
|
||||||
|
error_msg = json.dumps({"error": {error_key: message}}).encode()
|
||||||
await send({"type": "http.response.body", "body": error_msg})
|
await send({"type": "http.response.body", "body": error_msg})
|
||||||
|
|
||||||
|
|
||||||
|
def _has_required_scope(required_scope: str, user: User | None) -> bool:
|
||||||
|
# if no user, assume auth is not enabled
|
||||||
|
if not user:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not user.attributes:
|
||||||
|
return False
|
||||||
|
|
||||||
|
user_scopes = user.attributes.get("scopes", [])
|
||||||
|
return required_scope in user_scopes
|
||||||
|
|
|
@ -41,7 +41,11 @@ from llama_stack.distribution.datatypes import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
||||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
from llama_stack.distribution.request_headers import (
|
||||||
|
PROVIDER_DATA_VAR,
|
||||||
|
request_provider_data_context,
|
||||||
|
user_from_scope,
|
||||||
|
)
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
from llama_stack.distribution.server.routes import (
|
from llama_stack.distribution.server.routes import (
|
||||||
find_matching_route,
|
find_matching_route,
|
||||||
|
@ -223,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def route_handler(request: Request, **kwargs):
|
async def route_handler(request: Request, **kwargs):
|
||||||
# Get auth attributes from the request scope
|
# Get auth attributes from the request scope
|
||||||
user_attributes = request.scope.get("user_attributes", {})
|
user = user_from_scope(request.scope)
|
||||||
principal = request.scope.get("principal", "")
|
|
||||||
user = User(principal=principal, attributes=user_attributes)
|
|
||||||
|
|
||||||
await log_request_pre_validation(request)
|
await log_request_pre_validation(request)
|
||||||
|
|
||||||
|
@ -437,10 +439,21 @@ def main(args: argparse.Namespace | None = None):
|
||||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||||
app.add_middleware(ClientVersionMiddleware)
|
app.add_middleware(ClientVersionMiddleware)
|
||||||
|
|
||||||
# Add authentication middleware if configured
|
try:
|
||||||
|
# Create and set the event loop that will be used for both construction and server runtime
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
# Construct the stack in the persistent event loop
|
||||||
|
impls = loop.run_until_complete(construct_stack(config))
|
||||||
|
|
||||||
|
except InvalidProviderError as e:
|
||||||
|
logger.error(f"Error: {str(e)}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
if config.server.auth:
|
if config.server.auth:
|
||||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
|
||||||
else:
|
else:
|
||||||
if config.server.quota:
|
if config.server.quota:
|
||||||
quota = config.server.quota
|
quota = config.server.quota
|
||||||
|
@ -471,18 +484,6 @@ def main(args: argparse.Namespace | None = None):
|
||||||
window_seconds=window_seconds,
|
window_seconds=window_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
# Create and set the event loop that will be used for both construction and server runtime
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
# Construct the stack in the persistent event loop
|
|
||||||
impls = loop.run_until_complete(construct_stack(config))
|
|
||||||
|
|
||||||
except InvalidProviderError as e:
|
|
||||||
logger.error(f"Error: {str(e)}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -22,6 +22,7 @@ class WebMethod:
|
||||||
# A descriptive name of the corresponding span created by tracing
|
# A descriptive name of the corresponding span created by tracing
|
||||||
descriptive_name: str | None = None
|
descriptive_name: str | None = None
|
||||||
experimental: bool | None = False
|
experimental: bool | None = False
|
||||||
|
required_scope: str | None = None
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Callable[..., Any])
|
T = TypeVar("T", bound=Callable[..., Any])
|
||||||
|
@ -36,6 +37,7 @@ def webmethod(
|
||||||
raw_bytes_request_body: bool | None = False,
|
raw_bytes_request_body: bool | None = False,
|
||||||
descriptive_name: str | None = None,
|
descriptive_name: str | None = None,
|
||||||
experimental: bool | None = False,
|
experimental: bool | None = False,
|
||||||
|
required_scope: str | None = None,
|
||||||
) -> Callable[[T], T]:
|
) -> Callable[[T], T]:
|
||||||
"""
|
"""
|
||||||
Decorator that supplies additional metadata to an endpoint operation function.
|
Decorator that supplies additional metadata to an endpoint operation function.
|
||||||
|
@ -45,6 +47,7 @@ def webmethod(
|
||||||
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
|
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
|
||||||
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
|
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
|
||||||
:param experimental: True if the operation is experimental and subject to change.
|
:param experimental: True if the operation is experimental and subject to change.
|
||||||
|
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrap(func: T) -> T:
|
def wrap(func: T) -> T:
|
||||||
|
@ -57,6 +60,7 @@ def webmethod(
|
||||||
raw_bytes_request_body=raw_bytes_request_body,
|
raw_bytes_request_body=raw_bytes_request_body,
|
||||||
descriptive_name=descriptive_name,
|
descriptive_name=descriptive_name,
|
||||||
experimental=experimental,
|
experimental=experimental,
|
||||||
|
required_scope=required_scope,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import (
|
||||||
OAuth2JWKSConfig,
|
OAuth2JWKSConfig,
|
||||||
OAuth2TokenAuthConfig,
|
OAuth2TokenAuthConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
from llama_stack.distribution.request_headers import User
|
||||||
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope
|
||||||
from llama_stack.distribution.server.auth_providers import (
|
from llama_stack.distribution.server.auth_providers import (
|
||||||
get_attributes_from_claims,
|
get_attributes_from_claims,
|
||||||
)
|
)
|
||||||
|
@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint):
|
||||||
),
|
),
|
||||||
access_policy=[],
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||||
|
|
||||||
@app.get("/test")
|
@app.get("/test")
|
||||||
def test_endpoint():
|
def test_endpoint():
|
||||||
|
@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint):
|
||||||
),
|
),
|
||||||
access_policy=[],
|
access_policy=[],
|
||||||
)
|
)
|
||||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_impls():
|
||||||
|
"""Mock implementations for scope testing"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def scope_middleware_with_mocks(mock_auth_endpoint):
|
||||||
|
"""Create AuthenticationMiddleware with mocked route implementations"""
|
||||||
|
mock_app = AsyncMock()
|
||||||
|
auth_config = AuthenticationConfig(
|
||||||
|
provider_config=CustomAuthConfig(
|
||||||
|
type=AuthProviderType.CUSTOM,
|
||||||
|
endpoint=mock_auth_endpoint,
|
||||||
|
),
|
||||||
|
access_policy=[],
|
||||||
|
)
|
||||||
|
middleware = AuthenticationMiddleware(mock_app, auth_config, {})
|
||||||
|
|
||||||
|
# Mock the route_impls to simulate finding routes with required scopes
|
||||||
|
from llama_stack.schema_utils import WebMethod
|
||||||
|
|
||||||
|
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
|
||||||
|
|
||||||
|
public_webmethod = WebMethod(route="/test/public", method="GET")
|
||||||
|
|
||||||
|
# Mock the route finding logic
|
||||||
|
def mock_find_matching_route(method, path, route_impls):
|
||||||
|
if method == "POST" and path == "/test/scoped":
|
||||||
|
return None, {}, "/test/scoped", scoped_webmethod
|
||||||
|
elif method == "GET" and path == "/test/public":
|
||||||
|
return None, {}, "/test/public", public_webmethod
|
||||||
|
else:
|
||||||
|
raise ValueError("No matching route")
|
||||||
|
|
||||||
|
import llama_stack.distribution.server.auth
|
||||||
|
|
||||||
|
llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route
|
||||||
|
llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {}
|
||||||
|
|
||||||
|
return middleware, mock_app
|
||||||
|
|
||||||
|
|
||||||
async def mock_post_success(*args, **kwargs):
|
async def mock_post_success(*args, **kwargs):
|
||||||
|
@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs):
|
||||||
raise Exception("Connection error")
|
raise Exception("Connection error")
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_post_success_with_scope(*args, **kwargs):
|
||||||
|
"""Mock auth response for user with test.read scope"""
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"message": "Authentication successful",
|
||||||
|
"principal": "test-user",
|
||||||
|
"attributes": {
|
||||||
|
"scopes": ["test.read", "other.scope"],
|
||||||
|
"roles": ["user"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_post_success_no_scope(*args, **kwargs):
|
||||||
|
"""Mock auth response for user without required scope"""
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"message": "Authentication successful",
|
||||||
|
"principal": "test-user",
|
||||||
|
"attributes": {
|
||||||
|
"scopes": ["other.scope"],
|
||||||
|
"roles": ["user"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# HTTP Endpoint Tests
|
# HTTP Endpoint Tests
|
||||||
def test_missing_auth_header(http_client):
|
def test_missing_auth_header(http_client):
|
||||||
response = http_client.get("/test")
|
response = http_client.get("/test")
|
||||||
|
@ -252,7 +326,7 @@ def oauth2_app():
|
||||||
),
|
),
|
||||||
access_policy=[],
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||||
|
|
||||||
@app.get("/test")
|
@app.get("/test")
|
||||||
def test_endpoint():
|
def test_endpoint():
|
||||||
|
@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token():
|
||||||
),
|
),
|
||||||
access_policy=[],
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||||
|
|
||||||
@app.get("/test")
|
@app.get("/test")
|
||||||
def test_endpoint():
|
def test_endpoint():
|
||||||
|
@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint):
|
||||||
),
|
),
|
||||||
access_policy=[],
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||||
|
|
||||||
@app.get("/test")
|
@app.get("/test")
|
||||||
def test_endpoint():
|
def test_endpoint():
|
||||||
|
@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||||
),
|
),
|
||||||
access_policy=[],
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||||
|
|
||||||
@app.get("/test")
|
@app.get("/test")
|
||||||
def test_endpoint():
|
def test_endpoint():
|
||||||
|
@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"message": "Authentication successful"}
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
# Scope-based authorization tests
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
|
||||||
|
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
|
||||||
|
"""Test that user with required scope can access protected endpoint"""
|
||||||
|
middleware, mock_app = scope_middleware_with_mocks
|
||||||
|
mock_receive = AsyncMock()
|
||||||
|
mock_send = AsyncMock()
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"path": "/test/scoped",
|
||||||
|
"method": "POST",
|
||||||
|
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||||
|
}
|
||||||
|
|
||||||
|
await middleware(scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
# Should call the downstream app (no 403 error sent)
|
||||||
|
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
|
||||||
|
mock_send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||||
|
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
|
||||||
|
"""Test that user without required scope gets 403 access denied"""
|
||||||
|
middleware, mock_app = scope_middleware_with_mocks
|
||||||
|
mock_receive = AsyncMock()
|
||||||
|
mock_send = AsyncMock()
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"path": "/test/scoped",
|
||||||
|
"method": "POST",
|
||||||
|
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||||
|
}
|
||||||
|
|
||||||
|
await middleware(scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
# Should send 403 error, not call downstream app
|
||||||
|
mock_app.assert_not_called()
|
||||||
|
assert mock_send.call_count == 2 # start + body
|
||||||
|
|
||||||
|
# Check the response
|
||||||
|
start_call = mock_send.call_args_list[0][0][0]
|
||||||
|
assert start_call["status"] == 403
|
||||||
|
|
||||||
|
body_call = mock_send.call_args_list[1][0][0]
|
||||||
|
body_text = body_call["body"].decode()
|
||||||
|
assert "Access denied" in body_text
|
||||||
|
assert "test.read" in body_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||||
|
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
|
||||||
|
"""Test that public endpoints work without specific scopes"""
|
||||||
|
middleware, mock_app = scope_middleware_with_mocks
|
||||||
|
mock_receive = AsyncMock()
|
||||||
|
mock_send = AsyncMock()
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"path": "/test/public",
|
||||||
|
"method": "GET",
|
||||||
|
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||||
|
}
|
||||||
|
|
||||||
|
await middleware(scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
# Should call the downstream app (no error)
|
||||||
|
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
|
||||||
|
mock_send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
|
||||||
|
"""Test that when auth is disabled (no user), scope checks are bypassed"""
|
||||||
|
middleware, mock_app = scope_middleware_with_mocks
|
||||||
|
mock_receive = AsyncMock()
|
||||||
|
mock_send = AsyncMock()
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"path": "/test/scoped",
|
||||||
|
"method": "POST",
|
||||||
|
"headers": [], # No authorization header
|
||||||
|
}
|
||||||
|
|
||||||
|
await middleware(scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
# Should send 401 auth error, not call downstream app
|
||||||
|
mock_app.assert_not_called()
|
||||||
|
assert mock_send.call_count == 2 # start + body
|
||||||
|
|
||||||
|
# Check the response
|
||||||
|
start_call = mock_send.call_args_list[0][0][0]
|
||||||
|
assert start_call["status"] == 401
|
||||||
|
|
||||||
|
body_call = mock_send.call_args_list[1][0][0]
|
||||||
|
body_text = body_call["body"].decode()
|
||||||
|
assert "Authentication required" in body_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_required_scope_function():
|
||||||
|
"""Test the _has_required_scope function directly"""
|
||||||
|
# Test user with required scope
|
||||||
|
user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]})
|
||||||
|
assert _has_required_scope("test.read", user_with_scope)
|
||||||
|
|
||||||
|
# Test user without required scope
|
||||||
|
user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]})
|
||||||
|
assert not _has_required_scope("test.read", user_without_scope)
|
||||||
|
|
||||||
|
# Test user with no scopes attribute
|
||||||
|
user_no_scopes = User(principal="test-user", attributes={})
|
||||||
|
assert not _has_required_scope("test.read", user_no_scopes)
|
||||||
|
|
||||||
|
# Test no user (auth disabled)
|
||||||
|
assert _has_required_scope("test.read", None)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue