forked from phoenix-oss/llama-stack-mirror
unregister for memory banks and remove update API (#458)
The semantics of an Update on resources is very tricky to reason about especially for memory banks and models. The best way to go forward here is for the user to unregister and register a new resource. We don't have a compelling reason to support update APIs. Tests: pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "chroma" --env CHROMA_HOST=localhost --env CHROMA_PORT=8000 pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "pgvector" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 $CONDA_PREFIX/bin/pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_model_registration.py --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
2eab3b7ed9
commit
0850ad656a
18 changed files with 286 additions and 250 deletions
|
@ -51,6 +51,16 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.memory:
|
||||
return await p.unregister_memory_bank(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
|
@ -148,17 +158,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
return obj
|
||||
|
||||
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
# TODO: delete from provider
|
||||
|
||||
async def update_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
registered_obj = await register_object_with_provider(
|
||||
await unregister_object_from_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
return await self.dist_registry.update(registered_obj)
|
||||
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
|
@ -232,32 +236,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
||||
async def update_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Model:
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
updated_model = Model(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id
|
||||
or existing_model.provider_resource_id,
|
||||
provider_id=provider_id or existing_model.provider_id,
|
||||
metadata=metadata or existing_model.metadata,
|
||||
)
|
||||
registered_model = await self.update_object(updated_model)
|
||||
return registered_model
|
||||
|
||||
async def delete_model(self, model_id: str) -> None:
|
||||
existing_model = await self.get_model(model_id)
|
||||
if existing_model is None:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.delete_object(existing_model)
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
@ -333,6 +316,12 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
existing_bank = await self.get_memory_bank(memory_bank_id)
|
||||
if existing_bank is None:
|
||||
raise ValueError(f"Memory bank {memory_bank_id} not found")
|
||||
await self.unregister_object(existing_bank)
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> List[Dataset]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue