mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 07:32:36 +00:00
add update and delete for memory banks
This commit is contained in:
parent
bba6edd06b
commit
9b75e92852
4 changed files with 240 additions and 3 deletions
|
|
@ -144,3 +144,15 @@ class MemoryBanks(Protocol):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank: ...
|
||||
|
||||
@webmethod(route="/memory_banks/update", method="POST")
|
||||
async def update_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank: ...
|
||||
|
||||
@webmethod(route="/memory_banks/delete", method="POST")
|
||||
async def delete_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
registered_obj = await register_object_with_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
return await self.dist_registry.update(registered_obj)
|
||||
return await self.dist_registry.update(registered_obj or obj)
|
||||
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
|
|
@ -333,6 +333,37 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
|
||||
async def update_memory_bank(
|
||||
self,
|
||||
memory_bank_id: str,
|
||||
params: BankParams,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_memory_bank_id: Optional[str] = None,
|
||||
) -> MemoryBank:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
|
||||
updated_bank = parse_obj_as(
|
||||
MemoryBank,
|
||||
{
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
"provider_id": provider_id or existing_bank.provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id
|
||||
or existing_bank.provider_resource_id,
|
||||
**params.model_dump(),
|
||||
},
|
||||
)
|
||||
registered_bank = await self.update_object(updated_bank)
|
||||
return registered_bank
|
||||
|
||||
async def delete_memory_bank(self, memory_bank_id: str) -> None:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
await self.delete_object(existing_bank)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> List[Dataset]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue