From 6ebd553da5aeeeaa940c07f4d2c18b0c4e19ac66 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 6 Nov 2024 13:32:46 -0800 Subject: [PATCH] fix routing tables look up key for memory bank (#383) Co-authored-by: Dinesh Yeduguru --- .../distribution/routers/routing_tables.py | 15 ++++++++++++++- llama_stack/providers/tests/memory/fixtures.py | 7 ++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 1efd02c89..6297182bc 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable): objs = await self.dist_registry.get_all() return [obj for obj in objs if obj.type == type] + async def get_all_with_types( + self, types: List[str] + ) -> List[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + return [obj for obj in objs if obj.type in types] + class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDefWithProvider]: @@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - return await self.get_all_with_type("memory_bank") + return await self.get_all_with_types( + [ + MemoryBankType.vector.value, + MemoryBankType.keyvalue.value, + MemoryBankType.keyword.value, + MemoryBankType.graph.value, + ] + ) async def get_memory_bank( self, identifier: str diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index adeab8476..c5e41d32d 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import os +import tempfile import pytest import pytest_asyncio @@ -15,6 +16,7 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -26,12 +28,15 @@ def memory_remote() -> ProviderFixture: @pytest.fixture(scope="session") def memory_meta_reference() -> ProviderFixture: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") return ProviderFixture( providers=[ Provider( provider_id="meta-reference", provider_type="meta-reference", - config=FaissImplConfig().model_dump(), + config=FaissImplConfig( + kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), + ).model_dump(), ) ], )