From 5cdcdbe074c0a3f9ad7c922d60ae4232a4ba5f46 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 8 Nov 2024 17:00:15 -0800 Subject: [PATCH] fix tests --- .../distribution/routers/routing_tables.py | 13 ++++--- .../distribution/utils/memory_bank_utils.py | 8 ++--- .../providers/tests/memory/test_memory.py | 36 ++++++++++++------- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 676ce14f6..41e84a3ac 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -274,9 +274,10 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def register_memory_bank( self, memory_bank_id: str, - provider_id: str, - provider_memorybank_id: str, - params: BankParams, + memory_bank_type: MemoryBankType, + provider_id: Optional[str] = None, + provider_memorybank_id: Optional[str] = None, + params: Optional[BankParams] = None, ) -> MemoryBank: if provider_memorybank_id is None: provider_memorybank_id = memory_bank_id @@ -289,7 +290,11 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): "No provider specified and multiple providers available. Please specify a provider_id." ) memory_bank = build_memory_bank( - memory_bank_id, params.type, provider_id, provider_memorybank_id, params + memory_bank_id, + memory_bank_type, + provider_id, + provider_memorybank_id, + params, ) await self.register_object(memory_bank) return memory_bank diff --git a/llama_stack/distribution/utils/memory_bank_utils.py b/llama_stack/distribution/utils/memory_bank_utils.py index b2d698257..e55977e28 100644 --- a/llama_stack/distribution/utils/memory_bank_utils.py +++ b/llama_stack/distribution/utils/memory_bank_utils.py @@ -25,7 +25,7 @@ def build_memory_bank( provider_memorybank_id: str, params: Optional[BankParams] = None, ) -> MemoryBank: - if memory_bank_type == MemoryBankType.vector: + if memory_bank_type == MemoryBankType.vector.value: assert isinstance(params, VectorMemoryBankParams) memory_bank = VectorMemoryBank( identifier=memory_bank_id, @@ -36,21 +36,21 @@ def build_memory_bank( chunk_size_in_tokens=params.chunk_size_in_tokens, overlap_size_in_tokens=params.overlap_size_in_tokens, ) - elif memory_bank_type == MemoryBankType.keyvalue: + elif memory_bank_type == MemoryBankType.keyvalue.value: memory_bank = KeyValueMemoryBank( identifier=memory_bank_id, provider_id=provider_id, provider_resource_id=provider_memorybank_id, memory_bank_type=memory_bank_type, ) - elif memory_bank_type == MemoryBankType.keyword: + elif memory_bank_type == MemoryBankType.keyword.value: memory_bank = KeywordMemoryBank( identifier=memory_bank_id, provider_id=provider_id, provider_resource_id=provider_memorybank_id, memory_bank_type=memory_bank_type, ) - elif memory_bank_type == MemoryBankType.graph: + elif memory_bank_type == MemoryBankType.graph.value: memory_bank = GraphMemoryBank( identifier=memory_bank_id, provider_id=provider_id, diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 512077b03..1cefd1d4a 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -8,6 +8,7 @@ import pytest from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams # How to run this test: # @@ -44,13 +45,14 @@ def sample_documents(): async def register_memory_bank(banks_impl: MemoryBanks): - await banks_impl.register_memory_bank( - VectorRegistration( - memory_bank_id="test_bank", + return await banks_impl.register_memory_bank( + memory_bank_id="test_bank", + memory_bank_type="vector", + params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, - ) + ), ) @@ -69,20 +71,30 @@ class TestMemory: # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful _, banks_impl = memory_stack - bank = VectorRegistration( - memory_bank_id="test_bank_no_provider", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - await banks_impl.register_memory_bank(bank) + bank = await banks_impl.register_memory_bank( + memory_bank_id="test_bank_no_provider", + memory_bank_type="vector", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) response = await banks_impl.list_memory_banks() assert isinstance(response, list) assert len(response) == 1 # register same memory bank with same id again will fail - await banks_impl.register_memory_bank(bank) + await banks_impl.register_memory_bank( + memory_bank_id="test_bank_no_provider", + memory_bank_type="vector", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) response = await banks_impl.list_memory_banks() assert isinstance(response, list) assert len(response) == 1