diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5606b50b..27a69c90a 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -325,20 +325,16 @@ def test_custom_tool(llama_stack_client, agent_config): def test_tool_choice(llama_stack_client, agent_config): - data = [ - ("required", '{"type": "function"'), - ("none", None), - ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'), - ] - client_tool = TestClientTool() - for tool_choice, expected_tool in data: - agent_config = { + def run_agent(tool_choice): + client_tool = TestClientTool() + + test_agent_config = { **agent_config, "tool_config": {"tool_choice": tool_choice}, "client_tools": [client_tool.get_tool_definition()], } - agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn( @@ -349,14 +345,19 @@ def test_tool_choice(llama_stack_client, agent_config): }, ], session_id=session_id, + stream=False, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - if expected_tool: - assert expected_tool in logs_str - else: - assert '{"type": "function"' not in logs_str + return [step for step in response.steps if step.step_type == "tool_execution"] + + tool_execution_steps = run_agent("required") + assert len(tool_execution_steps) > 0 + + tool_execution_steps = run_agent("none") + assert len(tool_execution_steps) == 0 + + tool_execution_steps = run_agent("get_boiling_point") + assert len(tool_execution_steps) == 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point" # TODO: fix this flaky test