llama-stack-mirror/tests/unit/server/test_quota.py
liangwen12year 2890243107
feat(quota): add server‑side per‑client request quotas (requires auth) (#2096)
# What does this PR do?
feat(quota): add server‑side per‑client request quotas (requires auth)
    
Unrestricted usage can lead to runaway costs and fragmented client-side
    workarounds. This commit introduces a native quota mechanism to the
    server, giving operators a unified, centrally managed throttle for
    per-client requests—without needing extra proxies or custom client
logic. This helps contain cloud-compute expenses, enables fine-grained
usage control, and simplifies deployment and monitoring of Llama Stack
services. Quotas are fully opt-in and have no effect unless explicitly
    configured.
    
    Note that quotas are fully opt-in and require authentication to be
    enabled. `sqlite` is currently the only supported quota `type`; any
    other `type` will be rejected. The only supported `period` is `day`.
    
    Highlights:
    
    - Adds `QuotaMiddleware` to enforce per-client request quotas:
      - Uses `Authorization: Bearer <client_id>` (from
        AuthenticationMiddleware)
      - Tracks usage via a SQLite-based KV store
      - Returns 429 when the quota is exceeded
    
    - Extends `ServerConfig` with a `quota` section (type + config)
    
- Enforces strict coupling: quotas require authentication or the server
      will fail to start
    
    Behavior changes:
    - Quotas are disabled by default unless explicitly configured
    - SQLite defaults to `./quotas.db` if no DB path is set
    - The server requires authentication when quotas are enabled
    
    To enable per-client request quotas in `run.yaml`, add:
    ```
    server:
      port: 8321
      auth:
        provider_type: "custom"
        config:
          endpoint: "https://auth.example.com/validate"
      quota:
        type: sqlite
        config:
          db_path: ./quotas.db
          limit:
            max_requests: 1000
            period: day
    ```

[//]: # (If resolving an issue, uncomment and update the line below)
Closes #2093

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

[//]: # (## Documentation)

Signed-off-by: Wen Liang <wenliang@redhat.com>
Co-authored-by: Wen Liang <wenliang@redhat.com>
2025-05-21 10:58:45 +02:00

127 lines
3.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware
from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.distribution.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class InjectClientIDMiddleware(BaseHTTPMiddleware):
    """
    Test-only stand-in for AuthenticationMiddleware.

    Stamps a fixed client identifier into the ASGI scope under the
    ``authenticated_client_id`` key, then hands the request on to the next
    layer unchanged. The downstream QuotaMiddleware presumably reads this key
    to treat the request as authenticated (confirm against quota.py).
    """

    def __init__(self, app, client_id="client1"):
        super().__init__(app)
        # Identifier written onto every request that passes through.
        self.client_id = client_id

    async def dispatch(self, request: Request, call_next):
        scope = request.scope
        scope["authenticated_client_id"] = self.client_id
        response = await call_next(request)
        return response
def build_quota_config(db_path) -> QuotaConfig:
    """
    Build the QuotaConfig shared by every test in this module.

    Usage is tracked in a SQLite KV store located at ``db_path``, which is
    converted to ``str``. Anonymous clients are allowed 1 request and
    authenticated clients 2, over a one-day period.
    """
    store = SqliteKVStoreConfig(db_path=str(db_path))
    return QuotaConfig(
        period=QuotaPeriod.DAY,
        authenticated_max_requests=2,
        anonymous_max_requests=1,
        kvstore=store,
    )
@pytest.fixture
def auth_app(tmp_path, request):
    """
    Return an ASGI app stack for authenticated-quota tests.

    The stack is InjectClientIDMiddleware -> QuotaMiddleware -> a FastAPI app
    exposing ``GET /test``. Both the SQLite DB file and the injected client id
    are derived from the requesting test's name, so each test starts from a
    fresh quota.
    """
    test_name = request.node.name

    inner_app = FastAPI()

    @inner_app.get("/test")
    async def ok_endpoint():
        return {"message": "ok"}

    quota = build_quota_config(tmp_path / f"quota_{test_name}.db")
    # One day in seconds, mirroring period=QuotaPeriod.DAY in build_quota_config.
    one_day_seconds = 24 * 60 * 60

    limited_app = QuotaMiddleware(
        inner_app,
        kv_config=quota.kvstore,
        anonymous_max_requests=quota.anonymous_max_requests,
        authenticated_max_requests=quota.authenticated_max_requests,
        window_seconds=one_day_seconds,
    )
    return InjectClientIDMiddleware(limited_app, client_id=f"client_{test_name}")
def test_authenticated_quota_allows_up_to_limit(auth_app):
    """An authenticated client can use its full quota of 2 requests."""
    client = TestClient(auth_app)
    for _ in range(2):
        assert client.get("/test").status_code == 200
def test_authenticated_quota_blocks_after_limit(auth_app):
    """
    The request after an authenticated client's quota of 2 is rejected with 429.

    The in-quota requests are asserted to succeed first. Without that check,
    the test would also pass if the quota blocked too early (for example, if
    every request returned 429), and so would not show that the limit is
    enforced at the right point.
    """
    client = TestClient(auth_app)
    for _ in range(2):
        assert client.get("/test").status_code == 200
    resp = client.get("/test")
    assert resp.status_code == 429
    assert resp.json()["error"]["message"] == "Quota exceeded"
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
    """An unauthenticated client's first request, within its quota of 1, succeeds."""
    db_file = tmp_path / f"quota_anon_{request.node.name}.db"
    cfg = build_quota_config(db_file)

    inner_app = FastAPI()

    @inner_app.get("/test")
    async def ok_endpoint():
        return {"message": "ok"}

    # QuotaMiddleware is used directly, without InjectClientIDMiddleware, so no
    # authenticated_client_id is placed in the request scope.
    limited_app = QuotaMiddleware(
        inner_app,
        kv_config=cfg.kvstore,
        anonymous_max_requests=cfg.anonymous_max_requests,
        authenticated_max_requests=cfg.authenticated_max_requests,
        window_seconds=24 * 60 * 60,
    )
    client = TestClient(limited_app)
    response = client.get("/test")
    assert response.status_code == 200
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
    """
    The request after an anonymous client's quota of 1 is rejected with 429.

    QuotaMiddleware is used directly, without InjectClientIDMiddleware, so no
    authenticated_client_id is injected. The first request is asserted to
    succeed. Otherwise a middleware that rejected every request would also
    pass this test.
    """
    inner_app = FastAPI()

    @inner_app.get("/test")
    async def test_endpoint():
        return {"message": "ok"}

    db_path = tmp_path / f"quota_anon_{request.node.name}.db"
    quota = build_quota_config(db_path)
    app = QuotaMiddleware(
        inner_app,
        kv_config=quota.kvstore,
        anonymous_max_requests=quota.anonymous_max_requests,
        authenticated_max_requests=quota.authenticated_max_requests,
        window_seconds=86400,  # one day, matching QuotaPeriod.DAY
    )
    client = TestClient(app)
    assert client.get("/test").status_code == 200
    resp = client.get("/test")
    assert resp.status_code == 429
    assert resp.json()["error"]["message"] == "Quota exceeded"