mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat: api level request metrics via middleware
add RequestMetricsMiddleware which tracks key metrics related to each request the LLS server will recieve: 1. llama_stack_requests_total: tracks the total amount of requests the server has processed 2. llama_stack_request_duration_seconds: tracks the duration of each request 3. llama_stack_concurrent_requests: tracks concurrently processed requests by the server The usage of a middleware allows this to be done on the server level without having to add custom handling to each router like the inference router has today for its API specific metrics. Also, add some unit tests for this functionality resolves #2597 Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
dbfc15123e
commit
49b729b30a
4 changed files with 433 additions and 0 deletions
153
llama_stack/core/server/metrics.py
Normal file
153
llama_stack/core/server/metrics.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
class RequestMetricsMiddleware:
|
||||
"""
|
||||
middleware that tracks request-level metrics including:
|
||||
- Request counts by API and status
|
||||
- Request duration
|
||||
- Concurrent requests
|
||||
|
||||
Metrics are logged to the telemetry system and can be exported to Prometheus
|
||||
via OpenTelemetry.
|
||||
"""
|
||||
|
||||
def __init__(self, app, telemetry: Telemetry | None = None):
|
||||
self.app = app
|
||||
self.telemetry = telemetry
|
||||
self.concurrent_requests = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# FastAPI built-in paths that should be excluded from metrics
|
||||
self.excluded_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
def _extract_api_from_path(self, path: str) -> str:
|
||||
"""Extract the API name from the request path."""
|
||||
# Remove version prefix if present
|
||||
if path.startswith("/v1/"):
|
||||
path = path[4:]
|
||||
# Extract the first path segment as the API name
|
||||
segments = path.strip("/").split("/")
|
||||
if (
|
||||
segments and segments[0]
|
||||
): # Check that first segment is not empty, this will return the API rather than the action `["datasets", "list"]`
|
||||
return segments[0]
|
||||
return "unknown"
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
"""Check if the path should be excluded from metrics."""
|
||||
return any(path.startswith(excluded) for excluded in self.excluded_paths)
|
||||
|
||||
async def _log_request_metrics(self, api: str, status: str, duration: float, concurrent_count: int):
|
||||
"""Log request metrics to the telemetry system."""
|
||||
if not self.telemetry:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get current span if available
|
||||
span = get_current_span()
|
||||
trace_id = span.trace_id if span else ""
|
||||
span_id = span.span_id if span else ""
|
||||
|
||||
# Log request count (send increment of 1 for each request)
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
metric="llama_stack_requests_total",
|
||||
value=1, # Send increment instead of total so the provider handles incrementation
|
||||
unit="requests",
|
||||
attributes={
|
||||
"api": api,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Log request duration
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
metric="llama_stack_request_duration_seconds",
|
||||
value=duration,
|
||||
unit="seconds",
|
||||
attributes={"api": api, "status": status},
|
||||
)
|
||||
)
|
||||
|
||||
# Log concurrent requests (as a gauge)
|
||||
await self.telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
metric="llama_stack_concurrent_requests",
|
||||
value=float(concurrent_count), # Convert to float for gauge
|
||||
unit="requests",
|
||||
attributes={"api": api},
|
||||
)
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to log request metrics: {e}")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Skip metrics for excluded paths
|
||||
if self._is_excluded_path(path):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
api = self._extract_api_from_path(path)
|
||||
start_time = time.time()
|
||||
status = 200
|
||||
|
||||
# Track concurrent requests
|
||||
async with self._lock:
|
||||
self.concurrent_requests += 1
|
||||
|
||||
# Create a wrapper to capture the response status
|
||||
async def send_wrapper(message):
|
||||
if message.get("type") == "http.response.start":
|
||||
nonlocal status
|
||||
status = message.get("status", 200)
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_wrapper)
|
||||
|
||||
except Exception:
|
||||
# Set status to 500 for any unhandled exception
|
||||
status = 500
|
||||
raise
|
||||
|
||||
finally:
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Capture concurrent count before decrementing
|
||||
async with self._lock:
|
||||
concurrent_count = self.concurrent_requests
|
||||
self.concurrent_requests -= 1
|
||||
|
||||
# Log metrics asynchronously to avoid blocking the response
|
||||
asyncio.create_task(self._log_request_metrics(api, str(status), duration, concurrent_count))
|
|
@ -76,6 +76,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .metrics import RequestMetricsMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
@ -536,6 +537,10 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
# Add request metrics middleware
|
||||
telemetry_impl = impls.get(Api.telemetry) if Api.telemetry in impls else None
|
||||
app.add_middleware(RequestMetricsMiddleware, telemetry=telemetry_impl)
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue