mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix routing tables look up key for memory bank (#383)
Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
748606195b
commit
6ebd553da5
2 changed files with 20 additions and 2 deletions
|
@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
async def get_all_with_types(
|
||||
self, types: List[str]
|
||||
) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type in types]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||
|
@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
return await self.get_all_with_type("memory_bank")
|
||||
return await self.get_all_with_types(
|
||||
[
|
||||
MemoryBankType.vector.value,
|
||||
MemoryBankType.keyvalue.value,
|
||||
MemoryBankType.keyword.value,
|
||||
MemoryBankType.graph.value,
|
||||
]
|
||||
)
|
||||
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue