From 07c9222b6f77723405c7439a6850058e4bb5cc0d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 14:54:45 -0800 Subject: [PATCH] debug --- .../agents/meta_reference/agent_instance.py | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 779dcf74d..743b77adf 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -30,8 +30,8 @@ from llama_stack.apis.agents import ( AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, + AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, - AgentTurnResponseTurnPendingPayload, AgentTurnResponseTurnStartPayload, Attachment, Document, @@ -225,26 +225,27 @@ class ChatAgent(ShieldRunnerMixin): completed_at=datetime.now(), steps=steps, ) - await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnPendingPayload( - turn=turn, + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) ) ) - ) - else: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) ) ) - ) + # only add to storage if turn is complete + await self.storage.add_turn_to_session(request.session_id, turn) - yield chunk + yield chunk async def run( self, @@ -626,11 +627,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message] else: log.info(f"{str(message)}") - tool_call = message.tool_calls[0] - if tool_call.tool_name in client_tools: - yield message - return - + # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -640,6 +637,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) + + tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -654,6 +653,12 @@ class ChatAgent(ShieldRunnerMixin): ) ) + # If tool is a client tool, yield CompletionMessage and return + if tool_call.tool_name in client_tools: + yield message + return + + # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value