diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7ef9e29af..44554f2ff 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 11:02:50.081698" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 21:05:58.323310" }, "servers": [ { @@ -429,6 +429,39 @@ } } }, + "/models/delete": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Models" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteModelRequest" + } + } + }, + "required": true + } + } + }, "/inference/embeddings": { "post": { "responses": { @@ -2225,6 +2258,46 @@ "required": true } } + }, + "/models/update": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Model" + } + } + } + } + }, + "tags": [ + "Models" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateModelRequest" + } + } + }, + "required": true + } + } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -4549,6 +4622,18 @@ "session_id" ] }, + "DeleteModelRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "model_id" + ] + }, "EmbeddingsRequest": { "type": "object", "properties": { @@ -7826,6 +7911,49 @@ "synthetic_data" ], "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." 
+ }, + "UpdateModelRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + }, + "provider_model_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "model_id" + ] } }, "responses": {} @@ -7837,23 +7965,20 @@ ], "tags": [ { - "name": "Inspect" + "name": "Agents" + }, + { + "name": "DatasetIO" }, { "name": "Models" }, - { - "name": "Eval" - }, - { - "name": "EvalTasks" - }, - { - "name": "Scoring" - }, { "name": "Inference" }, + { + "name": "BatchInference" + }, { "name": "Memory" }, @@ -7861,35 +7986,38 @@ "name": "Safety" }, { - "name": "PostTraining" + "name": "Inspect" }, { - "name": "ScoringFunctions" + "name": "EvalTasks" }, { - "name": "Telemetry" - }, - { - "name": "Shields" - }, - { - "name": "BatchInference" - }, - { - "name": "MemoryBanks" + "name": "Scoring" }, { "name": "Datasets" }, + { + "name": "PostTraining" + }, + { + "name": "Eval" + }, + { + "name": "Shields" + }, + { + "name": "Telemetry" + }, + { + "name": "ScoringFunctions" + }, + { + "name": "MemoryBanks" + }, { "name": "SyntheticDataGeneration" }, - { - "name": "DatasetIO" - }, - { - "name": "Agents" - }, { "name": "BuiltinTool", "description": "" @@ -8142,6 +8270,10 @@ "name": "DeleteAgentsSessionRequest", "description": "" }, + { + "name": "DeleteModelRequest", + "description": "" + }, { "name": "EmbeddingsRequest", "description": "" @@ -8453,6 +8585,10 @@ { "name": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" + }, + { + "name": "UpdateModelRequest", + "description": "" } ], "x-tagGroups": [ @@ -8521,6 +8657,7 @@ "Dataset", "DeleteAgentsRequest", "DeleteAgentsSessionRequest", + "DeleteModelRequest", "DoraFinetuningConfig", "EmbeddingsRequest", "EmbeddingsResponse", @@ -8618,6 +8755,7 @@ "Turn", "URL", "UnstructuredLogEvent", + "UpdateModelRequest", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 14f87cf54..fc28405d7 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -867,6 +867,14 @@ components: - agent_id - session_id type: object + DeleteModelRequest: + additionalProperties: false + properties: + model_id: + type: string + required: + - model_id + type: object DoraFinetuningConfig: additionalProperties: false properties: @@ -3272,6 +3280,28 @@ components: - message - severity type: object + UpdateModelRequest: + additionalProperties: false + properties: + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + model_id: + type: string + provider_id: + type: string + provider_model_id: + type: string + required: + - model_id + type: object UserMessage: additionalProperties: false properties: @@ -3384,7 +3414,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-13 11:02:50.081698" + \ draft and subject to change.\n Generated at 2024-11-13 21:05:58.323310" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4186,6 +4216,27 @@ paths: responses: {} tags: - MemoryBanks + /models/delete: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DeleteModelRequest' + required: true + responses: + '200': + description: OK + tags: + - Models /models/get: get: parameters: @@ -4256,6 +4307,31 @@ paths: description: OK tags: - Models + /models/update: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateModelRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Model' + description: OK + tags: + - Models /post_training/job/artifacts: get: parameters: @@ -4748,24 +4824,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Inspect +- name: Agents +- name: DatasetIO - name: Models -- name: Eval -- name: EvalTasks -- name: Scoring - name: Inference +- name: BatchInference - name: Memory - name: Safety -- name: PostTraining -- name: ScoringFunctions -- name: Telemetry -- name: Shields -- name: BatchInference -- name: MemoryBanks +- name: Inspect +- name: EvalTasks +- name: Scoring - name: Datasets +- name: PostTraining +- name: Eval +- name: Shields +- name: Telemetry +- name: ScoringFunctions +- name: MemoryBanks - name: SyntheticDataGeneration -- name: DatasetIO -- name: Agents - description: name: BuiltinTool - description: name: DeleteAgentsSessionRequest +- description: + name: DeleteModelRequest - description: name: EmbeddingsRequest @@ -5194,6 +5273,9 @@ tags: ' name: SyntheticDataGenerationResponse +- description: + name: UpdateModelRequest x-tagGroups: - name: Operations tags: @@ -5256,6 +5338,7 @@ x-tagGroups: - Dataset - DeleteAgentsRequest - DeleteAgentsSessionRequest + - DeleteModelRequest - DoraFinetuningConfig - EmbeddingsRequest - EmbeddingsResponse @@ -5353,6 +5436,7 @@ x-tagGroups: - Turn - URL - UnstructuredLogEvent + - UpdateModelRequest - UserMessage - VectorMemoryBank - VectorMemoryBankParams diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index d986828ee..aa63ca541 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -7,7 +7,7 @@ import asyncio import json -from typing import List, Optional +from typing import Any, Dict, List, Optional import fire import httpx @@ -61,6 +61,36 @@ class ModelsClient(Models): return None return Model(**j) + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: + async with httpx.AsyncClient() as client: + response = await client.put( + f"{self.base_url}/models/update", + json={ + "model_id": model_id, + "provider_model_id": provider_model_id, + "provider_id": 
provider_id, + "metadata": metadata, + }, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + return Model(**response.json()) + + async def delete_model(self, model_id: str) -> None: + async with httpx.AsyncClient() as client: + response = await client.delete( + f"{self.base_url}/models/delete", + params={"model_id": model_id}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + async def run_main(host: str, port: int, stream: bool): client = ModelsClient(f"http://{host}:{port}") diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 2cd12b4bc..5ffcde52f 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -54,3 +54,15 @@ class Models(Protocol): provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Model: ... + + @webmethod(route="/models/update", method="POST") + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: ... + + @webmethod(route="/models/delete", method="POST") + async def delete_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 8c1b0c1e7..861c830be 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -124,8 +124,8 @@ class CommonRoutingTableImpl(RoutingTable): apiname, objtype = apiname_object() # Get objects from disk registry - objects = self.dist_registry.get_cached(objtype, routing_key) - if not objects: + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" @@ -135,9 +135,8 @@ class CommonRoutingTableImpl(RoutingTable): f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
) - for obj in objects: - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] raise ValueError(f"Provider not found for `{routing_key}`") @@ -145,26 +144,36 @@ class CommonRoutingTableImpl(RoutingTable): self, type: str, identifier: str ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry - objects = await self.dist_registry.get(type, identifier) - if not objects: + obj = await self.dist_registry.get(type, identifier) + if not obj: return None - assert len(objects) == 1 - return objects[0] + return obj + + async def delete_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + # TODO: delete from provider + + async def update_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: + registered_obj = await register_object_with_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) + return await self.dist_registry.update(registered_obj) async def register_object( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: # Get existing objects from registry - existing_objects = await self.dist_registry.get(obj.type, obj.identifier) + existing_obj = await self.dist_registry.get(obj.type, obj.identifier) # Check for existing registration - for existing_obj in existing_objects: - if existing_obj.provider_id == obj.provider_id or not obj.provider_id: - print( - f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" - ) - return existing_obj + if existing_obj and existing_obj.provider_id == obj.provider_id: + print( + f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" + ) + return existing_obj # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: @@ -225,6 +234,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): registered_model = await self.register_object(model) return registered_model + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + + updated_model = Model( + identifier=model_id, + provider_resource_id=provider_model_id + or existing_model.provider_resource_id, + provider_id=provider_id or existing_model.provider_id, + metadata=metadata or existing_model.metadata, + ) + registered_model = await self.update_object(updated_model) + return registered_model + + async def delete_model(self, model_id: str) -> None: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + await self.delete_object(existing_model) + class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[Shield]: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index bb87c81fa..b876ee756 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -26,19 +26,21 @@ class DistributionRegistry(Protocol): async def initialize(self) -> None: ... - async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ... 
+ async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... - def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ... + def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... + + async def update( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: ... - # The current data structure allows multiple objects with the same identifier but different providers. - # This is not ideal - we should have a single object that can be served by multiple providers, - # suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider]. - # The current approach could lead to inconsistencies if the same logical object has different data across providers. async def register(self, obj: RoutableObjectWithProvider) -> bool: ... + async def delete(self, type: str, identifier: str) -> None: ... + REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v1" +KEY_VERSION = "v2" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" @@ -52,19 +54,11 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider """Utility function to parse registry values into RoutableObjectWithProvider objects.""" all_objects = [] for value in values: - try: - objects_data = json.loads(value) - objects = [ - pydantic.parse_obj_as( - RoutableObjectWithProvider, - json.loads(obj_str), - ) - for obj_str in objects_data - ] - all_objects.extend(objects) - except Exception as e: - print(f"Error parsing value: {e}") - traceback.print_exc() + obj = pydantic.parse_obj_as( + RoutableObjectWithProvider, + json.loads(value), + ) + all_objects.append(obj) return all_objects @@ -77,54 +71,60 @@ class DiskDistributionRegistry(DistributionRegistry): def get_cached( self, type: str, identifier: str - ) -> List[RoutableObjectWithProvider]: + ) -> Optional[RoutableObjectWithProvider]: # Disk registry does not have a cache - return [] + raise NotImplementedError("Disk registry does not have a cache") async def get_all(self) -> List[RoutableObjectWithProvider]: start_key, end_key = _get_registry_key_range() values = await self.kvstore.range(start_key, end_key) return _parse_registry_values(values) - async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + async def get( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: json_str = await self.kvstore.get( KEY_FORMAT.format(type=type, identifier=identifier) ) if not json_str: - return [] + return None objects_data = json.loads(json_str) - return [ - pydantic.parse_obj_as( + # Return only the first object if any exist + if objects_data: + return pydantic.parse_obj_as( RoutableObjectWithProvider, - json.loads(obj_str), + json.loads(objects_data), ) - for obj_str in objects_data - ] + return None - async def register(self, obj: RoutableObjectWithProvider) -> bool: - existing_objects = await self.get(obj.type, obj.identifier) - # dont register if the object's providerid already exists - for eobj in existing_objects: - if eobj.provider_id == obj.provider_id: - return False - - existing_objects.append(obj) - - objects_json = [ - obj.model_dump_json() for obj in existing_objects - ] # Fixed variable name + async def update(self, obj: RoutableObjectWithProvider) -> None: await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), - json.dumps(objects_json), + obj.model_dump_json(), + ) + return obj + + async def register(self, obj: 
RoutableObjectWithProvider) -> bool: + existing_obj = await self.get(obj.type, obj.identifier) + # dont register if the object's providerid already exists + if existing_obj and existing_obj.provider_id == obj.provider_id: + return False + + await self.kvstore.set( + KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), + obj.model_dump_json(), ) return True + async def delete(self, type: str, identifier: str) -> None: + await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier)) + class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {} + self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {} self._initialized = False self._initialize_lock = asyncio.Lock() self._cache_lock = asyncio.Lock() @@ -151,13 +151,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async with self._locked_cache() as cache: for obj in objects: cache_key = (obj.type, obj.identifier) - if cache_key not in cache: - cache[cache_key] = [] - if not any( - cached_obj.provider_id == obj.provider_id - for cached_obj in cache[cache_key] - ): - cache[cache_key].append(obj) + cache[cache_key] = obj self._initialized = True @@ -166,28 +160,22 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): def get_cached( self, type: str, identifier: str - ) -> List[RoutableObjectWithProvider]: - return self.cache.get((type, identifier), [])[:] # Return a copy + ) -> Optional[RoutableObjectWithProvider]: + return self.cache.get((type, identifier), None) async def get_all(self) -> List[RoutableObjectWithProvider]: await self._ensure_initialized() async with self._locked_cache() as cache: - return [item for sublist in cache.values() for item in sublist] + return list(cache.values()) - async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]: + async def get( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: await self._ensure_initialized() cache_key = (type, identifier) async with self._locked_cache() as cache: - if cache_key in cache: - return cache[cache_key][:] - - objects = await super().get(type, identifier) - if objects: - async with self._locked_cache() as cache: - cache[cache_key] = objects - - return objects + return cache.get(cache_key, None) async def register(self, obj: RoutableObjectWithProvider) -> bool: await self._ensure_initialized() @@ -196,16 +184,24 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): if success: cache_key = (obj.type, obj.identifier) async with self._locked_cache() as cache: - if cache_key not in cache: - cache[cache_key] = [] - if not any( - cached_obj.provider_id == obj.provider_id - for cached_obj in cache[cache_key] - ): - cache[cache_key].append(obj) + cache[cache_key] = obj return success + async def update(self, obj: RoutableObjectWithProvider) -> None: + await super().update(obj) + cache_key = (obj.type, obj.identifier) + async with self._locked_cache() as cache: + cache[cache_key] = obj + return obj + + async def delete(self, type: str, identifier: str) -> None: + await super().delete(type, identifier) + cache_key = (type, identifier) + async with self._locked_cache() as cache: + if cache_key in cache: + del cache[cache_key] + async def create_dist_registry( metadata_store: Optional[KVStoreConfig], diff --git a/llama_stack/providers/tests/inference/test_model_registration.py 
b/llama_stack/providers/tests/inference/test_model_registration.py index 4b20e519c..97f0ac576 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -6,6 +6,8 @@ import pytest +from llama_models.datatypes import CoreModelId + # How to run this test: # # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py @@ -33,3 +35,23 @@ class TestModelRegistration: await models_impl.register_model( model_id="Llama3-NonExistent-Model", ) + + @pytest.mark.asyncio + async def test_update_model(self, inference_stack): + _, models_impl = inference_stack + + # Register a model to update + model_id = CoreModelId.llama3_1_8b_instruct.value + old_model = await models_impl.register_model(model_id=model_id) + + # Update the model + new_model_id = CoreModelId.llama3_2_3b_instruct.value + updated_model = await models_impl.update_model( + model_id=model_id, provider_model_id=new_model_id + ) + + # Retrieve the updated model to verify changes + assert updated_model.provider_resource_id != old_model.provider_resource_id + + # Cleanup + await models_impl.delete_model(model_id=model_id)
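
A minimal usage sketch of the endpoints introduced in this diff, driven through the ModelsClient defined in llama_stack/apis/models/client.py. The method names and signatures (update_model, delete_model) are taken from the client code above; the base URL and port, the model identifiers, and the metadata payload are placeholders chosen for illustration and are not part of this change.

import asyncio

from llama_stack.apis.models.client import ModelsClient


async def main() -> None:
    # Placeholder endpoint; point this at a running llama-stack server.
    client = ModelsClient("http://localhost:5000")

    # Re-point an already-registered model at a different provider model id
    # and attach some metadata. Only model_id is required by UpdateModelRequest;
    # per ModelsRoutingTable.update_model, fields left unset fall back to the
    # values already stored for the model.
    updated = await client.update_model(
        model_id="Llama3.1-8B-Instruct",  # placeholder identifier
        provider_model_id="meta-llama/Llama-3.2-3B-Instruct",  # placeholder
        metadata={"notes": "illustrative metadata"},
    )
    print(updated.identifier, updated.provider_resource_id)

    # Remove the registration again. The client surfaces server-side failures
    # (e.g. an unknown model_id) via response.raise_for_status().
    await client.delete_model(model_id="Llama3.1-8B-Instruct")


if __name__ == "__main__":
    asyncio.run(main())

Since update_model only overrides the fields that are actually supplied, a partial update such as the one above leaves the model's provider_id untouched while replacing its provider_resource_id and metadata.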