feat: Propagate W3C trace context headers from clients (#2153)

# What does this PR do?

This extracts the W3C trace context headers (traceparent and tracestate)
from incoming requests, stuffs them as attributes on the spans we
create, and uses them within the tracing provider implementation to
actually wrap our spans in the proper context.

What this means in practice is that when a client (such as an OpenAI
client) is instrumented to create these traces, we'll continue that
distributed trace within Llama Stack as opposed to creating our own root
span that breaks the distributed trace between client and server.

It's slightly awkward to do this in Llama Stack because our Tracing API
knows nothing about opentelemetry, W3C trace headers, etc - that's only
knowledge the specific provider implementation has. So, that's why the
trace headers get extracted by in the server code but not actually used
until the provider implementation to form the proper context.

This also centralizes how we were adding the `__root__` and
`__root_span__` attributes, as those two were being added in different
parts of the code instead of from a single place.

Closes #2097

## Test Plan

This was tested manually using the helpful scripts from #2097. I
verified that Llama Stack properly joined the client's span when the
client was instrumented for distributed tracing, and that Llama Stack
properly started its own root span when the incoming request was not
part of an existing trace.

Here's an example of the joined spans:

![Screenshot 2025-05-13 at 8 46
09 AM](https://github.com/user-attachments/assets/dbefda28-9faa-4339-a08d-1441efefc149)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-05-19 21:56:54 -04:00 committed by GitHub
parent 82778ecbb0
commit 6d20b720b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 33 additions and 4 deletions

View file

@ -280,7 +280,18 @@ class TracingMiddleware:
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":

View file

@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from llama_stack.apis.telemetry import (
Event,
@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
)
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
from .config import TelemetryConfig, TelemetrySink
@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
# Extract these W3C trace context attributes so they are not written to
# underlying storage, as we just need them to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None)
if traceparent:
# If we have a traceparent header value, we're not the root span.
for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None)
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span)
else:
event.attributes["__root_span__"] = "true"
elif traceparent:
carrier = {
"traceparent": traceparent,
"tracestate": tracestate,
}
context = TraceContextTextMapPropagator().extract(carrier=carrier)
span = tracer.start_span(
name=event.payload.name,

View file

@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core")
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
def trace_id_to_str(trace_id: int) -> str:
"""Convenience trace ID formatting method
@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {})
context.push_span(name, attributes)
CURRENT_TRACE_CONTEXT.set(context)
return context