mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
api access
- Create BaseServerMiddleware base class for server middleware - Refactor TracingMiddleware to extend BaseServerMiddleware - Consolidate route matching logic in base class - Update server.py to use user_from_scope utility - Add required_scope parameter to WebMethod in schema_utils.py - Create AccessControlMiddleware with simplified scope checking - Update telemetry API to use required_scope protection - Add comprehensive test coverage for access control logic - Integrate access control middleware into server setup - Rename AccessControlMiddleware to AuthorizationMiddleware for better clarity - Update imports and references in server.py and tests - Keep the same functionality and API - Merge authorization logic directly into AuthenticationMiddleware - Remove separate access_control.py file - Update middleware setup in server.py to use single middleware - Rename and update tests to test the merged functionality - AuthenticationMiddleware now handles both authentication and authorization
This commit is contained in:
parent
632cf9eb72
commit
ebea3c8277
7 changed files with 331 additions and 36 deletions
|
@ -41,7 +41,11 @@ from llama_stack.distribution.datatypes import (
|
|||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.routes import (
|
||||
find_matching_route,
|
||||
|
@ -223,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
@functools.wraps(func)
|
||||
async def route_handler(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal=principal, attributes=user_attributes)
|
||||
user = user_from_scope(request.scope)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
|
@ -437,10 +439,21 @@ def main(args: argparse.Namespace | None = None):
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
|
||||
else:
|
||||
if config.server.quota:
|
||||
quota = config.server.quota
|
||||
|
@ -471,18 +484,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue