check for exact response in rag

This commit is contained in:
Hardik Shah 2025-01-30 15:07:02 -08:00
parent e2bfaf9cd3
commit fa528138f1

View file

@ -285,7 +285,8 @@ def test_rag_agent(llama_stack_client, agent_config):
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
# small chunks help to get specific info out of the docs
chunk_size_in_tokens=128,
)
agent_config = {
**agent_config,
@ -301,9 +302,13 @@ def test_rag_agent(llama_stack_client, agent_config):
rag_agent = Agent(llama_stack_client, agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
"What are the top 5 topics that were explained? Only list succinct bullet points.",
(
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped-query",
),
("What command to use to get access to Llama3-8B-Instruct ?", "tune download"),
]
for prompt in user_prompts:
for prompt, expected_kw in user_prompts:
print(f"User> {prompt}")
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
@ -312,6 +317,7 @@ def test_rag_agent(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_from_memory" in logs_str
assert expected_kw in logs_str.lower()
def test_rag_and_code_agent(llama_stack_client, agent_config):