mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 09:29:47 +00:00
add support for provider update and unregister for memory banks
This commit is contained in:
parent
9b75e92852
commit
e8b699797c
11 changed files with 240 additions and 65 deletions
|
|
@ -51,6 +51,24 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
||||
async def update_object_with_provider(
|
||||
obj: RoutableObject, p: Any
|
||||
) -> Optional[RoutableObject]:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.memory:
|
||||
return await p.update_memory_bank(obj)
|
||||
else:
|
||||
raise ValueError(f"Update not supported for {api}")
|
||||
|
||||
|
||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.memory:
|
||||
return await p.unregister_memory_bank(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
|
|
@ -148,14 +166,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
return obj
|
||||
|
||||
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
# TODO: delete from provider
|
||||
await unregister_object_from_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
|
||||
async def update_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
registered_obj = await register_object_with_provider(
|
||||
registered_obj = await update_object_with_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
return await self.dist_registry.update(registered_obj or obj)
|
||||
|
|
@ -253,11 +273,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
registered_model = await self.update_object(updated_model)
|
||||
return registered_model
|
||||
|
||||
async def delete_model(self, model_id: str) -> None:
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.delete_object(existing_model)
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
|
@ -358,11 +378,11 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
registered_bank = await self.update_object(updated_bank)
|
||||
return registered_bank
|
||||
|
||||
async def delete_memory_bank(self, memory_bank_id: str) -> None:
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
await self.delete_object(existing_bank)
|
||||
await self.unregister_object(existing_bank)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue