mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
feat(quota): support per‑client and anonymous server‑side request quotas
Unrestricted API usage can lead to runaway costs and fragmented client-side throttling logic. This commit introduces a built-in quota mechanism at the server level, enabling operators to centrally enforce per-client and anonymous rate limits—without needing external proxies or client changes. This helps contain compute costs, enforces fair usage, and simplifies deployment and monitoring of Llama Stack services. Quotas are fully opt-in and have no effect unless explicitly configured. Currently, SQLite is the only supported KV store. If quotas are configured but authentication is disabled, authenticated limits will gracefully fall back to anonymous limits. Highlights: - Adds `QuotaMiddleware` to enforce request quotas: - Uses bearer token as client ID if present; otherwise falls back to IP address - Tracks requests in KV store with per-key TTL expiration - Returns HTTP 429 if a client exceeds their quota - Extends `ServerConfig` with a `quota` section: - `kvstore`: configuration for the backend (currently only SQLite) - `anonymous_max_requests`: per-period cap for unauthenticated clients - `authenticated_max_requests`: per-period cap for authenticated clients - `period`: duration of the quota window (currently only `day` is supported) - Adds full test coverage with FastAPI `TestClient` and custom middleware injection Behavior changes: - Quotas are disabled by default unless explicitly configured - Anonymous users get a conservative default quota; authenticated clients can be given more generous limits To enable per-client request quotas in `run.yaml`, add: ```yaml server: port: 8321 auth: provider_type: custom config: endpoint: https://auth.example.com/validate quota: kvstore: type: sqlite db_path: ./quotas.db anonymous_max_requests: 100 authenticated_max_requests: 1000 period: day ``` Signed-off-by: Wen Liang <wenliang@redhat.com>
This commit is contained in:
parent
ed7b4731aa
commit
dacd522f57
6 changed files with 363 additions and 1 deletions
|
@ -208,6 +208,80 @@ And must respond with:
|
|||
|
||||
If no access attributes are returned, the token is used as a namespace.
|
||||
|
||||
### Quota Configuration
|
||||
|
||||
The `quota` section allows you to enable server-side request throttling for both
|
||||
authenticated and anonymous clients. This is useful for preventing abuse, enforcing
|
||||
fairness across tenants, and controlling infrastructure costs without requiring
|
||||
client-side rate limiting or external proxies.
|
||||
|
||||
Quotas are disabled by default. When enabled, each client is tracked using either:
|
||||
|
||||
* Their authenticated `client_id` (derived from the Bearer token), or
|
||||
* Their IP address (fallback for anonymous requests)
|
||||
|
||||
Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
|
||||
within a configurable time window (currently only `day` is supported).
|
||||
|
||||
#### Example
|
||||
|
||||
```yaml
|
||||
server:
|
||||
quota:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ./quotas.db
|
||||
anonymous_max_requests: 100
|
||||
authenticated_max_requests: 1000
|
||||
period: day
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
| Field | Description |
|
||||
| ---------------------------- | -------------------------------------------------------------------------- |
|
||||
| `kvstore` | Required. Backend storage config for tracking request counts. |
|
||||
| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
|
||||
| `kvstore.db_path` | File path to the SQLite database. |
|
||||
| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
|
||||
| `authenticated_max_requests` | Max requests per period for authenticated clients. |
|
||||
| `period` | Time window for quota enforcement. Only `"day"` is supported. |
|
||||
|
||||
> Note: if `authenticated_max_requests` is set but no authentication provider is
|
||||
configured, the server will fall back to applying `anonymous_max_requests` to all
|
||||
clients.
|
||||
|
||||
#### Example with Authentication Enabled
|
||||
|
||||
```yaml
|
||||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_type: custom
|
||||
config:
|
||||
endpoint: https://auth.example.com/validate
|
||||
quota:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ./quotas.db
|
||||
anonymous_max_requests: 100
|
||||
authenticated_max_requests: 1000
|
||||
period: day
|
||||
```
|
||||
|
||||
If a client exceeds their limit, the server responds with:
|
||||
|
||||
```http
|
||||
HTTP/1.1 429 Too Many Requests
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"error": {
|
||||
"message": "Quota exceeded"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Extending to handle Safety
|
||||
|
||||
Configuring Safety can be a little involved so it is instructive to go through an example.
|
||||
|
|
|
@ -25,7 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
|||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
@ -235,6 +235,19 @@ class AuthenticationConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class QuotaPeriod(str, Enum):
|
||||
DAY = "day"
|
||||
|
||||
|
||||
class QuotaConfig(BaseModel):
|
||||
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
|
||||
authenticated_max_requests: int = Field(
|
||||
default=1000, description="Max requests for authenticated clients per period"
|
||||
)
|
||||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
@ -262,6 +275,10 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="The host the server should listen on",
|
||||
)
|
||||
quota: QuotaConfig | None = Field(
|
||||
default=None,
|
||||
description="Per client quota request configuration",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
|
@ -113,6 +113,10 @@ class AuthenticationMiddleware:
|
|||
"roles": [token],
|
||||
}
|
||||
|
||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||
# can identify the requester and enforce per-client rate limits.
|
||||
scope["authenticated_client_id"] = token
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
scope["principal"] = validation_result.principal
|
||||
|
|
110
llama_stack/distribution/server/quota.py
Normal file
110
llama_stack/distribution/server/quota.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="quota")
|
||||
|
||||
|
||||
class QuotaMiddleware:
|
||||
"""
|
||||
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
|
||||
within a configurable time window.
|
||||
|
||||
- For authenticated requests, it reads the client ID from the
|
||||
`Authorization: Bearer <client_id>` header.
|
||||
- For anonymous requests, it falls back to the IP address of the client.
|
||||
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
|
||||
once a client exceeds its quota.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
kv_config: KVStoreConfig,
|
||||
anonymous_max_requests: int,
|
||||
authenticated_max_requests: int,
|
||||
window_seconds: int = 86400,
|
||||
):
|
||||
self.app = app
|
||||
self.kv_config = kv_config
|
||||
self.kv: KVStore | None = None
|
||||
self.anonymous_max_requests = anonymous_max_requests
|
||||
self.authenticated_max_requests = authenticated_max_requests
|
||||
self.window_seconds = window_seconds
|
||||
|
||||
if isinstance(self.kv_config, SqliteKVStoreConfig):
|
||||
logger.warning(
|
||||
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
|
||||
f"window_seconds={self.window_seconds}"
|
||||
)
|
||||
|
||||
async def _get_kv(self) -> KVStore:
|
||||
if self.kv is None:
|
||||
self.kv = await kvstore_impl(self.kv_config)
|
||||
return self.kv
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] == "http":
|
||||
# pick key & limit based on auth
|
||||
auth_id = scope.get("authenticated_client_id")
|
||||
if auth_id:
|
||||
key_id = auth_id
|
||||
limit = self.authenticated_max_requests
|
||||
else:
|
||||
# fallback to IP
|
||||
client = scope.get("client")
|
||||
key_id = client[0] if client else "anonymous"
|
||||
limit = self.anonymous_max_requests
|
||||
|
||||
current_window = int(time.time() // self.window_seconds)
|
||||
key = f"quota:{key_id}:{current_window}"
|
||||
|
||||
try:
|
||||
kv = await self._get_kv()
|
||||
prev = await kv.get(key) or "0"
|
||||
count = int(prev) + 1
|
||||
|
||||
if int(prev) == 0:
|
||||
# Set with expiration datetime when it is the first request in the window.
|
||||
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
|
||||
await kv.set(key, str(count), expiration=expiration)
|
||||
else:
|
||||
await kv.set(key, str(count))
|
||||
except Exception:
|
||||
logger.exception("Failed to access KV store for quota")
|
||||
return await self._send_error(send, 500, "Quota service error")
|
||||
|
||||
if count > limit:
|
||||
logger.warning(
|
||||
"Quota exceeded for client %s: %d/%d",
|
||||
key_id,
|
||||
count,
|
||||
limit,
|
||||
)
|
||||
return await self._send_error(send, 429, "Quota exceeded")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_error(self, send: Send, status: int, message: str):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
body = json.dumps({"error": {"message": message}}).encode()
|
||||
await send({"type": "http.response.body", "body": body})
|
|
@ -60,6 +60,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .endpoints import get_all_api_endpoints
|
||||
from .quota import QuotaMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
@ -434,6 +435,35 @@ def main(args: argparse.Namespace | None = None):
|
|||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
else:
|
||||
if config.server.quota:
|
||||
quota = config.server.quota
|
||||
logger.warning(
|
||||
"Configured authenticated_max_requests (%d) but no auth is enabled; "
|
||||
"falling back to anonymous_max_requests (%d) for all the requests",
|
||||
quota.authenticated_max_requests,
|
||||
quota.anonymous_max_requests,
|
||||
)
|
||||
|
||||
if config.server.quota:
|
||||
logger.info("Enabling quota middleware for authenticated and anonymous clients")
|
||||
|
||||
quota = config.server.quota
|
||||
anonymous_max_requests = quota.anonymous_max_requests
|
||||
# if auth is disabled, use the anonymous max requests
|
||||
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
|
||||
|
||||
kv_config = quota.kvstore
|
||||
window_map = {"day": 86400}
|
||||
window_seconds = window_map[quota.period.value]
|
||||
|
||||
app.add_middleware(
|
||||
QuotaMiddleware,
|
||||
kv_config=kv_config,
|
||||
anonymous_max_requests=anonymous_max_requests,
|
||||
authenticated_max_requests=authenticated_max_requests,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
|
|
127
tests/unit/server/test_quota.py
Normal file
127
tests/unit/server/test_quota.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod
|
||||
from llama_stack.distribution.server.quota import QuotaMiddleware
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
|
||||
"""
|
||||
|
||||
def __init__(self, app, client_id="client1"):
|
||||
super().__init__(app)
|
||||
self.client_id = client_id
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
request.scope["authenticated_client_id"] = self.client_id
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def build_quota_config(db_path) -> QuotaConfig:
|
||||
return QuotaConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||
anonymous_max_requests=1,
|
||||
authenticated_max_requests=2,
|
||||
period=QuotaPeriod.DAY,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_app(tmp_path, request):
|
||||
"""
|
||||
FastAPI app with InjectClientIDMiddleware and QuotaMiddleware for authenticated testing.
|
||||
Each test gets its own DB file.
|
||||
"""
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "ok"}
|
||||
|
||||
db_path = tmp_path / f"quota_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
app = InjectClientIDMiddleware(
|
||||
QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
),
|
||||
client_id=f"client_{request.node.name}",
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||
client = TestClient(auth_app)
|
||||
assert client.get("/test").status_code == 200
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_authenticated_quota_blocks_after_limit(auth_app):
|
||||
client = TestClient(auth_app)
|
||||
client.get("/test")
|
||||
client.get("/test")
|
||||
resp = client.get("/test")
|
||||
assert resp.status_code == 429
|
||||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||
|
||||
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "ok"}
|
||||
|
||||
db_path = tmp_path / f"quota_anon_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
app = QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "ok"}
|
||||
|
||||
db_path = tmp_path / f"quota_anon_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
app = QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
client.get("/test")
|
||||
resp = client.get("/test")
|
||||
assert resp.status_code == 429
|
||||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
Loading…
Add table
Add a link
Reference in a new issue