Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
Merge 49b729b30a into 5b312a80b9 (commit 204195f688)
4 changed files with 433 additions and 0 deletions
@@ -37,6 +37,9 @@ The following metrics are automatically generated for each inference request:

| `llama_stack_prompt_tokens_total` | Counter | `tokens` | Number of tokens in the input prompt | `model_id`, `provider_id` |
| `llama_stack_completion_tokens_total` | Counter | `tokens` | Number of tokens in the generated response | `model_id`, `provider_id` |
| `llama_stack_tokens_total` | Counter | `tokens` | Total tokens used (prompt + completion) | `model_id`, `provider_id` |
| `llama_stack_requests_total` | Counter | `requests` | Total number of requests | `api`, `status` |
| `llama_stack_request_duration_seconds` | Gauge | `seconds` | Request duration | `api`, `status` |
| `llama_stack_concurrent_requests` | Gauge | `requests` | Number of concurrent requests | `api` |
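The three request-level rows added here are each emitted as a `MetricEvent` through the `Telemetry` API by the new middleware (see `llama_stack/core/server/metrics.py` below). As a minimal sketch with illustrative attribute values, one `llama_stack_requests_total` sample corresponds to:

# Illustrative only: the MetricEvent behind a single llama_stack_requests_total sample.
from datetime import UTC, datetime

from llama_stack.apis.telemetry import MetricEvent

event = MetricEvent(
    trace_id="",  # empty when no span is active
    span_id="",
    timestamp=datetime.now(UTC),
    metric="llama_stack_requests_total",
    value=1,  # an increment per request; the telemetry provider accumulates the total
    unit="requests",
    attributes={"api": "inference", "status": "200"},  # example label values
)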
#### Metric Generation Flow
llama_stack/core/server/metrics.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import time
from datetime import UTC, datetime

from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import get_current_span

logger = get_logger(name=__name__, category="server")


class RequestMetricsMiddleware:
    """
    ASGI middleware that tracks request-level metrics, including:
    - Request counts by API and status
    - Request duration
    - Concurrent requests

    Metrics are logged to the telemetry system and can be exported to Prometheus
    via OpenTelemetry.
    """

    def __init__(self, app, telemetry: Telemetry | None = None):
        self.app = app
        self.telemetry = telemetry
        self.concurrent_requests = 0
        self._lock = asyncio.Lock()

        # FastAPI built-in paths that should be excluded from metrics
        self.excluded_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

    def _extract_api_from_path(self, path: str) -> str:
        """Extract the API name from the request path."""
        # Remove the version prefix if present
        if path.startswith("/v1/"):
            path = path[4:]
        # Extract the first path segment as the API name
        segments = path.strip("/").split("/")
        # Check that the first segment is not empty: for a path split into
        # ["datasets", "list"] this returns the API ("datasets") rather than the action
        if segments and segments[0]:
            return segments[0]
        return "unknown"

    def _is_excluded_path(self, path: str) -> bool:
        """Check if the path should be excluded from metrics."""
        return any(path.startswith(excluded) for excluded in self.excluded_paths)

    async def _log_request_metrics(self, api: str, status: str, duration: float, concurrent_count: int):
        """Log request metrics to the telemetry system."""
        if not self.telemetry:
            return

        try:
            # Get the current span if available
            span = get_current_span()
            trace_id = span.trace_id if span else ""
            span_id = span.span_id if span else ""

            # Log the request count (send an increment of 1 for each request)
            await self.telemetry.log_event(
                MetricEvent(
                    trace_id=trace_id,
                    span_id=span_id,
                    timestamp=datetime.now(UTC),
                    metric="llama_stack_requests_total",
                    value=1,  # Send an increment instead of a total so the provider handles accumulation
                    unit="requests",
                    attributes={
                        "api": api,
                        "status": status,
                    },
                )
            )

            # Log the request duration
            await self.telemetry.log_event(
                MetricEvent(
                    trace_id=trace_id,
                    span_id=span_id,
                    timestamp=datetime.now(UTC),
                    metric="llama_stack_request_duration_seconds",
                    value=duration,
                    unit="seconds",
                    attributes={"api": api, "status": status},
                )
            )

            # Log concurrent requests (as a gauge)
            await self.telemetry.log_event(
                MetricEvent(
                    trace_id=trace_id,
                    span_id=span_id,
                    timestamp=datetime.now(UTC),
                    metric="llama_stack_concurrent_requests",
                    value=float(concurrent_count),  # Convert to float for the gauge
                    unit="requests",
                    attributes={"api": api},
                )
            )

        except ValueError as e:
            logger.warning(f"Failed to log request metrics: {e}")

    async def __call__(self, scope, receive, send):
        if scope.get("type") != "http":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")

        # Skip metrics for excluded paths
        if self._is_excluded_path(path):
            return await self.app(scope, receive, send)

        api = self._extract_api_from_path(path)
        start_time = time.time()
        status = 200

        # Track concurrent requests
        async with self._lock:
            self.concurrent_requests += 1

        # Create a wrapper to capture the response status
        async def send_wrapper(message):
            if message.get("type") == "http.response.start":
                nonlocal status
                status = message.get("status", 200)
            await send(message)

        try:
            return await self.app(scope, receive, send_wrapper)
        except Exception:
            # Set the status to 500 for any unhandled exception
            status = 500
            raise
        finally:
            duration = time.time() - start_time

            # Capture the concurrent count before decrementing
            async with self._lock:
                concurrent_count = self.concurrent_requests
                self.concurrent_requests -= 1

            # Log metrics asynchronously to avoid blocking the response
            asyncio.create_task(self._log_request_metrics(api, str(status), duration, concurrent_count))
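As a usage sketch of the middleware on its own (the `PrintTelemetry` stand-in and the toy ASGI app below are hypothetical; only the async `log_event` method that the middleware actually awaits is required):

# Hypothetical smoke test: wrap a trivial ASGI app and print each emitted MetricEvent.
import asyncio

from llama_stack.core.server.metrics import RequestMetricsMiddleware


class PrintTelemetry:
    # Duck-typed stand-in for the Telemetry API; the middleware only awaits log_event.
    async def log_event(self, event):
        print(event.metric, event.value, event.attributes)


async def app(scope, receive, send):
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


async def main():
    wrapped = RequestMetricsMiddleware(app, telemetry=PrintTelemetry())
    scope = {"type": "http", "path": "/v1/models/list", "method": "GET"}

    async def receive():
        return {"type": "http.request"}

    async def send(message):
        pass  # discard the response; we only care about the metrics side effect

    await wrapped(scope, receive, send)
    await asyncio.sleep(0.1)  # give the fire-and-forget metrics task time to run


asyncio.run(main())

Running this prints one requests-total increment, one duration sample, and one concurrent-requests gauge value, all labeled with the "models" API.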
llama_stack/core/server/server.py
@@ -77,6 +77,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
)

from .auth import AuthenticationMiddleware
from .metrics import RequestMetricsMiddleware
from .quota import QuotaMiddleware

REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -540,6 +541,10 @@ def main(args: argparse.Namespace | None = None):
    app.__llama_stack_impls__ = impls
    app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

    # Add request metrics middleware
    telemetry_impl = impls.get(Api.telemetry) if Api.telemetry in impls else None
    app.add_middleware(RequestMetricsMiddleware, telemetry=telemetry_impl)

    import uvicorn

    # Configure SSL if certificates are provided
tests/unit/distribution/test_request_metrics.py (new file, 272 lines)
@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.core.server.metrics import RequestMetricsMiddleware


class TestRequestMetricsMiddleware:
    @pytest.fixture
    def mock_telemetry(self):
        telemetry = AsyncMock(spec=Telemetry)
        return telemetry

    @pytest.fixture
    def mock_app(self):
        app = AsyncMock()
        return app

    @pytest.fixture
    def middleware(self, mock_app, mock_telemetry):
        return RequestMetricsMiddleware(mock_app, mock_telemetry)

    def test_extract_api_from_path(self, middleware):
        """Test API extraction from various paths."""
        test_cases = [
            ("/v1/inference/chat/completions", "inference"),
            ("/v1/models/list", "models"),
            ("/v1/providers", "providers"),
            ("/", "unknown"),
            ("", "unknown"),
        ]

        for path, expected_api in test_cases:
            assert middleware._extract_api_from_path(path) == expected_api

    def test_is_excluded_path(self, middleware):
        """Test path exclusion logic."""
        excluded_paths = [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
            "/static/css/style.css",
        ]

        non_excluded_paths = [
            "/v1/inference/chat/completions",
            "/v1/models/list",
            "/health",
        ]

        for path in excluded_paths:
            assert middleware._is_excluded_path(path)

        for path in non_excluded_paths:
            assert not middleware._is_excluded_path(path)

    async def test_middleware_skips_excluded_paths(self, middleware, mock_app):
        """Test that middleware skips metrics for excluded paths."""
        scope = {
            "type": "http",
            "path": "/docs",
            "method": "GET",
        }

        receive = AsyncMock()
        send = AsyncMock()

        await middleware(scope, receive, send)

        # Should call the app directly without tracking metrics
        mock_app.assert_called_once_with(scope, receive, send)
        # Should not log any metrics
        middleware.telemetry.log_event.assert_not_called()

    async def test_middleware_tracks_metrics(self, middleware, mock_telemetry):
        """Test that middleware tracks metrics for valid requests."""
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }

        receive = AsyncMock()
        send_called = False

        async def mock_send(message):
            nonlocal send_called
            send_called = True
            if message["type"] == "http.response.start":
                message["status"] = 200

        # Mock the app to return successfully
        async def mock_app(scope, receive, send):
            await send({"type": "http.response.start", "status": 200})
            await send({"type": "http.response.body", "body": b"ok"})

        middleware.app = mock_app

        await middleware(scope, receive, mock_send)

        # Wait for async metric logging
        await asyncio.sleep(0.1)

        # Should have logged metrics
        assert mock_telemetry.log_event.call_count >= 2

        # Check that the right metrics were logged
        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]

        # Should have a request count metric
        request_count_metric = next(
            (
                call
                for call in call_args
                if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
            ),
            None,
        )
        assert request_count_metric is not None
        assert request_count_metric.value == 1
        assert request_count_metric.attributes["api"] == "inference"
        assert request_count_metric.attributes["status"] == "200"

        # Should have a duration metric
        duration_metric = next(
            (
                call
                for call in call_args
                if isinstance(call, MetricEvent) and call.metric == "llama_stack_request_duration_seconds"
            ),
            None,
        )
        assert duration_metric is not None
        assert duration_metric.attributes["api"] == "inference"
        assert duration_metric.attributes["status"] == "200"

    async def test_middleware_handles_errors(self, middleware, mock_telemetry):
        """Test that middleware tracks metrics even when errors occur."""
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }

        receive = AsyncMock()
        send = AsyncMock()

        # Mock the app to raise an exception
        async def mock_app(scope, receive, send):
            raise ValueError("Test error")

        middleware.app = mock_app

        with pytest.raises(ValueError):
            await middleware(scope, receive, send)

        # Wait for async metric logging
        await asyncio.sleep(0.1)

        # Should have logged metrics with error status
        assert mock_telemetry.log_event.call_count >= 2

        # Check that error metrics were logged
        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]

        request_count_metric = next(
            (
                call
                for call in call_args
                if isinstance(call, MetricEvent) and call.metric == "llama_stack_requests_total"
            ),
            None,
        )
        assert request_count_metric is not None
        assert request_count_metric.attributes["status"] == "500"

    async def test_concurrent_requests_tracking(self, middleware, mock_telemetry):
        """Test that concurrent requests are tracked correctly."""
        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }

        receive = AsyncMock()
        send = AsyncMock()

        # Mock the app to simulate a slow request
        async def mock_app(scope, receive, send):
            await asyncio.sleep(1)  # Simulate processing time
            await send({"type": "http.response.start", "status": 200})

        middleware.app = mock_app

        # Start multiple concurrent requests
        tasks = []
        for _ in range(3):
            task = asyncio.create_task(middleware(scope, receive, send))
            tasks.append(task)

        # Wait for all requests to complete
        await asyncio.gather(*tasks)

        # Wait for async metric logging
        await asyncio.sleep(0.2)

        # Should have logged metrics for all requests
        assert mock_telemetry.log_event.call_count >= 6  # 2 metrics per request * 3 requests

        # Check the concurrent requests metric
        call_args = [call.args[0] for call in mock_telemetry.log_event.call_args_list]
        concurrent_metrics = [
            call
            for call in call_args
            if isinstance(call, MetricEvent) and call.metric == "llama_stack_concurrent_requests"
        ]

        assert len(concurrent_metrics) >= 3
        # The concurrent count should have been > 0 during the concurrent requests
        max_concurrent = max(m.value for m in concurrent_metrics)
        assert max_concurrent > 0

    async def test_middleware_without_telemetry(self):
        """Test that middleware works without telemetry configured."""
        mock_app = AsyncMock()
        middleware = RequestMetricsMiddleware(mock_app, telemetry=None)

        scope = {
            "type": "http",
            "path": "/v1/inference/chat/completions",
            "method": "POST",
        }

        receive = AsyncMock()
        send = AsyncMock()

        async def mock_app_impl(scope, receive, send):
            await send({"type": "http.response.start", "status": 200})

        middleware.app = mock_app_impl

        # Should not raise any exceptions
        await middleware(scope, receive, send)

        # Should not try to log metrics
        # (no telemetry to call, so this is implicit)

    async def test_non_http_requests_ignored(self, middleware, mock_telemetry):
        """Test that non-HTTP requests are ignored."""
        scope = {
            "type": "lifespan",
            "path": "/",
        }

        receive = AsyncMock()
        send = AsyncMock()

        await middleware(scope, receive, send)

        # Should call the app directly without tracking metrics
        middleware.app.assert_called_once_with(scope, receive, send)
        # Should not log any metrics
        mock_telemetry.log_event.assert_not_called()
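Assuming the repo's pytest setup collects these bare async tests (e.g. via pytest-asyncio's auto mode; an assumption here), the module can be run on its own. A programmatic equivalent of `pytest -q tests/unit/distribution/test_request_metrics.py`:

# Run just this test module via pytest's Python API.
import pytest

raise SystemExit(pytest.main(["-q", "tests/unit/distribution/test_request_metrics.py"]))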