fix routing tables look up key for memory bank

This commit is contained in:
Dinesh Yeduguru 2024-11-06 13:24:24 -08:00
parent d289afdbde
commit 79740d5e54
2 changed files with 20 additions and 2 deletions

View file

@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable):
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type == type]
async def get_all_with_types(
self, types: List[str]
) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
return [obj for obj in objs if obj.type in types]
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDefWithProvider]:
@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
return await self.get_all_with_type("memory_bank")
return await self.get_all_with_types(
[
MemoryBankType.vector.value,
MemoryBankType.keyvalue.value,
MemoryBankType.keyword.value,
MemoryBankType.graph.value,
]
)
async def get_memory_bank(
self, identifier: str