mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Merge 49b729b30a
into 8422bd102a
This commit is contained in:
commit
6d68ece4ef
4 changed files with 433 additions and 0 deletions
272
tests/unit/distribution/test_request_metrics.py
Normal file
272
tests/unit/distribution/test_request_metrics.py
Normal file
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.core.server.metrics import RequestMetricsMiddleware
|
||||
|
||||
|
||||
class TestRequestMetricsMiddleware:
|
||||
@pytest.fixture
|
||||
def mock_telemetry(self):
|
||||
telemetry = AsyncMock(spec=Telemetry)
|
||||
return telemetry
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
app = AsyncMock()
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def middleware(self, mock_app, mock_telemetry):
|
||||
return RequestMetricsMiddleware(mock_app, mock_telemetry)
|
||||
|
||||
def test_extract_api_from_path(self, middleware):
|
||||
"""Test API extraction from various paths."""
|
||||
test_cases = [
|
||||
("/v1/inference/chat/completions", "inference"),
|
||||
("/v1/models/list", "models"),
|
||||
("/v1/providers", "providers"),
|
||||
("/", "unknown"),
|
||||
("", "unknown"),
|
||||
]
|
||||
|
||||
for path, expected_api in test_cases:
|
||||
assert middleware._extract_api_from_path(path) == expected_api
|
||||
|
||||
def test_is_excluded_path(self, middleware):
|
||||
"""Test path exclusion logic."""
|
||||
excluded_paths = [
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/favicon.ico",
|
||||
"/static/css/style.css",
|
||||
]
|
||||
|
||||
non_excluded_paths = [
|
||||
"/v1/inference/chat/completions",
|
||||
"/v1/models/list",
|
||||
"/health",
|
||||
]
|
||||
|
||||
for path in excluded_paths:
|
||||
assert middleware._is_excluded_path(path)
|
||||
|
||||
for path in non_excluded_paths:
|
||||
assert not middleware._is_excluded_path(path)
|
||||
|
||||
async def test_middleware_skips_excluded_paths(self, middleware, mock_app):
|
||||
"""Test that middleware skips metrics for excluded paths."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/docs",
|
||||
"method": "GET",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should call the app directly without tracking metrics
|
||||
mock_app.assert_called_once_with(scope, receive, send)
|
||||
# Should not log any metrics
|
||||
middleware.telemetry.log_event.assert_not_called()
|
||||
|
||||
async def test_middleware_tracks_metrics(self, middleware, mock_telemetry):
|
||||
"""Test that middleware tracks metrics for valid requests."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send_called = False
|
||||
|
||||
async def mock_send(message):
|
||||
nonlocal send_called
|
||||
send_called = True
|
||||
if message["type"] == "http.response.start":
|
||||
message["status"] = 200
|
||||
|
||||
# Mock the app to return successfully
|
||||
async def mock_app(scope, receive, send):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
await send({"type": "http.response.body", "body": b"ok"})
|
||||
|
||||
middleware.app = mock_app
|
||||
|
||||
await middleware(scope, receive, mock_send)
|
||||
|
||||
# Wait for async metric logging
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Should have logged metrics
|
||||
assert mock_telemetry.log_event.call_count >= 2
|
||||
|
||||
# Check that the right metrics were logged
|
||||
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
|
||||
|
||||
# Should have request count metric
|
||||
request_count_metric = next(
|
||||
(
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert request_count_metric is not None
|
||||
assert request_count_metric.value == 1
|
||||
assert request_count_metric.attributes["api"] == "inference"
|
||||
assert request_count_metric.attributes["status"] == "200"
|
||||
|
||||
# Should have duration metric
|
||||
duration_metric = next(
|
||||
(
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_request_duration_seconds"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert duration_metric is not None
|
||||
assert duration_metric.attributes["api"] == "inference"
|
||||
assert duration_metric.attributes["status"] == "200"
|
||||
|
||||
async def test_middleware_handles_errors(self, middleware, mock_telemetry):
|
||||
"""Test that middleware tracks metrics even when errors occur."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Mock the app to raise an exception
|
||||
async def mock_app(scope, receive, send):
|
||||
raise ValueError("Test error")
|
||||
|
||||
middleware.app = mock_app
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Wait for async metric logging
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Should have logged metrics with error status
|
||||
assert mock_telemetry.log_event.call_count >= 2
|
||||
|
||||
# Check that error metrics were logged
|
||||
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
|
||||
|
||||
request_count_metric = next(
|
||||
(
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert request_count_metric is not None
|
||||
assert request_count_metric.attributes["status"] == "500"
|
||||
|
||||
async def test_concurrent_requests_tracking(self, middleware, mock_telemetry):
|
||||
"""Test that concurrent requests are tracked correctly."""
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Mock the app to simulate a slow request
|
||||
async def mock_app(scope, receive, send):
|
||||
await asyncio.sleep(1) # Simulate processing time
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
|
||||
middleware.app = mock_app
|
||||
|
||||
# Start multiple concurrent requests
|
||||
tasks = []
|
||||
for _ in range(3):
|
||||
task = asyncio.create_task(middleware(scope, receive, send))
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all requests to complete
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Wait for async metric logging
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Should have logged metrics for all requests
|
||||
assert mock_telemetry.log_event.call_count >= 6 # 2 metrics per request * 3 requests
|
||||
|
||||
# Check concurrent requests metric
|
||||
call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
|
||||
concurrent_metrics = [
|
||||
call
|
||||
for call in call_args
|
||||
if isinstance(call, MetricEvent) and call.metric == "llama_stack_concurrent_requests"
|
||||
]
|
||||
|
||||
assert len(concurrent_metrics) >= 3
|
||||
# The concurrent count should have been > 0 during the concurrent requests
|
||||
max_concurrent = max(m.value for m in concurrent_metrics)
|
||||
assert max_concurrent > 0
|
||||
|
||||
async def test_middleware_without_telemetry(self):
|
||||
"""Test that middleware works without telemetry configured."""
|
||||
mock_app = AsyncMock()
|
||||
middleware = RequestMetricsMiddleware(mock_app, telemetry=None)
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/v1/inference/chat/completions",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
async def mock_app_impl(scope, receive, send):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
|
||||
middleware.app = mock_app_impl
|
||||
|
||||
# Should not raise any exceptions
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should not try to log metrics
|
||||
# (no telemetry to call, so this is implicit)
|
||||
|
||||
async def test_non_http_requests_ignored(self, middleware, mock_telemetry):
|
||||
"""Test that non-HTTP requests are ignored."""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"path": "/",
|
||||
}
|
||||
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should call the app directly without tracking metrics
|
||||
middleware.app.assert_called_once_with(scope, receive, send)
|
||||
# Should not log any metrics
|
||||
mock_telemetry.log_event.assert_not_called()
|
Loading…
Add table
Add a link
Reference in a new issue