[memory refactor][4/n] Update the client-sdk test for RAG (#834)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Update client-sdk tests
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:15:19 -08:00 committed by GitHub
parent 1a7490470a
commit 63f37f9b7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 236 additions and 228 deletions

View file

@ -286,19 +286,16 @@ def test_rag_agent(llama_stack_client, agent_config):
)
for i, url in enumerate(urls)
]
memory_bank_id = "test-memory-bank"
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
vector_db_id = "test-vector-db"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.memory.insert(
bank_id=memory_bank_id,
llama_stack_client.tool_runtime.rag_tool.insert_documents(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
agent_config = {
**agent_config,
@ -306,7 +303,7 @@ def test_rag_agent(llama_stack_client, agent_config):
dict(
name="builtin::memory",
args={
"memory_bank_ids": [memory_bank_id],
"vector_db_ids": [vector_db_id],
},
)
],
@ -324,4 +321,4 @@ def test_rag_agent(llama_stack_client, agent_config):
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_memory" in logs_str
assert "Tool:rag_tool.query_context" in logs_str