forked from phoenix-oss/llama-stack-mirror
test: fix test_tool_choice (#1234)
Summary: Test Plan: --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1234). * __->__ #1234 * #1214
This commit is contained in:
parent
14c38acf97
commit
e8f4efba44
1 changed files with 16 additions and 15 deletions
|
@ -325,20 +325,16 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
|
||||
|
||||
def test_tool_choice(llama_stack_client, agent_config):
|
||||
data = [
|
||||
("required", '{"type": "function"'),
|
||||
("none", None),
|
||||
("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'),
|
||||
]
|
||||
client_tool = TestClientTool()
|
||||
for tool_choice, expected_tool in data:
|
||||
agent_config = {
|
||||
def run_agent(tool_choice):
|
||||
client_tool = TestClientTool()
|
||||
|
||||
test_agent_config = {
|
||||
**agent_config,
|
||||
"tool_config": {"tool_choice": tool_choice},
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||
agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,))
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
response = agent.create_turn(
|
||||
|
@ -349,14 +345,19 @@ def test_tool_choice(llama_stack_client, agent_config):
|
|||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
if expected_tool:
|
||||
assert expected_tool in logs_str
|
||||
else:
|
||||
assert '{"type": "function"' not in logs_str
|
||||
return [step for step in response.steps if step.step_type == "tool_execution"]
|
||||
|
||||
tool_execution_steps = run_agent("required")
|
||||
assert len(tool_execution_steps) > 0
|
||||
|
||||
tool_execution_steps = run_agent("none")
|
||||
assert len(tool_execution_steps) == 0
|
||||
|
||||
tool_execution_steps = run_agent("get_boiling_point")
|
||||
assert len(tool_execution_steps) == 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
|
||||
|
||||
|
||||
# TODO: fix this flaky test
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue