mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-27 21:30:25 +00:00
feat(quota): support per‑client and anonymous server‑side request quotas
Unrestricted API usage can lead to runaway costs and fragmented client-side
throttling logic. This commit introduces a built-in quota mechanism at the
server level, enabling operators to centrally enforce per-client and anonymous
rate limits—without needing external proxies or client changes.
This helps contain compute costs, enforces fair usage, and simplifies deployment
and monitoring of Llama Stack services. Quotas are fully opt-in and have no
effect unless explicitly configured.
Currently, SQLite is the only supported KV store. If quotas are
configured but authentication is disabled, authenticated limits will
gracefully fall back to anonymous limits.
Highlights:
- Adds `QuotaMiddleware` to enforce request quotas:
- Uses bearer token as client ID if present; otherwise falls back to IP address
- Tracks requests in KV store with per-key TTL expiration
- Returns HTTP 429 if a client exceeds their quota
- Extends `ServerConfig` with a `quota` section:
- `kvstore`: configuration for the backend (currently only SQLite)
- `anonymous_max_requests`: per-period cap for unauthenticated clients
- `authenticated_max_requests`: per-period cap for authenticated clients
- `period`: duration of the quota window (currently only `day` is supported)
- Adds full test coverage with FastAPI `TestClient` and custom middleware injection
Behavior changes:
- Quotas are disabled by default unless explicitly configured
- Anonymous users get a conservative default quota; authenticated clients can be given more generous limits
To enable per-client request quotas in `run.yaml`, add:
```yaml
server:
port: 8321
auth:
provider_type: custom
config:
endpoint: https://auth.example.com/validate
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```
Signed-off-by: Wen Liang <wenliang@redhat.com>
This commit is contained in:
parent
ed7b4731aa
commit
dacd522f57
6 changed files with 363 additions and 1 deletions
110
llama_stack/distribution/server/quota.py
Normal file
110
llama_stack/distribution/server/quota.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="quota")
|
||||
|
||||
|
||||
class QuotaMiddleware:
|
||||
"""
|
||||
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
|
||||
within a configurable time window.
|
||||
|
||||
- For authenticated requests, it reads the client ID from the
|
||||
`Authorization: Bearer <client_id>` header.
|
||||
- For anonymous requests, it falls back to the IP address of the client.
|
||||
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
|
||||
once a client exceeds its quota.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
kv_config: KVStoreConfig,
|
||||
anonymous_max_requests: int,
|
||||
authenticated_max_requests: int,
|
||||
window_seconds: int = 86400,
|
||||
):
|
||||
self.app = app
|
||||
self.kv_config = kv_config
|
||||
self.kv: KVStore | None = None
|
||||
self.anonymous_max_requests = anonymous_max_requests
|
||||
self.authenticated_max_requests = authenticated_max_requests
|
||||
self.window_seconds = window_seconds
|
||||
|
||||
if isinstance(self.kv_config, SqliteKVStoreConfig):
|
||||
logger.warning(
|
||||
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
|
||||
f"window_seconds={self.window_seconds}"
|
||||
)
|
||||
|
||||
async def _get_kv(self) -> KVStore:
|
||||
if self.kv is None:
|
||||
self.kv = await kvstore_impl(self.kv_config)
|
||||
return self.kv
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] == "http":
|
||||
# pick key & limit based on auth
|
||||
auth_id = scope.get("authenticated_client_id")
|
||||
if auth_id:
|
||||
key_id = auth_id
|
||||
limit = self.authenticated_max_requests
|
||||
else:
|
||||
# fallback to IP
|
||||
client = scope.get("client")
|
||||
key_id = client[0] if client else "anonymous"
|
||||
limit = self.anonymous_max_requests
|
||||
|
||||
current_window = int(time.time() // self.window_seconds)
|
||||
key = f"quota:{key_id}:{current_window}"
|
||||
|
||||
try:
|
||||
kv = await self._get_kv()
|
||||
prev = await kv.get(key) or "0"
|
||||
count = int(prev) + 1
|
||||
|
||||
if int(prev) == 0:
|
||||
# Set with expiration datetime when it is the first request in the window.
|
||||
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
|
||||
await kv.set(key, str(count), expiration=expiration)
|
||||
else:
|
||||
await kv.set(key, str(count))
|
||||
except Exception:
|
||||
logger.exception("Failed to access KV store for quota")
|
||||
return await self._send_error(send, 500, "Quota service error")
|
||||
|
||||
if count > limit:
|
||||
logger.warning(
|
||||
"Quota exceeded for client %s: %d/%d",
|
||||
key_id,
|
||||
count,
|
||||
limit,
|
||||
)
|
||||
return await self._send_error(send, 429, "Quota exceeded")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_error(self, send: Send, status: int, message: str):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
body = json.dumps({"error": {"message": message}}).encode()
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
Loading…
Add table
Add a link
Reference in a new issue