add model update and delete

This commit is contained in:
Dinesh Yeduguru 2024-11-13 15:30:17 -08:00
parent 4253cfcd7f
commit 4b1b196251
6 changed files with 356 additions and 49 deletions

View file

@ -152,6 +152,10 @@ class CommonRoutingTableImpl(RoutingTable):
assert len(objects) == 1
return objects[0]
async def delete_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
# TODO: delete from provider
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
@ -225,6 +229,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
registered_model = await self.register_object(model)
return registered_model
async def update_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
updated_model = Model(
identifier=model_id,
provider_resource_id=provider_model_id
or existing_model.provider_resource_id,
provider_id=provider_id or existing_model.provider_id,
metadata=metadata or existing_model.metadata,
)
registered_model = await self.register_object(updated_model)
return registered_model
async def delete_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.delete_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[Shield]:

View file

@ -36,6 +36,8 @@ class DistributionRegistry(Protocol):
# The current approach could lead to inconsistencies if the same logical object has different data across providers.
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
async def delete(self, type: str, identifier: str) -> None: ...
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v1"
@ -120,6 +122,9 @@ class DiskDistributionRegistry(DistributionRegistry):
)
return True
async def delete(self, type: str, identifier: str) -> None:
await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier))
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
@ -206,6 +211,13 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
return success
async def delete(self, type: str, identifier: str) -> None:
await super().delete(type, identifier)
cache_key = (type, identifier)
async with self._locked_cache() as cache:
if cache_key in cache:
del cache[cache_key]
async def create_dist_registry(
metadata_store: Optional[KVStoreConfig],