From fa528138f1d34cdd91efe9efdc497c1878994e0c Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 30 Jan 2025 15:07:02 -0800 Subject: [PATCH] check for exact response in rag --- tests/client-sdk/agents/test_agents.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 969f926ff..e0f86e3d7 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -285,7 +285,8 @@ def test_rag_agent(llama_stack_client, agent_config): llama_stack_client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=vector_db_id, - chunk_size_in_tokens=512, + # small chunks help to get specific info out of the docs + chunk_size_in_tokens=128, ) agent_config = { **agent_config, @@ -301,9 +302,13 @@ def test_rag_agent(llama_stack_client, agent_config): rag_agent = Agent(llama_stack_client, agent_config) session_id = rag_agent.create_session(f"test-session-{uuid4()}") user_prompts = [ - "What are the top 5 topics that were explained? Only list succinct bullet points.", + ( + "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", + "grouped-query", + ), + ("What command to use to get access to Llama3-8B-Instruct ?", "tune download"), ] - for prompt in user_prompts: + for prompt, expected_kw in user_prompts: print(f"User> {prompt}") response = rag_agent.create_turn( messages=[{"role": "user", "content": prompt}], @@ -312,6 +317,7 @@ def test_rag_agent(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "Tool:query_from_memory" in logs_str + assert expected_kw in logs_str.lower() def test_rag_and_code_agent(llama_stack_client, agent_config):