unregister for memory banks and remove update API (#458)

The semantics of an Update on resources is very tricky to reason about
especially for memory banks and models. The best way to go forward here
is for the user to unregister and register a new resource. We don't have
a compelling reason to support update APIs.


Tests:
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m
"chroma" --env CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m
"pgvector" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env
PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0

$CONDA_PREFIX/bin/pytest -v -s -m "ollama"
llama_stack/providers/tests/inference/test_model_registration.py

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-14 17:12:11 -08:00 committed by GitHub
parent 2eab3b7ed9
commit 0850ad656a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 286 additions and 250 deletions

View file

@ -51,6 +51,16 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]]
@ -148,17 +158,11 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
# TODO: delete from provider
async def update_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
registered_obj = await register_object_with_provider(
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
return await self.dist_registry.update(registered_obj)
async def register_object(
self, obj: RoutableObjectWithProvider
@ -232,32 +236,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
registered_model = await self.register_object(model)
return registered_model
async def update_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model:
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
updated_model = Model(
identifier=model_id,
provider_resource_id=provider_model_id
or existing_model.provider_resource_id,
provider_id=provider_id or existing_model.provider_id,
metadata=metadata or existing_model.metadata,
)
registered_model = await self.update_object(updated_model)
return registered_model
async def delete_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.delete_object(existing_model)
await self.unregister_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@ -333,6 +316,12 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
await self.register_object(memory_bank)
return memory_bank
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
await self.unregister_object(existing_bank)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[Dataset]: