From bf38d0aba0e2a526c91591268bc2ed4d4b3f90b3 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 21 Feb 2025 15:24:28 -0800 Subject: [PATCH] test: fix test_rag_agent test (#1215) Summary: Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py::test_rag_agent --safety-shield meta-llama/Llama-Guard-3-8B --- tests/client-sdk/agents/test_agents.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 23ae601e4..7ede5e517 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -496,10 +496,11 @@ def test_rag_agent(llama_stack_client, agent_config): stream=False, ) # rag is called - assert response.steps[0].tool_calls[0].tool_name == "query_from_memory" + tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") + assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory" # document ids are present in metadata - assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"] - assert expected_kw in response.output_message.content + assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"] + assert expected_kw in response.output_message.content.lower() def test_rag_and_code_agent(llama_stack_client, agent_config):