mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: Propagate W3C trace context headers from clients (#2153)
# What does this PR do? This extracts the W3C trace context headers (traceparent and tracestate) from incoming requests, stuffs them as attributes on the spans we create, and uses them within the tracing provider implementation to actually wrap our spans in the proper context. What this means in practice is that when a client (such as an OpenAI client) is instrumented to create these traces, we'll continue that distributed trace within Llama Stack as opposed to creating our own root span that breaks the distributed trace between client and server. It's slightly awkward to do this in Llama Stack because our Tracing API knows nothing about opentelemetry, W3C trace headers, etc - that's only knowledge the specific provider implementation has. So, that's why the trace headers get extracted by in the server code but not actually used until the provider implementation to form the proper context. This also centralizes how we were adding the `__root__` and `__root_span__` attributes, as those two were being added in different parts of the code instead of from a single place. Closes #2097 ## Test Plan This was tested manually using the helpful scripts from #2097. I verified that Llama Stack properly joined the client's span when the client was instrumented for distributed tracing, and that Llama Stack properly started its own root span when the incoming request was not part of an existing trace. Here's an example of the joined spans:  Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
82778ecbb0
commit
6d20b720b8
3 changed files with 33 additions and 4 deletions
|
@ -280,7 +280,18 @@ class TracingMiddleware:
|
|||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
|
|
|
@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
|
|||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
|
@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
|
|||
)
|
||||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
|
@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
|
@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
else:
|
||||
event.attributes["__root_span__"] = "true"
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
}
|
||||
context = TraceContextTextMapPropagator().extract(carrier=carrier)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
|
|
|
@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core")
|
|||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
|
||||
def trace_id_to_str(trace_id: int) -> str:
|
||||
"""Convenience trace ID formatting method
|
||||
|
@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont
|
|||
|
||||
trace_id = generate_trace_id()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {})
|
||||
context.push_span(name, attributes)
|
||||
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue