api access

- Create BaseServerMiddleware base class for server middleware
- Refactor TracingMiddleware to extend BaseServerMiddleware
- Consolidate route matching logic in base class
- Update server.py to use user_from_scope utility
- Add required_scope parameter to WebMethod in schema_utils.py
- Create AccessControlMiddleware with simplified scope checking
- Update telemetry API to use required_scope protection
- Add comprehensive test coverage for access control logic
- Integrate access control middleware into server setup
- Rename AccessControlMiddleware to AuthorizationMiddleware for better clarity
- Update imports and references in server.py and tests
- Keep the same functionality and API
- Merge authorization logic directly into AuthenticationMiddleware
- Remove separate access_control.py file
- Update middleware setup in server.py to use single middleware
- Rename and update tests to test the merged functionality
- AuthenticationMiddleware now handles both authentication and authorization
This commit is contained in:
Eric Huang 2025-07-24 14:56:17 -07:00
parent 632cf9eb72
commit ebea3c8277
7 changed files with 331 additions and 36 deletions

View file

@ -41,7 +41,11 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
find_matching_route,
@ -223,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal=principal, attributes=user_attributes)
user = user_from_scope(request.scope)
await log_request_pre_validation(request)
@ -437,10 +439,21 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
@ -471,18 +484,6 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else: