fix routing tables look up key for memory bank (#383)

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-06 13:32:46 -08:00 committed by GitHub
parent 748606195b
commit 6ebd553da5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 20 additions and 2 deletions

View file

@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable):
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
async def get_all_with_types(
self, types: List[str]
) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type in types]
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDefWithProvider]:
@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
return await self.get_all_with_type("memory_bank")
return await self.get_all_with_types(
[
MemoryBankType.vector.value,
MemoryBankType.keyvalue.value,
MemoryBankType.keyword.value,
MemoryBankType.graph.value,
]
)
async def get_memory_bank(
self, identifier: str

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import os
import tempfile
import pytest
import pytest_asyncio
@ -15,6 +16,7 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@ -26,12 +28,15 @@ def memory_remote() -> ProviderFixture:
@pytest.fixture(scope="session")
def memory_meta_reference() -> ProviderFixture:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config=FaissImplConfig().model_dump(),
config=FaissImplConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
)
],
)