mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
feat: api level request metrics via middleware
add RequestMetricsMiddleware which tracks key metrics related to each request the LLS server will recieve: 1. llama_stack_requests_total: tracks the total amount of requests the server has processed 2. llama_stack_request_duration_seconds: tracks the duration of each request 3. llama_stack_concurrent_requests: tracks concurrently processed requests by the server The usage of a middleware allows this to be done on the server level without having to add custom handling to each router like the inference router has today for its API specific metrics. Also, add some unit tests for this functionality resolves #2597 Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
dbfc15123e
commit
49b729b30a
4 changed files with 433 additions and 0 deletions
|
@ -37,6 +37,9 @@ The following metrics are automatically generated for each inference request:
|
|||
| `llama_stack_prompt_tokens_total` | Counter | `tokens` | Number of tokens in the input prompt | `model_id`, `provider_id` |
|
||||
| `llama_stack_completion_tokens_total` | Counter | `tokens` | Number of tokens in the generated response | `model_id`, `provider_id` |
|
||||
| `llama_stack_tokens_total` | Counter | `tokens` | Total tokens used (prompt + completion) | `model_id`, `provider_id` |
|
||||
| `llama_stack_requests_total` | Counter | `requests` | Total number of requests | `api`, `status` |
|
||||
| `llama_stack_request_duration_seconds` | Gauge | `seconds` | Request duration | `api`, `status` |
|
||||
| `llama_stack_concurrent_requests` | Gauge | `requests` | Number of concurrent requests | `api` |
|
||||
|
||||
#### Metric Generation Flow
|
||||
|
||||
|
|
153
llama_stack/core/server/metrics.py
Normal file
153
llama_stack/core/server/metrics.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
class RequestMetricsMiddleware:
|
||||
"""
|
||||
middleware that tracks request-level metrics including:
|
||||
- Request counts by API and status
|
||||
- Request duration
|
||||
- Concurrent requests
|
||||
|
||||
Metrics are logged to the telemetry system and can be exported to Prometheus
|
||||
via OpenTelemetry.
|
||||
"""
|
||||
|
||||
def __init__(self, app, telemetry: Telemetry | None = None):
|
||||
self.app = app
|
||||
self.telemetry = telemetry
|
||||
self.concurrent_requests = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# FastAPI built-in paths that should be excluded from metrics
|
||||
self.excluded_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
def _extract_api_from_path(self, path: str) -> str:
|
||||
"""Extract the API name from the request path."""
|
||||
# Remove version prefix if present
|
||||
if path.startswith("/v1/"):
|
||||
path = path[4:]
|
||||
# Extract the first path segment as the API name
|
||||
segments = path.strip("/").split("/")
|
||||
if (
|
||||
segments and segments[0]
|
||||
): # Check that first segment is not empty, this will return the API rather than the action `["datasets", "list"]`
|
||||
return segments[0]
|
||||
return "unknown"
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
"""Check if the path should be excluded from metrics."""
|
||||
return any(path.startswith(excluded) for excluded in self.excluded_paths)
|
||||
|
||||
async def _log_request_metrics(self, api: str, status: str, duration: float, concurrent_count: int):
|
||||
"""Log request metrics to the telemetry system."""
|
||||
if not self.telemetry:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get current span if available
|
||||
span = get_current_span()
|
||||
trace_id = span.trace_id if span else ""
|
||||
span_id = span.span_id if span else ""
|
||||
|
||||
# Log request count (send increment of 1 for each request)
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
metric="llama_stack_requests_total",
|
||||
value=1, # Send increment instead of total so the provider handles incrementation
|
||||
unit="requests",
|
||||
attributes={
|
||||
"api": api,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Log request duration
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
metric="llama_stack_request_duration_seconds",
|
||||
value=duration,
|
||||
unit="seconds",
|
||||
attributes={"api": api, "status": status},
|
||||
)
|
||||
)
|
||||
|
||||
# Log concurrent requests (as a gauge)
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
metric="llama_stack_concurrent_requests",
|
||||
value=float(concurrent_count), # Convert to float for gauge
|
||||
unit="requests",
|
||||
attributes={"api": api},
|
||||
)
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to log request metrics: {e}")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Skip metrics for excluded paths
|
||||
if self._is_excluded_path(path):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
api = self._extract_api_from_path(path)
|
||||
start_time = time.time()
|
||||
status = 200
|
||||
|
||||
# Track concurrent requests
|
||||
async with self._lock:
|
||||
self.concurrent_requests += 1
|
||||
|
||||
# Create a wrapper to capture the response status
|
||||
async def send_wrapper(message):
|
||||
if message.get("type") == "http.response.start":
|
||||
nonlocal status
|
||||
status = message.get("status", 200)
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_wrapper)
|
||||
|
||||
except Exception:
|
||||
# Set status to 500 for any unhandled exception
|
||||
status = 500
|
||||
raise
|
||||
|
||||
finally:
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Capture concurrent count before decrementing
|
||||
async with self._lock:
|
||||
concurrent_count = self.concurrent_requests
|
||||
self.concurrent_requests -= 1
|
||||
|
||||
# Log metrics asynchronously to avoid blocking the response
|
||||
asyncio.create_task(self._log_request_metrics(api, str(status), duration, concurrent_count))
|
|
@ -76,6 +76,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .metrics import RequestMetricsMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
@ -536,6 +537,10 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
# Add request metrics middleware
|
||||
telemetry_impl = impls.get(Api.telemetry) if Api.telemetry in impls else None
|
||||
app.add_middleware(RequestMetricsMiddleware, telemetry=telemetry_impl)
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
|
|
272
tests/unit/distribution/test_request_metrics.py
Normal file
272
tests/unit/distribution/test_request_metrics.py
Normal file
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.core.server.metrics import RequestMetricsMiddleware
|
||||
|
||||
|
||||
class TestRequestMetricsMiddleware:
|
||||
@pytest.fixture
|
||||
def mock_telemetry(self):
|
||||
telemetry = AsyncMock(spec=Telemetry)
|
||||
return telemetry
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
app = AsyncMock()
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def middleware(self, mock_app, mock_telemetry):
|
||||
return RequestMetricsMiddleware(mock_app, mock_telemetry)
|
||||
|
||||
def test_extract_api_from_path(self, middleware):
|
||||
"""Test API extraction from various paths."""
|
||||
test_cases = [
|
||||
("/v1/inference/chat/completions", "inference"),
|
||||
("/v1/models/list", "models"),
|
||||
("/v1/providers", "providers"),
|
||||
("/", "unknown"),
|
||||
("", "unknown"),
|
||||
]
|
||||
|
||||
for path, expected_api in test_cases:
|
||||
assert middleware._extract_api_from_path(path) == expected_api
|
||||
|
||||
def test_is_excluded_path(self, middleware):
|
||||
"""Test path exclusion logic."""
|
||||
excluded_paths = [
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/favicon.ico",
|
||||
"/static/css/style.css",
|
||||
]
|
||||
|
||||
non_excluded_paths = [
|
||||
"/v1/inference/chat/completions",
|
||||
"/v1/models/list",
|
||||
"/health",
|
||||
]
|
||||
|
||||
for path in excluded_paths:
|
||||
assert middleware._is_excluded_path(path)
|
||||
|
||||
for path in non_excluded_paths:
|
||||
assert not middleware._is_excluded_path(path)
|
||||
|
||||
async def test_middleware_skips_excluded_paths(self, middleware, mock_app):
|
||||
"""Test that middleware skips metrics for excluded paths."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/docs",
|
||||
"method": "GET",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should call the app directly without tracking metrics
|
||||
mock_app.assert_called_once_with(scope, receive, send)
|
||||
# Should not log any metrics
|
||||
middleware.telemetry.log_event.assert_not_called()
|
||||
|
||||
async def test_middleware_tracks_metrics(self, middleware, mock_telemetry):
|
||||
"""Test that middleware tracks metrics for valid requests."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send_called = False
|
||||
|
||||
async def mock_send(message):
|
||||
nonlocal send_called
|
||||
send_called = True
|
||||
if message["type"] == "http.response.start":
|
||||
message["status"] = 200
|
||||
|
||||
# Mock the app to return successfully
|
||||
async def mock_app(scope, receive, send):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
await send({"type": "http.response.body", "body": b"ok"})
|
||||
|
||||
middleware.app = mock_app
|
||||
|
||||
await middleware(scope, receive, mock_send)
|
||||
|
||||
# Wait for async metric logging
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Should have logged metrics
|
||||
assert mock_telemetry.log_event.call_count >= 2
|
||||
|
||||
# Check that the right metrics were logged
|
||||
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
|
||||
|
||||
# Should have request count metric
|
||||
request_count_metric = next(
|
||||
(
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert request_count_metric is not None
|
||||
assert request_count_metric.value == 1
|
||||
assert request_count_metric.attributes["api"] == "inference"
|
||||
assert request_count_metric.attributes["status"] == "200"
|
||||
|
||||
# Should have duration metric
|
||||
duration_metric = next(
|
||||
(
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_request_duration_seconds"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert duration_metric is not None
|
||||
assert duration_metric.attributes["api"] == "inference"
|
||||
assert duration_metric.attributes["status"] == "200"
|
||||
|
||||
async def test_middleware_handles_errors(self, middleware, mock_telemetry):
|
||||
"""Test that middleware tracks metrics even when errors occur."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Mock the app to raise an exception
|
||||
async def mock_app(scope, receive, send):
|
||||
raise ValueError("Test error")
|
||||
|
||||
middleware.app = mock_app
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Wait for async metric logging
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Should have logged metrics with error status
|
||||
assert mock_telemetry.log_event.call_count >= 2
|
||||
|
||||
# Check that error metrics were logged
|
||||
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
|
||||
|
||||
request_count_metric = next(
|
||||
(
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert request_count_metric is not None
|
||||
assert request_count_metric.attributes["status"] == "500"
|
||||
|
||||
async def test_concurrent_requests_tracking(self, middleware, mock_telemetry):
|
||||
"""Test that concurrent requests are tracked correctly."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Mock the app to simulate a slow request
|
||||
async def mock_app(scope, receive, send):
|
||||
await asyncio.sleep(1) # Simulate processing time
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
|
||||
middleware.app = mock_app
|
||||
|
||||
# Start multiple concurrent requests
|
||||
tasks = []
|
||||
for _ in range(3):
|
||||
task = asyncio.create_task(middleware(scope, receive, send))
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all requests to complete
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Wait for async metric logging
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Should have logged metrics for all requests
|
||||
assert mock_telemetry.log_event.call_count >= 6 # 2 metrics per request * 3 requests
|
||||
|
||||
# Check concurrent requests metric
|
||||
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
|
||||
concurrent_metrics = [
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_concurrent_requests"
|
||||
]
|
||||
|
||||
assert len(concurrent_metrics) >= 3
|
||||
# The concurrent count should have been > 0 during the concurrent requests
|
||||
max_concurrent = max(m.value for m in concurrent_metrics)
|
||||
assert max_concurrent > 0
|
||||
|
||||
async def test_middleware_without_telemetry(self):
|
||||
"""Test that middleware works without telemetry configured."""
|
||||
mock_app = AsyncMock()
|
||||
middleware = RequestMetricsMiddleware(mock_app, telemetry=None)
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
async def mock_app_impl(scope, receive, send):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
|
||||
middleware.app = mock_app_impl
|
||||
|
||||
# Should not raise any exceptions
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should not try to log metrics
|
||||
# (no telemetry to call, so this is implicit)
|
||||
|
||||
async def test_non_http_requests_ignored(self, middleware, mock_telemetry):
|
||||
"""Test that non-HTTP requests are ignored."""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"path": "/",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should call the app directly without tracking metrics
|
||||
middleware.app.assert_called_once_with(scope, receive, send)
|
||||
# Should not log any metrics
|
||||
mock_telemetry.log_event.assert_not_called()
|
Loading…
Add table
Add a link
Reference in a new issue