llama-stack-mirror/llama_stack/telemetry/middleware.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time

from opentelemetry import trace
from opentelemetry.metrics import Counter, Histogram
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="instrumentation::otel")


class StreamingMetricsMiddleware:
    """
    ASGI middleware to track streaming response metrics.

    Detects ``text/event-stream`` responses and annotates the active server span
    (``http.response.is_streaming``, ``http.streaming.total_duration_ms``) so that
    MetricsSpanExporter can record streaming metrics at export time.

    :param app: The ASGI app to wrap
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        logger.debug(f"StreamingMetricsMiddleware called for {scope.get('method')} {scope.get('path')}")

        start_time = time.time()
        is_streaming = False

        async def send_wrapper(message: Message):
            nonlocal is_streaming

            # Detect streaming responses by headers
            if message["type"] == "http.response.start":
                headers = message.get("headers", [])
                for name, value in headers:
                    if name == b"content-type" and b"text/event-stream" in value:
                        is_streaming = True
                        # Add streaming attribute to current span
                        current_span = trace.get_current_span()
                        if current_span and current_span.is_recording():
                            current_span.set_attribute("http.response.is_streaming", True)
                        break
            # Record total duration when response body completes
            elif message["type"] == "http.response.body" and not message.get("more_body", False):
                if is_streaming:
                    current_span = trace.get_current_span()
                    if current_span and current_span.is_recording():
                        total_duration_ms = (time.time() - start_time) * 1000
                        current_span.set_attribute("http.streaming.total_duration_ms", total_duration_ms)

            await send(message)

        await self.app(scope, receive, send_wrapper)
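

# A minimal usage sketch (illustrative only; how llama-stack actually wires this
# middleware is not shown here). The middleware wraps the raw ASGI callable, so it
# composes with any other ASGI middleware; `inner_app` is an assumed placeholder.
def _example_wrap_app(inner_app: ASGIApp) -> ASGIApp:
    # Responses served as `text/event-stream` get `http.response.is_streaming`
    # and `http.streaming.total_duration_ms` set on the active server span.
    return StreamingMetricsMiddleware(inner_app)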


class MetricsSpanExporter(SpanExporter):
    """
    Records additional custom HTTP metrics during OTel span export.

    :param request_duration: Histogram to record request duration (ms)
    :param streaming_duration: Histogram to record streaming duration (ms)
    :param streaming_requests: Counter to record streaming requests
    :param request_count: Counter to record request count
    """

    def __init__(
        self,
        request_duration: Histogram,
        streaming_duration: Histogram,
        streaming_requests: Counter,
        request_count: Counter,
    ):
        self.request_duration = request_duration
        self.streaming_duration = streaming_duration
        self.streaming_requests = streaming_requests
        self.request_count = request_count

    def export(self, spans):
        for span in spans:
            # Skip spans that don't carry HTTP attributes
            if not span.attributes or not span.attributes.get("http.method"):
                continue

            logger.debug(f"Processing span: {span.name}")

            if span.end_time is None or span.start_time is None:
                continue

            # Span timestamps are in nanoseconds; convert to milliseconds
            duration_ms = (span.end_time - span.start_time) / 1_000_000
            is_streaming = span.attributes.get("http.response.is_streaming", False)

            attributes = {
                "http.method": str(span.attributes.get("http.method", "UNKNOWN")),
                "http.route": str(span.attributes.get("http.route", span.attributes.get("http.target", "/"))),
                "http.status_code": str(span.attributes.get("http.status_code", 0)),
                "trace_id": str(span.attributes.get("trace_id", "")),
                "span_id": str(span.attributes.get("span_id", "")),
            }

            # Record request count and duration
            logger.debug(f"Recording metrics: duration={duration_ms}ms, attributes={attributes}")
            self.request_count.add(1, attributes)
            self.request_duration.record(duration_ms, attributes)

            if is_streaming:
                logger.debug(f"MetricsSpanExporter: Recording streaming metrics for {span.name}")
                self.streaming_requests.add(1, attributes)
                stream_duration = span.attributes.get("http.streaming.total_duration_ms")
                if stream_duration and isinstance(stream_duration, int | float):
                    self.streaming_duration.record(float(stream_duration), attributes)

        return SpanExportResult.SUCCESS
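

# A minimal wiring sketch (illustrative only, not invoked anywhere in this module):
# instruments come from an OTel meter and the exporter is attached to the SDK tracer
# provider so it sees every finished span. The metric names and the
# `_example_setup_exporter` helper are assumptions for the example, not names defined
# elsewhere in the codebase.
def _example_setup_exporter(tracer_provider, meter) -> MetricsSpanExporter:
    # `tracer_provider` is an opentelemetry.sdk.trace.TracerProvider and `meter`
    # is an opentelemetry.metrics.Meter obtained via metrics.get_meter(...).
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor

    exporter = MetricsSpanExporter(
        request_duration=meter.create_histogram("http.server.request.duration_ms", unit="ms"),
        streaming_duration=meter.create_histogram("http.server.streaming.duration_ms", unit="ms"),
        streaming_requests=meter.create_counter("http.server.streaming.requests"),
        request_count=meter.create_counter("http.server.requests"),
    )
    # SimpleSpanProcessor hands each span to export() as soon as it ends.
    tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
    return exporter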