feat: api level request metrics via middleware

add RequestMetricsMiddleware which tracks key metrics related to each request the LLS server will recieve:

1. llama_stack_requests_total: tracks the total amount of requests the server has processed
2. llama_stack_request_duration_seconds: tracks the duration of each request
3. llama_stack_concurrent_requests: tracks concurrently processed requests by the server

The usage of a middleware allows this to be done on the server level without having to add custom handling to each router like the inference router has today for its API specific metrics.

Also, add some unit tests for this functionality

resolves #2597

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-07-11 20:52:32 -04:00
parent dbfc15123e
commit 49b729b30a
4 changed files with 433 additions and 0 deletions

View file

@ -37,6 +37,9 @@ The following metrics are automatically generated for each inference request:
| `llama_stack_prompt_tokens_total` | Counter | `tokens` | Number of tokens in the input prompt | `model_id`, `provider_id` |
| `llama_stack_completion_tokens_total` | Counter | `tokens` | Number of tokens in the generated response | `model_id`, `provider_id` |
| `llama_stack_tokens_total` | Counter | `tokens` | Total tokens used (prompt + completion) | `model_id`, `provider_id` |
| `llama_stack_requests_total` | Counter | `requests` | Total number of requests | `api`, `status` |
| `llama_stack_request_duration_seconds` | Gauge | `seconds` | Request duration | `api`, `status` |
| `llama_stack_concurrent_requests` | Gauge | `requests` | Number of concurrent requests | `api` |
#### Metric Generation Flow

View file

@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import time
from datetime import UTC, datetime
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="server")
class RequestMetricsMiddleware:
"""
middleware that tracks request-level metrics including:
- Request counts by API and status
- Request duration
- Concurrent requests
Metrics are logged to the telemetry system and can be exported to Prometheus
via OpenTelemetry.
"""
def __init__(self, app, telemetry: Telemetry | None = None):
self.app = app
self.telemetry = telemetry
self.concurrent_requests = 0
self._lock = asyncio.Lock()
# FastAPI built-in paths that should be excluded from metrics
self.excluded_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
def _extract_api_from_path(self, path: str) -> str:
"""Extract the API name from the request path."""
# Remove version prefix if present
if path.startswith("/v1/"):
path = path[4:]
# Extract the first path segment as the API name
segments = path.strip("/").split("/")
if (
segments and segments[0]
): # Check that first segment is not empty, this will return the API rather than the action `["datasets", "list"]`
return segments[0]
return "unknown"
def _is_excluded_path(self, path: str) -> bool:
"""Check if the path should be excluded from metrics."""
return any(path.startswith(excluded) for excluded in self.excluded_paths)
async def _log_request_metrics(self, api: str, status: str, duration: float, concurrent_count: int):
"""Log request metrics to the telemetry system."""
if not self.telemetry:
return
try:
# Get current span if available
span = get_current_span()
trace_id = span.trace_id if span else ""
span_id = span.span_id if span else ""
# Log request count (send increment of 1 for each request)
await self.telemetry.log_event(
MetricEvent(
trace_id=trace_id,
span_id=span_id,
timestamp=datetime.now(UTC),
metric="llama_stack_requests_total",
value=1, # Send increment instead of total so the provider handles incrementation
unit="requests",
attributes={
"api": api,
"status": status,
},
)
)
# Log request duration
await self.telemetry.log_event(
MetricEvent(
trace_id=trace_id,
span_id=span_id,
timestamp=datetime.now(UTC),
metric="llama_stack_request_duration_seconds",
value=duration,
unit="seconds",
attributes={"api": api, "status": status},
)
)
# Log concurrent requests (as a gauge)
await self.telemetry.log_event(
MetricEvent(
trace_id=trace_id,
span_id=span_id,
timestamp=datetime.now(UTC),
metric="llama_stack_concurrent_requests",
value=float(concurrent_count), # Convert to float for gauge
unit="requests",
attributes={"api": api},
)
)
except ValueError as e:
logger.warning(f"Failed to log request metrics: {e}")
async def __call__(self, scope, receive, send):
if scope.get("type") != "http":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Skip metrics for excluded paths
if self._is_excluded_path(path):
return await self.app(scope, receive, send)
api = self._extract_api_from_path(path)
start_time = time.time()
status = 200
# Track concurrent requests
async with self._lock:
self.concurrent_requests += 1
# Create a wrapper to capture the response status
async def send_wrapper(message):
if message.get("type") == "http.response.start":
nonlocal status
status = message.get("status", 200)
await send(message)
try:
return await self.app(scope, receive, send_wrapper)
except Exception:
# Set status to 500 for any unhandled exception
status = 500
raise
finally:
duration = time.time() - start_time
# Capture concurrent count before decrementing
async with self._lock:
concurrent_count = self.concurrent_requests
self.concurrent_requests -= 1
# Log metrics asynchronously to avoid blocking the response
asyncio.create_task(self._log_request_metrics(api, str(status), duration, concurrent_count))

View file

@ -76,6 +76,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from .auth import AuthenticationMiddleware
from .metrics import RequestMetricsMiddleware
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -536,6 +537,10 @@ def main(args: argparse.Namespace | None = None):
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
# Add request metrics middleware
telemetry_impl = impls.get(Api.telemetry) if Api.telemetry in impls else None
app.add_middleware(RequestMetricsMiddleware, telemetry=telemetry_impl)
import uvicorn
# Configure SSL if certificates are provided

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.core.server.metrics import RequestMetricsMiddleware
class TestRequestMetricsMiddleware:
@pytest.fixture
def mock_telemetry(self):
telemetry = AsyncMock(spec=Telemetry)
return telemetry
@pytest.fixture
def mock_app(self):
app = AsyncMock()
return app
@pytest.fixture
def middleware(self, mock_app, mock_telemetry):
return RequestMetricsMiddleware(mock_app, mock_telemetry)
def test_extract_api_from_path(self, middleware):
"""Test API extraction from various paths."""
test_cases = [
("/v1/inference/chat/completions", "inference"),
("/v1/models/list", "models"),
("/v1/providers", "providers"),
("/", "unknown"),
("", "unknown"),
]
for path, expected_api in test_cases:
assert middleware._extract_api_from_path(path) == expected_api
def test_is_excluded_path(self, middleware):
"""Test path exclusion logic."""
excluded_paths = [
"/docs",
"/redoc",
"/openapi.json",
"/favicon.ico",
"/static/css/style.css",
]
non_excluded_paths = [
"/v1/inference/chat/completions",
"/v1/models/list",
"/health",
]
for path in excluded_paths:
assert middleware._is_excluded_path(path)
for path in non_excluded_paths:
assert not middleware._is_excluded_path(path)
async def test_middleware_skips_excluded_paths(self, middleware, mock_app):
"""Test that middleware skips metrics for excluded paths."""
scope = {
"type": "http",
"path": "/docs",
"method": "GET",
}
receive = AsyncMock()
send = AsyncMock()
await middleware(scope, receive, send)
# Should call the app directly without tracking metrics
mock_app.assert_called_once_with(scope, receive, send)
# Should not log any metrics
middleware.telemetry.log_event.assert_not_called()
async def test_middleware_tracks_metrics(self, middleware, mock_telemetry):
"""Test that middleware tracks metrics for valid requests."""
scope = {
"type": "http",
"path": "/v1/inference/chat/completions",
"method": "POST",
}
receive = AsyncMock()
send_called = False
async def mock_send(message):
nonlocal send_called
send_called = True
if message["type"] == "http.response.start":
message["status"] = 200
# Mock the app to return successfully
async def mock_app(scope, receive, send):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b"ok"})
middleware.app = mock_app
await middleware(scope, receive, mock_send)
# Wait for async metric logging
await asyncio.sleep(0.1)
# Should have logged metrics
assert mock_telemetry.log_event.call_count >= 2
# Check that the right metrics were logged
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
# Should have request count metric
request_count_metric = next(
(
call
for call in call_args
if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
),
None,
)
assert request_count_metric is not None
assert request_count_metric.value == 1
assert request_count_metric.attributes["api"] == "inference"
assert request_count_metric.attributes["status"] == "200"
# Should have duration metric
duration_metric = next(
(
call
for call in call_args
if isinstance(call, MetricEvent) and call.metric == "llama_stack_request_duration_seconds"
),
None,
)
assert duration_metric is not None
assert duration_metric.attributes["api"] == "inference"
assert duration_metric.attributes["status"] == "200"
async def test_middleware_handles_errors(self, middleware, mock_telemetry):
"""Test that middleware tracks metrics even when errors occur."""
scope = {
"type": "http",
"path": "/v1/inference/chat/completions",
"method": "POST",
}
receive = AsyncMock()
send = AsyncMock()
# Mock the app to raise an exception
async def mock_app(scope, receive, send):
raise ValueError("Test error")
middleware.app = mock_app
with pytest.raises(ValueError):
await middleware(scope, receive, send)
# Wait for async metric logging
await asyncio.sleep(0.1)
# Should have logged metrics with error status
assert mock_telemetry.log_event.call_count >= 2
# Check that error metrics were logged
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
request_count_metric = next(
(
call
for call in call_args
if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
),
None,
)
assert request_count_metric is not None
assert request_count_metric.attributes["status"] == "500"
async def test_concurrent_requests_tracking(self, middleware, mock_telemetry):
"""Test that concurrent requests are tracked correctly."""
scope = {
"type": "http",
"path": "/v1/inference/chat/completions",
"method": "POST",
}
receive = AsyncMock()
send = AsyncMock()
# Mock the app to simulate a slow request
async def mock_app(scope, receive, send):
await asyncio.sleep(1) # Simulate processing time
await send({"type": "http.response.start", "status": 200})
middleware.app = mock_app
# Start multiple concurrent requests
tasks = []
for _ in range(3):
task = asyncio.create_task(middleware(scope, receive, send))
tasks.append(task)
# Wait for all requests to complete
await asyncio.gather(*tasks)
# Wait for async metric logging
await asyncio.sleep(0.2)
# Should have logged metrics for all requests
assert mock_telemetry.log_event.call_count >= 6 # 2 metrics per request * 3 requests
# Check concurrent requests metric
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
concurrent_metrics = [
call
for call in call_args
if isinstance(call, MetricEvent) and call.metric == "llama_stack_concurrent_requests"
]
assert len(concurrent_metrics) >= 3
# The concurrent count should have been > 0 during the concurrent requests
max_concurrent = max(m.value for m in concurrent_metrics)
assert max_concurrent > 0
async def test_middleware_without_telemetry(self):
"""Test that middleware works without telemetry configured."""
mock_app = AsyncMock()
middleware = RequestMetricsMiddleware(mock_app, telemetry=None)
scope = {
"type": "http",
"path": "/v1/inference/chat/completions",
"method": "POST",
}
receive = AsyncMock()
send = AsyncMock()
async def mock_app_impl(scope, receive, send):
await send({"type": "http.response.start", "status": 200})
middleware.app = mock_app_impl
# Should not raise any exceptions
await middleware(scope, receive, send)
# Should not try to log metrics
# (no telemetry to call, so this is implicit)
async def test_non_http_requests_ignored(self, middleware, mock_telemetry):
"""Test that non-HTTP requests are ignored."""
scope = {
"type": "lifespan",
"path": "/",
}
receive = AsyncMock()
send = AsyncMock()
await middleware(scope, receive, send)
# Should call the app directly without tracking metrics
middleware.app.assert_called_once_with(scope, receive, send)
# Should not log any metrics
mock_telemetry.log_event.assert_not_called()