diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index dd430dbcd..39de1e4df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -237,9 +237,17 @@ class TracingMiddleware: # Use the matched template or original path trace_path = route_template or path - await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + + async def send_with_trace_id(message): + if message["type"] == "http.response.start": + headers = message.get("headers", []) + headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) + message["headers"] = headers + await send(message) + try: - return await self.app(scope, receive, send) + return await self.app(scope, receive, send_with_trace_id) finally: await end_trace() diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 5ed586fce..e9a003db6 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -12,6 +12,7 @@ from datetime import datetime, timezone from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.trace import Span +from opentelemetry.trace.span import format_span_id, format_trace_id class SQLiteSpanProcessor(SpanProcessor): @@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor): conn = self._get_connection() cursor = conn.cursor() - trace_id = format(span.get_span_context().trace_id, "032x") - span_id = format(span.get_span_context().span_id, "016x") + trace_id = format_trace_id(span.get_span_context().trace_id) + span_id = format_span_id(span.get_span_context().span_id) service_name = span.resource.attributes.get("service.name", "unknown") parent_span_id = None parent_context = span.parent if parent_context: - parent_span_id = format(parent_context.span_id, "016x") + parent_span_id = format_span_id(parent_context.span_id) # Insert into traces cursor.execute( @@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor): ( trace_id, service_name, - (span_id if not parent_span_id else None), + (span_id if span.attributes.get("__root_span__") == "true" else None), datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(), datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(), ), diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 46a88a7b8..181bfda9b 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -54,16 +54,6 @@ _global_lock = threading.Lock() _TRACER_PROVIDER = None -def string_to_trace_id(s: str) -> int: - # Convert the string to bytes and then to an integer - return int.from_bytes(s.encode(), byteorder="big", signed=False) - - -def string_to_span_id(s: str) -> int: - # Use only the first 8 bytes (64 bits) for span ID - return int.from_bytes(s.encode()[:8], byteorder="big", signed=False) - - def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: return span.is_recording() @@ -137,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: # Use global storage instead of instance storage - span_id = string_to_span_id(event.span_id) + span_id = event.span_id span = _GLOBAL_STORAGE["active_spans"].get(span_id) if span: @@ -197,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: - span_id = string_to_span_id(event.span_id) - trace_id = string_to_trace_id(event.trace_id) + span_id = int(event.span_id, 16) tracer = trace.get_tracer(__name__) if event.attributes is None: event.attributes = {} @@ -209,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): if span_id in _GLOBAL_STORAGE["active_spans"]: return - parent_span = None + context = None if event.payload.parent_span_id: - parent_span_id = string_to_span_id(event.payload.parent_span_id) + parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - - context = trace.Context(trace_id=trace_id) - if parent_span: - context = trace.set_span_in_context(parent_span, context) + context = trace.set_span_in_context(parent_span) + else: + context = trace.set_span_in_context( + trace.NonRecordingSpan( + trace.SpanContext( + trace_id=int(event.trace_id, 16), + span_id=span_id, + is_remote=False, + trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), + ) + ) + ) + event.attributes["__root_span__"] = "true" span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 607d1a918..3d5c717d6 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -5,12 +5,11 @@ # the root directory of this source tree. import asyncio -import base64 import contextvars import logging import queue +import random import threading -import uuid from datetime import datetime, timezone from functools import wraps from typing import Any, Callable, Dict, List, Optional @@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value logger = get_logger(__name__, category="core") -def generate_short_uuid(len: int = 8): - full_uuid = uuid.uuid4() - uuid_bytes = full_uuid.bytes - encoded = base64.urlsafe_b64encode(uuid_bytes) - return encoded.rstrip(b"=").decode("ascii")[:len] +INVALID_SPAN_ID = 0x0000000000000000 +INVALID_TRACE_ID = 0x00000000000000000000000000000000 + + +def trace_id_to_str(trace_id: int) -> str: + """Convenience trace ID formatting method + Args: + trace_id: Trace ID int + + Returns: + The trace ID as 32-byte hexadecimal string + """ + return format(trace_id, "032x") + + +def span_id_to_str(span_id: int) -> str: + """Convenience span ID formatting method + Args: + span_id: Span ID int + + Returns: + The span ID as 16-byte hexadecimal string + """ + return format(span_id, "016x") + + +def generate_span_id() -> str: + span_id = random.getrandbits(64) + while span_id == INVALID_SPAN_ID: + span_id = random.getrandbits(64) + return span_id_to_str(span_id) + + +def generate_trace_id() -> str: + trace_id = random.getrandbits(128) + while trace_id == INVALID_TRACE_ID: + trace_id = random.getrandbits(128) + return trace_id_to_str(trace_id) CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None) @@ -83,7 +115,7 @@ class TraceContext: def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span: current_span = self.get_current_span() span = Span( - span_id=generate_short_uuid(), + span_id=generate_span_id(), trace_id=self.trace_id, name=name, start_time=datetime.now(timezone.utc), @@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont logger.debug("No Telemetry implementation set. Skipping trace initialization...") return - trace_id = generate_short_uuid(16) + trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) context.push_span(name, {"__root__": True, **(attributes or {})})