add support for provider update and unregister for memory banks

This commit is contained in:
Dinesh Yeduguru 2024-11-14 16:08:24 -08:00
parent 9b75e92852
commit e8b699797c
11 changed files with 240 additions and 65 deletions

View file

@ -51,6 +51,24 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
raise ValueError(f"Unknown API {api} for registering object with provider")
async def update_object_with_provider(
obj: RoutableObject, p: Any
) -> Optional[RoutableObject]:
api = get_impl_api(p)
if api == Api.memory:
return await p.update_memory_bank(obj)
else:
raise ValueError(f"Update not supported for {api}")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
return await p.unregister_memory_bank(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]]
@ -148,14 +166,16 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
# TODO: delete from provider
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
async def update_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
registered_obj = await register_object_with_provider(
registered_obj = await update_object_with_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
return await self.dist_registry.update(registered_obj or obj)
@ -253,11 +273,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
registered_model = await self.update_object(updated_model)
return registered_model
async def delete_model(self, model_id: str) -> None:
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.delete_object(existing_model)
await self.unregister_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@ -358,11 +378,11 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
registered_bank = await self.update_object(updated_bank)
return registered_bank
async def delete_memory_bank(self, memory_bank_id: str) -> None:
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
await self.delete_object(existing_bank)
await self.unregister_object(existing_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):