mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
feat(quota): add server‑side per‑client request quotas (requires auth) (#2096)
# What does this PR do?
feat(quota): add server‑side per‑client request quotas (requires auth)
Unrestricted usage can lead to runaway costs and fragmented client-side
workarounds. This commit introduces a native quota mechanism to the
server, giving operators a unified, centrally managed throttle for
per-client requests—without needing extra proxies or custom client
logic. This helps contain cloud-compute expenses, enables fine-grained
usage control, and simplifies deployment and monitoring of Llama Stack
services. Quotas are fully opt-in and have no effect unless explicitly
configured.
Notice that Quotas are fully opt-in and require authentication to be
enabled. The 'sqlite' is the only supported quota `type` at this time,
any other `type` will be rejected. And the only supported `period` is
'day'.
Highlights:
- Adds `QuotaMiddleware` to enforce per-client request quotas:
- Uses `Authorization: Bearer <client_id>` (from
AuthenticationMiddleware)
- Tracks usage via a SQLite-based KV store
- Returns 429 when the quota is exceeded
- Extends `ServerConfig` with a `quota` section (type + config)
- Enforces strict coupling: quotas require authentication or the server
will fail to start
Behavior changes:
- Quotas are disabled by default unless explicitly configured
- SQLite defaults to `./quotas.db` if no DB path is set
- The server requires authentication when quotas are enabled
To enable per-client request quotas in `run.yaml`, add:
```
server:
port: 8321
auth:
provider_type: "custom"
config:
endpoint: "https://auth.example.com/validate"
quota:
type: sqlite
config:
db_path: ./quotas.db
limit:
max_requests: 1000
period: day
[//]: # (If resolving an issue, uncomment and update the line below)
Closes #2093
## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]
[//]: # (## Documentation)
Signed-off-by: Wen Liang <wenliang@redhat.com>
Co-authored-by: Wen Liang <wenliang@redhat.com>
This commit is contained in:
parent
5a3d777b20
commit
2890243107
6 changed files with 363 additions and 1 deletions
127
tests/unit/server/test_quota.py
Normal file
127
tests/unit/server/test_quota.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod
|
||||
from llama_stack.distribution.server.quota import QuotaMiddleware
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class InjectClientIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
|
||||
"""
|
||||
|
||||
def __init__(self, app, client_id="client1"):
|
||||
super().__init__(app)
|
||||
self.client_id = client_id
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
request.scope["authenticated_client_id"] = self.client_id
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def build_quota_config(db_path) -> QuotaConfig:
|
||||
return QuotaConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=str(db_path)),
|
||||
anonymous_max_requests=1,
|
||||
authenticated_max_requests=2,
|
||||
period=QuotaPeriod.DAY,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_app(tmp_path, request):
|
||||
"""
|
||||
FastAPI app with InjectClientIDMiddleware and QuotaMiddleware for authenticated testing.
|
||||
Each test gets its own DB file.
|
||||
"""
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "ok"}
|
||||
|
||||
db_path = tmp_path / f"quota_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
app = InjectClientIDMiddleware(
|
||||
QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
),
|
||||
client_id=f"client_{request.node.name}",
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||
client = TestClient(auth_app)
|
||||
assert client.get("/test").status_code == 200
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_authenticated_quota_blocks_after_limit(auth_app):
|
||||
client = TestClient(auth_app)
|
||||
client.get("/test")
|
||||
client.get("/test")
|
||||
resp = client.get("/test")
|
||||
assert resp.status_code == 429
|
||||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||
|
||||
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "ok"}
|
||||
|
||||
db_path = tmp_path / f"quota_anon_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
app = QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "ok"}
|
||||
|
||||
db_path = tmp_path / f"quota_anon_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
app = QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
client.get("/test")
|
||||
resp = client.get("/test")
|
||||
assert resp.status_code == 429
|
||||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||
Loading…
Add table
Add a link
Reference in a new issue