diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 3619b3f67..fedd695c1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin): return messages async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: - with tracing.span("create_and_execute_turn") as span: + async with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) span.set_attribute("agent_id", self.agent_id) span.set_attribute("request", request.model_dump_json()) @@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin): yield chunk async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: - with tracing.span("resume_turn") as span: + async with tracing.span("resume_turn") as span: span.set_attribute("agent_id", self.agent_id) span.set_attribute("session_id", request.session_id) span.set_attribute("turn_id", request.turn_id) @@ -390,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin): shields: List[str], touchpoint: str, ) -> AsyncGenerator: - with tracing.span("run_shields") as span: + async with tracing.span("run_shields") as span: span.set_attribute("input", [m.model_dump_json() for m in messages]) if len(shields) == 0: span.set_attribute("output", "no shields") @@ -508,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin): content = "" stop_reason = None - with tracing.span("inference") as span: + async with tracing.span("inference") as span: async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -685,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin): tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value - with tracing.span( + async with tracing.span( "tool_execution", { "tool_name": tool_name, diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 2497be070..bef16eaba 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -10,6 +10,7 @@ from typing import List from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel +from llama_stack.providers.utils.telemetry import tracing log = logging.getLogger(__name__) @@ -32,15 +33,14 @@ class ShieldRunnerMixin: self.output_shields = output_shields async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None: - responses = await asyncio.gather( - *[ - self.safety_api.run_shield( + async def run_shield_with_span(identifier: str): + async with tracing.span(f"run_shield_{identifier}"): + return await self.safety_api.run_shield( shield_id=identifier, messages=messages, ) - for identifier in identifiers - ] - ) + + responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers]) for identifier, response in zip(identifiers, responses, strict=False): if not response.violation: continue diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index d84024941..bef229080 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -6,6 +6,7 @@ import asyncio import base64 +import contextvars import logging import queue import threading @@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import ( Telemetry, UnstructuredLogEvent, ) +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value -log = logging.getLogger(__name__) +logger = get_logger(__name__, category="core") def generate_short_uuid(len: int = 8): @@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8): return encoded.rstrip(b"=").decode("ascii")[:len] -CURRENT_TRACE_CONTEXT = None +CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None) BACKGROUND_LOGGER = None @@ -51,7 +53,7 @@ class BackgroundLogger: try: self.log_queue.put_nowait(event) except queue.Full: - log.error("Log queue is full, dropping event") + logger.error("Log queue is full, dropping event") def _process_logs(self): while True: @@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO): if BACKGROUND_LOGGER is None: BACKGROUND_LOGGER = BackgroundLogger(api) - logger = logging.getLogger() - logger.setLevel(level) - logger.addHandler(TelemetryHandler()) + root_logger = logging.getLogger() + root_logger.setLevel(level) + root_logger.addHandler(TelemetryHandler()) async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext: global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: - log.info("No Telemetry implementation set. Skipping trace initialization...") + logger.debug("No Telemetry implementation set. Skipping trace initialization...") return trace_id = generate_short_uuid(16) context = TraceContext(BACKGROUND_LOGGER, trace_id) context.push_span(name, {"__root__": True, **(attributes or {})}) - CURRENT_TRACE_CONTEXT = context + CURRENT_TRACE_CONTEXT.set(context) return context async def end_trace(status: SpanStatus = SpanStatus.OK): global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT.get() if context is None: + logger.debug("No trace context to end") return context.pop_span(status) - CURRENT_TRACE_CONTEXT = None + CURRENT_TRACE_CONTEXT.set(None) def severity(levelname: str) -> LogSeverity: @@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler): if BACKGROUND_LOGGER is None: raise RuntimeError("Telemetry API not initialized") - context = CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT.get() if context is None: return @@ -218,16 +221,22 @@ class SpanContextManager: def __enter__(self): global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT - if context: - self.span = context.push_span(self.name, self.attributes) + context = CURRENT_TRACE_CONTEXT.get() + if not context: + logger.debug("No trace context to push span") + return self + + self.span = context.push_span(self.name, self.attributes) return self def __exit__(self, exc_type, exc_value, traceback): global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT - if context: - context.pop_span() + context = CURRENT_TRACE_CONTEXT.get() + if not context: + logger.debug("No trace context to pop span") + return + + context.pop_span() def set_attribute(self, key: str, value: Any): if self.span: @@ -237,16 +246,22 @@ class SpanContextManager: async def __aenter__(self): global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT - if context: - self.span = context.push_span(self.name, self.attributes) + context = CURRENT_TRACE_CONTEXT.get() + if not context: + logger.debug("No trace context to push span") + return self + + self.span = context.push_span(self.name, self.attributes) return self async def __aexit__(self, exc_type, exc_value, traceback): global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT - if context: - context.pop_span() + context = CURRENT_TRACE_CONTEXT.get() + if not context: + logger.debug("No trace context to pop span") + return + + context.pop_span() def __call__(self, func: Callable): @wraps(func) @@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None): def get_current_span() -> Optional[Span]: global CURRENT_TRACE_CONTEXT - context = CURRENT_TRACE_CONTEXT + if CURRENT_TRACE_CONTEXT is None: + logger.debug("No trace context to get current span") + return None + + context = CURRENT_TRACE_CONTEXT.get() if context: return context.get_current_span() return None