mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: Add Kubernetes authentication (#1778)
# What does this PR do? This commit adds a new authentication system to the Llama Stack server with support for Kubernetes and custom authentication providers. Key changes include: - Implemented KubernetesAuthProvider for validating Kubernetes service account tokens - Implemented CustomAuthProvider for validating tokens against external endpoints - this is the same code that was already present. - Added test for Kubernetes - Updated server configuration to support authentication settings - Added documentation for authentication configuration and usage The authentication system supports: - Bearer token validation - Kubernetes service account token validation - Custom authentication endpoints ## Test Plan Setup a Kube cluster using Kind or Minikube. Run a server with: ``` server: port: 8321 auth: provider_type: kubernetes config: api_server_url: http://url ca_cert_path: path/to/cert (optional) ``` Run: ``` curl -s -L -H "Authorization: Bearer $(kubectl create token my-user)" http://127.0.0.1:8321/v1/providers ``` Or replace "my-user" with your service account. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
e6bbf8d20b
commit
79851d93aa
11 changed files with 886 additions and 154 deletions
136
.github/workflows/integration-auth-tests.yml
vendored
Normal file
136
.github/workflows/integration-auth-tests.yml
vendored
Normal file
|
@ -0,0 +1,136 @@
|
|||
name: Integration Auth Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'distributions/**'
|
||||
- 'llama_stack/**'
|
||||
- 'tests/integration/**'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/integration-auth-tests.yml' # This workflow
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test-matrix:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
auth-provider: [kubernetes]
|
||||
fail-fast: false # we want to run all tests regardless of failure
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Set Up Environment and Install Dependencies
|
||||
run: |
|
||||
uv sync --extra dev --extra test
|
||||
uv pip install -e .
|
||||
llama stack build --template ollama --image-type venv
|
||||
|
||||
- name: Install minikube
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
uses: medyagh/setup-minikube@latest
|
||||
|
||||
- name: Start minikube
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
run: |
|
||||
minikube start
|
||||
kubectl get pods -A
|
||||
|
||||
- name: Configure Kube Auth
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
run: |
|
||||
kubectl create namespace llama-stack
|
||||
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||
kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
|
||||
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||
|
||||
- name: Set Kubernetes Config
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
run: |
|
||||
echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
|
||||
echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
|
||||
|
||||
- name: Set Kube Auth Config and run server
|
||||
env:
|
||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
run: |
|
||||
run_dir=$(mktemp -d)
|
||||
cat <<'EOF' > $run_dir/run.yaml
|
||||
version: '2'
|
||||
image_name: kube
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
server:
|
||||
port: 8321
|
||||
EOF
|
||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
|
||||
cat $run_dir/run.yaml
|
||||
|
||||
source .venv/bin/activate
|
||||
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
echo "Waiting for Llama Stack server..."
|
||||
for i in {1..30}; do
|
||||
if curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://localhost:8321/v1/health | grep -q "OK"; then
|
||||
echo "Llama Stack server is up!"
|
||||
if grep -q "Enabling authentication with provider: ${{ matrix.auth-provider }}" server.log; then
|
||||
echo "Llama Stack server is configured to use ${{ matrix.auth-provider }} auth"
|
||||
exit 0
|
||||
else
|
||||
echo "Llama Stack server is not configured to use ${{ matrix.auth-provider }} auth"
|
||||
cat server.log
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Llama Stack server failed to start"
|
||||
cat server.log
|
||||
exit 1
|
||||
|
||||
- name: Test auth
|
||||
run: |
|
||||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers|jq
|
|
@ -53,6 +53,13 @@ models:
|
|||
provider_id: ollama
|
||||
provider_model_id: null
|
||||
shields: []
|
||||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_type: "kubernetes"
|
||||
config:
|
||||
api_server_url: "https://kubernetes.default.svc"
|
||||
ca_cert_path: "/path/to/ca.crt"
|
||||
```
|
||||
|
||||
Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:
|
||||
|
@ -102,6 +109,105 @@ A Model is an instance of a "Resource" (see [Concepts](../concepts/index)) and i
|
|||
|
||||
What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. When omitted, the server will set `provider_model_id` to be the same as `model_id`.
|
||||
|
||||
## Server Configuration
|
||||
|
||||
The `server` section configures the HTTP server that serves the Llama Stack APIs:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
port: 8321 # Port to listen on (default: 8321)
|
||||
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
|
||||
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
|
||||
auth: # Optional: Authentication configuration
|
||||
provider_type: "kubernetes" # Type of auth provider
|
||||
config: # Provider-specific configuration
|
||||
api_server_url: "https://kubernetes.default.svc"
|
||||
ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
|
||||
```
|
||||
|
||||
### Authentication Configuration
|
||||
|
||||
The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
|
||||
|
||||
```
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
The server supports multiple authentication providers:
|
||||
|
||||
#### Kubernetes Provider
|
||||
|
||||
The Kubernetes cluster must be configured to use a service account for authentication.
|
||||
|
||||
```bash
|
||||
kubectl create namespace llama-stack
|
||||
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||
kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
|
||||
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||
```
|
||||
|
||||
Validates tokens against the Kubernetes API server:
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_type: "kubernetes"
|
||||
config:
|
||||
api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server
|
||||
ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
|
||||
```
|
||||
|
||||
The provider extracts user information from the JWT token:
|
||||
- Username from the `sub` claim becomes a role
|
||||
- Kubernetes groups become teams
|
||||
|
||||
You can easily validate a request by running:
|
||||
|
||||
```bash
|
||||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||
```
|
||||
|
||||
#### Custom Provider
|
||||
Validates tokens against a custom authentication endpoint:
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_type: "custom"
|
||||
config:
|
||||
endpoint: "https://auth.example.com/validate" # URL of the auth endpoint
|
||||
```
|
||||
|
||||
The custom endpoint receives a POST request with:
|
||||
```json
|
||||
{
|
||||
"api_key": "<token>",
|
||||
"request": {
|
||||
"path": "/api/v1/endpoint",
|
||||
"headers": {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "curl/7.64.1"
|
||||
},
|
||||
"params": {
|
||||
"key": ["value"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
And must respond with:
|
||||
```json
|
||||
{
|
||||
"access_attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research"]
|
||||
},
|
||||
"message": "Authentication successful"
|
||||
}
|
||||
```
|
||||
|
||||
If no access attributes are returned, the token is used as a namespace.
|
||||
|
||||
## Extending to handle Safety
|
||||
|
||||
Configuring Safety can be a little involved so it is instructive to go through an example.
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -235,10 +236,21 @@ class LoggingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class AuthProviderType(str, Enum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
endpoint: str = Field(
|
||||
provider_type: AuthProviderType = Field(
|
||||
...,
|
||||
description="Endpoint URL to validate authentication tokens",
|
||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||
)
|
||||
config: Dict[str, str] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -5,74 +5,29 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using an external auth endpoint.
|
||||
"""Middleware that authenticates requests using configured authentication provider.
|
||||
|
||||
This middleware:
|
||||
1. Extracts the Bearer token from the Authorization header
|
||||
2. Sends it to the configured auth endpoint along with request details
|
||||
3. Validates the response and extracts user attributes
|
||||
2. Uses the configured auth provider to validate the token
|
||||
3. Extracts user attributes from the provider's response
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Authentication Request Format:
|
||||
The middleware supports multiple authentication providers through the AuthProvider interface:
|
||||
- Kubernetes: Validates tokens against the Kubernetes API server
|
||||
- Custom: Validates tokens against a custom endpoint
|
||||
|
||||
Authentication Request Format for Custom Auth Provider:
|
||||
```json
|
||||
{
|
||||
"api_key": "the-api-key-extracted-from-auth-header",
|
||||
|
@ -105,21 +60,26 @@ class AuthenticationMiddleware:
|
|||
}
|
||||
```
|
||||
|
||||
Token Validation:
|
||||
Each provider implements its own token validation logic:
|
||||
- Kubernetes: Uses TokenReview API to validate service account tokens
|
||||
- Custom: Sends token to custom endpoint for validation
|
||||
|
||||
Attribute-Based Access Control:
|
||||
The attributes returned by the auth endpoint are used to determine which
|
||||
The attributes returned by the auth provider are used to determine which
|
||||
resources the user can access. Resources can specify required attributes
|
||||
using the access_attributes field. For a user to access a resource:
|
||||
|
||||
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||
2. For each category, the user must have at least one matching value
|
||||
|
||||
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||
If the auth provider doesn't return any attributes, the user will only be able to
|
||||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_endpoint):
|
||||
def __init__(self, app, auth_config: AuthProviderConfig):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
self.auth_provider = create_auth_provider(auth_config)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
|
@ -129,66 +89,34 @@ class AuthenticationMiddleware:
|
|||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
|
||||
api_key = auth_header.split("Bearer ", 1)[1]
|
||||
token = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=api_key,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
# Validate token and get access attributes
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.auth_endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [api_key],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||
access_attributes = await self.auth_provider.validate_token(token, scope)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except ValueError as e:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, str(e))
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if access_attributes:
|
||||
user_attributes = access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to token by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
}
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
|
|
262
llama_stack/distribution/server/auth_providers.py
Normal file
262
llama_stack/distribution/server/auth_providers.py
Normal file
|
@ -0,0 +1,262 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthProviderType(str, Enum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class AuthProviderConfig(BaseModel):
|
||||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
self.api_server_url = config["api_server_url"]
|
||||
self.ca_cert_path = config.get("ca_cert_path")
|
||||
self._client = None
|
||||
|
||||
async def _get_client(self):
|
||||
"""Get or create a Kubernetes client."""
|
||||
if self._client is None:
|
||||
# kubernetes-client has not async support, see:
|
||||
# https://github.com/kubernetes-client/python/issues/323
|
||||
from kubernetes import client
|
||||
from kubernetes.client import ApiClient
|
||||
|
||||
# Configure the client
|
||||
configuration = client.Configuration()
|
||||
configuration.host = self.api_server_url
|
||||
if self.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.ca_cert_path)
|
||||
|
||||
# Create API client
|
||||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
||||
# Set the token in the client
|
||||
client.set_default_header("Authorization", f"Bearer {token}")
|
||||
|
||||
# Make a request to validate the token
|
||||
# We use the /api endpoint which requires authentication
|
||||
from kubernetes.client import CoreV1Api
|
||||
|
||||
api = CoreV1Api(client)
|
||||
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
|
||||
|
||||
# If we get here, the token is valid
|
||||
# Extract user info from the token claims
|
||||
import base64
|
||||
|
||||
# Decode the token (without verification since we've already validated it)
|
||||
token_parts = token.split(".")
|
||||
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
|
||||
|
||||
# Extract user information from the token
|
||||
username = payload.get("sub", "")
|
||||
groups = payload.get("groups", [])
|
||||
|
||||
return AccessAttributes(
|
||||
roles=[username], # Use username as a role
|
||||
teams=groups, # Use Kubernetes groups as teams
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to validate Kubernetes token")
|
||||
raise ValueError("Invalid or expired token") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
self.endpoint = config["endpoint"]
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Authentication endpoint not configured")
|
||||
|
||||
if scope is None:
|
||||
scope = {}
|
||||
|
||||
headers = dict(scope.get("headers", []))
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=token,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
return auth_response.access_attributes
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
return auth_response.access_attributes
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during authentication")
|
||||
raise ValueError("Authentication service error") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "kubernetes":
|
||||
return KubernetesAuthProvider(config.config)
|
||||
elif provider_type == "custom":
|
||||
return CustomAuthProvider(config.config)
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
|
@ -419,9 +419,9 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth and config.server.auth.endpoint:
|
||||
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
|
|
|
@ -39,6 +39,7 @@ dependencies = [
|
|||
"tiktoken",
|
||||
"pillow",
|
||||
"h11>=0.16.0",
|
||||
"kubernetes",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
@ -48,7 +49,7 @@ dev = [
|
|||
"pytest-cov",
|
||||
"pytest-html",
|
||||
"pytest-json-report",
|
||||
"nbval", # For notebook testing
|
||||
"nbval", # For notebook testing
|
||||
"black",
|
||||
"ruff",
|
||||
"types-requests",
|
||||
|
@ -56,7 +57,7 @@ dev = [
|
|||
"pre-commit",
|
||||
"uvicorn",
|
||||
"fastapi",
|
||||
"ruamel.yaml", # needed for openapi generator
|
||||
"ruamel.yaml", # needed for openapi generator
|
||||
]
|
||||
# These are the dependencies required for running unit tests.
|
||||
unit = [
|
||||
|
@ -67,7 +68,7 @@ unit = [
|
|||
"pypdf",
|
||||
"chardet",
|
||||
"qdrant-client",
|
||||
"opentelemetry-exporter-otlp-proto-http"
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
|
|
|
@ -4,15 +4,18 @@ annotated-types==0.7.0
|
|||
anyio==4.8.0
|
||||
attrs==25.1.0
|
||||
blobfile==3.0.0
|
||||
cachetools==5.5.2
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
colorama==0.4.6 ; sys_platform == 'win32'
|
||||
distro==1.9.0
|
||||
durationpy==0.9
|
||||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||
filelock==3.17.0
|
||||
fire==0.7.0
|
||||
fsspec==2024.12.0
|
||||
google-auth==2.38.0
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
|
@ -22,18 +25,22 @@ jinja2==3.1.6
|
|||
jiter==0.8.2
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
kubernetes==32.0.1
|
||||
llama-stack-client==0.2.2
|
||||
lxml==5.3.1
|
||||
markdown-it-py==3.0.0
|
||||
markupsafe==3.0.2
|
||||
mdurl==0.1.2
|
||||
numpy==2.2.3
|
||||
oauthlib==3.2.2
|
||||
openai==1.71.0
|
||||
packaging==24.2
|
||||
pandas==2.2.3
|
||||
pillow==11.1.0
|
||||
prompt-toolkit==3.0.50
|
||||
pyaml==25.1.0
|
||||
pyasn1==0.6.1
|
||||
pyasn1-modules==0.4.2
|
||||
pycryptodomex==3.21.0
|
||||
pydantic==2.10.6
|
||||
pydantic-core==2.27.2
|
||||
|
@ -45,8 +52,10 @@ pyyaml==6.0.2
|
|||
referencing==0.36.2
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
requests-oauthlib==2.0.0
|
||||
rich==13.9.4
|
||||
rpds-py==0.22.3
|
||||
rsa==4.9
|
||||
setuptools==75.8.0
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
|
@ -57,3 +66,4 @@ typing-extensions==4.12.2
|
|||
tzdata==2025.1
|
||||
urllib3==2.3.0
|
||||
wcwidth==0.2.13
|
||||
websocket-client==1.8.0
|
||||
|
|
|
@ -12,7 +12,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelWithACL
|
||||
from llama_stack.distribution.server.auth import AccessAttributes
|
||||
from llama_stack.distribution.server.auth_providers import AccessAttributes
|
||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
|
|
@ -10,7 +10,9 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, AuthProviderType
|
||||
|
||||
|
||||
class MockResponse:
|
||||
|
@ -38,9 +40,23 @@ def invalid_api_key():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_auth_endpoint):
|
||||
def valid_token():
|
||||
return "valid.jwt.token"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_token():
|
||||
return "invalid.jwt.token"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -50,8 +66,29 @@ def app(mock_auth_endpoint):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_client(http_app):
|
||||
return TestClient(http_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_client(k8s_app):
|
||||
return TestClient(k8s_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -61,7 +98,7 @@ def mock_scope():
|
|||
"path": "/models/list",
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"authorization", b"Bearer test-api-key"),
|
||||
(b"authorization", b"Bearer test.jwt.token"),
|
||||
(b"user-agent", b"test-user-agent"),
|
||||
],
|
||||
"query_string": b"limit=100&offset=0",
|
||||
|
@ -69,13 +106,38 @@ def mock_scope():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_middleware(mock_auth_endpoint):
|
||||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
return AuthenticationMiddleware(mock_app, mock_auth_endpoint), mock_app
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
return MockResponse(200, {"message": "Authentication successful"})
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful",
|
||||
"access_attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research", "production"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_post_failure(*args, **kwargs):
|
||||
|
@ -86,45 +148,46 @@ async def mock_post_exception(*args, **kwargs):
|
|||
raise Exception("Connection error")
|
||||
|
||||
|
||||
def test_missing_auth_header(client):
|
||||
response = client.get("/test")
|
||||
# HTTP Endpoint Tests
|
||||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(client):
|
||||
response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
||||
def test_valid_authentication(client, valid_api_key):
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
def test_valid_http_authentication(http_client, valid_api_key):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_failure)
|
||||
def test_invalid_authentication(client, invalid_api_key):
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
def test_invalid_http_authentication(http_client, invalid_api_key):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_exception)
|
||||
def test_auth_service_error(client, valid_api_key):
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
def test_http_auth_service_error(http_client, valid_api_key):
|
||||
response = http_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
|
||||
assert response.status_code == 401
|
||||
assert "Authentication service error" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
||||
def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoint):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(200, {"message": "Authentication successful"})
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client.get(
|
||||
http_client.get(
|
||||
"/test?param1=value1¶m2=value2",
|
||||
headers={
|
||||
"Authorization": f"Bearer {valid_api_key}",
|
||||
|
@ -149,40 +212,43 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_with_access_attributes(mock_middleware, mock_scope):
|
||||
middleware, mock_app = mock_middleware
|
||||
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
|
||||
"""Test HTTP middleware behavior with access attributes"""
|
||||
middleware, mock_app = mock_http_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.post.return_value = MockResponse(
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful",
|
||||
"access_attributes": {
|
||||
"roles": ["admin", "user"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["project-x", "project-y"],
|
||||
}
|
||||
"teams": ["ml-team", "nlp-team"],
|
||||
"projects": ["llama-3", "project-x"],
|
||||
"namespaces": ["research", "production"],
|
||||
},
|
||||
},
|
||||
)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"]["roles"] == ["admin", "user"]
|
||||
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
|
||||
assert mock_scope["user_attributes"]["projects"] == ["project-x", "project-y"]
|
||||
attributes = mock_scope["user_attributes"]
|
||||
assert attributes["roles"] == ["admin", "user"]
|
||||
assert attributes["teams"] == ["ml-team", "nlp-team"]
|
||||
assert attributes["projects"] == ["llama-3", "project-x"]
|
||||
assert attributes["namespaces"] == ["research", "production"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
|
||||
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
|
||||
"""Test middleware behavior with no access attributes"""
|
||||
middleware, mock_app = mock_middleware
|
||||
middleware, mock_app = mock_http_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
@ -203,4 +269,104 @@ async def test_auth_middleware_no_attributes(mock_middleware, mock_scope):
|
|||
assert "user_attributes" in mock_scope
|
||||
attributes = mock_scope["user_attributes"]
|
||||
assert "namespaces" in attributes
|
||||
assert attributes["namespaces"] == ["test-api-key"]
|
||||
assert attributes["namespaces"] == ["test.jwt.token"]
|
||||
|
||||
|
||||
# Kubernetes Tests
|
||||
def test_missing_auth_header_k8s(k8s_client):
|
||||
response = k8s_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_k8s(k8s_client):
|
||||
response = k8s_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("kubernetes.client.ApiClient")
|
||||
def test_valid_k8s_authentication(mock_api_client, k8s_client, valid_token):
|
||||
# Mock the Kubernetes client
|
||||
mock_client = AsyncMock()
|
||||
mock_api_client.return_value = mock_client
|
||||
|
||||
# Mock successful token validation
|
||||
mock_client.set_default_header = AsyncMock()
|
||||
|
||||
# Mock the token validation to return valid access attributes
|
||||
with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
|
||||
mock_validate.return_value = AccessAttributes(
|
||||
roles=["admin"], teams=["ml-team"], projects=["llama-3"], namespaces=["research"]
|
||||
)
|
||||
response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("kubernetes.client.ApiClient")
|
||||
def test_invalid_k8s_authentication(mock_api_client, k8s_client, invalid_token):
|
||||
# Mock the Kubernetes client
|
||||
mock_client = AsyncMock()
|
||||
mock_api_client.return_value = mock_client
|
||||
|
||||
# Mock failed token validation by raising an exception
|
||||
with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
|
||||
mock_validate.side_effect = ValueError("Invalid or expired token")
|
||||
response = k8s_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid or expired token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_k8s_middleware_with_access_attributes(mock_k8s_middleware, mock_scope):
|
||||
middleware, mock_app = mock_k8s_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("kubernetes.client.ApiClient") as mock_api_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_api_client.return_value = mock_client
|
||||
|
||||
# Mock successful token validation
|
||||
mock_client.set_default_header = AsyncMock()
|
||||
|
||||
# Mock token payload with access attributes
|
||||
mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiIsImdyb3VwcyI6WyJtbC10ZWFtIl19", "signature"]
|
||||
mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode())
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
assert mock_scope["user_attributes"]["roles"] == ["admin"]
|
||||
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
|
||||
"""Test middleware behavior with no access attributes"""
|
||||
middleware, mock_app = mock_k8s_middleware
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
with patch("kubernetes.client.ApiClient") as mock_api_client:
|
||||
mock_client = AsyncMock()
|
||||
mock_api_client.return_value = mock_client
|
||||
|
||||
# Mock successful token validation
|
||||
mock_client.set_default_header = AsyncMock()
|
||||
|
||||
# Mock token payload without access attributes
|
||||
mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiJ9", "signature"]
|
||||
mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode())
|
||||
|
||||
await middleware(mock_scope, mock_receive, mock_send)
|
||||
|
||||
assert "user_attributes" in mock_scope
|
||||
attributes = mock_scope["user_attributes"]
|
||||
assert "roles" in attributes
|
||||
assert attributes["roles"] == ["admin"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
|
111
uv.lock
generated
111
uv.lock
generated
|
@ -676,6 +676,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "durationpy"
|
||||
version = "0.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/31/e9/f49c4e7fccb77fa5c43c2480e09a857a78b41e7331a75e128ed5df45c56b/durationpy-0.9.tar.gz", hash = "sha256:fd3feb0a69a0057d582ef643c355c40d2fa1c942191f914d12203b1a01ac722a", size = 3186 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/a3/ac312faeceffd2d8f86bc6dcb5c401188ba5a01bc88e69bed97578a0dfcd/durationpy-0.9-py3-none-any.whl", hash = "sha256:e65359a7af5cedad07fb77a2dd3f390f8eb0b74cb845589fa6c057086834dd38", size = 3461 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.2"
|
||||
|
@ -842,6 +851,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-auth"
|
||||
version = "2.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cachetools" },
|
||||
{ name = "pyasn1-modules" },
|
||||
{ name = "rsa" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c6/eb/d504ba1daf190af6b204a9d4714d457462b486043744901a6eeea711f913/google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4", size = 270866 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/47/603554949a37bca5b7f894d51896a9c534b9eab808e2520a748e081669d0/google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a", size = 210770 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "googleapis-common-protos"
|
||||
version = "1.67.0"
|
||||
|
@ -1289,6 +1312,28 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kubernetes"
|
||||
version = "32.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "durationpy" },
|
||||
{ name = "google-auth" },
|
||||
{ name = "oauthlib" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "requests-oauthlib" },
|
||||
{ name = "six" },
|
||||
{ name = "urllib3" },
|
||||
{ name = "websocket-client" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b7/e8/0598f0e8b4af37cd9b10d8b87386cf3173cb8045d834ab5f6ec347a758b3/kubernetes-32.0.1.tar.gz", hash = "sha256:42f43d49abd437ada79a79a16bd48a604d3471a117a8347e87db693f2ba0ba28", size = 946691 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/08/10/9f8af3e6f569685ce3af7faab51c8dd9d93b9c38eba339ca31c746119447/kubernetes-32.0.1-py2.py3-none-any.whl", hash = "sha256:35282ab8493b938b08ab5526c7ce66588232df00ef5e1dbe88a419107dc10998", size = 1988070 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "levenshtein"
|
||||
version = "0.27.1"
|
||||
|
@ -1384,6 +1429,7 @@ dependencies = [
|
|||
{ name = "huggingface-hub" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "kubernetes" },
|
||||
{ name = "llama-stack-client" },
|
||||
{ name = "openai" },
|
||||
{ name = "pillow" },
|
||||
|
@ -1485,6 +1531,7 @@ requires-dist = [
|
|||
{ name = "jinja2", specifier = ">=3.1.6" },
|
||||
{ name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "kubernetes" },
|
||||
{ name = "llama-stack-client", specifier = ">=0.2.2" },
|
||||
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.1" },
|
||||
{ name = "mcp", marker = "extra == 'test'" },
|
||||
|
@ -2022,6 +2069,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/17/7f/d322a4125405920401450118dbdc52e0384026bd669939484670ce8b2ab9/numpy-2.2.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:783145835458e60fa97afac25d511d00a1eca94d4a8f3ace9fe2043003c678e4", size = 12839607 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oauthlib"
|
||||
version = "3.2.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.71.0"
|
||||
|
@ -2525,6 +2581,27 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/ed/bd/54907846383dcc7ee28772d7e646f6c34276a17da740002a5cefe90f04f7/pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:58d9397b2e273ef76264b45531e9d552d8ec8a6688b7390b5be44c02a37aade8", size = 42085744 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1-modules"
|
||||
version = "0.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyasn1" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
|
@ -3135,6 +3212,19 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests-oauthlib"
|
||||
version = "2.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "oauthlib" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.9.4"
|
||||
|
@ -3234,6 +3324,18 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/9f/2e/c5c1689e80298d4e94c75b70faada4c25445739d91b94c211244a3ed7ed1/rpds_py-0.22.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:177c7c0fce2855833819c98e43c262007f42ce86651ffbb84f37883308cb0e7d", size = 233338 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "4.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyasn1" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruamel-yaml"
|
||||
version = "0.18.10"
|
||||
|
@ -4109,6 +4211,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "websocket-client"
|
||||
version = "1.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e6/30/fba0d96b4b5fbf5948ed3f4681f7da2f9f64512e1d303f94b4cc174c24a5/websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da", size = 54648 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "15.0"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue