add update and delete for memory banks

This commit is contained in:
Dinesh Yeduguru 2024-11-14 14:44:37 -08:00
parent bba6edd06b
commit 9b75e92852
4 changed files with 240 additions and 3 deletions

View file

@ -144,3 +144,15 @@ class MemoryBanks(Protocol):
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/update", method="POST")
async def update_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/delete", method="POST")
async def delete_memory_bank(self, memory_bank_id: str) -> None: ...

View file

@ -158,7 +158,7 @@ class CommonRoutingTableImpl(RoutingTable):
registered_obj = await register_object_with_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
return await self.dist_registry.update(registered_obj)
return await self.dist_registry.update(registered_obj or obj)
async def register_object(
self, obj: RoutableObjectWithProvider
@ -333,6 +333,37 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
await self.register_object(memory_bank)
return memory_bank
async def update_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
updated_bank = parse_obj_as(
MemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id or existing_bank.provider_id,
"provider_resource_id": provider_memory_bank_id
or existing_bank.provider_resource_id,
**params.model_dump(),
},
)
registered_bank = await self.update_object(updated_bank)
return registered_bank
async def delete_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
await self.delete_object(existing_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]: