From 4667c1f542cba66d1fad527d118e4049f053552e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 17 Oct 2024 16:43:14 -0700 Subject: [PATCH] memory test --- .../tests/memory/provider_config_example.yaml | 4 ++-- .../providers/tests/memory/test_memory.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml index cac1adde5..5b5440f8d 100644 --- a/llama_stack/providers/tests/memory/provider_config_example.yaml +++ b/llama_stack/providers/tests/memory/provider_config_example.yaml @@ -2,8 +2,8 @@ providers: - provider_id: test-faiss provider_type: meta-reference config: {} - - provider_id: test-chroma - provider_type: remote::chroma + - provider_id: test-chromadb + provider_type: remote::chromadb config: host: localhost port: 6001 diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index c5ebdf9c7..d92feaba8 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -89,6 +89,30 @@ async def test_banks_list(memory_settings): assert len(response) == 0 +@pytest.mark.asyncio +async def test_banks_register(memory_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + banks_impl = memory_settings["memory_banks_impl"] + bank = VectorMemoryBankDef( + identifier="test_bank_no_provider", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 + + # register same memory bank with same id again will fail + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 + + @pytest.mark.asyncio async def test_query_documents(memory_settings, sample_documents): memory_impl = memory_settings["memory_impl"]