forked from phoenix-oss/llama-stack-mirror
feat(quota): add server‑side per‑client request quotas (requires auth) (#2096)
# What does this PR do? feat(quota): add server‑side per‑client request quotas (requires auth) Unrestricted usage can lead to runaway costs and fragmented client-side workarounds. This commit introduces a native quota mechanism to the server, giving operators a unified, centrally managed throttle for per-client requests—without needing extra proxies or custom client logic. This helps contain cloud-compute expenses, enables fine-grained usage control, and simplifies deployment and monitoring of Llama Stack services. Quotas are fully opt-in and have no effect unless explicitly configured. Notice that Quotas are fully opt-in and require authentication to be enabled. The 'sqlite' is the only supported quota `type` at this time, any other `type` will be rejected. And the only supported `period` is 'day'. Highlights: - Adds `QuotaMiddleware` to enforce per-client request quotas: - Uses `Authorization: Bearer <client_id>` (from AuthenticationMiddleware) - Tracks usage via a SQLite-based KV store - Returns 429 when the quota is exceeded - Extends `ServerConfig` with a `quota` section (type + config) - Enforces strict coupling: quotas require authentication or the server will fail to start Behavior changes: - Quotas are disabled by default unless explicitly configured - SQLite defaults to `./quotas.db` if no DB path is set - The server requires authentication when quotas are enabled To enable per-client request quotas in `run.yaml`, add: ``` server: port: 8321 auth: provider_type: "custom" config: endpoint: "https://auth.example.com/validate" quota: type: sqlite config: db_path: ./quotas.db limit: max_requests: 1000 period: day [//]: # (If resolving an issue, uncomment and update the line below) Closes #2093 ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: Wen Liang <wenliang@redhat.com> Co-authored-by: Wen Liang <wenliang@redhat.com>
This commit is contained in:
parent
5a3d777b20
commit
2890243107
6 changed files with 363 additions and 1 deletions
|
@ -208,6 +208,80 @@ And must respond with:
|
||||||
|
|
||||||
If no access attributes are returned, the token is used as a namespace.
|
If no access attributes are returned, the token is used as a namespace.
|
||||||
|
|
||||||
|
### Quota Configuration
|
||||||
|
|
||||||
|
The `quota` section allows you to enable server-side request throttling for both
|
||||||
|
authenticated and anonymous clients. This is useful for preventing abuse, enforcing
|
||||||
|
fairness across tenants, and controlling infrastructure costs without requiring
|
||||||
|
client-side rate limiting or external proxies.
|
||||||
|
|
||||||
|
Quotas are disabled by default. When enabled, each client is tracked using either:
|
||||||
|
|
||||||
|
* Their authenticated `client_id` (derived from the Bearer token), or
|
||||||
|
* Their IP address (fallback for anonymous requests)
|
||||||
|
|
||||||
|
Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
|
||||||
|
within a configurable time window (currently only `day` is supported).
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
quota:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./quotas.db
|
||||||
|
anonymous_max_requests: 100
|
||||||
|
authenticated_max_requests: 1000
|
||||||
|
period: day
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Configuration Options
|
||||||
|
|
||||||
|
| Field | Description |
|
||||||
|
| ---------------------------- | -------------------------------------------------------------------------- |
|
||||||
|
| `kvstore` | Required. Backend storage config for tracking request counts. |
|
||||||
|
| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
|
||||||
|
| `kvstore.db_path` | File path to the SQLite database. |
|
||||||
|
| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
|
||||||
|
| `authenticated_max_requests` | Max requests per period for authenticated clients. |
|
||||||
|
| `period` | Time window for quota enforcement. Only `"day"` is supported. |
|
||||||
|
|
||||||
|
> Note: if `authenticated_max_requests` is set but no authentication provider is
|
||||||
|
configured, the server will fall back to applying `anonymous_max_requests` to all
|
||||||
|
clients.
|
||||||
|
|
||||||
|
#### Example with Authentication Enabled
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
port: 8321
|
||||||
|
auth:
|
||||||
|
provider_type: custom
|
||||||
|
config:
|
||||||
|
endpoint: https://auth.example.com/validate
|
||||||
|
quota:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./quotas.db
|
||||||
|
anonymous_max_requests: 100
|
||||||
|
authenticated_max_requests: 1000
|
||||||
|
period: day
|
||||||
|
```
|
||||||
|
|
||||||
|
If a client exceeds their limit, the server responds with:
|
||||||
|
|
||||||
|
```http
|
||||||
|
HTTP/1.1 429 Too Many Requests
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": "Quota exceeded"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Extending to handle Safety
|
## Extending to handle Safety
|
||||||
|
|
||||||
Configuring Safety can be a little involved so it is instructive to go through an example.
|
Configuring Safety can be a little involved so it is instructive to go through an example.
|
||||||
|
|
|
@ -25,7 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
|
|
||||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||||
|
@ -235,6 +235,19 @@ class AuthenticationConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaPeriod(str, Enum):
|
||||||
|
DAY = "day"
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaConfig(BaseModel):
|
||||||
|
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||||
|
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
|
||||||
|
authenticated_max_requests: int = Field(
|
||||||
|
default=1000, description="Max requests for authenticated clients per period"
|
||||||
|
)
|
||||||
|
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||||
|
|
||||||
|
|
||||||
class ServerConfig(BaseModel):
|
class ServerConfig(BaseModel):
|
||||||
port: int = Field(
|
port: int = Field(
|
||||||
default=8321,
|
default=8321,
|
||||||
|
@ -262,6 +275,10 @@ class ServerConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="The host the server should listen on",
|
description="The host the server should listen on",
|
||||||
)
|
)
|
||||||
|
quota: QuotaConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Per client quota request configuration",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StackRunConfig(BaseModel):
|
class StackRunConfig(BaseModel):
|
||||||
|
|
|
@ -113,6 +113,10 @@ class AuthenticationMiddleware:
|
||||||
"roles": [token],
|
"roles": [token],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||||
|
# can identify the requester and enforce per-client rate limits.
|
||||||
|
scope["authenticated_client_id"] = token
|
||||||
|
|
||||||
# Store attributes in request scope
|
# Store attributes in request scope
|
||||||
scope["user_attributes"] = user_attributes
|
scope["user_attributes"] = user_attributes
|
||||||
scope["principal"] = validation_result.principal
|
scope["principal"] = validation_result.principal
|
||||||
|
|
110
llama_stack/distribution/server/quota.py
Normal file
110
llama_stack/distribution/server/quota.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
|
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="quota")
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaMiddleware:
|
||||||
|
"""
|
||||||
|
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
|
||||||
|
within a configurable time window.
|
||||||
|
|
||||||
|
- For authenticated requests, it reads the client ID from the
|
||||||
|
`Authorization: Bearer <client_id>` header.
|
||||||
|
- For anonymous requests, it falls back to the IP address of the client.
|
||||||
|
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
|
||||||
|
once a client exceeds its quota.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
kv_config: KVStoreConfig,
|
||||||
|
anonymous_max_requests: int,
|
||||||
|
authenticated_max_requests: int,
|
||||||
|
window_seconds: int = 86400,
|
||||||
|
):
|
||||||
|
self.app = app
|
||||||
|
self.kv_config = kv_config
|
||||||
|
self.kv: KVStore | None = None
|
||||||
|
self.anonymous_max_requests = anonymous_max_requests
|
||||||
|
self.authenticated_max_requests = authenticated_max_requests
|
||||||
|
self.window_seconds = window_seconds
|
||||||
|
|
||||||
|
if isinstance(self.kv_config, SqliteKVStoreConfig):
|
||||||
|
logger.warning(
|
||||||
|
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
|
||||||
|
f"window_seconds={self.window_seconds}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_kv(self) -> KVStore:
|
||||||
|
if self.kv is None:
|
||||||
|
self.kv = await kvstore_impl(self.kv_config)
|
||||||
|
return self.kv
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
if scope["type"] == "http":
|
||||||
|
# pick key & limit based on auth
|
||||||
|
auth_id = scope.get("authenticated_client_id")
|
||||||
|
if auth_id:
|
||||||
|
key_id = auth_id
|
||||||
|
limit = self.authenticated_max_requests
|
||||||
|
else:
|
||||||
|
# fallback to IP
|
||||||
|
client = scope.get("client")
|
||||||
|
key_id = client[0] if client else "anonymous"
|
||||||
|
limit = self.anonymous_max_requests
|
||||||
|
|
||||||
|
current_window = int(time.time() // self.window_seconds)
|
||||||
|
key = f"quota:{key_id}:{current_window}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
kv = await self._get_kv()
|
||||||
|
prev = await kv.get(key) or "0"
|
||||||
|
count = int(prev) + 1
|
||||||
|
|
||||||
|
if int(prev) == 0:
|
||||||
|
# Set with expiration datetime when it is the first request in the window.
|
||||||
|
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
|
||||||
|
await kv.set(key, str(count), expiration=expiration)
|
||||||
|
else:
|
||||||
|
await kv.set(key, str(count))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to access KV store for quota")
|
||||||
|
return await self._send_error(send, 500, "Quota service error")
|
||||||
|
|
||||||
|
if count > limit:
|
||||||
|
logger.warning(
|
||||||
|
"Quota exceeded for client %s: %d/%d",
|
||||||
|
key_id,
|
||||||
|
count,
|
||||||
|
limit,
|
||||||
|
)
|
||||||
|
return await self._send_error(send, 429, "Quota exceeded")
|
||||||
|
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
async def _send_error(self, send: Send, status: int, message: str):
|
||||||
|
await send(
|
||||||
|
{
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": status,
|
||||||
|
"headers": [[b"content-type", b"application/json"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
body = json.dumps({"error": {"message": message}}).encode()
|
||||||
|
await send({"type": "http.response.body", "body": body})
|
|
@ -60,6 +60,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
|
|
||||||
from .auth import AuthenticationMiddleware
|
from .auth import AuthenticationMiddleware
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
from .quota import QuotaMiddleware
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
@ -434,6 +435,35 @@ def main(args: argparse.Namespace | None = None):
|
||||||
if config.server.auth:
|
if config.server.auth:
|
||||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||||
|
else:
|
||||||
|
if config.server.quota:
|
||||||
|
quota = config.server.quota
|
||||||
|
logger.warning(
|
||||||
|
"Configured authenticated_max_requests (%d) but no auth is enabled; "
|
||||||
|
"falling back to anonymous_max_requests (%d) for all the requests",
|
||||||
|
quota.authenticated_max_requests,
|
||||||
|
quota.anonymous_max_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.server.quota:
|
||||||
|
logger.info("Enabling quota middleware for authenticated and anonymous clients")
|
||||||
|
|
||||||
|
quota = config.server.quota
|
||||||
|
anonymous_max_requests = quota.anonymous_max_requests
|
||||||
|
# if auth is disabled, use the anonymous max requests
|
||||||
|
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
|
||||||
|
|
||||||
|
kv_config = quota.kvstore
|
||||||
|
window_map = {"day": 86400}
|
||||||
|
window_seconds = window_map[quota.period.value]
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
QuotaMiddleware,
|
||||||
|
kv_config=kv_config,
|
||||||
|
anonymous_max_requests=anonymous_max_requests,
|
||||||
|
authenticated_max_requests=authenticated_max_requests,
|
||||||
|
window_seconds=window_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(construct_stack(config))
|
impls = asyncio.run(construct_stack(config))
|
||||||
|
|
127
tests/unit/server/test_quota.py
Normal file
127
tests/unit/server/test_quota.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod
|
||||||
|
from llama_stack.distribution.server.quota import QuotaMiddleware
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app, client_id="client1"):
|
||||||
|
super().__init__(app)
|
||||||
|
self.client_id = client_id
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
request.scope["authenticated_client_id"] = self.client_id
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
def build_quota_config(db_path) -> QuotaConfig:
|
||||||
|
return QuotaConfig(
|
||||||
|
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||||
|
anonymous_max_requests=1,
|
||||||
|
authenticated_max_requests=2,
|
||||||
|
period=QuotaPeriod.DAY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_app(tmp_path, request):
|
||||||
|
"""
|
||||||
|
FastAPI app with InjectClientIDMiddleware and QuotaMiddleware for authenticated testing.
|
||||||
|
Each test gets its own DB file.
|
||||||
|
"""
|
||||||
|
inner_app = FastAPI()
|
||||||
|
|
||||||
|
@inner_app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
db_path = tmp_path / f"quota_{request.node.name}.db"
|
||||||
|
quota = build_quota_config(db_path)
|
||||||
|
|
||||||
|
app = InjectClientIDMiddleware(
|
||||||
|
QuotaMiddleware(
|
||||||
|
inner_app,
|
||||||
|
kv_config=quota.kvstore,
|
||||||
|
anonymous_max_requests=quota.anonymous_max_requests,
|
||||||
|
authenticated_max_requests=quota.authenticated_max_requests,
|
||||||
|
window_seconds=86400,
|
||||||
|
),
|
||||||
|
client_id=f"client_{request.node.name}",
|
||||||
|
)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||||
|
client = TestClient(auth_app)
|
||||||
|
assert client.get("/test").status_code == 200
|
||||||
|
assert client.get("/test").status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_authenticated_quota_blocks_after_limit(auth_app):
|
||||||
|
client = TestClient(auth_app)
|
||||||
|
client.get("/test")
|
||||||
|
client.get("/test")
|
||||||
|
resp = client.get("/test")
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||||
|
|
||||||
|
|
||||||
|
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||||
|
inner_app = FastAPI()
|
||||||
|
|
||||||
|
@inner_app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
db_path = tmp_path / f"quota_anon_{request.node.name}.db"
|
||||||
|
quota = build_quota_config(db_path)
|
||||||
|
|
||||||
|
app = QuotaMiddleware(
|
||||||
|
inner_app,
|
||||||
|
kv_config=quota.kvstore,
|
||||||
|
anonymous_max_requests=quota.anonymous_max_requests,
|
||||||
|
authenticated_max_requests=quota.authenticated_max_requests,
|
||||||
|
window_seconds=86400,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
assert client.get("/test").status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||||
|
inner_app = FastAPI()
|
||||||
|
|
||||||
|
@inner_app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
db_path = tmp_path / f"quota_anon_{request.node.name}.db"
|
||||||
|
quota = build_quota_config(db_path)
|
||||||
|
|
||||||
|
app = QuotaMiddleware(
|
||||||
|
inner_app,
|
||||||
|
kv_config=quota.kvstore,
|
||||||
|
anonymous_max_requests=quota.anonymous_max_requests,
|
||||||
|
authenticated_max_requests=quota.authenticated_max_requests,
|
||||||
|
window_seconds=86400,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/test")
|
||||||
|
resp = client.get("/test")
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert resp.json()["error"]["message"] == "Quota exceeded"
|
Loading…
Add table
Add a link
Reference in a new issue