mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
[memory refactor][4/n] Update the client-sdk test for RAG (#834)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Update client-sdk tests
This commit is contained in:
parent
1a7490470a
commit
63f37f9b7c
3 changed files with 236 additions and 228 deletions
|
@ -286,19 +286,16 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
)
|
)
|
||||||
for i, url in enumerate(urls)
|
for i, url in enumerate(urls)
|
||||||
]
|
]
|
||||||
memory_bank_id = "test-memory-bank"
|
vector_db_id = "test-vector-db"
|
||||||
llama_stack_client.memory_banks.register(
|
llama_stack_client.vector_dbs.register(
|
||||||
memory_bank_id=memory_bank_id,
|
vector_db_id=vector_db_id,
|
||||||
params={
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
"memory_bank_type": "vector",
|
embedding_dimension=384,
|
||||||
"embedding_model": "all-MiniLM-L6-v2",
|
|
||||||
"chunk_size_in_tokens": 512,
|
|
||||||
"overlap_size_in_tokens": 64,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
llama_stack_client.memory.insert(
|
llama_stack_client.tool_runtime.rag_tool.insert_documents(
|
||||||
bank_id=memory_bank_id,
|
|
||||||
documents=documents,
|
documents=documents,
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
)
|
)
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
|
@ -306,7 +303,7 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
dict(
|
dict(
|
||||||
name="builtin::memory",
|
name="builtin::memory",
|
||||||
args={
|
args={
|
||||||
"memory_bank_ids": [memory_bank_id],
|
"vector_db_ids": [vector_db_id],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
@ -324,4 +321,4 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
)
|
)
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert "Tool:query_memory" in logs_str
|
assert "Tool:rag_tool.query_context" in logs_str
|
||||||
|
|
180
tests/client-sdk/tool_runtime/test_rag_tool.py
Normal file
180
tests/client-sdk/tool_runtime/test_rag_tool.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack_client.types.tool_runtime import DocumentParam
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def empty_vector_db_registry(llama_stack_client):
|
||||||
|
vector_dbs = [
|
||||||
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
|
]
|
||||||
|
for vector_db_id in vector_dbs:
|
||||||
|
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
|
||||||
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
|
llama_stack_client.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
embedding_dimension=384,
|
||||||
|
provider_id="faiss",
|
||||||
|
)
|
||||||
|
vector_dbs = [
|
||||||
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
|
]
|
||||||
|
return vector_dbs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def sample_documents():
|
||||||
|
return [
|
||||||
|
DocumentParam(
|
||||||
|
document_id="test-doc-1",
|
||||||
|
content="Python is a high-level programming language.",
|
||||||
|
metadata={"category": "programming", "difficulty": "beginner"},
|
||||||
|
),
|
||||||
|
DocumentParam(
|
||||||
|
document_id="test-doc-2",
|
||||||
|
content="Machine learning is a subset of artificial intelligence.",
|
||||||
|
metadata={"category": "AI", "difficulty": "advanced"},
|
||||||
|
),
|
||||||
|
DocumentParam(
|
||||||
|
document_id="test-doc-3",
|
||||||
|
content="Data structures are fundamental to computer science.",
|
||||||
|
metadata={"category": "computer science", "difficulty": "intermediate"},
|
||||||
|
),
|
||||||
|
DocumentParam(
|
||||||
|
document_id="test-doc-4",
|
||||||
|
content="Neural networks are inspired by biological neural networks.",
|
||||||
|
metadata={"category": "AI", "difficulty": "advanced"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def assert_valid_response(response):
|
||||||
|
assert len(response.chunks) > 0
|
||||||
|
assert len(response.scores) > 0
|
||||||
|
assert len(response.chunks) == len(response.scores)
|
||||||
|
for chunk in response.chunks:
|
||||||
|
assert isinstance(chunk.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_db_insert_inline_and_query(
|
||||||
|
llama_stack_client, single_entry_vector_db_registry, sample_documents
|
||||||
|
):
|
||||||
|
vector_db_id = single_entry_vector_db_registry[0]
|
||||||
|
llama_stack_client.tool_runtime.rag_tool.insert_documents(
|
||||||
|
documents=sample_documents,
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query with a direct match
|
||||||
|
query1 = "programming language"
|
||||||
|
response1 = llama_stack_client.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query=query1,
|
||||||
|
)
|
||||||
|
assert_valid_response(response1)
|
||||||
|
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||||
|
|
||||||
|
# Query with semantic similarity
|
||||||
|
query2 = "AI and brain-inspired computing"
|
||||||
|
response2 = llama_stack_client.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query=query2,
|
||||||
|
)
|
||||||
|
assert_valid_response(response2)
|
||||||
|
assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
|
||||||
|
|
||||||
|
# Query with limit on number of results (max_chunks=2)
|
||||||
|
query3 = "computer"
|
||||||
|
response3 = llama_stack_client.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query=query3,
|
||||||
|
params={"max_chunks": 2},
|
||||||
|
)
|
||||||
|
assert_valid_response(response3)
|
||||||
|
assert len(response3.chunks) <= 2
|
||||||
|
|
||||||
|
# Query with threshold on similarity score
|
||||||
|
query4 = "computer"
|
||||||
|
response4 = llama_stack_client.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query=query4,
|
||||||
|
params={"score_threshold": 0.01},
|
||||||
|
)
|
||||||
|
assert_valid_response(response4)
|
||||||
|
assert all(score >= 0.01 for score in response4.scores)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_db_insert_from_url_and_query(
|
||||||
|
llama_stack_client, empty_vector_db_registry
|
||||||
|
):
|
||||||
|
providers = [p for p in llama_stack_client.providers.list() if p.api == "vector_io"]
|
||||||
|
assert len(providers) > 0
|
||||||
|
|
||||||
|
vector_db_id = "test_vector_db"
|
||||||
|
|
||||||
|
llama_stack_client.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
embedding_dimension=384,
|
||||||
|
provider_id="faiss",
|
||||||
|
)
|
||||||
|
|
||||||
|
# list to check memory bank is successfully registered
|
||||||
|
available_vector_dbs = [
|
||||||
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
|
]
|
||||||
|
assert vector_db_id in available_vector_dbs
|
||||||
|
|
||||||
|
# URLs of documents to insert
|
||||||
|
# TODO: Move to test/memory/resources then update the url to
|
||||||
|
# https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
|
||||||
|
urls = [
|
||||||
|
"memory_optimizations.rst",
|
||||||
|
"chat.rst",
|
||||||
|
"llama3.rst",
|
||||||
|
]
|
||||||
|
documents = [
|
||||||
|
DocumentParam(
|
||||||
|
document_id=f"num-{i}",
|
||||||
|
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||||
|
mime_type="text/plain",
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
for i, url in enumerate(urls)
|
||||||
|
]
|
||||||
|
|
||||||
|
llama_stack_client.tool_runtime.rag_tool.insert_documents(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query for the name of method
|
||||||
|
response1 = llama_stack_client.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query="What's the name of the fine-tunning method used?",
|
||||||
|
)
|
||||||
|
assert_valid_response(response1)
|
||||||
|
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
|
||||||
|
|
||||||
|
# Query for the name of model
|
||||||
|
response2 = llama_stack_client.vector_io.query(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
query="Which Llama model is mentioned?",
|
||||||
|
)
|
||||||
|
assert_valid_response(response2)
|
||||||
|
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
|
@ -8,251 +8,82 @@ import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.memory import MemoryBankDocument
|
|
||||||
from llama_stack_client.types.memory_insert_params import Document
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def empty_memory_bank_registry(llama_stack_client):
|
def empty_vector_db_registry(llama_stack_client):
|
||||||
memory_banks = [
|
vector_dbs = [
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
]
|
]
|
||||||
for memory_bank_id in memory_banks:
|
for vector_db_id in vector_dbs:
|
||||||
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
|
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_registry):
|
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
|
||||||
memory_bank_id = f"test_bank_{random.randint(1000, 9999)}"
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
llama_stack_client.memory_banks.register(
|
llama_stack_client.vector_dbs.register(
|
||||||
memory_bank_id=memory_bank_id,
|
vector_db_id=vector_db_id,
|
||||||
params={
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
"memory_bank_type": "vector",
|
embedding_dimension=384,
|
||||||
"embedding_model": "all-MiniLM-L6-v2",
|
|
||||||
"chunk_size_in_tokens": 512,
|
|
||||||
"overlap_size_in_tokens": 64,
|
|
||||||
},
|
|
||||||
provider_id="faiss",
|
provider_id="faiss",
|
||||||
)
|
)
|
||||||
memory_banks = [
|
vector_dbs = [
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
]
|
]
|
||||||
return memory_banks
|
return vector_dbs
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
def test_vector_db_retrieve(llama_stack_client, empty_vector_db_registry):
|
||||||
def sample_documents():
|
|
||||||
return [
|
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="test-doc-1",
|
|
||||||
content="Python is a high-level programming language.",
|
|
||||||
metadata={"category": "programming", "difficulty": "beginner"},
|
|
||||||
),
|
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="test-doc-2",
|
|
||||||
content="Machine learning is a subset of artificial intelligence.",
|
|
||||||
metadata={"category": "AI", "difficulty": "advanced"},
|
|
||||||
),
|
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="test-doc-3",
|
|
||||||
content="Data structures are fundamental to computer science.",
|
|
||||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
|
||||||
),
|
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="test-doc-4",
|
|
||||||
content="Neural networks are inspired by biological neural networks.",
|
|
||||||
metadata={"category": "AI", "difficulty": "advanced"},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def assert_valid_response(response):
|
|
||||||
assert len(response.chunks) > 0
|
|
||||||
assert len(response.scores) > 0
|
|
||||||
assert len(response.chunks) == len(response.scores)
|
|
||||||
for chunk in response.chunks:
|
|
||||||
assert isinstance(chunk.content, str)
|
|
||||||
assert chunk.document_id is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
|
|
||||||
# Register a memory bank first
|
# Register a memory bank first
|
||||||
memory_bank_id = f"test_bank_{random.randint(1000, 9999)}"
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
llama_stack_client.memory_banks.register(
|
llama_stack_client.vector_dbs.register(
|
||||||
memory_bank_id=memory_bank_id,
|
vector_db_id=vector_db_id,
|
||||||
params={
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
"memory_bank_type": "vector",
|
embedding_dimension=384,
|
||||||
"embedding_model": "all-MiniLM-L6-v2",
|
|
||||||
"chunk_size_in_tokens": 512,
|
|
||||||
"overlap_size_in_tokens": 64,
|
|
||||||
},
|
|
||||||
provider_id="faiss",
|
provider_id="faiss",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retrieve the memory bank and validate its properties
|
# Retrieve the memory bank and validate its properties
|
||||||
response = llama_stack_client.memory_banks.retrieve(memory_bank_id=memory_bank_id)
|
response = llama_stack_client.vector_dbs.retrieve(vector_db_id=vector_db_id)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.identifier == memory_bank_id
|
assert response.identifier == vector_db_id
|
||||||
assert response.type == "memory_bank"
|
|
||||||
assert response.memory_bank_type == "vector"
|
|
||||||
assert response.embedding_model == "all-MiniLM-L6-v2"
|
assert response.embedding_model == "all-MiniLM-L6-v2"
|
||||||
assert response.chunk_size_in_tokens == 512
|
|
||||||
assert response.overlap_size_in_tokens == 64
|
|
||||||
assert response.provider_id == "faiss"
|
assert response.provider_id == "faiss"
|
||||||
assert response.provider_resource_id == memory_bank_id
|
assert response.provider_resource_id == vector_db_id
|
||||||
|
|
||||||
|
|
||||||
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
|
def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
|
||||||
memory_banks_after_register = [
|
vector_dbs_after_register = [
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
]
|
]
|
||||||
assert len(memory_banks_after_register) == 0
|
assert len(vector_dbs_after_register) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
|
def test_vector_db_register(llama_stack_client, empty_vector_db_registry):
|
||||||
memory_provider_id = "faiss"
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
memory_bank_id = f"test_bank_{random.randint(1000, 9999)}"
|
llama_stack_client.vector_dbs.register(
|
||||||
llama_stack_client.memory_banks.register(
|
vector_db_id=vector_db_id,
|
||||||
memory_bank_id=memory_bank_id,
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
params={
|
embedding_dimension=384,
|
||||||
"memory_bank_type": "vector",
|
provider_id="faiss",
|
||||||
"embedding_model": "all-MiniLM-L6-v2",
|
|
||||||
"chunk_size_in_tokens": 512,
|
|
||||||
"overlap_size_in_tokens": 64,
|
|
||||||
},
|
|
||||||
provider_id=memory_provider_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_banks_after_register = [
|
vector_dbs_after_register = [
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
]
|
]
|
||||||
assert memory_banks_after_register == [memory_bank_id]
|
assert vector_dbs_after_register == [vector_db_id]
|
||||||
|
|
||||||
|
|
||||||
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
|
def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry):
|
||||||
memory_banks = [
|
vector_dbs = [
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
]
|
]
|
||||||
assert len(memory_banks) == 1
|
assert len(vector_dbs) == 1
|
||||||
|
|
||||||
memory_bank_id = memory_banks[0]
|
vector_db_id = vector_dbs[0]
|
||||||
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
|
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||||
|
|
||||||
memory_banks = [
|
vector_dbs = [
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
|
||||||
]
|
]
|
||||||
assert len(memory_banks) == 0
|
assert len(vector_dbs) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_memory_bank_insert_inline_and_query(
|
|
||||||
llama_stack_client, single_entry_memory_bank_registry, sample_documents
|
|
||||||
):
|
|
||||||
memory_bank_id = single_entry_memory_bank_registry[0]
|
|
||||||
llama_stack_client.memory.insert(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
documents=sample_documents,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query with a direct match
|
|
||||||
query1 = "programming language"
|
|
||||||
response1 = llama_stack_client.memory.query(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
query=query1,
|
|
||||||
)
|
|
||||||
assert_valid_response(response1)
|
|
||||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
|
||||||
|
|
||||||
# Query with semantic similarity
|
|
||||||
query2 = "AI and brain-inspired computing"
|
|
||||||
response2 = llama_stack_client.memory.query(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
query=query2,
|
|
||||||
)
|
|
||||||
assert_valid_response(response2)
|
|
||||||
assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
|
|
||||||
|
|
||||||
# Query with limit on number of results (max_chunks=2)
|
|
||||||
query3 = "computer"
|
|
||||||
response3 = llama_stack_client.memory.query(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
query=query3,
|
|
||||||
params={"max_chunks": 2},
|
|
||||||
)
|
|
||||||
assert_valid_response(response3)
|
|
||||||
assert len(response3.chunks) <= 2
|
|
||||||
|
|
||||||
# Query with threshold on similarity score
|
|
||||||
query4 = "computer"
|
|
||||||
response4 = llama_stack_client.memory.query(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
query=query4,
|
|
||||||
params={"score_threshold": 0.01},
|
|
||||||
)
|
|
||||||
assert_valid_response(response4)
|
|
||||||
assert all(score >= 0.01 for score in response4.scores)
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_bank_insert_from_url_and_query(
|
|
||||||
llama_stack_client, empty_memory_bank_registry
|
|
||||||
):
|
|
||||||
providers = [p for p in llama_stack_client.providers.list() if p.api == "memory"]
|
|
||||||
assert len(providers) > 0
|
|
||||||
|
|
||||||
memory_provider_id = providers[0].provider_id
|
|
||||||
memory_bank_id = "test_bank"
|
|
||||||
|
|
||||||
llama_stack_client.memory_banks.register(
|
|
||||||
memory_bank_id=memory_bank_id,
|
|
||||||
params={
|
|
||||||
"memory_bank_type": "vector",
|
|
||||||
"embedding_model": "all-MiniLM-L6-v2",
|
|
||||||
"chunk_size_in_tokens": 512,
|
|
||||||
"overlap_size_in_tokens": 64,
|
|
||||||
},
|
|
||||||
provider_id=memory_provider_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# list to check memory bank is successfully registered
|
|
||||||
available_memory_banks = [
|
|
||||||
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
|
|
||||||
]
|
|
||||||
assert memory_bank_id in available_memory_banks
|
|
||||||
|
|
||||||
# URLs of documents to insert
|
|
||||||
# TODO: Move to test/memory/resources then update the url to
|
|
||||||
# https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
|
|
||||||
urls = [
|
|
||||||
"memory_optimizations.rst",
|
|
||||||
"chat.rst",
|
|
||||||
"llama3.rst",
|
|
||||||
]
|
|
||||||
documents = [
|
|
||||||
Document(
|
|
||||||
document_id=f"num-{i}",
|
|
||||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
|
||||||
mime_type="text/plain",
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
for i, url in enumerate(urls)
|
|
||||||
]
|
|
||||||
|
|
||||||
llama_stack_client.memory.insert(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
documents=documents,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query for the name of method
|
|
||||||
response1 = llama_stack_client.memory.query(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
query="What's the name of the fine-tunning method used?",
|
|
||||||
)
|
|
||||||
assert_valid_response(response1)
|
|
||||||
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
|
|
||||||
|
|
||||||
# Query for the name of model
|
|
||||||
response2 = llama_stack_client.memory.query(
|
|
||||||
bank_id=memory_bank_id,
|
|
||||||
query="Which Llama model is mentioned?",
|
|
||||||
)
|
|
||||||
assert_valid_response(response1)
|
|
||||||
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue