feat(auth): API access control (#2822)

# What does this PR do?
- Added ability to specify `required_scope` when declaring an API. This
is part of the `@webmethod` decorator.
- If auth is enabled, a user can access an API only if
`user.attributes['scope']` includes the `required_scope`
- We add `required_scope='telemetry.read'` to the telemetry read APIs.

## Test Plan
CI with added tests

1. Enable server.auth with github token
2. Observe `client.telemetry.query_traces()` returns 403
This commit is contained in:
ehhuang 2025-07-24 15:30:48 -07:00 committed by GitHub
parent 7cc4819e90
commit 21bae296f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 331 additions and 36 deletions

View file

@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
REQUIRED_SCOPE = "telemetry.read"
@json_schema_type
class SpanStatus(Enum):
@ -259,7 +261,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces", method="POST")
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,
@ -277,7 +279,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
@ -286,7 +288,9 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
@webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
)
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
@ -296,7 +300,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
async def get_span_tree(
self,
span_id: str,
@ -312,7 +316,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans", method="POST")
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
async def query_spans(
self,
attribute_filters: list[QueryCondition],
@ -345,7 +349,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
async def query_metrics(
self,
metric_name: str,

View file

@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
if not provider_data:
return None
return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)

View file

@ -7,9 +7,12 @@
import json
import httpx
from aiohttp import hdrs
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import AuthenticationConfig, User
from llama_stack.distribution.request_headers import user_from_scope
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthenticationConfig):
def __init__(self, app, auth_config: AuthenticationConfig, impls):
self.app = app
self.impls = impls
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
send,
f"Access denied: user does not have required scope: {webmethod.required_scope}",
status=403,
)
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):
async def _send_auth_error(self, send, message, status=401):
await send(
{
"type": "http.response.start",
"status": 401,
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps({"error": {"message": message}}).encode()
error_key = "message" if status == 401 else "detail"
error_msg = json.dumps({"error": {error_key: message}}).encode()
await send({"type": "http.response.body", "body": error_msg})
def _has_required_scope(required_scope: str, user: User | None) -> bool:
# if no user, assume auth is not enabled
if not user:
return True
if not user.attributes:
return False
user_scopes = user.attributes.get("scopes", [])
return required_scope in user_scopes

View file

@ -41,7 +41,11 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
find_matching_route,
@ -223,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal=principal, attributes=user_attributes)
user = user_from_scope(request.scope)
await log_request_pre_validation(request)
@ -437,10 +439,21 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
@ -471,18 +484,6 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:

View file

@ -22,6 +22,7 @@ class WebMethod:
# A descriptive name of the corresponding span created by tracing
descriptive_name: str | None = None
experimental: bool | None = False
required_scope: str | None = None
T = TypeVar("T", bound=Callable[..., Any])
@ -36,6 +37,7 @@ def webmethod(
raw_bytes_request_body: bool | None = False,
descriptive_name: str | None = None,
experimental: bool | None = False,
required_scope: str | None = None,
) -> Callable[[T], T]:
"""
Decorator that supplies additional metadata to an endpoint operation function.
@ -45,6 +47,7 @@ def webmethod(
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
:param experimental: True if the operation is experimental and subject to change.
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
"""
def wrap(func: T) -> T:
@ -57,6 +60,7 @@ def webmethod(
raw_bytes_request_body=raw_bytes_request_body,
descriptive_name=descriptive_name,
experimental=experimental,
required_scope=required_scope,
)
return func