feat: api level request metrics via middleware

add RequestMetricsMiddleware, which tracks key metrics for each request the LLS server receives:

1. llama_stack_requests_total: counts the total number of requests the server has processed
2. llama_stack_request_duration_seconds: records the duration of each request
3. llama_stack_concurrent_requests: tracks how many requests the server is processing concurrently

Using a middleware lets this happen at the server level, without adding custom handling to each router the way the inference router does today for its API-specific metrics.
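For illustration, here is a minimal sketch of the shape such an ASGI middleware can take. The class, method, and metric names are taken from the unit tests below; everything else (the exact exclusion list, fire-and-forget logging, the elided MetricEvent fields) is an assumption, not the committed implementation:

    import asyncio
    import time

    from llama_stack.apis.telemetry import MetricEvent


    class RequestMetricsMiddleware:
        """Sketch: ASGI wrapper that records per-request metrics via Telemetry."""

        def __init__(self, app, telemetry=None):
            self.app = app
            self.telemetry = telemetry
            self.concurrent = 0  # requests currently in flight

        def _extract_api_from_path(self, path: str) -> str:
            # "/v1/inference/chat/completions" -> "inference"
            parts = [p for p in path.split("/") if p]
            return parts[1] if len(parts) > 1 else "unknown"

        def _is_excluded_path(self, path: str) -> bool:
            # skip docs and static assets; exact list assumed from the tests
            return path.startswith(("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static"))

        async def __call__(self, scope, receive, send):
            if scope["type"] != "http" or self._is_excluded_path(scope.get("path", "")):
                return await self.app(scope, receive, send)

            status = 500  # assume failure unless a response start says otherwise
            start = time.monotonic()
            self.concurrent += 1

            async def wrapped_send(message):
                nonlocal status
                if message["type"] == "http.response.start":
                    status = message["status"]
                await send(message)

            try:
                await self.app(scope, receive, wrapped_send)
            finally:
                duration = time.monotonic() - start
                in_flight = self.concurrent  # snapshot before this request leaves
                self.concurrent -= 1
                if self.telemetry:
                    # fire-and-forget so metric logging never blocks the response
                    asyncio.create_task(self._log(scope["path"], status, duration, in_flight))

        async def _log(self, path, status, duration, in_flight):
            attrs = {"api": self._extract_api_from_path(path), "status": str(status)}
            # Only the fields the tests assert on are shown; the real MetricEvent
            # also carries trace/span/timestamp fields.
            await self.telemetry.log_event(
                MetricEvent(metric="llama_stack_requests_total", value=1, attributes=attrs)
            )
            await self.telemetry.log_event(
                MetricEvent(metric="llama_stack_request_duration_seconds", value=duration, attributes=attrs)
            )
            await self.telemetry.log_event(
                MetricEvent(metric="llama_stack_concurrent_requests", value=in_flight, attributes=attrs)
            )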

Also, add unit tests for this functionality.

resolves #2597

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern, 2025-07-11 20:52:32 -04:00
commit 49b729b30a (parent dbfc15123e)
4 changed files with 433 additions and 0 deletions


@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.core.server.metrics import RequestMetricsMiddleware
class TestRequestMetricsMiddleware:
    @pytest.fixture
    def mock_telemetry(self):
        telemetry = AsyncMock(spec=Telemetry)
        return telemetry

    @pytest.fixture
    def mock_app(self):
        app = AsyncMock()
        return app

    @pytest.fixture
    def middleware(self, mock_app, mock_telemetry):
        return RequestMetricsMiddleware(mock_app, mock_telemetry)
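    # NOTE: the bare `async def test_*` methods below assume pytest-asyncio in
    # auto mode (or an equivalent plugin) so they are collected as coroutines.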
    def test_extract_api_from_path(self, middleware):
        """Test API extraction from various paths."""
        test_cases = [
            ("/v1/inference/chat/completions", "inference"),
            ("/v1/models/list", "models"),
            ("/v1/providers", "providers"),
            ("/", "unknown"),
            ("", "unknown"),
        ]

        for path, expected_api in test_cases:
            assert middleware._extract_api_from_path(path) == expected_api
    def test_is_excluded_path(self, middleware):
        """Test path exclusion logic."""
        excluded_paths = [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
            "/static/css/style.css",
        ]
        non_excluded_paths = [
            "/v1/inference/chat/completions",
            "/v1/models/list",
            "/health",
        ]

        for path in excluded_paths:
            assert middleware._is_excluded_path(path)
        for path in non_excluded_paths:
            assert not middleware._is_excluded_path(path)
    async def test_middleware_skips_excluded_paths(self, middleware, mock_app):
        """Test that middleware skips metrics for excluded paths."""
        scope = {
            "type": "http",
            "path": "/docs",
            "method": "GET",
        }
        receive = AsyncMock()
        send = AsyncMock()

        await middleware(scope, receive, send)

        # Should call the app directly without tracking metrics
        mock_app.assert_called_once_with(scope, receive, send)
        # Should not log any metrics
        middleware.telemetry.log_event.assert_not_called()
    async def test_middleware_tracks_metrics(self, middleware, mock_telemetry):
        """Test that middleware tracks metrics for valid requests."""
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }
        receive = AsyncMock()
        send_called = False

        async def mock_send(message):
            nonlocal send_called
            send_called = True

        # Mock the app to return successfully
        async def mock_app(scope, receive, send):
            await send({"type": "http.response.start", "status": 200})
            await send({"type": "http.response.body", "body": b"ok"})

        middleware.app = mock_app
        await middleware(scope, receive, mock_send)
        assert send_called

        # Wait for async metric logging
        await asyncio.sleep(0.1)

        # Should have logged metrics
        assert mock_telemetry.log_event.call_count >= 2

        # Check that the right metrics were logged
        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]

        # Should have request count metric
        request_count_metric = next(
            (
                call
                for call in call_args
                if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
            ),
            None,
        )
        assert request_count_metric is not None
        assert request_count_metric.value == 1
        assert request_count_metric.attributes["api"] == "inference"
        assert request_count_metric.attributes["status"] == "200"

        # Should have duration metric
        duration_metric = next(
            (
                call
                for call in call_args
                if isinstance(call, MetricEvent) and call.metric == "llama_stack_request_duration_seconds"
            ),
            None,
        )
        assert duration_metric is not None
        assert duration_metric.attributes["api"] == "inference"
        assert duration_metric.attributes["status"] == "200"
    async def test_middleware_handles_errors(self, middleware, mock_telemetry):
        """Test that middleware tracks metrics even when errors occur."""
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }
        receive = AsyncMock()
        send = AsyncMock()

        # Mock the app to raise an exception
        async def mock_app(scope, receive, send):
            raise ValueError("Test error")

        middleware.app = mock_app
        with pytest.raises(ValueError):
            await middleware(scope, receive, send)

        # Wait for async metric logging
        await asyncio.sleep(0.1)

        # Should have logged metrics with error status
        assert mock_telemetry.log_event.call_count >= 2

        # Check that error metrics were logged
        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
        request_count_metric = next(
            (
                call
                for call in call_args
                if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
            ),
            None,
        )
        assert request_count_metric is not None
        assert request_count_metric.attributes["status"] == "500"
    async def test_concurrent_requests_tracking(self, middleware, mock_telemetry):
        """Test that concurrent requests are tracked correctly."""
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }
        receive = AsyncMock()
        send = AsyncMock()

        # Mock the app to simulate a slow request
        async def mock_app(scope, receive, send):
            await asyncio.sleep(1)  # Simulate processing time
            await send({"type": "http.response.start", "status": 200})

        middleware.app = mock_app

        # Start multiple concurrent requests
        tasks = []
        for _ in range(3):
            task = asyncio.create_task(middleware(scope, receive, send))
            tasks.append(task)

        # Wait for all requests to complete
        await asyncio.gather(*tasks)

        # Wait for async metric logging
        await asyncio.sleep(0.2)

        # Should have logged metrics for all requests
        assert mock_telemetry.log_event.call_count >= 6  # 2 metrics per request * 3 requests

        # Check concurrent requests metric
        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
        concurrent_metrics = [
            call
            for call in call_args
            if isinstance(call, MetricEvent) and call.metric == "llama_stack_concurrent_requests"
        ]
        assert len(concurrent_metrics) >= 3

        # The concurrent count should have been > 0 during the concurrent requests
        max_concurrent = max(m.value for m in concurrent_metrics)
        assert max_concurrent > 0
    async def test_middleware_without_telemetry(self):
        """Test that middleware works without telemetry configured."""
        mock_app = AsyncMock()
        middleware = RequestMetricsMiddleware(mock_app, telemetry=None)
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }
        receive = AsyncMock()
        send = AsyncMock()

        async def mock_app_impl(scope, receive, send):
            await send({"type": "http.response.start", "status": 200})

        middleware.app = mock_app_impl

        # Should not raise any exceptions
        await middleware(scope, receive, send)
        # Should not try to log metrics
        # (no telemetry to call, so this is implicit)
    async def test_non_http_requests_ignored(self, middleware, mock_telemetry):
        """Test that non-HTTP requests are ignored."""
        scope = {
            "type": "lifespan",
            "path": "/",
        }
        receive = AsyncMock()
        send = AsyncMock()

        await middleware(scope, receive, send)

        # Should call the app directly without tracking metrics
        middleware.app.assert_called_once_with(scope, receive, send)
        # Should not log any metrics
        mock_telemetry.log_event.assert_not_called()
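
Not shown in this excerpt: hooking the middleware into the server. Since the tests construct it with the plain ASGI pattern RequestMetricsMiddleware(app, telemetry), registration in the server's FastAPI app would look roughly like the following sketch; telemetry_impl and the exact hookup point are assumptions, as the server-side changes live in the other files of this commit:

    from fastapi import FastAPI

    from llama_stack.core.server.metrics import RequestMetricsMiddleware

    app = FastAPI()
    # telemetry_impl is hypothetical: however the server resolves its Telemetry
    # provider at startup. Starlette instantiates the middleware as
    # RequestMetricsMiddleware(app, telemetry=telemetry_impl) when it builds
    # the app's middleware stack.
    app.add_middleware(RequestMetricsMiddleware, telemetry=telemetry_impl)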