feat: Add Kubernetes authentication (#1778)

# What does this PR do?

This commit adds a new authentication system to the Llama Stack server
with support for Kubernetes and custom authentication providers. Key
changes include:

- Implemented KubernetesAuthProvider for validating Kubernetes service
account tokens
- Implemented CustomAuthProvider for validating tokens against external
endpoints - this is the same code that was already present.
- Added test for Kubernetes
- Updated server configuration to support authentication settings
- Added documentation for authentication configuration and usage

The authentication system supports:
- Bearer token validation
- Kubernetes service account token validation
- Custom authentication endpoints

## Test Plan

Setup a Kube cluster using Kind or Minikube.

Run a server with:

```
server:
  port: 8321
  auth:
    provider_type: kubernetes
    config:
      api_server_url: http://url
      ca_cert_path: path/to/cert (optional)
```

Run:

```
curl -s -L -H "Authorization: Bearer $(kubectl create token my-user)" http://127.0.0.1:8321/v1/providers
```

Or replace "my-user" with your service account.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-04-28 22:24:58 +02:00 committed by GitHub
parent e6bbf8d20b
commit 79851d93aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 886 additions and 154 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
@ -235,10 +236,21 @@ class LoggingConfig(BaseModel):
)
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
class AuthenticationConfig(BaseModel):
endpoint: str = Field(
provider_type: AuthProviderType = Field(
...,
description="Endpoint URL to validate authentication tokens",
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
)
config: Dict[str, str] = Field(
...,
description="Provider-specific configuration",
)

View file

@ -5,74 +5,29 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
"""Middleware that authenticates requests using configured authentication provider.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
2. Uses the configured auth provider to validate the token
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
Authentication Request Format for Custom Auth Provider:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
@ -105,21 +60,26 @@ class AuthenticationMiddleware:
}
```
Token Validation:
Each provider implements its own token validation logic:
- Kubernetes: Uses TokenReview API to validate service account tokens
- Custom: Sends token to custom endpoint for validation
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
The attributes returned by the auth provider are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
If the auth provider doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint):
def __init__(self, app, auth_config: AuthProviderConfig):
self.app = app
self.auth_endpoint = auth_endpoint
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
@ -129,66 +89,34 @@ class AuthenticationMiddleware:
if not auth_header or not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
api_key = auth_header.split("Bearer ", 1)[1]
token = auth_header.split("Bearer ", 1)[1]
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=api_key,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
# Validate token and get access attributes
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
access_attributes = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"namespaces": [token],
}
# Store attributes in request scope
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):

View file

@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
"""Validate a token and return access attributes."""
pass
@abstractmethod
async def close(self):
"""Clean up any resources."""
pass
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: Dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
self._client = None
async def _get_client(self):
"""Get or create a Kubernetes client."""
if self._client is None:
# kubernetes-client has not async support, see:
# https://github.com/kubernetes-client/python/issues/323
from kubernetes import client
from kubernetes.client import ApiClient
# Configure the client
configuration = client.Configuration()
configuration.host = self.api_server_url
if self.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path)
# Create API client
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
# Set the token in the client
client.set_default_header("Authorization", f"Bearer {token}")
# Make a request to validate the token
# We use the /api endpoint which requires authentication
from kubernetes.client import CoreV1Api
api = CoreV1Api(client)
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
# If we get here, the token is valid
# Extract user info from the token claims
import base64
# Decode the token (without verification since we've already validated it)
token_parts = token.split(".")
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
# Extract user information from the token
username = payload.get("sub", "")
groups = payload.get("groups", [])
return AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
)
except Exception as e:
logger.exception("Failed to validate Kubernetes token")
raise ValueError("Invalid or expired token") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
self._client.close()
self._client = None
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: Dict[str, str]):
self.endpoint = config["endpoint"]
self._client = None
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
if scope is None:
scope = {}
headers = dict(scope.get("headers", []))
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=token,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during authentication")
raise ValueError("Authentication service error") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config)
elif provider_type == "custom":
return CustomAuthProvider(config.config)
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")

View file

@ -419,9 +419,9 @@ def main(args: Optional[argparse.Namespace] = None):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth and config.server.auth.endpoint:
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
try:
impls = asyncio.run(construct_stack(config))