From 7bb5f2bf0e191620f4976f58fc91a7c2cee777b3 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 28 Feb 2025 12:49:17 -0800 Subject: [PATCH] add a test --- tests/client-sdk/agents/test_agents.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 9690a8139..aa36b7829 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -278,6 +278,33 @@ def test_custom_tool(llama_stack_client, agent_config): assert "get_boiling_point" in logs_str +def test_custom_tool_infinite_loop(llama_stack_client, agent_config): + client_tool = get_boiling_point + agent_config = { + **agent_config, + "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ", + "client_tools": [client_tool.get_tool_definition()], + "max_infer_iters": 5, + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Get the boiling point of polyjuice with a tool call.", + }, + ], + session_id=session_id, + stream=False, + ) + + num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps]) + assert num_tool_calls <= 5 + + def test_tool_choice(llama_stack_client, agent_config): def run_agent(tool_choice): client_tool = get_boiling_point