fix rag tests (#918)

make more deterministic
This commit is contained in:
Hardik Shah 2025-01-31 15:29:29 -08:00 committed by GitHub
parent 216cde5ee8
commit 589a6911ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -276,7 +276,7 @@ def test_rag_agent(llama_stack_client, agent_config):
) )
for i, url in enumerate(urls) for i, url in enumerate(urls)
] ]
vector_db_id = "test-vector-db" vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register( llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
@ -286,7 +286,7 @@ def test_rag_agent(llama_stack_client, agent_config):
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
# small chunks help to get specific info out of the docs # small chunks help to get specific info out of the docs
chunk_size_in_tokens=128, chunk_size_in_tokens=256,
) )
agent_config = { agent_config = {
**agent_config, **agent_config,
@ -304,9 +304,12 @@ def test_rag_agent(llama_stack_client, agent_config):
user_prompts = [ user_prompts = [
( (
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?", "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped-query", "grouped",
),
(
"What `tune` command to use for getting access to Llama3-8B-Instruct ?",
"download",
), ),
("What command to use to get access to Llama3-8B-Instruct ?", "tune download"),
] ]
for prompt, expected_kw in user_prompts: for prompt, expected_kw in user_prompts:
print(f"User> {prompt}") print(f"User> {prompt}")