forked from phoenix-oss/llama-stack-mirror
fix routing tables look up key for memory bank (#383)
Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
748606195b
commit
6ebd553da5
2 changed files with 20 additions and 2 deletions
|
@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
async def get_all_with_types(
|
||||
self, types: List[str]
|
||||
) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type in types]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||
|
@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
return await self.get_all_with_type("memory_bank")
|
||||
return await self.get_all_with_types(
|
||||
[
|
||||
MemoryBankType.vector.value,
|
||||
MemoryBankType.keyvalue.value,
|
||||
MemoryBankType.keyword.value,
|
||||
MemoryBankType.graph.value,
|
||||
]
|
||||
)
|
||||
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
@ -15,6 +16,7 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
|||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
@ -26,12 +28,15 @@ def memory_remote() -> ProviderFixture:
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_meta_reference() -> ProviderFixture:
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
config=FaissImplConfig().model_dump(),
|
||||
config=FaissImplConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
||||
).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue