From e8f4efba44526b6bbe94eaaac3f5ab03b5684f16 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 24 Feb 2025 12:42:42 -0800 Subject: [PATCH] test: fix test_tool_choice (#1234) Summary: Test Plan: --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1234). * __->__ #1234 * #1214 --- tests/client-sdk/agents/test_agents.py | 31 +++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5606b50b..27a69c90a 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -325,20 +325,16 @@ def test_custom_tool(llama_stack_client, agent_config): def test_tool_choice(llama_stack_client, agent_config): - data = [ - ("required", '{"type": "function"'), - ("none", None), - ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'), - ] - client_tool = TestClientTool() - for tool_choice, expected_tool in data: - agent_config = { + def run_agent(tool_choice): + client_tool = TestClientTool() + + test_agent_config = { **agent_config, "tool_config": {"tool_choice": tool_choice}, "client_tools": [client_tool.get_tool_definition()], } - agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn( @@ -349,14 +345,19 @@ def test_tool_choice(llama_stack_client, agent_config): }, ], session_id=session_id, + stream=False, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - if expected_tool: - assert expected_tool in logs_str - else: - assert '{"type": "function"' not in logs_str + return [step for step in response.steps if step.step_type == "tool_execution"] + + tool_execution_steps = run_agent("required") + assert len(tool_execution_steps) > 0 + + tool_execution_steps = run_agent("none") + assert len(tool_execution_steps) == 0 + + tool_execution_steps = run_agent("get_boiling_point") + assert len(tool_execution_steps) == 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point" # TODO: fix this flaky test