mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 18:32:41 +00:00
registery to handle updates and deletes
This commit is contained in:
parent
4b1b196251
commit
9e68ed3f36
5 changed files with 113 additions and 120 deletions
|
|
@ -124,8 +124,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
apiname, objtype = apiname_object()
|
||||
|
||||
# Get objects from disk registry
|
||||
objects = self.dist_registry.get_cached(objtype, routing_key)
|
||||
if not objects:
|
||||
obj = self.dist_registry.get_cached(objtype, routing_key)
|
||||
if not obj:
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if len(provider_ids) > 1:
|
||||
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
|
||||
|
|
@ -135,9 +135,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
|
||||
)
|
||||
|
||||
for obj in objects:
|
||||
if not provider_id or provider_id == obj.provider_id:
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
if not provider_id or provider_id == obj.provider_id:
|
||||
return self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
|
|
@ -145,30 +144,36 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
# Get from disk registry
|
||||
objects = await self.dist_registry.get(type, identifier)
|
||||
if not objects:
|
||||
obj = await self.dist_registry.get(type, identifier)
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
assert len(objects) == 1
|
||||
return objects[0]
|
||||
return obj
|
||||
|
||||
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
# TODO: delete from provider
|
||||
|
||||
async def update_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
registered_obj = await register_object_with_provider(
|
||||
obj, self.impls_by_provider_id[obj.provider_id]
|
||||
)
|
||||
return await self.dist_registry.update(registered_obj)
|
||||
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
# Get existing objects from registry
|
||||
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
|
||||
# Check for existing registration
|
||||
for existing_obj in existing_objects:
|
||||
if existing_obj.provider_id == obj.provider_id or not obj.provider_id:
|
||||
print(
|
||||
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
|
||||
)
|
||||
return existing_obj
|
||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||
print(
|
||||
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
|
||||
)
|
||||
return existing_obj
|
||||
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
|
|
@ -247,7 +252,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id or existing_model.provider_id,
|
||||
metadata=metadata or existing_model.metadata,
|
||||
)
|
||||
registered_model = await self.register_object(updated_model)
|
||||
registered_model = await self.update_object(updated_model)
|
||||
return registered_model
|
||||
|
||||
async def delete_model(self, model_id: str) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue