mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-24 16:57:21 +00:00 
			
		
		
		
	This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.
A rule defines a list of actions either to permit or to forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.
A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default).
Some examples in yaml:
```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```
---------
Signed-off-by: Gordon Sim <gsim@redhat.com>
		
	
			
		
			
				
	
	
		
			131 lines
		
	
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			131 lines
		
	
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import json
 | |
| 
 | |
| import httpx
 | |
| 
 | |
| from llama_stack.distribution.datatypes import AuthenticationConfig
 | |
| from llama_stack.distribution.server.auth_providers import create_auth_provider
 | |
| from llama_stack.log import get_logger
 | |
| 
 | |
| logger = get_logger(name=__name__, category="auth")
 | |
| 
 | |
| 
 | |
class AuthenticationMiddleware:
    """Middleware that authenticates requests using configured authentication provider.

    This middleware:
    1. Extracts the Bearer token from the Authorization header
    2. Uses the configured auth provider to validate the token
    3. Extracts user attributes from the provider's response
    4. Makes these attributes available to the route handlers for access control

    The middleware supports multiple authentication providers through the AuthProvider interface:
    - Kubernetes: Validates tokens against the Kubernetes API server
    - Custom: Validates tokens against a custom endpoint

    Authentication Request Format for Custom Auth Provider:
    ```json
    {
        "api_key": "the-api-key-extracted-from-auth-header",
        "request": {
            "path": "/models/list",
            "headers": {
                "content-type": "application/json",
                "user-agent": "..."
                // All headers except Authorization
            },
            "params": {
                "limit": ["100"],
                "offset": ["0"]
                // Query parameters as key -> list of values
            }
        }
    }
    ```

    Expected Auth Endpoint Response Format:
    ```json
    {
        "access_attributes": {    // Structured attribute format
            "roles": ["admin", "user"],
            "teams": ["ml-team", "nlp-team"],
            "projects": ["llama-3", "project-x"],
            "namespaces": ["research"]
        },
        "message": "Optional message about auth result"
    }
    ```

    Token Validation:
    Each provider implements its own token validation logic:
    - Kubernetes: Uses TokenReview API to validate service account tokens
    - Custom: Sends token to custom endpoint for validation

    Attribute-Based Access Control:
    The attributes returned by the auth provider are used to determine which
    resources the user can access. Resources can specify required attributes
    using the access_attributes field. For a user to access a resource:

    1. All attribute categories specified in the resource must be present in the user's attributes
    2. For each category, the user must have at least one matching value

    If the auth provider doesn't return any attributes, the user will only be able to
    access resources that don't have access_attributes defined.
    """

    def __init__(self, app, auth_config: AuthenticationConfig):
        """Wrap ``app`` and build the auth provider described by ``auth_config``.

        Args:
            app: The downstream ASGI application to call once a request is authenticated.
            auth_config: Configuration passed to ``create_auth_provider``.
        """
        self.app = app
        self.auth_provider = create_auth_provider(auth_config)

    async def __call__(self, scope, receive, send):
        """ASGI entry point: authenticate HTTP requests, pass everything else through.

        For ``http`` scopes the Authorization header must carry a ``Bearer`` token that
        the auth provider accepts; otherwise a 401 JSON response is sent and the wrapped
        app is not invoked. On success the following keys are written to ``scope``:

        - ``authenticated_client_id``: the raw bearer token
        - ``principal``: the principal reported by the auth provider
        - ``user_attributes``: the provider's attributes (only when non-empty)

        Scopes of any other type are forwarded to the wrapped app unchanged.
        """
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            raw_auth_header = headers.get(b"authorization", b"")

            # Header bytes come straight from the client. Undecodable bytes must be
            # answered with a 401 rather than escaping as an unhandled exception.
            try:
                auth_header = raw_auth_header.decode()
            except UnicodeDecodeError:
                return await self._send_auth_error(send, "Missing or invalid Authorization header")

            if not auth_header or not auth_header.startswith("Bearer "):
                return await self._send_auth_error(send, "Missing or invalid Authorization header")

            token = auth_header.split("Bearer ", 1)[1]

            # Validate token and get access attributes
            try:
                validation_result = await self.auth_provider.validate_token(token, scope)
            except httpx.TimeoutException:
                logger.exception("Authentication request timed out")
                return await self._send_auth_error(send, "Authentication service timeout")
            except ValueError as e:
                # ValueError messages are relayed to the client verbatim.
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, str(e))
            except Exception:
                # Anything else is reported generically so internal details are not leaked.
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, "Authentication service error")

            # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
            # can identify the requester and enforce per-client rate limits.
            scope["authenticated_client_id"] = token

            # Store attributes in request scope
            scope["principal"] = validation_result.principal
            if validation_result.attributes:
                scope["user_attributes"] = validation_result.attributes

            # The provider may return no attributes at all (None); count those as zero
            # instead of letting len(None) fail an otherwise successful authentication.
            attribute_count = len(validation_result.attributes or {})
            logger.debug(
                f"Authentication successful: {validation_result.principal} with {attribute_count} attributes"
            )

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message):
        """Send a complete 401 response whose JSON body is ``{"error": {"message": message}}``.

        Args:
            send: The ASGI ``send`` callable for the current connection.
            message: Human-readable reason placed in the response body.
        """
        await send(
            {
                "type": "http.response.start",
                "status": 401,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": error_msg})
 |