fix:Tracing fixes for correct trace context propogation

This commit is contained in:
Dinesh Yeduguru 2025-03-10 11:26:58 -07:00
parent 0db3a2f511
commit 38805e0ca1
3 changed files with 54 additions and 35 deletions

View file

@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
async with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
@ -390,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
@ -508,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference") as span:
async with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -685,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,

View file

@ -10,6 +10,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__)
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
)
for identifier in identifiers
]
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
for identifier, response in zip(identifiers, responses, strict=False):
if not response.violation:
continue

View file

@ -6,6 +6,7 @@
import asyncio
import base64
import contextvars
import logging
import queue
import threading
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
log = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8):
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None
@ -51,7 +53,7 @@ class BackgroundLogger:
try:
self.log_queue.put_nowait(event)
except queue.Full:
log.error("Log queue is full, dropping event")
logger.error("Log queue is full, dropping event")
def _process_logs(self):
while True:
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
log.info("No Telemetry implementation set. Skipping trace initialization...")
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
CURRENT_TRACE_CONTEXT.set(context)
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
logger.debug("No trace context to end")
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT = None
CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity:
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
@ -218,16 +221,22 @@ class SpanContextManager:
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any):
if self.span:
@ -237,16 +246,22 @@ class SpanContextManager:
async def __aenter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable):
@wraps(func)
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context:
return context.get_current_span()
return None