diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index b62227a84..7a42f503a 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -208,6 +208,80 @@ And must respond with: If no access attributes are returned, the token is used as a namespace. +### Quota Configuration + +The `quota` section allows you to enable server-side request throttling for both +authenticated and anonymous clients. This is useful for preventing abuse, enforcing +fairness across tenants, and controlling infrastructure costs without requiring +client-side rate limiting or external proxies. + +Quotas are disabled by default. When enabled, each client is tracked using either: + +* Their authenticated `client_id` (derived from the Bearer token), or +* Their IP address (fallback for anonymous requests) + +Quota state is stored in a SQLite-backed key-value store, and rate limits are applied +within a configurable time window (currently only `day` is supported). + +#### Example + +```yaml +server: + quota: + kvstore: + type: sqlite + db_path: ./quotas.db + anonymous_max_requests: 100 + authenticated_max_requests: 1000 + period: day +``` + +#### Configuration Options + +| Field | Description | +| ---------------------------- | -------------------------------------------------------------------------- | +| `kvstore` | Required. Backend storage config for tracking request counts. | +| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. | +| `kvstore.db_path` | File path to the SQLite database. | +| `anonymous_max_requests` | Max requests per period for unauthenticated clients. | +| `authenticated_max_requests` | Max requests per period for authenticated clients. | +| `period` | Time window for quota enforcement. Only `"day"` is supported. 
| + +> Note: if `authenticated_max_requests` is set but no authentication provider is +configured, the server will fall back to applying `anonymous_max_requests` to all +clients. + +#### Example with Authentication Enabled + +```yaml +server: + port: 8321 + auth: + provider_type: custom + config: + endpoint: https://auth.example.com/validate + quota: + kvstore: + type: sqlite + db_path: ./quotas.db + anonymous_max_requests: 100 + authenticated_max_requests: 1000 + period: day +``` + +If a client exceeds their limit, the server responds with: + +```http +HTTP/1.1 429 Too Many Requests +Content-Type: application/json + +{ + "error": { + "message": "Quota exceeded" + } +} +``` + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index be5629ba1..ca3664828 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -25,7 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.datatypes import Api, ProviderSpec -from llama_stack.providers.utils.kvstore.config import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -235,6 +235,19 @@ class AuthenticationConfig(BaseModel): ) +class QuotaPeriod(str, Enum): + DAY = "day" + + +class QuotaConfig(BaseModel): + kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") + anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period") + authenticated_max_requests: int = Field( + default=1000, description="Max requests for authenticated 
clients per period" + ) + period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -262,6 +275,10 @@ class ServerConfig(BaseModel): default=None, description="The host the server should listen on", ) + quota: QuotaConfig | None = Field( + default=None, + description="Per client quota request configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 83436c51f..67acffe3e 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -113,6 +113,10 @@ class AuthenticationMiddleware: "roles": [token], } + # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) + # can identify the requester and enforce per-client rate limits. + scope["authenticated_client_id"] = token + # Store attributes in request scope scope["user_attributes"] = user_attributes scope["principal"] = validation_result.principal diff --git a/llama_stack/distribution/server/quota.py b/llama_stack/distribution/server/quota.py new file mode 100644 index 000000000..ddbffae64 --- /dev/null +++ b/llama_stack/distribution/server/quota.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
import asyncio
import hashlib
import json
import time
from datetime import datetime, timedelta, timezone

from starlette.types import ASGIApp, Receive, Scope, Send

from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

logger = get_logger(name=__name__, category="quota")


