diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index 6e9fc7c63..7d949603e 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -26,6 +26,7 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
 from .rag.context_retriever import generate_rag_query
@@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
+    @tracing.span("create_and_execute_turn")
    async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
@@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         yield final_response
 
+    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -348,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin):
 
         # TODO: find older context from the session and either replace it
         # or append with a sliding window. this is really a very simplistic implementation
-        rag_context, bank_ids = await self._retrieve_context(
-            session_id, input_messages, attachments
-        )
+        with tracing.span("retrieve_rag_context"):
+            rag_context, bank_ids = await self._retrieve_context(
+                session_id, input_messages, attachments
+            )
 
         step_id = str(uuid.uuid4())
         yield AgentTurnResponseStreamChunk(
@@ -403,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin):
         tool_calls = []
         content = ""
         stop_reason = None
-        async for chunk in self.inference_api.chat_completion(
-            self.agent_config.model,
-            input_messages,
-            tools=self._get_tools(),
-            tool_prompt_format=self.agent_config.tool_prompt_format,
-            stream=True,
-            sampling_params=sampling_params,
-        ):
-            event = chunk.event
-            if event.event_type == ChatCompletionResponseEventType.start:
-                continue
-            elif event.event_type == ChatCompletionResponseEventType.complete:
-                stop_reason = StopReason.end_of_turn
-                continue
-
-            delta = event.delta
-            if isinstance(delta, ToolCallDelta):
-                if delta.parse_status == ToolCallParseStatus.success:
-                    tool_calls.append(delta.content)
+        with tracing.span("inference"):
+            async for chunk in self.inference_api.chat_completion(
+                self.agent_config.model,
+                input_messages,
+                tools=self._get_tools(),
+                tool_prompt_format=self.agent_config.tool_prompt_format,
+                stream=True,
+                sampling_params=sampling_params,
+            ):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.start:
+                    continue
+                elif event.event_type == ChatCompletionResponseEventType.complete:
+                    stop_reason = StopReason.end_of_turn
+                    continue
 
-                if stream:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
-                                step_id=step_id,
-                                model_response_text_delta="",
-                                tool_call_delta=delta,
+                delta = event.delta
+                if isinstance(delta, ToolCallDelta):
+                    if delta.parse_status == ToolCallParseStatus.success:
+                        tool_calls.append(delta.content)
+
+                    if stream:
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.inference.value,
+                                    step_id=step_id,
+                                    model_response_text_delta="",
+                                    tool_call_delta=delta,
+                                )
                             )
                         )
-                    )
-            elif isinstance(delta, str):
-                content += delta
-                if stream and event.stop_reason is None:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
-                                step_id=step_id,
-                                model_response_text_delta=event.delta,
+                elif isinstance(delta, str):
+                    content += delta
+                    if stream and event.stop_reason is None:
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.inference.value,
+                                    step_id=step_id,
+                                    model_response_text_delta=event.delta,
+                                )
                             )
                         )
-                    )
-            else:
-                raise ValueError(f"Unexpected delta type {type(delta)}")
+                else:
+                    raise ValueError(f"Unexpected delta type {type(delta)}")
 
-            if event.stop_reason is not None:
-                stop_reason = event.stop_reason
+                if event.stop_reason is not None:
+                    stop_reason = event.stop_reason
 
         stop_reason = stop_reason or StopReason.out_of_tokens
         message = CompletionMessage(
@@ -528,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
             )
         )
 
-        result_messages = await execute_tool_call_maybe(
-            self.tools_dict,
-            [message],
-        )
-        assert (
-            len(result_messages) == 1
-        ), "Currently not supporting multiple messages"
-        result_message = result_messages[0]
+        with tracing.span("tool_execution"):
+            result_messages = await execute_tool_call_maybe(
+                self.tools_dict,
+                [message],
+            )
+            assert (
+                len(result_messages) == 1
+            ), "Currently not supporting multiple messages"
+            result_message = result_messages[0]
 
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
@@ -669,7 +676,8 @@ class ChatAgent(ShieldRunnerMixin):
                 )
                 for a in attachments
             ]
-            await self.memory_api.insert_documents(bank_id, documents)
+            with tracing.span("insert_documents"):
+                await self.memory_api.insert_documents(bank_id, documents)
         else:
             session_info = await self.storage.get_session_info(session_id)
             if session_info.memory_bank_id:
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index 5284dfac0..45868b408 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -12,7 +12,7 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List
 
 from llama_stack.apis.telemetry import *  # noqa: F403
 
@@ -196,33 +196,40 @@ class TelemetryHandler(logging.Handler):
         pass
 
 
-def span(name: str, attributes: Dict[str, Any] = None):
-    def decorator(func):
+class SpanContextManager:
+    def __init__(self, name: str, attributes: Dict[str, Any] = None):
+        self.name = name
+        self.attributes = attributes
+
+    def __enter__(self):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.push_span(self.name, self.attributes)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.pop_span()
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+    def __call__(self, func: Callable):
         @wraps(func)
         def sync_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
+            with self:
+                return func(*args, **kwargs)
 
         @wraps(func)
         async def async_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = await func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
+            async with self:
+                return await func(*args, **kwargs)
 
         @wraps(func)
         def wrapper(*args, **kwargs):
@@ -233,4 +240,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
 
     return wrapper
 
-    return decorator
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    return SpanContextManager(name, attributes)
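
Note: with this change, tracing.span(...) returns a SpanContextManager, so the same helper works as a decorator (via __call__), as a sync context manager, and as an async context manager; push_span/pop_span are skipped entirely when no trace context is active. A minimal usage sketch follows (the function names are hypothetical, and it assumes CURRENT_TRACE_CONTEXT has already been initialized for the running task, e.g. by this module's trace-setup machinery):

    from llama_stack.providers.utils.telemetry import tracing

    # Decorator form: __call__ wraps the function, and the returned
    # wrapper dispatches to the sync or async variant as appropriate.
    @tracing.span("load_documents")
    def load_documents():
        ...

    async def handle_turn():
        # Async context-manager form: __aenter__/__aexit__ delegate to the
        # sync push_span/pop_span on the current trace context, if one exists.
        async with tracing.span("handle_turn", {"session_id": "abc"}):
            await run_inference()  # hypothetical coroutine

One caveat worth noting: __exit__ pops a span whenever a trace context exists at exit time, even if __enter__ found none and pushed nothing, so correct push/pop pairing relies on the trace context staying stable for the duration of the span.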