From ee3c174bb3687150af577665b49172504a8a57a8 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 20 Feb 2025 17:40:39 -0800
Subject: [PATCH] add back 2/n

---
 .../agents/meta_reference/agent_instance.py   | 21 ++++++++++++++++-
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 2ae71ded6..a6fba6f04 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -31,6 +31,7 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepProgressPayload,
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
@@ -227,13 +228,23 @@ class ChatAgent(ShieldRunnerMixin):
         )
         await self.storage.add_turn_to_session(request.session_id, turn)
 
-        chunk = AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseTurnCompletePayload(
-                    turn=turn,
+        if output_message.tool_calls:
+            chunk = AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnAwaitingInputPayload(
+                        turn=turn,
+                    )
                 )
             )
-        )
+        else:
+            chunk = AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnCompletePayload(
+                        turn=turn,
+                    )
+                )
+            )
+
         yield chunk
 
     async def continue_turn(self, request: AgentTurnContinueRequest) -> AsyncGenerator: