mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 06:07:43 +00:00
feat(auth): API access control (#2822)
# What does this PR do? - Added ability to specify `required_scope` when declaring an API. This is part of the `@webmethod` decorator. - If auth is enabled, a user can access an API only if `user.attributes['scope']` includes the `required_scope` - We add `required_scope='telemetry.read'` to the telemetry read APIs. ## Test Plan CI with added tests 1. Enable server.auth with github token 2. Observe `client.telemetry.query_traces()` returns 403
This commit is contained in:
parent
7cc4819e90
commit
21bae296f2
7 changed files with 331 additions and 36 deletions
|
@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import (
|
|||
OAuth2JWKSConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_impls():
|
||||
"""Mock implementations for scope testing"""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scope_middleware_with_mocks(mock_auth_endpoint):
|
||||
"""Create AuthenticationMiddleware with mocked route implementations"""
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
middleware = AuthenticationMiddleware(mock_app, auth_config, {})
|
||||
|
||||
# Mock the route_impls to simulate finding routes with required scopes
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
|
||||
|
||||
public_webmethod = WebMethod(route="/test/public", method="GET")
|
||||
|
||||
# Mock the route finding logic
|
||||
def mock_find_matching_route(method, path, route_impls):
|
||||
if method == "POST" and path == "/test/scoped":
|
||||
return None, {}, "/test/scoped", scoped_webmethod
|
||||
elif method == "GET" and path == "/test/public":
|
||||
return None, {}, "/test/public", public_webmethod
|
||||
else:
|
||||
raise ValueError("No matching route")
|
||||
|
||||
import llama_stack.distribution.server.auth
|
||||
|
||||
llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route
|
||||
llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {}
|
||||
|
||||
return middleware, mock_app
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
|
@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs):
|
|||
raise Exception("Connection error")
|
||||
|
||||
|
||||
async def mock_post_success_with_scope(*args, **kwargs):
|
||||
"""Mock auth response for user with test.read scope"""
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful",
|
||||
"principal": "test-user",
|
||||
"attributes": {
|
||||
"scopes": ["test.read", "other.scope"],
|
||||
"roles": ["user"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_post_success_no_scope(*args, **kwargs):
|
||||
"""Mock auth response for user without required scope"""
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful",
|
||||
"principal": "test-user",
|
||||
"attributes": {
|
||||
"scopes": ["other.scope"],
|
||||
"roles": ["user"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# HTTP Endpoint Tests
|
||||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
|
@ -252,7 +326,7 @@ def oauth2_app():
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token():
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
|||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
# Scope-based authorization tests
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
|
||||
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
|
||||
"""Test that user with required scope can access protected endpoint"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/scoped",
|
||||
"method": "POST",
|
||||
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should call the downstream app (no 403 error sent)
|
||||
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
|
||||
"""Test that user without required scope gets 403 access denied"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/scoped",
|
||||
"method": "POST",
|
||||
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should send 403 error, not call downstream app
|
||||
mock_app.assert_not_called()
|
||||
assert mock_send.call_count == 2 # start + body
|
||||
|
||||
# Check the response
|
||||
start_call = mock_send.call_args_list[0][0][0]
|
||||
assert start_call["status"] == 403
|
||||
|
||||
body_call = mock_send.call_args_list[1][0][0]
|
||||
body_text = body_call["body"].decode()
|
||||
assert "Access denied" in body_text
|
||||
assert "test.read" in body_text
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
|
||||
"""Test that public endpoints work without specific scopes"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/public",
|
||||
"method": "GET",
|
||||
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should call the downstream app (no error)
|
||||
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
|
||||
"""Test that when auth is disabled (no user), scope checks are bypassed"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/scoped",
|
||||
"method": "POST",
|
||||
"headers": [], # No authorization header
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should send 401 auth error, not call downstream app
|
||||
mock_app.assert_not_called()
|
||||
assert mock_send.call_count == 2 # start + body
|
||||
|
||||
# Check the response
|
||||
start_call = mock_send.call_args_list[0][0][0]
|
||||
assert start_call["status"] == 401
|
||||
|
||||
body_call = mock_send.call_args_list[1][0][0]
|
||||
body_text = body_call["body"].decode()
|
||||
assert "Authentication required" in body_text
|
||||
|
||||
|
||||
def test_has_required_scope_function():
|
||||
"""Test the _has_required_scope function directly"""
|
||||
# Test user with required scope
|
||||
user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]})
|
||||
assert _has_required_scope("test.read", user_with_scope)
|
||||
|
||||
# Test user without required scope
|
||||
user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]})
|
||||
assert not _has_required_scope("test.read", user_without_scope)
|
||||
|
||||
# Test user with no scopes attribute
|
||||
user_no_scopes = User(principal="test-user", attributes={})
|
||||
assert not _has_required_scope("test.read", user_no_scopes)
|
||||
|
||||
# Test no user (auth disabled)
|
||||
assert _has_required_scope("test.read", None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue