From 4b1b1962511329c838a589fc8fd41f97083ebd65 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 15:30:17 -0800 Subject: [PATCH] add model update and delete --- docs/resources/llama-stack-spec.html | 204 +++++++++++++++--- docs/resources/llama-stack-spec.yaml | 114 ++++++++-- llama_stack/apis/models/client.py | 32 ++- llama_stack/apis/models/models.py | 12 ++ .../distribution/routers/routing_tables.py | 31 +++ llama_stack/distribution/store/registry.py | 12 ++ 6 files changed, 356 insertions(+), 49 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7ef9e29af..7fb46a724 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 11:02:50.081698" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" }, "servers": [ { @@ -429,6 +429,39 @@ } } }, + "/models/delete": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Models" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteModelRequest" + } + } + }, + "required": true + } + } + }, "/inference/embeddings": { "post": { "responses": { @@ -2225,6 +2258,46 @@ "required": true } } + }, + "/models/update": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Model" + } + } + } + } + }, + "tags": [ + "Models" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateModelRequest" + } + } + }, + "required": true + } + } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -4549,6 +4622,18 @@ "session_id" ] }, + "DeleteModelRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "model_id" + ] + }, "EmbeddingsRequest": { "type": "object", "properties": { @@ -7826,6 +7911,49 @@ "synthetic_data" ], "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." + }, + "UpdateModelRequest": { + "type": "object", + "properties": { + "model_id": { + "type": "string" + }, + "provider_model_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "model_id" + ] } }, "responses": {} @@ -7837,53 +7965,53 @@ ], "tags": [ { - "name": "Inspect" - }, - { - "name": "Models" - }, - { - "name": "Eval" - }, - { - "name": "EvalTasks" - }, - { - "name": "Scoring" + "name": "Datasets" }, { "name": "Inference" }, - { - "name": "Memory" - }, - { - "name": "Safety" - }, - { - "name": "PostTraining" - }, { "name": "ScoringFunctions" }, - { - "name": "Telemetry" - }, - { - "name": "Shields" - }, - { - "name": "BatchInference" - }, { "name": "MemoryBanks" }, { - "name": "Datasets" + "name": "Telemetry" + }, + { + "name": "PostTraining" + }, + { + "name": "Models" + }, + { + "name": "Inspect" + }, + { + "name": "Safety" + }, + { + "name": "Scoring" + }, + { + "name": "BatchInference" + }, + { + "name": "Eval" }, { "name": "SyntheticDataGeneration" }, + { + "name": "EvalTasks" + }, + { + "name": "Shields" + }, + { + "name": "Memory" + }, { "name": "DatasetIO" }, @@ -8142,6 +8270,10 @@ "name": "DeleteAgentsSessionRequest", "description": "" }, + { + "name": "DeleteModelRequest", + "description": "" + }, { "name": "EmbeddingsRequest", "description": "" @@ -8453,6 +8585,10 @@ { "name": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" + }, + { + "name": "UpdateModelRequest", + "description": "" } ], "x-tagGroups": [ @@ -8521,6 +8657,7 @@ "Dataset", "DeleteAgentsRequest", "DeleteAgentsSessionRequest", + "DeleteModelRequest", "DoraFinetuningConfig", "EmbeddingsRequest", "EmbeddingsResponse", @@ -8618,6 +8755,7 @@ "Turn", "URL", "UnstructuredLogEvent", + "UpdateModelRequest", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 14f87cf54..06a4afa85 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -867,6 +867,14 @@ components: - agent_id - session_id type: object + DeleteModelRequest: + additionalProperties: false + properties: + model_id: + type: string + required: + - model_id + type: object DoraFinetuningConfig: additionalProperties: false properties: @@ -3272,6 +3280,28 @@ components: - message - severity type: object + UpdateModelRequest: + additionalProperties: false + properties: + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + model_id: + type: string + provider_id: + type: string + provider_model_id: + type: string + required: + - model_id + type: object UserMessage: additionalProperties: false properties: @@ -3384,7 +3414,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-13 11:02:50.081698" + \ draft and subject to change.\n Generated at 2024-11-13 15:29:27.077633" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4186,6 +4216,27 @@ paths: responses: {} tags: - MemoryBanks + /models/delete: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DeleteModelRequest' + required: true + responses: + '200': + description: OK + tags: + - Models /models/get: get: parameters: @@ -4256,6 +4307,31 @@ paths: description: OK tags: - Models + /models/update: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateModelRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Model' + description: OK + tags: + - Models /post_training/job/artifacts: get: parameters: @@ -4748,22 +4824,22 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Inspect -- name: Models -- name: Eval -- name: EvalTasks -- name: Scoring -- name: Inference -- name: Memory -- name: Safety -- name: PostTraining -- name: ScoringFunctions -- name: Telemetry -- name: Shields -- name: BatchInference -- name: MemoryBanks - name: Datasets +- name: Inference +- name: ScoringFunctions +- name: MemoryBanks +- name: Telemetry +- name: PostTraining +- name: Models +- name: Inspect +- name: Safety +- name: Scoring +- name: BatchInference +- name: Eval - name: SyntheticDataGeneration +- name: EvalTasks +- name: Shields +- name: Memory - name: DatasetIO - name: Agents - description: @@ -4964,6 +5040,9 @@ tags: - description: name: DeleteAgentsSessionRequest +- description: + name: DeleteModelRequest - description: name: EmbeddingsRequest @@ -5194,6 +5273,9 @@ tags: ' name: SyntheticDataGenerationResponse +- description: + name: UpdateModelRequest x-tagGroups: - name: Operations tags: @@ -5256,6 +5338,7 @@ x-tagGroups: - Dataset - DeleteAgentsRequest - DeleteAgentsSessionRequest + - DeleteModelRequest - DoraFinetuningConfig - EmbeddingsRequest - EmbeddingsResponse @@ -5353,6 +5436,7 @@ x-tagGroups: - Turn - URL - UnstructuredLogEvent + - UpdateModelRequest - UserMessage - VectorMemoryBank - VectorMemoryBankParams diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index d986828ee..aa63ca541 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -7,7 +7,7 @@ import asyncio import json -from typing import List, Optional +from typing import Any, Dict, List, Optional import fire import httpx @@ -61,6 +61,36 @@ class ModelsClient(Models): return None return Model(**j) + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: + async with httpx.AsyncClient() as client: + response = await client.put( + f"{self.base_url}/models/update", + json={ + "model_id": model_id, + "provider_model_id": provider_model_id, + "provider_id": provider_id, + "metadata": metadata, + }, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + return Model(**response.json()) + + async def delete_model(self, model_id: str) -> None: + async with httpx.AsyncClient() as client: + response = await client.delete( + f"{self.base_url}/models/delete", + params={"model_id": model_id}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + async def run_main(host: str, port: int, stream: bool): client = ModelsClient(f"http://{host}:{port}") diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 2cd12b4bc..7eebe5b9f 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -54,3 +54,15 @@ class Models(Protocol): provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> Model: ... + + @webmethod(route="/models/update", method="PUT") + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: ... + + @webmethod(route="/models/delete", method="DELETE") + async def delete_model(self, model_id: str) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 8c1b0c1e7..32a341278 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -152,6 +152,10 @@ class CommonRoutingTableImpl(RoutingTable): assert len(objects) == 1 return objects[0] + async def delete_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + # TODO: delete from provider + async def register_object( self, obj: RoutableObjectWithProvider ) -> RoutableObjectWithProvider: @@ -225,6 +229,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): registered_model = await self.register_object(model) return registered_model + async def update_model( + self, + model_id: str, + provider_model_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Model: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + + updated_model = Model( + identifier=model_id, + provider_resource_id=provider_model_id + or existing_model.provider_resource_id, + provider_id=provider_id or existing_model.provider_id, + metadata=metadata or existing_model.metadata, + ) + registered_model = await self.register_object(updated_model) + return registered_model + + async def delete_model(self, model_id: str) -> None: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + await self.delete_object(existing_model) + class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[Shield]: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index bb87c81fa..35276b439 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,6 +36,8 @@ class DistributionRegistry(Protocol): # The current approach could lead to inconsistencies if the same logical object has different data across providers. async def register(self, obj: RoutableObjectWithProvider) -> bool: ... + async def delete(self, type: str, identifier: str) -> None: ... + REGISTER_PREFIX = "distributions:registry" KEY_VERSION = "v1" @@ -120,6 +122,9 @@ class DiskDistributionRegistry(DistributionRegistry): ) return True + async def delete(self, type: str, identifier: str) -> None: + await self.kvstore.delete(KEY_FORMAT.format(type=type, identifier=identifier)) + class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): @@ -206,6 +211,13 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): return success + async def delete(self, type: str, identifier: str) -> None: + await super().delete(type, identifier) + cache_key = (type, identifier) + async with self._locked_cache() as cache: + if cache_key in cache: + del cache[cache_key] + async def create_dist_registry( metadata_store: Optional[KVStoreConfig],