From 6d20b720b872d120ef4ade58b66b56bafb2f7302 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 19 May 2025 21:56:54 -0400 Subject: [PATCH] feat: Propagate W3C trace context headers from clients (#2153) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This extracts the W3C trace context headers (traceparent and tracestate) from incoming requests, stuffs them as attributes on the spans we create, and uses them within the tracing provider implementation to actually wrap our spans in the proper context. What this means in practice is that when a client (such as an OpenAI client) is instrumented to create these traces, we'll continue that distributed trace within Llama Stack as opposed to creating our own root span that breaks the distributed trace between client and server. It's slightly awkward to do this in Llama Stack because our Tracing API knows nothing about opentelemetry, W3C trace headers, etc - that's only knowledge the specific provider implementation has. So, that's why the trace headers get extracted by in the server code but not actually used until the provider implementation to form the proper context. This also centralizes how we were adding the `__root__` and `__root_span__` attributes, as those two were being added in different parts of the code instead of from a single place. Closes #2097 ## Test Plan This was tested manually using the helpful scripts from #2097. I verified that Llama Stack properly joined the client's span when the client was instrumented for distributed tracing, and that Llama Stack properly started its own root span when the incoming request was not part of an existing trace. Here's an example of the joined spans: ![Screenshot 2025-05-13 at 8 46 09 AM](https://github.com/user-attachments/assets/dbefda28-9faa-4339-a08d-1441efefc149) Signed-off-by: Ben Browning --- llama_stack/distribution/server/server.py | 13 ++++++++++++- .../telemetry/meta_reference/telemetry.py | 19 +++++++++++++++++-- .../providers/utils/telemetry/tracing.py | 5 ++++- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8cc028769..e25bf0817 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -280,7 +280,18 @@ class TracingMiddleware: logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") return await self.app(scope, receive, send) - trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_context = await start_trace(trace_path, trace_attributes) async def send_with_trace_id(message): if message["type"] == "http.response.start": diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 67362dd36..1bc979894 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from llama_stack.apis.telemetry import ( Event, @@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor ) from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore +from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS from .config import TelemetryConfig, TelemetrySink @@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): event.attributes = {} event.attributes["__ttl__"] = ttl_seconds + # Extract these W3C trace context attributes so they are not written to + # underlying storage, as we just need them to propagate the trace context. + traceparent = event.attributes.pop("traceparent", None) + tracestate = event.attributes.pop("tracestate", None) + if traceparent: + # If we have a traceparent header value, we're not the root span. + for root_attribute in ROOT_SPAN_MARKERS: + event.attributes.pop(root_attribute, None) + if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates if span_id in _GLOBAL_STORAGE["active_spans"]: @@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) context = trace.set_span_in_context(parent_span) - else: - event.attributes["__root_span__"] = "true" + elif traceparent: + carrier = { + "traceparent": traceparent, + "tracestate": tracestate, + } + context = TraceContextTextMapPropagator().extract(carrier=carrier) span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 0f4fdd0d8..4edfa6516 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core") INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 +ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] + def trace_id_to_str(trace_id: int) -> str: """Convenience trace ID formatting method @@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) - context.push_span(name, {"__root__": True, **(attributes or {})}) + attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {}) + context.push_span(name, attributes) CURRENT_TRACE_CONTEXT.set(context) return context