mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
Update test_agents.py for Llama 4 models and remote-vllm
This updates test_agents.py a bit after testing with Llama 4 Scout and the remote-vllm provider. The main difference here is slightly more verbose prompting to encourage tool calls, because Llama 4 Scout tends to reply that polyjuice is fictional and has no boiling point rather than calling our custom tool, unless it is prodded a bit. Also, the remote-vllm distribution doesn't use input/output shields by default, so test_multi_tool_calls was adjusted to expect the shield results only if shields are in use, and otherwise not to check for shield usage. Note that passing these tests requires changes to the vLLM pythonic tool parser - those are listed at https://gist.github.com/bbrowning/4734240ce96b4264340caa9584e47c9e With this change, all of the agent tests pass with Llama 4 Scout and remote-vllm except one of the RAG tests, which looks to be an unrelated (and pre-existing) failure. ``` VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/agents/test_agents.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" ``` Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
9f2a7e6a74
commit
b3493ee94f
2 changed files with 29 additions and 17 deletions
|
@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
|||
assert found_tool_execution
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
|
||||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
|
@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the boiling point of polyjuice?",
|
||||
"content": "What is the boiling point of the liquid polyjuice in celsius?",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
|
@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the boiling point of polyjuice?",
|
||||
"content": "What is the boiling point of the liquid polyjuice in celsius?",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
|
@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
|
|||
|
||||
|
||||
def test_multi_tool_calls(llama_stack_client, agent_config):
|
||||
if "gpt" not in agent_config["model"]:
|
||||
pytest.xfail("Only tested on GPT models")
|
||||
if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
|
||||
pytest.xfail("Only tested on GPT and Llama 4 models")
|
||||
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
|
@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
|
||||
"content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
stream=False,
|
||||
)
|
||||
steps = response.steps
|
||||
assert len(steps) == 7
|
||||
assert steps[0].step_type == "shield_call"
|
||||
assert steps[1].step_type == "inference"
|
||||
assert steps[2].step_type == "shield_call"
|
||||
assert steps[3].step_type == "tool_execution"
|
||||
assert steps[4].step_type == "shield_call"
|
||||
assert steps[5].step_type == "inference"
|
||||
assert steps[6].step_type == "shield_call"
|
||||
|
||||
tool_execution_step = steps[3]
|
||||
has_input_shield = True if agent_config.get("input_shields", None) else False
|
||||
has_output_shield = True if agent_config.get("output_shields", None) else False
|
||||
assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
|
||||
if has_input_shield:
|
||||
assert steps[0].step_type == "shield_call"
|
||||
steps.pop(0)
|
||||
assert steps[0].step_type == "inference"
|
||||
if has_output_shield:
|
||||
assert steps[1].step_type == "shield_call"
|
||||
steps.pop(1)
|
||||
assert steps[1].step_type == "tool_execution"
|
||||
tool_execution_step = steps[1]
|
||||
if has_input_shield:
|
||||
assert steps[2].step_type == "shield_call"
|
||||
steps.pop(2)
|
||||
assert steps[2].step_type == "inference"
|
||||
if has_output_shield:
|
||||
assert steps[3].step_type == "shield_call"
|
||||
steps.pop(3)
|
||||
|
||||
assert len(tool_execution_step.tool_calls) == 2
|
||||
assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
|
||||
assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
|
||||
|
|
|
@ -532,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
|
|||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 2
|
||||
assert len(chunks) == 3
|
||||
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
|
@ -585,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
|||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 2
|
||||
assert len(chunks) == 3
|
||||
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
|
@ -634,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
|||
yield chunk
|
||||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 2
|
||||
assert len(chunks) == 3
|
||||
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue