diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 63fd74f53..d4f47c837 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
     assert found_tool_execution
 
 
+@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
 def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
@@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
 
 
 def test_multi_tool_calls(llama_stack_client, agent_config):
-    if "gpt" not in agent_config["model"]:
-        pytest.xfail("Only tested on GPT models")
+    if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
+        pytest.xfail("Only tested on GPT and Llama 4 models")
 
     agent_config = {
         **agent_config,
@@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
+                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
             },
         ],
         session_id=session_id,
         stream=False,
     )
 
     steps = response.steps
-    assert len(steps) == 7
-    assert steps[0].step_type == "shield_call"
-    assert steps[1].step_type == "inference"
-    assert steps[2].step_type == "shield_call"
-    assert steps[3].step_type == "tool_execution"
-    assert steps[4].step_type == "shield_call"
-    assert steps[5].step_type == "inference"
-    assert steps[6].step_type == "shield_call"
-    tool_execution_step = steps[3]
+    has_input_shield = True if agent_config.get("input_shields", None) else False
+    has_output_shield = True if agent_config.get("output_shields", None) else False
+    assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
+    if has_input_shield:
+        assert steps[0].step_type == "shield_call"
+        steps.pop(0)
+    assert steps[0].step_type == "inference"
+    if has_output_shield:
+        assert steps[1].step_type == "shield_call"
+        steps.pop(1)
+    assert steps[1].step_type == "tool_execution"
+    tool_execution_step = steps[1]
+    if has_input_shield:
+        assert steps[2].step_type == "shield_call"
+        steps.pop(2)
+    assert steps[2].step_type == "inference"
+    if has_output_shield:
+        assert steps[3].step_type == "shield_call"
+        steps.pop(3)
+
     assert len(tool_execution_step.tool_calls) == 2
     assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
     assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index f452d9fd9..f9eaee7d6 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -532,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -585,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -634,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name