forked from phoenix-oss/llama-stack-mirror
fix routing tables look up key for memory bank (#383)
Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
748606195b
commit
6ebd553da5
2 changed files with 20 additions and 2 deletions
|
@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
objs = await self.dist_registry.get_all()
|
objs = await self.dist_registry.get_all()
|
||||||
return [obj for obj in objs if obj.type == type]
|
return [obj for obj in objs if obj.type == type]
|
||||||
|
|
||||||
|
async def get_all_with_types(
|
||||||
|
self, types: List[str]
|
||||||
|
) -> List[RoutableObjectWithProvider]:
|
||||||
|
objs = await self.dist_registry.get_all()
|
||||||
|
return [obj for obj in objs if obj.type in types]
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||||
|
@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
|
||||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||||
return await self.get_all_with_type("memory_bank")
|
return await self.get_all_with_types(
|
||||||
|
[
|
||||||
|
MemoryBankType.vector.value,
|
||||||
|
MemoryBankType.keyvalue.value,
|
||||||
|
MemoryBankType.keyword.value,
|
||||||
|
MemoryBankType.graph.value,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
async def get_memory_bank(
|
async def get_memory_bank(
|
||||||
self, identifier: str
|
self, identifier: str
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
@ -15,6 +16,7 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
||||||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
@ -26,12 +28,15 @@ def memory_remote() -> ProviderFixture:
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def memory_meta_reference() -> ProviderFixture:
|
def memory_meta_reference() -> ProviderFixture:
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
providers=[
|
providers=[
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="meta-reference",
|
provider_id="meta-reference",
|
||||||
provider_type="meta-reference",
|
provider_type="meta-reference",
|
||||||
config=FaissImplConfig().model_dump(),
|
config=FaissImplConfig(
|
||||||
|
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
||||||
|
).model_dump(),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue