diff --git a/docs/source/building_applications/telemetry.md b/docs/source/building_applications/telemetry.md
index d93242f75..2106d3bec 100644
--- a/docs/source/building_applications/telemetry.md
+++ b/docs/source/building_applications/telemetry.md
@@ -37,6 +37,9 @@ The following metrics are automatically generated for each inference request:
 | `llama_stack_prompt_tokens_total` | Counter | `tokens` | Number of tokens in the input prompt | `model_id`, `provider_id` |
 | `llama_stack_completion_tokens_total` | Counter | `tokens` | Number of tokens in the generated response | `model_id`, `provider_id` |
 | `llama_stack_tokens_total` | Counter | `tokens` | Total tokens used (prompt + completion) | `model_id`, `provider_id` |
+| `llama_stack_requests_total` | Counter | `requests` | Total number of requests | `api`, `status` |
+| `llama_stack_request_duration_seconds` | Gauge | `seconds` | Request duration | `api`, `status` |
+| `llama_stack_concurrent_requests` | Gauge | `requests` | Number of concurrent requests | `api` |
 
 #### Metric Generation Flow
 
diff --git a/llama_stack/core/server/metrics.py b/llama_stack/core/server/metrics.py
new file mode 100644
index 000000000..50fc676f0
--- /dev/null
+++ b/llama_stack/core/server/metrics.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import time
+from datetime import UTC, datetime
+
+from llama_stack.apis.telemetry import MetricEvent, Telemetry
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
+
+logger = get_logger(name=__name__, category="server")
+
+
+class RequestMetricsMiddleware:
+    """
+    Middleware that tracks request-level metrics, including:
+    - Request counts by API and status
+    - Request duration
+    - Concurrent requests
+
+    Metrics are logged to the telemetry system and can be exported to Prometheus
+    via OpenTelemetry.
+    """
+
+    def __init__(self, app, telemetry: Telemetry | None = None):
+        self.app = app
+        self.telemetry = telemetry
+        self.concurrent_requests = 0
+        self._lock = asyncio.Lock()
+
+        # FastAPI built-in paths that should be excluded from metrics
+        self.excluded_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
+
+    def _extract_api_from_path(self, path: str) -> str:
+        """Extract the API name from the request path."""
+        # Remove the version prefix if present
+        if path.startswith("/v1/"):
+            path = path[4:]
+        # Use the first non-empty path segment as the API name, e.g. for
+        # "/datasets/list" this returns the API ("datasets"), not the action ("list")
+        segments = path.strip("/").split("/")
+        if segments and segments[0]:
+            return segments[0]
+        return "unknown"
+
+    def _is_excluded_path(self, path: str) -> bool:
+        """Check if the path should be excluded from metrics."""
+        return any(path.startswith(excluded) for excluded in self.excluded_paths)
+
+    async def _log_request_metrics(self, api: str, status: str, duration: float, concurrent_count: int):
+        """Log request metrics to the telemetry system."""
+        if not self.telemetry:
+            return
+
+        try:
+            # Get the current span if available
+            span = get_current_span()
+            trace_id = span.trace_id if span else ""
+            span_id = span.span_id if span else ""
+
+            # Log the request count (send an increment of 1 for each request)
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=trace_id,
+                    span_id=span_id,
+                    timestamp=datetime.now(UTC),
+                    metric="llama_stack_requests_total",
+                    value=1,  # Send an increment rather than a running total; the provider handles accumulation
+                    unit="requests",
+                    attributes={
+                        "api": api,
+                        "status": status,
+                    },
+                )
+            )
+
+            # Log the request duration
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=trace_id,
+                    span_id=span_id,
+                    timestamp=datetime.now(UTC),
+                    metric="llama_stack_request_duration_seconds",
+                    value=duration,
+                    unit="seconds",
+                    attributes={"api": api, "status": status},
+                )
+            )
+
+            # Log concurrent requests (as a gauge)
+            await self.telemetry.log_event(
+                MetricEvent(
+                    trace_id=trace_id,
+                    span_id=span_id,
+                    timestamp=datetime.now(UTC),
+                    metric="llama_stack_concurrent_requests",
+                    value=float(concurrent_count),  # Convert to float for the gauge
+                    unit="requests",
+                    attributes={"api": api},
+                )
+            )
+
+        except ValueError as e:
+            logger.warning(f"Failed to log request metrics: {e}")
+
+    async def __call__(self, scope, receive, send):
+        if scope.get("type") != "http":
+            return await self.app(scope, receive, send)
+
+        path = scope.get("path", "")
+
+        # Skip metrics for excluded paths
+        if self._is_excluded_path(path):
+            return await self.app(scope, receive, send)
+
+        api = self._extract_api_from_path(path)
+        start_time = time.time()
+        status = 200
+
+        # Track concurrent requests
+        async with self._lock:
+            self.concurrent_requests += 1
+
+        # Create a wrapper to capture the response status
+        async def send_wrapper(message):
+            if message.get("type") == "http.response.start":
+                nonlocal status
+                status = message.get("status", 200)
+            await send(message)
+
+        try:
+            return await self.app(scope, receive, send_wrapper)
+        except Exception:
+            # Set status to 500 for any unhandled exception
+            status = 500
+            raise
+        finally:
+            duration = time.time() - start_time
+
+            # Capture the concurrent count before decrementing
+            async with self._lock:
+                concurrent_count = self.concurrent_requests
+                self.concurrent_requests -= 1
+
+            # Log metrics asynchronously to avoid blocking the response
+            asyncio.create_task(self._log_request_metrics(api, str(status), duration, concurrent_count))
diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py
index e9d70fc8d..82880908f 100644
--- a/llama_stack/core/server/server.py
+++ b/llama_stack/core/server/server.py
@@ -77,6 +77,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 
 from .auth import AuthenticationMiddleware
+from .metrics import RequestMetricsMiddleware
 from .quota import QuotaMiddleware
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -540,6 +541,10 @@ def main(args: argparse.Namespace | None = None):
     app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    # Add request metrics middleware
+    telemetry_impl = impls.get(Api.telemetry) if Api.telemetry in impls else None
+    app.add_middleware(RequestMetricsMiddleware, telemetry=telemetry_impl)
+
     import uvicorn
 
     # Configure SSL if certificates are provided
diff --git a/tests/unit/distribution/test_request_metrics.py b/tests/unit/distribution/test_request_metrics.py
new file mode 100644
index 000000000..40ebf6b10
--- /dev/null
+++ b/tests/unit/distribution/test_request_metrics.py
@@ -0,0 +1,272 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from unittest.mock import AsyncMock
+
+import pytest
+
+from llama_stack.apis.telemetry import MetricEvent, Telemetry
+from llama_stack.core.server.metrics import RequestMetricsMiddleware
+
+
+class TestRequestMetricsMiddleware:
+    @pytest.fixture
+    def mock_telemetry(self):
+        telemetry = AsyncMock(spec=Telemetry)
+        return telemetry
+
+    @pytest.fixture
+    def mock_app(self):
+        app = AsyncMock()
+        return app
+
+    @pytest.fixture
+    def middleware(self, mock_app, mock_telemetry):
+        return RequestMetricsMiddleware(mock_app, mock_telemetry)
+
+    def test_extract_api_from_path(self, middleware):
+        """Test API extraction from various paths."""
+        test_cases = [
+            ("/v1/inference/chat/completions", "inference"),
+            ("/v1/models/list", "models"),
+            ("/v1/providers", "providers"),
+            ("/", "unknown"),
+            ("", "unknown"),
+        ]
+
+        for path, expected_api in test_cases:
+            assert middleware._extract_api_from_path(path) == expected_api
+
+    def test_is_excluded_path(self, middleware):
+        """Test path exclusion logic."""
+        excluded_paths = [
+            "/docs",
+            "/redoc",
+            "/openapi.json",
+            "/favicon.ico",
+            "/static/css/style.css",
+        ]
+
+        non_excluded_paths = [
+            "/v1/inference/chat/completions",
+            "/v1/models/list",
+            "/health",
+        ]
+
+        for path in excluded_paths:
+            assert middleware._is_excluded_path(path)
+
+        for path in non_excluded_paths:
+            assert not middleware._is_excluded_path(path)
+
+    async def test_middleware_skips_excluded_paths(self, middleware, mock_app):
+        """Test that middleware skips metrics for excluded paths."""
+        scope = {
+            "type": "http",
+            "path": "/docs",
+            "method": "GET",
+        }
+
+        receive = AsyncMock()
+        send = AsyncMock()
+
+        await middleware(scope, receive, send)
+
+        # Should call the app directly without tracking metrics
+        mock_app.assert_called_once_with(scope, receive, send)
+        # Should not log any metrics
+        middleware.telemetry.log_event.assert_not_called()
+
+    async def test_middleware_tracks_metrics(self, middleware, mock_telemetry):
+        """Test that middleware tracks metrics for valid requests."""
+        scope = {
+            "type": "http",
+            "path": "/v1/inference/chat/completions",
+            "method": "POST",
+        }
+
+        receive = AsyncMock()
+        send_called = False
+
+        async def mock_send(message):
+            nonlocal send_called
+            send_called = True
+            if message["type"] == "http.response.start":
+                message["status"] = 200
+
+        # Mock the app to return successfully
+        async def mock_app(scope, receive, send):
+            await send({"type": "http.response.start", "status": 200})
+            await send({"type": "http.response.body", "body": b"ok"})
+
+        middleware.app = mock_app
+
+        await middleware(scope, receive, mock_send)
+
+        # Wait for async metric logging
+        await asyncio.sleep(0.1)
+
+        # Should have logged metrics
+        assert mock_telemetry.log_event.call_count >= 2
+
+        # Check that the right metrics were logged
+        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
+
+        # Should have request count metric
+        request_count_metric = next(
+            (
+                call
+                for call in call_args
+                if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
+            ),
+            None,
+        )
+        assert request_count_metric is not None
+        assert request_count_metric.value == 1
+        assert request_count_metric.attributes["api"] == "inference"
+        assert request_count_metric.attributes["status"] == "200"
+
+        # Should have duration metric
+        duration_metric = next(
+            (
+                call
+                for call in call_args
+                if isinstance(call, MetricEvent) and call.metric == "llama_stack_request_duration_seconds"
+            ),
+            None,
+        )
+        assert duration_metric is not None
+        assert duration_metric.attributes["api"] == "inference"
+        assert duration_metric.attributes["status"] == "200"
+
+    async def test_middleware_handles_errors(self, middleware, mock_telemetry):
+        """Test that middleware tracks metrics even when errors occur."""
+        scope = {
+            "type": "http",
+            "path": "/v1/inference/chat/completions",
+            "method": "POST",
+        }
+
+        receive = AsyncMock()
+        send = AsyncMock()
+
+        # Mock the app to raise an exception
+        async def mock_app(scope, receive, send):
+            raise ValueError("Test error")
+
+        middleware.app = mock_app
+
+        with pytest.raises(ValueError):
+            await middleware(scope, receive, send)
+
+        # Wait for async metric logging
+        await asyncio.sleep(0.1)
+
+        # Should have logged metrics with error status
+        assert mock_telemetry.log_event.call_count >= 2
+
+        # Check that error metrics were logged
+        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
+
+        request_count_metric = next(
+            (
+                call
+                for call in call_args
+                if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
+            ),
+            None,
+        )
+        assert request_count_metric is not None
+        assert request_count_metric.attributes["status"] == "500"
+
+    async def test_concurrent_requests_tracking(self, middleware, mock_telemetry):
+        """Test that concurrent requests are tracked correctly."""
+        scope = {
+            "type": "http",
+            "path": "/v1/inference/chat/completions",
+            "method": "POST",
+        }
+
+        receive = AsyncMock()
+        send = AsyncMock()
+
+        # Mock the app to simulate a slow request
+        async def mock_app(scope, receive, send):
+            await asyncio.sleep(1)  # Simulate processing time
+            await send({"type": "http.response.start", "status": 200})
+
+        middleware.app = mock_app
+
+        # Start multiple concurrent requests
+        tasks = []
+        for _ in range(3):
+            task = asyncio.create_task(middleware(scope, receive, send))
+            tasks.append(task)
+
+        # Wait for all requests to complete
+        await asyncio.gather(*tasks)
+
+        # Wait for async metric logging
+        await asyncio.sleep(0.2)
+
+        # Should have logged metrics for all requests
+        assert mock_telemetry.log_event.call_count >= 6  # 2 metrics per request * 3 requests
+
+        # Check concurrent requests metric
+        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
+        concurrent_metrics = [
+            call
+            for call in call_args
+            if isinstance(call, MetricEvent) and call.metric == "llama_stack_concurrent_requests"
+        ]
+
+        assert len(concurrent_metrics) >= 3
+        # The concurrent count should have been > 0 during the concurrent requests
+        max_concurrent = max(m.value for m in concurrent_metrics)
+        assert max_concurrent > 0
+
+    async def test_middleware_without_telemetry(self):
+        """Test that middleware works without telemetry configured."""
+        mock_app = AsyncMock()
+        middleware = RequestMetricsMiddleware(mock_app, telemetry=None)
+
+        scope = {
+            "type": "http",
+            "path": "/v1/inference/chat/completions",
+            "method": "POST",
+        }
+
+        receive = AsyncMock()
+        send = AsyncMock()
+
+        async def mock_app_impl(scope, receive, send):
+            await send({"type": "http.response.start", "status": 200})
+
+        middleware.app = mock_app_impl
+
+        # Should not raise any exceptions
+        await middleware(scope, receive, send)
+
+        # Should not try to log metrics
+        # (no telemetry to call, so this is implicit)
+
+    async def test_non_http_requests_ignored(self, middleware, mock_telemetry):
+        """Test that non-HTTP requests are ignored."""
+        scope = {
+            "type": "lifespan",
+            "path": "/",
+        }
+
+        receive = AsyncMock()
+        send = AsyncMock()
+
+        await middleware(scope, receive, send)
+
+        # Should call the app directly without tracking metrics
+        middleware.app.assert_called_once_with(scope, receive, send)
+        # Should not log any metrics
+        mock_telemetry.log_event.assert_not_called()
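
For a quick sanity check of the middleware in isolation, here is a minimal sketch of how it can be exercised outside the server wiring above. It relies only on the ASGI `__call__` contract and the `log_event` call shown in this diff; `StubTelemetry` and `plain_asgi_app` are illustrative names, not part of the change, and the short sleep at the end is needed because metrics are emitted from a fire-and-forget `asyncio.create_task` in the middleware's `finally` block.

```python
import asyncio

from llama_stack.core.server.metrics import RequestMetricsMiddleware


class StubTelemetry:
    # Duck-typed stand-in for the Telemetry API; the middleware only calls log_event().
    async def log_event(self, event):
        print(f"{event.metric} value={event.value} attributes={event.attributes}")


async def plain_asgi_app(scope, receive, send):
    # Minimal ASGI app that always answers 200 OK.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


async def main():
    app = RequestMetricsMiddleware(plain_asgi_app, telemetry=StubTelemetry())
    scope = {"type": "http", "path": "/v1/models/list", "method": "GET"}

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        pass  # Discard the response; only the emitted metrics are of interest here.

    await app(scope, receive, send)
    # Metric logging runs in a background task, so give it a moment before the loop exits.
    await asyncio.sleep(0.1)


if __name__ == "__main__":
    asyncio.run(main())
```

This should print one `llama_stack_requests_total` increment, one `llama_stack_request_duration_seconds` sample, and one `llama_stack_concurrent_requests` gauge value, all tagged with `api="models"`.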