diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index be7a6bf3c..34a9db630 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -540,7 +540,8 @@ class ChatAgent(ShieldRunnerMixin):
 
         output_attachments = []
 
-        n_iter = 0
+        n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
+
         # Build a map of custom tools to their definitions for faster lookup
         client_tools = {}
         for tool in self.agent_config.client_tools:
@@ -627,6 +628,9 @@ class ChatAgent(ShieldRunnerMixin):
                 )
                 span.set_attribute("output", output_attr)
 
+            n_iter += 1
+            await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
+
             stop_reason = stop_reason or StopReason.out_of_tokens
 
             # If tool calls are parsed successfully,
@@ -662,6 +666,9 @@ class ChatAgent(ShieldRunnerMixin):
 
             if n_iter >= self.agent_config.max_infer_iters:
                 log.info("Done with MAX iterations, exiting.")
+                # NOTE: mark end_of_turn to indicate to client that we are done with the turn
+                # Do not continue the tool call loop after this point
+                message.stop_reason = StopReason.end_of_turn
                 yield message
                 break
 
@@ -711,6 +718,9 @@ class ChatAgent(ShieldRunnerMixin):
 
             # If tool is a client tool, yield CompletionMessage and return
             if tool_call.tool_name in client_tools:
+                # NOTE: mark end_of_message to indicate to client that it may
+                # call the tool and continue the conversation with the tool's response.
+                message.stop_reason = StopReason.end_of_message
                 await self.storage.set_in_progress_tool_call_step(
                     session_id,
                     turn_id,
@@ -796,8 +806,6 @@ class ChatAgent(ShieldRunnerMixin):
 
             input_messages = input_messages + [message, result_message]
 
-            n_iter += 1
-
     async def _get_tool_defs(
         self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
     ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index 9a2663567..2c04305fd 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -105,3 +105,15 @@ class AgentPersistence:
             key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
         )
         return ToolExecutionStep(**json.loads(value)) if value else None
+
+    async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
+        await self.kvstore.set(
+            key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
+            value=str(num_infer_iters),
+        )
+
+    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
+        value = await self.kvstore.get(
+            key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        return int(value) if value else None
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 201dcc1d4..4d808da23 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -278,6 +278,33 @@ def test_custom_tool(llama_stack_client, agent_config):
     assert "get_boiling_point" in logs_str
 
 
+def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
+    client_tool = get_boiling_point
+    agent_config = {
+        **agent_config,
+        "instructions": "You are a helpful assistant. Always respond with tool calls no matter what.",
+        "client_tools": [client_tool.get_tool_definition()],
+        "max_infer_iters": 5,
+    }
+
+    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "Get the boiling point of polyjuice with a tool call.",
+            },
+        ],
+        session_id=session_id,
+        stream=False,
+    )
+
+    num_tool_calls = sum(1 for step in response.steps if step.step_type == "tool_execution")
+    assert num_tool_calls <= 5
+
+
 def test_tool_choice(llama_stack_client, agent_config):
     def run_agent(tool_choice):
         client_tool = get_boiling_point