diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 7d56098c8..b403b9203 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin): async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - with tracing.SpanContextManager("create_and_execute_turn") as span: + with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) span.set_attribute("agent_id", self.agent_id) span.set_attribute("request", request.model_dump_json()) @@ -279,7 +279,7 @@ class ChatAgent(ShieldRunnerMixin): shields: List[str], touchpoint: str, ) -> AsyncGenerator: - with tracing.SpanContextManager("run_shields") as span: + with tracing.span("run_shields") as span: span.set_attribute("input", [m.model_dump_json() for m in messages]) if len(shields) == 0: span.set_attribute("output", "no shields") @@ -359,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation - with tracing.SpanContextManager("retrieve_rag_context") as span: + with tracing.span("retrieve_rag_context") as span: rag_context, bank_ids = await self._retrieve_context( session_id, input_messages, attachments ) @@ -419,7 +419,7 @@ class ChatAgent(ShieldRunnerMixin): content = "" stop_reason = None - with tracing.SpanContextManager("inference") as span: + with tracing.span("inference") as span: async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -558,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) - with tracing.SpanContextManager( + with tracing.span( "tool_execution", { "tool_name": tool_call.tool_name, @@ -707,7 +707,7 @@ class ChatAgent(ShieldRunnerMixin): ) for a in attachments ] - with tracing.SpanContextManager("insert_documents"): + with tracing.span("insert_documents"): await self.memory_api.insert_documents(bank_id, documents) else: session_info = await self.storage.get_session_info(session_id) diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 57f58d50f..938d333fa 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -81,9 +81,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]: self, *args, **kwargs ) - with tracing.SpanContextManager( - f"{class_name}.{method_name}", span_attributes - ) as span: + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: try: count = 0 async for item in method(self, *args, **kwargs): @@ -98,9 +96,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]: self, *args, **kwargs ) - with tracing.SpanContextManager( - f"{class_name}.{method_name}", span_attributes - ) as span: + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: try: result = await method(self, *args, **kwargs) span.set_attribute("output", serialize_value(result)) @@ -115,9 +111,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]: self, *args, **kwargs ) - with tracing.SpanContextManager( - f"{class_name}.{method_name}", span_attributes - ) as span: + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: try: result = method(self, *args, **kwargs) span.set_attribute("output", serialize_value(result)) diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 326a7c023..54558afdc 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -259,6 +259,10 @@ class SpanContextManager: return wrapper +def span(name: str, attributes: Dict[str, Any] = None): + return SpanContextManager(name, attributes) + + def get_current_span() -> Optional[Span]: global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT