mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-15 22:47:59 +00:00
migrate memory banks to Resource and new registration
This commit is contained in:
parent
b4416b72fd
commit
c82f13bf9e
16 changed files with 178 additions and 104 deletions
|
@ -188,12 +188,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
async def get_all_with_types(
|
||||
self, types: List[str]
|
||||
) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type in types]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[Model]:
|
||||
|
@ -233,7 +227,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[Shield]:
|
||||
return await self.get_all_with_type("shield")
|
||||
return await self.get_all_with_type(ResourceType.shield.value)
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
|
@ -270,25 +264,29 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
return await self.get_all_with_types(
|
||||
[
|
||||
MemoryBankType.vector.value,
|
||||
MemoryBankType.keyvalue.value,
|
||||
MemoryBankType.keyword.value,
|
||||
MemoryBankType.graph.value,
|
||||
]
|
||||
)
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return await self.get_all_with_type(ResourceType.memory_bank.value)
|
||||
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
) -> Optional[MemoryBankDefWithProvider]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_object_by_identifier(memory_bank_id)
|
||||
|
||||
async def register_memory_bank(
|
||||
self, memory_bank: MemoryBankDefWithProvider
|
||||
) -> None:
|
||||
self,
|
||||
request: RegistrationRequest,
|
||||
) -> MemoryBank:
|
||||
if request.provider_resource_id is None:
|
||||
request.provider_resource_id = request.memory_bank_id
|
||||
if request.provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
request.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
memory_bank = registration_request_to_memory_bank(request)
|
||||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue