feat: use same trace ids in stack and otel

This commit is contained in:
Dinesh Yeduguru 2025-03-21 14:45:10 -07:00
parent baf68c665c
commit a7de2f3ce4
4 changed files with 73 additions and 34 deletions

View file

@ -5,12 +5,11 @@
# the root directory of this source tree.
import asyncio
import base64
import contextvars
import logging
import queue
import random
import threading
import uuid
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
return encoded.rstrip(b"=").decode("ascii")[:len]
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
def format_trace_id(trace_id: int) -> str:
"""Convenience trace ID formatting method
Args:
trace_id: Trace ID int
Returns:
The trace ID as 32-byte hexadecimal string
"""
return format(trace_id, "032x")
def format_span_id(span_id: int) -> str:
"""Convenience span ID formatting method
Args:
span_id: Span ID int
Returns:
The span ID as 16-byte hexadecimal string
"""
return format(span_id, "016x")
def generate_span_id() -> str:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return format_span_id(span_id)
def generate_trace_id() -> str:
trace_id = random.getrandbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = random.getrandbits(128)
return format_trace_id(trace_id)
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
@ -83,7 +115,7 @@ class TraceContext:
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(timezone.utc),
@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceCont
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid(16)
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})