From b3493ee94f2940172c75b0a45041c63bcd7dc3bb Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 14 May 2025 10:41:30 -0400 Subject: [PATCH] Update test_agents.py for Llama 4 models and remote-vllm This updates test_agents.py a bit after testing with Llama 4 Scout and the remote-vllm provider. The main difference here is a bit more verbose prompting to encourage tool calls because Llama 4 Scout likes to reply that polyjuice is fictional and has no boiling point vs calling our custom tool unless it's prodded a bit. Also, the remote-vllm distribution doesn't use input/output shields by default so test_multi_tool_calls was adjusted to only expect the shield results if shields are in use and otherwise not check for shield usage. Note that it requires changes to the vLLM pythonic tool parser to pass these tests - those are listed at https://gist.github.com/bbrowning/4734240ce96b4264340caa9584e47c9e With this change, all of the agent tests pass with Llama 4 Scout and remote-vllm except one of the RAG tests, that looks to be an unrelated (and pre-existing) failure. ``` VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/agents/test_agents.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" ``` Signed-off-by: Ben Browning --- tests/integration/agents/test_agents.py | 40 ++++++++++++------- .../providers/inference/test_remote_vllm.py | 6 +-- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 63fd74f53..d4f47c837 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config): assert found_tool_execution +@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack") def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, @@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config): messages=[ { "role": "user", - "content": "What is the boiling point of polyjuice?", + "content": "What is the boiling point of the liquid polyjuice in celsius?", }, ], session_id=session_id, @@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice): messages=[ { "role": "user", - "content": "What is the boiling point of polyjuice?", + "content": "What is the boiling point of the liquid polyjuice in celsius?", }, ], session_id=session_id, @@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools): def test_multi_tool_calls(llama_stack_client, agent_config): - if "gpt" not in agent_config["model"]: - pytest.xfail("Only tested on GPT models") + if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower(): + pytest.xfail("Only tested on GPT and Llama 4 models") agent_config = { **agent_config, @@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config): messages=[ { "role": "user", - "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?", + "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.", }, ], session_id=session_id, stream=False, ) steps = response.steps - assert len(steps) == 7 - assert steps[0].step_type == "shield_call" - assert steps[1].step_type == "inference" - assert steps[2].step_type == "shield_call" - assert steps[3].step_type == "tool_execution" - assert steps[4].step_type == "shield_call" - assert steps[5].step_type == "inference" - assert steps[6].step_type == "shield_call" - tool_execution_step = steps[3] + has_input_shield = True if agent_config.get("input_shields", None) else False + has_output_shield = True if agent_config.get("output_shields", None) else False + assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0) + if has_input_shield: + assert steps[0].step_type == "shield_call" + steps.pop(0) + assert steps[0].step_type == "inference" + if has_output_shield: + assert steps[1].step_type == "shield_call" + steps.pop(1) + assert steps[1].step_type == "tool_execution" + tool_execution_step = steps[1] + if has_input_shield: + assert steps[2].step_type == "shield_call" + steps.pop(2) + assert steps[2].step_type == "inference" + if has_output_shield: + assert steps[3].step_type == "shield_call" + steps.pop(3) + assert len(tool_execution_step.tool_calls) == 2 assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point") assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point") diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index f452d9fd9..f9eaee7d6 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -532,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_ yield chunk chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] - assert len(chunks) == 2 + assert len(chunks) == 3 assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete assert chunks[-2].event.delta.type == "tool_call" assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name @@ -585,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason(): yield chunk chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] - assert len(chunks) == 2 + assert len(chunks) == 3 assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete assert chunks[-2].event.delta.type == "tool_call" assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name @@ -634,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args(): yield chunk chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] - assert len(chunks) == 2 + assert len(chunks) == 3 assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete assert chunks[-2].event.delta.type == "tool_call" assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name