class QuotaMiddleware:
    """
    ASGI middleware that enforces separate request quotas for authenticated and
    anonymous clients within a fixed time window.

    - A request is treated as authenticated when the ASGI scope carries a truthy
      ``authenticated_client_id`` entry. This middleware does NOT parse the
      ``Authorization`` header itself; some upstream middleware must populate
      the scope entry before this middleware runs.
    - Otherwise the request is anonymous and is keyed by the client IP taken
      from ``scope["client"]`` (or the literal ``"anonymous"`` if absent).

    Requests are counted in a KV store under ``quota:<id>:<window-number>`` and
    HTTP 429 is returned once a client exceeds its limit. KV store failures
    produce HTTP 500 (the middleware fails closed). Non-HTTP scopes are passed
    through untouched.

    NOTE(review): Starlette's ``add_middleware`` makes the most recently added
    middleware the outermost one. For ``authenticated_client_id`` to be visible
    here, the authentication middleware must wrap this one -- confirm the
    registration order at the call site.

    :param app: The downstream ASGI application.
    :param kv_config: Configuration for the KV store that holds the counters.
    :param anonymous_max_requests: Per-window limit for anonymous clients.
    :param authenticated_max_requests: Per-window limit for authenticated clients.
    :param window_seconds: Length of the quota window in seconds (must be > 0).
    :raises ValueError: If ``window_seconds`` is not positive.
    """

    def __init__(
        self,
        app: ASGIApp,
        kv_config: KVStoreConfig,
        anonymous_max_requests: int,
        authenticated_max_requests: int,
        window_seconds: int = 86400,
    ):
        # A non-positive window would raise ZeroDivisionError (or yield nonsense
        # window numbers) on every request; reject it once, up front.
        if window_seconds <= 0:
            raise ValueError("window_seconds must be a positive number of seconds")

        self.app = app
        self.kv_config = kv_config
        self.kv: KVStore | None = None
        self.anonymous_max_requests = anonymous_max_requests
        self.authenticated_max_requests = authenticated_max_requests
        self.window_seconds = window_seconds
        # Serializes the read-increment-write cycle so that concurrent requests
        # handled by this process cannot both read the same count and
        # under-count. This does not coordinate across multiple processes.
        self._lock: asyncio.Lock = asyncio.Lock()

        if isinstance(self.kv_config, SqliteKVStoreConfig):
            logger.warning(
                "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
                f"window_seconds={self.window_seconds}"
            )

    async def _get_kv(self) -> KVStore:
        """Return the KV store, creating it from ``kv_config`` on first use."""
        if self.kv is None:
            self.kv = await kvstore_impl(self.kv_config)
        return self.kv

    def _identify(self, scope: Scope) -> tuple[str, int]:
        """
        Pick the counter identity and the applicable limit for a request.

        :param scope: The ASGI connection scope.
        :returns: ``(key_id, limit)``. For authenticated requests ``key_id`` is
            the SHA-256 hex digest of the client ID; for anonymous requests it
            is the client IP or ``"anonymous"``.
        """
        auth_id = scope.get("authenticated_client_id")
        if auth_id:
            # The client ID may be the caller's bearer token, so never persist
            # or log it verbatim: use a stable one-way digest instead.
            digest = hashlib.sha256(str(auth_id).encode("utf-8")).hexdigest()
            return digest, self.authenticated_max_requests

        # fallback to IP
        client = scope.get("client")
        return (client[0] if client else "anonymous"), self.anonymous_max_requests

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """
        Count the request against the caller's quota, then either forward it to
        the wrapped app or short-circuit with a JSON error response.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        key_id, limit = self._identify(scope)

        current_window = int(time.time() // self.window_seconds)
        key = f"quota:{key_id}:{current_window}"

        # The key is only ever read during its own window, so it can expire at
        # the window's end. The expiration is passed on EVERY write: omitting it
        # on later writes would replace the stored entry without one.
        window_start = datetime.fromtimestamp(current_window * self.window_seconds, tz=timezone.utc)
        expiration = window_start + timedelta(seconds=self.window_seconds)

        try:
            async with self._lock:
                kv = await self._get_kv()
                prev = await kv.get(key) or "0"
                count = int(prev) + 1
                await kv.set(key, str(count), expiration=expiration)
        except Exception:
            logger.exception("Failed to access KV store for quota")
            return await self._send_error(send, 500, "Quota service error")

        if count > limit:
            logger.warning(
                "Quota exceeded for client %s: %d/%d",
                key_id,
                count,
                limit,
            )
            return await self._send_error(send, 429, "Quota exceeded")

        return await self.app(scope, receive, send)

    async def _send_error(self, send: Send, status: int, message: str):
        """
        Send a complete JSON error response of the form
        ``{"error": {"message": <message>}}`` with the given HTTP status.
        """
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        body = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": body})
from llama_stack.providers.utils.telemetry.tracing import ( from .auth import AuthenticationMiddleware from .endpoints import get_all_api_endpoints +from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -434,6 +435,35 @@ def main(args: argparse.Namespace | None = None): if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}") app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) + else: + if config.server.quota: + quota = config.server.quota + logger.warning( + "Configured authenticated_max_requests (%d) but no auth is enabled; " + "falling back to anonymous_max_requests (%d) for all the requests", + quota.authenticated_max_requests, + quota.anonymous_max_requests, + ) + + if config.server.quota: + logger.info("Enabling quota middleware for authenticated and anonymous clients") + + quota = config.server.quota + anonymous_max_requests = quota.anonymous_max_requests + # if auth is disabled, use the anonymous max requests + authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests + + kv_config = quota.kvstore + window_map = {"day": 86400} + window_seconds = window_map[quota.period.value] + + app.add_middleware( + QuotaMiddleware, + kv_config=kv_config, + anonymous_max_requests=anonymous_max_requests, + authenticated_max_requests=authenticated_max_requests, + window_seconds=window_seconds, + ) try: impls = asyncio.run(construct_stack(config)) diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py new file mode 100644 index 000000000..763bf8e94 --- /dev/null +++ b/tests/unit/server/test_quota.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware

from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.distribution.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class InjectClientIDMiddleware(BaseHTTPMiddleware):
    """
    Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
    """

    def __init__(self, app, client_id="client1"):
        super().__init__(app)
        self.client_id = client_id

    async def dispatch(self, request: Request, call_next):
        # QuotaMiddleware reads this scope entry to decide that the request is authenticated.
        request.scope["authenticated_client_id"] = self.client_id
        return await call_next(request)


def build_quota_config(db_path) -> QuotaConfig:
    """Build a deliberately tiny quota (1 anonymous / 2 authenticated per day) backed by `db_path`."""
    return QuotaConfig(
        kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
        anonymous_max_requests=1,
        authenticated_max_requests=2,
        period=QuotaPeriod.DAY,
    )


def _make_inner_app() -> FastAPI:
    """Create a minimal FastAPI app exposing a single GET /test endpoint."""
    inner_app = FastAPI()

    @inner_app.get("/test")
    async def test_endpoint():
        return {"message": "ok"}

    return inner_app


def _wrap_with_quota(inner_app, db_path) -> QuotaMiddleware:
    """Wrap `inner_app` in a QuotaMiddleware using the limits from build_quota_config."""
    quota = build_quota_config(db_path)
    return QuotaMiddleware(
        inner_app,
        kv_config=quota.kvstore,
        anonymous_max_requests=quota.anonymous_max_requests,
        authenticated_max_requests=quota.authenticated_max_requests,
        window_seconds=86400,
    )


def _build_anonymous_app(tmp_path, request) -> QuotaMiddleware:
    """Quota-wrapped app with no client-ID injection; each test gets its own DB file."""
    db_path = tmp_path / f"quota_anon_{request.node.name}.db"
    return _wrap_with_quota(_make_inner_app(), db_path)


@pytest.fixture
def auth_app(tmp_path, request):
    """
    FastAPI app with InjectClientIDMiddleware and QuotaMiddleware for authenticated testing.
    Each test gets its own DB file.
    """
    db_path = tmp_path / f"quota_{request.node.name}.db"
    # The injector must be the outer layer so the client ID is in the scope
    # before QuotaMiddleware inspects it.
    return InjectClientIDMiddleware(
        _wrap_with_quota(_make_inner_app(), db_path),
        client_id=f"client_{request.node.name}",
    )


def test_authenticated_quota_allows_up_to_limit(auth_app):
    client = TestClient(auth_app)
    assert client.get("/test").status_code == 200
    assert client.get("/test").status_code == 200


def test_authenticated_quota_blocks_after_limit(auth_app):
    client = TestClient(auth_app)
    client.get("/test")
    client.get("/test")
    resp = client.get("/test")
    assert resp.status_code == 429
    assert resp.json()["error"]["message"] == "Quota exceeded"


def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
    client = TestClient(_build_anonymous_app(tmp_path, request))
    assert client.get("/test").status_code == 200


def test_anonymous_quota_blocks_after_limit(tmp_path, request):
    client = TestClient(_build_anonymous_app(tmp_path, request))
    client.get("/test")
    resp = client.get("/test")
    assert resp.status_code == 429
    assert resp.json()["error"]["message"] == "Quota exceeded"