From 9e68ed3f36aafa3fcb40378d10a6ccb7a5a99a4e Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Wed, 13 Nov 2024 20:50:26 -0800
Subject: [PATCH] registry to handle updates and deletes

---
 docs/resources/llama-stack-spec.html          |  48 +++----
 docs/resources/llama-stack-spec.yaml          |  24 ++--
 llama_stack/apis/models/models.py             |   2 +-
 .../distribution/routers/routing_tables.py    |  39 +++---
 llama_stack/distribution/store/registry.py    | 120 ++++++++----------
 5 files changed, 113 insertions(+), 120 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 7fb46a724..3cac93967 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633"
+        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 18:16:59.065989"
     },
     "servers": [
         {
@@ -7965,10 +7965,7 @@
     ],
     "tags": [
         {
-            "name": "Datasets"
-        },
-        {
-            "name": "Inference"
+            "name": "Shields"
         },
         {
             "name": "ScoringFunctions"
@@ -7977,20 +7974,38 @@
             "name": "MemoryBanks"
         },
         {
-            "name": "Telemetry"
+            "name": "Datasets"
         },
         {
-            "name": "PostTraining"
+            "name": "Agents"
         },
         {
-            "name": "Models"
+            "name": "DatasetIO"
+        },
+        {
+            "name": "Inference"
         },
         {
             "name": "Inspect"
         },
+        {
+            "name": "Memory"
+        },
+        {
+            "name": "Models"
+        },
+        {
+            "name": "PostTraining"
+        },
         {
             "name": "Safety"
         },
+        {
+            "name": "SyntheticDataGeneration"
+        },
+        {
+            "name": "EvalTasks"
+        },
         {
             "name": "Scoring"
         },
@@ -8001,22 +8016,7 @@
             "name": "Eval"
         },
         {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "EvalTasks"
-        },
-        {
-            "name": "Shields"
-        },
-        {
-            "name": "Memory"
-        },
-        {
-            "name": "DatasetIO"
-        },
-        {
-            "name": "Agents"
+            "name": "Telemetry"
         },
         {
             "name": "BuiltinTool",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 06a4afa85..5d2b91d84 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -3414,7 +3414,7 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" + \ draft and subject to change.\n Generated at 2024-11-13 18:16:59.065989" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4824,24 +4824,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Datasets -- name: Inference +- name: Shields - name: ScoringFunctions - name: MemoryBanks -- name: Telemetry -- name: PostTraining -- name: Models +- name: Datasets +- name: Agents +- name: DatasetIO +- name: Inference - name: Inspect +- name: Memory +- name: Models +- name: PostTraining - name: Safety +- name: SyntheticDataGeneration +- name: EvalTasks - name: Scoring - name: BatchInference - name: Eval -- name: SyntheticDataGeneration -- name: EvalTasks -- name: Shields -- name: Memory -- name: DatasetIO -- name: Agents +- name: Telemetry - description: name: BuiltinTool - description: Model: ... - @webmethod(route="/models/delete", method="DELETE") + @webmethod(route="/models/delete", method="POST") async def delete_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 32a341278..861c830be 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -124,8 +124,8 @@ class CommonRoutingTableImpl(RoutingTable): apiname, objtype = apiname_object() # Get objects from disk registry - objects = self.dist_registry.get_cached(objtype, routing_key) - if not objects: + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" @@ -135,9 +135,8 @@ class CommonRoutingTableImpl(RoutingTable): f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
) - for obj in objects: - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] raise ValueError(f"Provider not found for `{routing_key}`") @@ -145,30 +144,36 @@ class CommonRoutingTableImpl(RoutingTable): self, type: str, identifier: str ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry - objects = await self.dist_registry.get(type, identifier) - if not objects: + obj = await self.dist_registry.get(type, identifier) + if not obj: return None - assert len(objects) == 1 - return objects[0] + return obj async def delete_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) # TODO: delete from provider + async def update_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: + registered_obj = await register_object_with_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) + return await self.dist_registry.update(registered_obj) + async def register_object( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: # Get existing objects from registry - existing_objects = await self.dist_registry.get(obj.type, obj.identifier) + existing_obj = await self.dist_registry.get(obj.type, obj.identifier) # Check for existing registration - for existing_obj in existing_objects: - if existing_obj.provider_id == obj.provider_id or not obj.provider_id: - print( - f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" - ) - return existing_obj + if existing_obj and existing_obj.provider_id == obj.provider_id: + print( + f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" + ) + return existing_obj # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: @@ -247,7 +252,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id or existing_model.provider_id, metadata=metadata or existing_model.metadata, ) - registered_model = await self.register_object(updated_model) + registered_model = await self.update_object(updated_model) return registered_model async def delete_model(self, model_id: str) -> None: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 35276b439..d8a1a04e3 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -26,9 +26,13 @@ class DistributionRegistry(Protocol): async def initialize(self) -> None: ... - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... + + async def update( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: ... # The current data structure allows multiple objects with the same identifier but different providers. 
# This is not ideal - we should have a single object that can be served by multiple providers, @@ -40,7 +44,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v1" +KEY_VERSION = "v2" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" @@ -54,19 +58,11 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider """Utility function to parse registry values into RoutableObjectWithProvider objects.""" all_objects = [] for value in values: - try: - objects_data = json.loads(value) - objects = [ - pydantic.parse_obj_as( - RoutableObjectWithProvider, - json.loads(obj_str), - ) - for obj_str in objects_data - ] - all_objects.extend(objects) - except Exception as e: - print(f"Error parsing value: {e}") - traceback.print_exc() + obj = pydantic.parse_obj_as( + RoutableObjectWithProvider, + json.loads(value), + ) + all_objects.append(obj) return all_objects @@ -79,46 +75,49 @@ class DiskDistributionRegistry(DistributionRegistry): def get_cached( self, type: str, identifier: str - ) -> List[RoutableObjectWithProvider]: + ) -> Optional[RoutableObjectWithProvider]: # Disk registry does not have a cache - return [] + raise NotImplementedError("Disk registry does not have a cache") async def get_all(self) -> List[RoutableObjectWithProvider]: start_key, end_key = _get_registry_key_range() values = await self.kvstore.range(start_key, end_key) return _parse_registry_values(values) - async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + async def get( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: json_str = await self.kvstore.get( KEY_FORMAT.format(type=type, identifier=identifier) ) if not json_str: - return [] + return None objects_data = json.loads(json_str) - return [ - pydantic.parse_obj_as( + # Return only the first object if any exist + if objects_data: + return pydantic.parse_obj_as( RoutableObjectWithProvider, - json.loads(obj_str), + json.loads(objects_data), ) - for obj_str in objects_data - ] + return None - async def register(self, obj: RoutableObjectWithProvider) -> bool: - existing_objects = await self.get(obj.type, obj.identifier) - # dont register if the object's providerid already exists - for eobj in existing_objects: - if eobj.provider_id == obj.provider_id: - return False - - existing_objects.append(obj) - - objects_json = [ - obj.model_dump_json() for obj in existing_objects - ] # Fixed variable name + async def update(self, obj: RoutableObjectWithProvider) -> None: await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), - json.dumps(objects_json), + obj.model_dump_json(), + ) + return obj + + async def register(self, obj: RoutableObjectWithProvider) -> bool: + existing_obj = await self.get(obj.type, obj.identifier) + # dont register if the object's providerid already exists + if existing_obj and existing_obj.provider_id == obj.provider_id: + return False + + await self.kvstore.set( + KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), + obj.model_dump_json(), ) return True @@ -129,7 +128,7 @@ class DiskDistributionRegistry(DistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} + self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {} self._initialized = False self._initialize_lock = asyncio.Lock() 
self._cache_lock = asyncio.Lock() @@ -156,13 +155,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async with self._locked_cache() as cache: for obj in objects: cache_key = (obj.type, obj.identifier) - if cache_key not in cache: - cache[cache_key] = [] - if not any( - cached_obj.provider_id == obj.provider_id - for cached_obj in cache[cache_key] - ): - cache[cache_key].append(obj) + cache[cache_key] = obj self._initialized = True @@ -171,28 +164,22 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): def get_cached( self, type: str, identifier: str - ) -> List[RoutableObjectWithProvider]: - return self.cache.get((type, identifier), [])[:] # Return a copy + ) -> Optional[RoutableObjectWithProvider]: + return self.cache.get((type, identifier), None) async def get_all(self) -> List[RoutableObjectWithProvider]: await self._ensure_initialized() async with self._locked_cache() as cache: - return [item for sublist in cache.values() for item in sublist] + return list(cache.values()) - async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + async def get( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: await self._ensure_initialized() cache_key = (type, identifier) async with self._locked_cache() as cache: - if cache_key in cache: - return cache[cache_key][:] - - objects = await super().get(type, identifier) - if objects: - async with self._locked_cache() as cache: - cache[cache_key] = objects - - return objects + return cache.get(cache_key, None) async def register(self, obj: RoutableObjectWithProvider) -> bool: await self._ensure_initialized() @@ -201,16 +188,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): if success: cache_key = (obj.type, obj.identifier) async with self._locked_cache() as cache: - if cache_key not in cache: - cache[cache_key] = [] - if not any( - cached_obj.provider_id == obj.provider_id - for cached_obj in cache[cache_key] - ): - cache[cache_key].append(obj) + cache[cache_key] = obj return success + async def update(self, obj: RoutableObjectWithProvider) -> None: + await super().update(obj) + cache_key = (obj.type, obj.identifier) + async with self._locked_cache() as cache: + cache[cache_key] = obj + return obj + async def delete(self, type: str, identifier: str) -> None: await super().delete(type, identifier) cache_key = (type, identifier)
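The registry changes in this patch boil down to: each (type, identifier) key now stores exactly one serialized object under the new v2 key prefix, register() becomes a no-op when the same provider is already registered, and overwrites go through the new update() method, which the routing table's update_object()/update_model() path calls. Below is a minimal sketch of those semantics using a plain dict in place of the real KVStore; FakeModel and FakeRegistry are hypothetical stand-ins for illustration, not llama_stack classes.

# Illustrative only: a dict-backed stand-in for the KVStore showing the
# single-object-per-key semantics this patch introduces. FakeModel and
# FakeRegistry are hypothetical stand-ins, not llama_stack classes.
import asyncio
import json

# Same key layout as the patch: REGISTER_PREFIX + KEY_VERSION (v2) + type + identifier.
KEY_FORMAT = "distributions:registry:v2::{type}:{identifier}"


class FakeModel:
    def __init__(self, identifier: str, provider_id: str, metadata: dict):
        self.type = "model"
        self.identifier = identifier
        self.provider_id = provider_id
        self.metadata = metadata

    def model_dump_json(self) -> str:
        # Close enough to pydantic's model_dump_json() for this sketch.
        return json.dumps(self.__dict__)


class FakeRegistry:
    """One serialized object per (type, identifier); overwrites only via update()."""

    def __init__(self) -> None:
        self.store: dict[str, str] = {}

    async def register(self, obj: FakeModel) -> bool:
        key = KEY_FORMAT.format(type=obj.type, identifier=obj.identifier)
        existing = self.store.get(key)
        if existing is not None and json.loads(existing)["provider_id"] == obj.provider_id:
            return False  # already registered with this provider: no-op
        self.store[key] = obj.model_dump_json()
        return True

    async def update(self, obj: FakeModel) -> FakeModel:
        key = KEY_FORMAT.format(type=obj.type, identifier=obj.identifier)
        self.store[key] = obj.model_dump_json()  # unconditional overwrite
        return obj


async def main() -> None:
    registry = FakeRegistry()
    model = FakeModel("llama-3-8b", "provider-a", {})
    assert await registry.register(model) is True
    assert await registry.register(model) is False  # duplicate registration
    model.metadata = {"context_length": 8192}
    await registry.update(model)  # updates now route through update(), not register()


asyncio.run(main())

Keeping register() non-destructive and funneling overwrites through update() is what lets CachedDiskDistributionRegistry hold its cache as a plain (type, identifier) -> object mapping instead of a list of objects per key.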