test: fix test_tool_choice (#1234)

Summary:

Test Plan:
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1234).
* __->__ #1234
* #1214
This commit is contained in:
ehhuang 2025-02-24 12:42:42 -08:00 committed by GitHub
parent 14c38acf97
commit e8f4efba44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -325,20 +325,16 @@ def test_custom_tool(llama_stack_client, agent_config):
def test_tool_choice(llama_stack_client, agent_config):
data = [
("required", '{"type": "function"'),
("none", None),
("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
]
client_tool = TestClientTool()
for tool_choice, expected_tool in data:
agent_config = {
def run_agent(tool_choice):
client_tool = TestClientTool()
test_agent_config = {
**agent_config,
"tool_config": {"tool_choice": tool_choice},
"client_tools": [client_tool.get_tool_definition()],
}
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -349,14 +345,19 @@ def test_tool_choice(llama_stack_client, agent_config):
},
],
session_id=session_id,
stream=False,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
if expected_tool:
assert expected_tool in logs_str
else:
assert '{"type": "function"' not in logs_str
return [step for step in response.steps if step.step_type == "tool_execution"]
tool_execution_steps = run_agent("required")
assert len(tool_execution_steps) > 0
tool_execution_steps = run_agent("none")
assert len(tool_execution_steps) == 0
tool_execution_steps = run_agent("get_boiling_point")
assert len(tool_execution_steps) == 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
# TODO: fix this flaky test