llama-stack-mirror/llama_stack/distribution/server/auth.py
liangwen12year 2890243107
feat(quota): add server‑side per‑client request quotas (requires auth) (#2096)
# What does this PR do?
feat(quota): add server‑side per‑client request quotas (requires auth)
    
Unrestricted usage can lead to runaway costs and fragmented client-side
    workarounds. This commit introduces a native quota mechanism to the
    server, giving operators a unified, centrally managed throttle for
    per-client requests—without needing extra proxies or custom client
logic. This helps contain cloud-compute expenses, enables fine-grained
usage control, and simplifies deployment and monitoring of Llama Stack
services. Quotas are fully opt-in and have no effect unless explicitly
    configured.
    
    Note that quotas are fully opt-in and require authentication to be
    enabled. `sqlite` is currently the only supported quota `type`; any
    other `type` will be rejected. The only supported `period` is `day`.
    
    Highlights:
    
    - Adds `QuotaMiddleware` to enforce per-client request quotas:
      - Uses `Authorization: Bearer <client_id>` (from
        AuthenticationMiddleware)
      - Tracks usage via a SQLite-based KV store
      - Returns 429 when the quota is exceeded
    
    - Extends `ServerConfig` with a `quota` section (type + config)
    
- Enforces strict coupling: quotas require authentication or the server
      will fail to start
    
    Behavior changes:
    - Quotas are disabled by default unless explicitly configured
    - SQLite defaults to `./quotas.db` if no DB path is set
    - The server requires authentication when quotas are enabled
    
    To enable per-client request quotas in `run.yaml`, add:
    ```
    server:
      port: 8321
      auth:
        provider_type: "custom"
        config:
          endpoint: "https://auth.example.com/validate"
      quota:
        type: sqlite
        config:
          db_path: ./quotas.db
          limit:
            max_requests: 1000
            period: day
    ```

[//]: # (If resolving an issue, uncomment and update the line below)
Closes #2093

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

[//]: # (## Documentation)

Signed-off-by: Wen Liang <wenliang@redhat.com>
Co-authored-by: Wen Liang <wenliang@redhat.com>
2025-05-21 10:58:45 +02:00

138 lines
5.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import httpx
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
from llama_stack.log import get_logger
# Module-level logger for this file; records are tagged with the "auth" category.
logger = get_logger(name=__name__, category="auth")
class AuthenticationMiddleware:
"""Middleware that authenticates requests using configured authentication provider.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Uses the configured auth provider to validate the token
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
Authentication Request Format for Custom Auth Provider:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
"request": {
"path": "/models/list",
"headers": {
"content-type": "application/json",
"user-agent": "..."
// All headers except Authorization
},
"params": {
"limit": ["100"],
"offset": ["0"]
// Query parameters as key -> list of values
}
}
}
```
Expected Auth Endpoint Response Format:
```json
{
"access_attributes": { // Structured attribute format
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
"namespaces": ["research"]
},
"message": "Optional message about auth result"
}
```
Token Validation:
Each provider implements its own token validation logic:
- Kubernetes: Uses TokenReview API to validate service account tokens
- Custom: Sends token to custom endpoint for validation
Attribute-Based Access Control:
The attributes returned by the auth provider are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth provider doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthProviderConfig):
self.app = app
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
token = auth_header.split("Bearer ", 1)[1]
# Validate token and get access attributes
try:
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"roles": [token],
}
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
# Store attributes in request scope
scope["user_attributes"] = user_attributes
scope["principal"] = validation_result.principal
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
)
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):
await send(
{
"type": "http.response.start",
"status": 401,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": error_msg})