Update test_agents.py for Llama 4 models and remote-vllm

This updates test_agents.py after testing with Llama 4 Scout and the
remote-vllm provider. The main change is more verbose prompting to
encourage tool calls, because Llama 4 Scout tends to reply that
polyjuice is fictional and has no boiling point instead of calling our
custom tool unless it's prodded a bit.
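
For context, the custom tool these prompts are nudging the model toward
is the tests' get_boiling_point client tool, which returns a made-up
boiling point for the fictional liquid polyjuice. A rough sketch of what
such a client tool can look like (the import path, decorator, and return
values here are assumptions rather than the repository's exact fixture):

```python
from llama_stack_client.lib.agents.client_tool import client_tool  # import path assumed


@client_tool
def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
    """
    Returns the boiling point of a liquid in Celsius or Fahrenheit.

    :param liquid_name: name of the liquid to look up
    :param celsius: return Celsius when True, Fahrenheit otherwise
    :return: the boiling point, or -1 for unknown liquids
    """
    # Made-up value for the fictional liquid the agent tests ask about
    if liquid_name.lower() == "polyjuice":
        return -100 if celsius else -148
    return -1
```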

Also, the remote-vllm distribution doesn't use input/output shields by
default, so test_multi_tool_calls was adjusted to expect shield_call
steps only when shields are configured and to skip those checks
otherwise.
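
Whether shields are in play is driven entirely by the agent_config the
test receives; a minimal sketch of the two configurations the adjusted
test now handles (the shield identifier is illustrative):

```python
# Shields configured: the turn wraps inference and tool execution
# with shield_call steps, so the test expects 3 + 2 + 2 = 7 steps.
agent_config_with_shields = {
    **agent_config,
    "input_shields": ["llama_guard"],   # illustrative shield id
    "output_shields": ["llama_guard"],
}

# No shields (the remote-vllm distribution's default): the same turn
# produces only the 3 core inference/tool_execution steps.
agent_config_without_shields = {
    **agent_config,
    "input_shields": [],
    "output_shields": [],
}
```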

Note that passing these tests requires changes to the vLLM pythonic
tool parser; those changes are listed at
https://gist.github.com/bbrowning/4734240ce96b4264340caa9584e47c9e

With this change, all of the agent tests pass with Llama 4 Scout and
remote-vllm except one of the RAG tests, which looks to be an
unrelated (and pre-existing) failure.

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/agents/test_agents.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-05-14 10:41:30 -04:00
parent 9f2a7e6a74
commit b3493ee94f
2 changed files with 29 additions and 17 deletions

Changed file: tests/integration/agents/test_agents.py

@@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
     assert found_tool_execution


+@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
 def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
@@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):


 def test_multi_tool_calls(llama_stack_client, agent_config):
-    if "gpt" not in agent_config["model"]:
-        pytest.xfail("Only tested on GPT models")
+    if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
+        pytest.xfail("Only tested on GPT and Llama 4 models")

     agent_config = {
         **agent_config,
@@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
+                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
             },
         ],
         session_id=session_id,
         stream=False,
     )

     steps = response.steps
-    assert len(steps) == 7
-    assert steps[0].step_type == "shield_call"
-    assert steps[1].step_type == "inference"
-    assert steps[2].step_type == "shield_call"
-    assert steps[3].step_type == "tool_execution"
-    assert steps[4].step_type == "shield_call"
-    assert steps[5].step_type == "inference"
-    assert steps[6].step_type == "shield_call"
-    tool_execution_step = steps[3]
+    has_input_shield = True if agent_config.get("input_shields", None) else False
+    has_output_shield = True if agent_config.get("output_shields", None) else False
+    assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
+    if has_input_shield:
+        assert steps[0].step_type == "shield_call"
+        steps.pop(0)
+    assert steps[0].step_type == "inference"
+    if has_output_shield:
+        assert steps[1].step_type == "shield_call"
+        steps.pop(1)
+    assert steps[1].step_type == "tool_execution"
+    tool_execution_step = steps[1]
+    if has_input_shield:
+        assert steps[2].step_type == "shield_call"
+        steps.pop(2)
+    assert steps[2].step_type == "inference"
+    if has_output_shield:
+        assert steps[3].step_type == "shield_call"
+        steps.pop(3)
+
     assert len(tool_execution_step.tool_calls) == 2
     assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
     assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
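
For reference, these are the step_type sequences the updated assertions
accept, sketched out; the 7-step sequence matches what the test
previously hard-coded:

```python
# With input and output shields configured: 3 + 2 + 2 = 7 steps.
expected_with_shields = [
    "shield_call",     # input shield
    "inference",       # model emits the two get_boiling_point calls
    "shield_call",     # output shield
    "tool_execution",  # both tool calls run in a single step
    "shield_call",     # input shield
    "inference",       # model answers from the tool responses
    "shield_call",     # output shield
]

# Without shields (the remote-vllm default): just the core steps.
expected_without_shields = ["inference", "tool_execution", "inference"]
```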

Changed file: remote vLLM provider unit tests for chat completion stream processing

@@ -532,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -585,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -634,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name