llama-stack-mirror/llama_stack/providers/inline/telemetry/otel/otel.py

301 lines
13 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import time
from fastapi import FastAPI
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.metrics import Counter, Histogram
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
SimpleSpanProcessor,
SpanExporter,
SpanExportResult,
)
from sqlalchemy import Engine
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from llama_stack.core.telemetry.telemetry import TelemetryProvider
from llama_stack.log import get_logger
from .config import OTelTelemetryConfig
logger = get_logger(name=__name__, category="telemetry::otel")
class StreamingMetricsMiddleware:
    """
    Pure ASGI middleware to track streaming response metrics.

    This follows Starlette best practices by implementing pure ASGI,
    which is more efficient and less prone to bugs than BaseHTTPMiddleware.

    For ``http`` scopes it wraps ``send`` so that:

    - when the response ``content-type`` contains ``text/event-stream``, the current
      span gets ``http.response.is_streaming=True``;
    - for such streaming responses, when the final body message is sent
      (``more_body`` false or absent), the current span gets
      ``http.streaming.total_duration_ms`` (elapsed time since the request entered
      this middleware).

    Any other scope type is passed straight through to the wrapped app.
    """

    def __init__(self, app: ASGIApp):
        """Wrap the downstream ASGI application ``app``."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        logger.debug(f"StreamingMetricsMiddleware called for {scope.get('method')} {scope.get('path')}")
        # Use a monotonic clock for durations: time.time() follows the wall clock and
        # can jump (NTP adjustments, manual changes) in the middle of a long stream.
        start_time = time.perf_counter()
        # Track if this is a streaming response
        is_streaming = False

        async def send_wrapper(message: Message):
            nonlocal is_streaming
            # Detect streaming responses by headers
            if message["type"] == "http.response.start":
                headers = message.get("headers", [])
                for name, value in headers:
                    # ASGI requires header names to be lowercased byte strings, so an
                    # exact match on b"content-type" is sufficient.
                    if name == b"content-type" and b"text/event-stream" in value:
                        is_streaming = True
                        # Add streaming attribute to current span
                        current_span = trace.get_current_span()
                        if current_span and current_span.is_recording():
                            current_span.set_attribute("http.response.is_streaming", True)
                        break
            # Record total duration when response body completes
            elif message["type"] == "http.response.body" and not message.get("more_body", False):
                if is_streaming:
                    current_span = trace.get_current_span()
                    if current_span and current_span.is_recording():
                        total_duration_ms = (time.perf_counter() - start_time) * 1000
                        current_span.set_attribute("http.streaming.total_duration_ms", total_duration_ms)
            await send(message)

        await self.app(scope, receive, send_wrapper)
class MetricsSpanExporter(SpanExporter):
    """Records HTTP metrics from span data.

    This exporter does not ship spans anywhere. It inspects each finished span and
    records HTTP server metrics into the instruments supplied at construction time.
    ``OTelTelemetryProvider.fastapi_middleware`` registers it on the SDK tracer
    provider behind a ``BatchSpanProcessor``. Metrics are therefore recorded when
    span batches are exported, not inline with the request.
    """

    def __init__(
        self,
        request_duration: Histogram,
        streaming_duration: Histogram,
        streaming_requests: Counter,
        request_count: Counter,
    ):
        """Keep references to the metric instruments to record into.

        Args:
            request_duration: Histogram (milliseconds) of server span durations.
            streaming_duration: Histogram (milliseconds) of full streaming-response durations.
            streaming_requests: Counter incremented once per streaming response.
            request_count: Counter incremented once per exported HTTP server span.
        """
        self.request_duration = request_duration
        self.streaming_duration = streaming_duration
        self.streaming_requests = streaming_requests
        self.request_count = request_count

    def export(self, spans):
        """Record HTTP metrics for every finished HTTP server span in ``spans``.

        A span is skipped when any of the following holds:

        - it is not a SERVER span;
        - it has no ``http.method`` attribute;
        - it is missing its start or end timestamp.

        Args:
            spans: Finished spans handed over by the span processor.

        Returns:
            ``SpanExportResult.SUCCESS`` always. Recording into metric instruments has
            no failure mode that this method reports.
        """
        logger.debug(f"MetricsSpanExporter.export called with {len(spans)} spans")
        for span in spans:
            # Only SERVER spans describe incoming HTTP requests. Client spans, e.g. from
            # an instrumented outbound HTTP client, also carry ``http.method``. Without
            # this check they would be recorded as server requests.
            if span.kind != trace.SpanKind.SERVER:
                continue
            if not span.attributes or not span.attributes.get("http.method"):
                continue
            logger.debug(f"Processing span: {span.name}")
            if span.end_time is None or span.start_time is None:
                continue

            # Span timestamps are in nanoseconds; convert the span duration to ms.
            # NOTE(review): the original comment called this "time-to-first-byte".
            # Whether the server span closes at the first byte or after a streamed body
            # has been fully sent depends on the ASGI instrumentation. Confirm this.
            duration_ns = span.end_time - span.start_time
            duration_ms = duration_ns / 1_000_000

            # Set by StreamingMetricsMiddleware for text/event-stream responses.
            is_streaming = span.attributes.get("http.response.is_streaming", False)

            # Metric attributes are stringified; missing values get placeholder defaults.
            attributes = {
                "http.method": str(span.attributes.get("http.method", "UNKNOWN")),
                "http.route": str(span.attributes.get("http.route", span.attributes.get("http.target", "/"))),
                "http.status_code": str(span.attributes.get("http.status_code", 0)),
            }
            # set distributed trace attributes
            # NOTE(review): if these span attributes are ever populated, per-trace and
            # per-span ids create one metric time series per request (unbounded
            # cardinality). Consider exemplars for trace correlation instead.
            if span.attributes.get("trace_id"):
                attributes["trace_id"] = str(span.attributes.get("trace_id"))
            if span.attributes.get("span_id"):
                attributes["span_id"] = str(span.attributes.get("span_id"))

            # Record request count and duration
            # NOTE(review): server_request_hook in OTelTelemetryProvider.fastapi_middleware
            # also adds 1 to the counter passed in here as ``request_count``. Sampled
            # requests are therefore counted twice in that metric. Confirm which place
            # should own the count.
            logger.debug(f"Recording metrics: duration={duration_ms}ms, attributes={attributes}")
            self.request_count.add(1, attributes)
            self.request_duration.record(duration_ms, attributes)
            logger.debug("Metrics recorded successfully")

            # For streaming, record separately
            if is_streaming:
                logger.debug(f"MetricsSpanExporter: Recording streaming metrics for {span.name}")
                self.streaming_requests.add(1, attributes)
                # If full streaming duration is available
                stream_total_duration = span.attributes.get("http.streaming.total_duration_ms")
                # Test the type rather than truthiness so a measured 0.0 ms duration is
                # still recorded. bool is excluded because it is a subclass of int.
                if isinstance(stream_total_duration, int | float) and not isinstance(stream_total_duration, bool):
                    logger.debug(f"MetricsSpanExporter: Recording streaming duration: {stream_total_duration}ms")
                    self.streaming_duration.record(float(stream_total_duration), attributes)
                else:
                    logger.warning(
                        "MetricsSpanExporter: Streaming span has no http.streaming.total_duration_ms attribute"
                    )
        return SpanExportResult.SUCCESS

    def shutdown(self):
        """No-op: this exporter holds no resources of its own."""
# NOTE: DO NOT ALLOW LLM TO MODIFY THIS WITHOUT TESTING AND SUPERVISION: it frequently breaks otel integrations
class OTelTelemetryProvider(TelemetryProvider):
"""
A simple Open Telemetry native telemetry provider.
"""
config: OTelTelemetryConfig
def model_post_init(self, __context) -> None:
    """Initialize provider after Pydantic validation.

    Sets up process-global OpenTelemetry tracing and metrics:

    - Warns, without failing, when the OTLP endpoint environment variables are unset.
    - Resolves the service name. A non-empty ``OTEL_SERVICE_NAME`` wins; otherwise
      ``self.config.service_name`` is used.
    - Installs a global ``TracerProvider`` that exports spans through an
      ``OTLPSpanExporter``. It uses a ``BatchSpanProcessor`` when
      ``OTEL_SPAN_PROCESSOR`` is unset or exactly ``"batch"``, and a
      ``SimpleSpanProcessor`` for any other value.
    - Installs a global ``MeterProvider`` whose ``PeriodicExportingMetricReader``
      exports through an ``OTLPMetricExporter``.

    Both OTLP exporters are constructed with no arguments. Their endpoint settings
    are presumably taken from the standard ``OTEL_EXPORTER_OTLP_*`` environment
    variables checked below; confirm this for the installed exporter version.
    """
    # Do not fail the application, but warn the user if the endpoints are not set properly.
    # A set OTEL_EXPORTER_OTLP_ENDPOINT satisfies both checks. Without it, the per-signal
    # traces and metrics endpoint variables are checked individually.
    # NOTE(review): the OTLP HTTP exporters are documented to fall back to a default
    # endpoint (http://localhost:4318) when none of these are set. "will not be exported"
    # may therefore be overstated; confirm.
    if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
        if not os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"):
            logger.warning(
                "OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_TRACES_ENDPOINT is not set. Traces will not be exported."
            )
        if not os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT"):
            logger.warning(
                "OTEL_EXPORTER_OTLP_ENDPOINT or OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is not set. Metrics will not be exported."
            )
    # Respect OTEL design standards where environment variables get highest precedence
    # (an empty OTEL_SERVICE_NAME is treated as unset and falls back to the config value).
    service_name = os.environ.get("OTEL_SERVICE_NAME")
    if not service_name:
        service_name = self.config.service_name
    # Create resource with service name
    resource = Resource.create({"service.name": service_name})
    # Configure the tracer provider (always, since llama stack run spawns subprocess without opentelemetry-instrument)
    # NOTE(review): the OTel API honours only the first global tracer/meter provider set in a
    # process. Later set_*_provider calls are ignored with a warning; confirm this runs at
    # most once per process.
    tracer_provider = TracerProvider(resource=resource)
    trace.set_tracer_provider(tracer_provider)
    # Configure OTLP span exporter
    otlp_span_exporter = OTLPSpanExporter()
    # Add span processor (simple for immediate export, batch for performance)
    # Defaults to "batch". Any other value, including differently-cased "BATCH", selects
    # the SimpleSpanProcessor.
    span_processor_type = os.environ.get("OTEL_SPAN_PROCESSOR", "batch")
    if span_processor_type == "batch":
        tracer_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    else:
        tracer_provider.add_span_processor(SimpleSpanProcessor(otlp_span_exporter))
    # Configure meter provider with OTLP exporter for metrics
    metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)
    logger.info(
        f"Initialized OpenTelemetry provider with service.name={service_name}, span_processor={span_processor_type}"
    )
def fastapi_middleware(self, app: FastAPI):
    """
    Instrument FastAPI with OTel for automatic tracing and metrics.

    Captures telemetry for both regular and streaming HTTP requests:
    - Distributed traces (via FastAPIInstrumentor)
    - HTTP request metrics (count, duration, status)
    - Streaming-specific metrics (time-to-first-byte, total stream duration)

    Steps performed:
    1. Create the HTTP metric instruments on the ``llama_stack.http.server`` meter.
    2. Instrument ``app`` with ``FastAPIInstrumentor``, installing a
       ``server_request_hook`` that tags each server span and increments
       ``http.server.request.count``.
    3. Add ``StreamingMetricsMiddleware`` so streaming responses are marked on their span.
    4. If the global tracer provider is an SDK ``TracerProvider``, attach a
       ``MetricsSpanExporter`` (behind a ``BatchSpanProcessor``) that turns finished
       spans into duration, count and streaming metrics. Otherwise only log a warning.

    Args:
        app: The FastAPI application to instrument. Per the note in the body, this is
            expected to be called before routes are added.
    """
    # Create meter for HTTP metrics
    meter = metrics.get_meter("llama_stack.http.server")
    # HTTP Metrics following OTel semantic conventions
    # https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
    request_duration = meter.create_histogram(
        "http.server.request.duration",
        unit="ms",
        description="Duration of HTTP requests (time-to-first-byte for streaming)",
    )
    streaming_duration = meter.create_histogram(
        "http.server.streaming.duration",
        unit="ms",
        description="Total duration of streaming responses (from start to stream completion)",
    )
    request_count = meter.create_counter(
        "http.server.request.count", unit="requests", description="Total number of HTTP requests"
    )
    streaming_requests = meter.create_counter(
        "http.server.streaming.count", unit="requests", description="Number of streaming requests"
    )

    # Hook to enrich spans and record initial metrics
    def server_request_hook(span, scope):
        """
        Called by FastAPIInstrumentor for each request.

        This only reads from scope (ASGI dict), never touches request body.
        Safe to use without interfering with body parsing.

        ``span`` is not guaranteed to be a recording SDK span. When the sampler drops
        the request (e.g. a custom sampler, or an incoming not-sampled parent), the
        instrumentation passes a non-recording span. Such a span has no ``attributes``
        property. ``set_attribute`` is a no-op on it, but reading ``span.attributes``
        directly would raise ``AttributeError`` and fail the request. The attributes
        are therefore read defensively below.
        """
        method = scope.get("method", "UNKNOWN")
        path = scope.get("path", "/")

        # Add custom attributes
        span.set_attribute("service.component", "llama-stack-api")
        span.set_attribute("http.request", path)
        span.set_attribute("http.method", method)

        span_attributes = getattr(span, "attributes", None) or {}
        attributes = {
            "http.request": path,
            "http.method": method,
            "trace_id": span_attributes.get("trace_id", ""),
            "span_id": span_attributes.get("span_id", ""),
        }
        # NOTE(review): MetricsSpanExporter.export also calls request_count.add(1, ...) on
        # this same counter for exported HTTP spans. Sampled requests are therefore
        # counted twice in http.server.request.count. Confirm which place should own it.
        request_count.add(1, attributes)
        logger.debug(f"server_request_hook: recorded request_count for {method} {path}, attributes={attributes}")

    # NOTE: This is called BEFORE routes are added to the app
    # FastAPIInstrumentor.instrument_app() patches build_middleware_stack(),
    # which will be called on first request (after routes are added), so hooks should work.
    logger.debug("Instrumenting FastAPI (routes will be added later)")
    FastAPIInstrumentor.instrument_app(
        app,
        server_request_hook=server_request_hook,
    )
    logger.debug(f"FastAPI instrumented: {getattr(app, '_is_instrumented_by_opentelemetry', False)}")

    # Add pure ASGI middleware for streaming metrics (always add, regardless of instrumentation)
    app.add_middleware(StreamingMetricsMiddleware)

    # Add metrics span processor
    # Only an SDK TracerProvider supports add_span_processor; with any other provider
    # (e.g. the API's default proxy provider) the span-derived metrics are skipped.
    provider = trace.get_tracer_provider()
    logger.debug(f"TracerProvider: {provider}")
    if isinstance(provider, TracerProvider):
        metrics_exporter = MetricsSpanExporter(
            request_duration=request_duration,
            streaming_duration=streaming_duration,
            streaming_requests=streaming_requests,
            request_count=request_count,
        )
        provider.add_span_processor(BatchSpanProcessor(metrics_exporter))
        logger.debug("Added MetricsSpanExporter as BatchSpanProcessor")
    else:
        logger.warning(
            f"TracerProvider is not TracerProvider instance, it's {type(provider)}. MetricsSpanExporter not added."
        )