diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8cc028769..e25bf0817 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -280,7 +280,18 @@ class TracingMiddleware: logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") return await self.app(scope, receive, send) - trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_context = await start_trace(trace_path, trace_attributes) async def send_with_trace_id(message): if message["type"] == "http.response.start": diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 67362dd36..1bc979894 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from llama_stack.apis.telemetry import ( Event, @@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor ) from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore +from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS from .config import TelemetryConfig, TelemetrySink @@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): event.attributes = {} event.attributes["__ttl__"] = ttl_seconds + # Extract these W3C trace context attributes so they are not written to + # underlying storage, as we just need them to propagate the trace context. + traceparent = event.attributes.pop("traceparent", None) + tracestate = event.attributes.pop("tracestate", None) + if traceparent: + # If we have a traceparent header value, we're not the root span. + for root_attribute in ROOT_SPAN_MARKERS: + event.attributes.pop(root_attribute, None) + if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates if span_id in _GLOBAL_STORAGE["active_spans"]: @@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) context = trace.set_span_in_context(parent_span) - else: - event.attributes["__root_span__"] = "true" + elif traceparent: + carrier = { + "traceparent": traceparent, + "tracestate": tracestate, + } + context = TraceContextTextMapPropagator().extract(carrier=carrier) span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 0f4fdd0d8..4edfa6516 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core") INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 +ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] + def trace_id_to_str(trace_id: int) -> str: """Convenience trace ID formatting method @@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) - context.push_span(name, {"__root__": True, **(attributes or {})}) + attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {}) + context.push_span(name, attributes) CURRENT_TRACE_CONTEXT.set(context) return context