From beb2db487df5e375ba2226aeba72b5417cf1baa9 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sun, 7 Sep 2025 22:03:01 -0400 Subject: [PATCH] fixed update_prompt to properly handle latest and default version, made version a required parameter, and removed unused CreatePromptRequest Signed-off-by: Francisco Javier Arceo --- docs/_static/llama-stack-spec.html | 13 ++++----- docs/_static/llama-stack-spec.yaml | 11 ++++---- llama_stack/apis/prompts/__init__.py | 4 +-- llama_stack/apis/prompts/prompts.py | 13 +++------ llama_stack/core/prompts/prompts.py | 12 +++++---- tests/unit/prompts/prompts/test_prompts.py | 31 +++++++++++++--------- 6 files changed, 43 insertions(+), 41 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 31c637bbb..2996f4790 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -10018,7 +10018,7 @@ "items": { "type": "string" }, - "description": "Dictionary of prompt variable names and values" + "description": "List of prompt variable names that can be used in the prompt template" }, "is_default": { "type": "boolean", @@ -17824,21 +17824,22 @@ "type": "string", "description": "The updated prompt text content." }, + "version": { + "type": "string", + "description": "The current version of the prompt being updated (as a string)." + }, "variables": { "type": "array", "items": { "type": "string" }, "description": "Updated list of variable names that can be used in the prompt template." - }, - "version": { - "type": "string", - "description": "The current version of the prompt being updated (as a string)." } }, "additionalProperties": false, "required": [ - "prompt" + "prompt", + "version" ], "title": "UpdatePromptRequest" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 03457ecf0..106464f79 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7403,7 +7403,7 @@ components: items: type: string description: >- - Dictionary of prompt variable names and values + List of prompt variable names that can be used in the prompt template is_default: type: boolean default: false @@ -13226,19 +13226,20 @@ components: prompt: type: string description: The updated prompt text content. + version: + type: string + description: >- + The current version of the prompt being updated (as a string). variables: type: array items: type: string description: >- Updated list of variable names that can be used in the prompt template. - version: - type: string - description: >- - The current version of the prompt being updated (as a string). additionalProperties: false required: - prompt + - version title: UpdatePromptRequest VersionInfo: type: object diff --git a/llama_stack/apis/prompts/__init__.py b/llama_stack/apis/prompts/__init__.py index c76b3b825..3534f5864 100644 --- a/llama_stack/apis/prompts/__init__.py +++ b/llama_stack/apis/prompts/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .prompts import CreatePromptRequest, ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest +from .prompts import ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest -__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "CreatePromptRequest", "UpdatePromptRequest"] +__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "UpdatePromptRequest"] diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py index 8d34aa5ca..31073b750 100644 --- a/llama_stack/apis/prompts/prompts.py +++ b/llama_stack/apis/prompts/prompts.py @@ -21,7 +21,7 @@ class Prompt(BaseModel): :param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API. :param version: Version string (integer start at 1 cast as string, incremented on save) :param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>' - :param variables: Dictionary of prompt variable names and values + :param variables: List of prompt variable names that can be used in the prompt template :param is_default: Boolean indicating whether this version is the default version for this prompt """ @@ -90,13 +90,6 @@ class Prompt(BaseModel): return f"pmpt_{hex_string}" -class CreatePromptRequest(BaseModel): - """Request model to create a prompt.""" - - prompt: str = Field(description="The prompt text content") - variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection") - - class UpdatePromptRequest(BaseModel): """Request model for updating a prompt.""" @@ -168,15 +161,15 @@ class Prompts(Protocol): self, prompt_id: str, prompt: str, + version: str, variables: list[str] | None = None, - version: str | None = None, ) -> Prompt: """Update an existing prompt (increments version). :param prompt_id: The identifier of the prompt to update. :param prompt: The updated prompt text content. - :param variables: Updated list of variable names that can be used in the prompt template. :param version: The current version of the prompt being updated (as a string). + :param variables: Updated list of variable names that can be used in the prompt template. :returns: The updated Prompt resource with incremented version. """ ... diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py index 624764020..ecd4ff2c2 100644 --- a/llama_stack/core/prompts/prompts.py +++ b/llama_stack/core/prompts/prompts.py @@ -142,20 +142,22 @@ class PromptServiceImpl(Prompts): self, prompt_id: str, prompt: str, + version: str, variables: list[str] | None = None, - version: str | None = None, ) -> Prompt: """Update an existing prompt (increments version).""" if variables is None: variables = [] - current_prompt = await self.get_prompt(prompt_id) - if version and current_prompt.version != version: + prompt_versions = await self.list_prompt_versions(prompt_id) + latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version)) + + if version and latest_prompt.version != version: raise ValueError( - f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{current_prompt.version}' in request." + f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request." ) - current_version = current_prompt.version if version is None else version + current_version = latest_prompt.version if version is None else version new_version = str(int(current_version) + 1) updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables) diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py index 93f443374..a4512ed02 100644 --- a/tests/unit/prompts/prompts/test_prompts.py +++ b/tests/unit/prompts/prompts/test_prompts.py @@ -42,7 +42,7 @@ class TestPrompts: async def test_update_prompt(self, store): prompt = await store.create_prompt("Original") - updated = await store.update_prompt(prompt.prompt_id, "Updated", ["v"]) + updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"]) assert updated.version == "2" assert updated.prompt == "Updated" @@ -51,16 +51,16 @@ class TestPrompts: prompt = await store.create_prompt("Original") assert prompt.version == "1" - prompt = await store.update_prompt(prompt.prompt_id, "Updated", ["v"], version_for_update) + prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"]) assert prompt.version == "2" with pytest.raises(ValueError): # now this is a stale version - await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], version_for_update) + await store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"]) with pytest.raises(ValueError): # this version does not exist - await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], "99") + await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"]) async def test_delete_prompt(self, store): prompt = await store.create_prompt("to be deleted") @@ -80,7 +80,7 @@ class TestPrompts: async def test_version(self, store): prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2") + await store.update_prompt(prompt.prompt_id, "V2", "1") v1 = await store.get_prompt(prompt.prompt_id, version="1") assert v1.version == "1" and v1.prompt == "V1" @@ -89,11 +89,16 @@ class TestPrompts: assert latest.version == "2" and latest.prompt == "V2" async def test_set_default_version(self, store): - prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2") + prompt0 = await store.create_prompt("V1") + prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1") - await store.set_default_version(prompt.prompt_id, "1") - assert (await store.get_prompt(prompt.prompt_id)).version == "1" + assert (await store.get_prompt(prompt0.prompt_id)).version == "2" + prompt_default = await store.set_default_version(prompt0.prompt_id, "1") + assert (await store.get_prompt(prompt0.prompt_id)).version == "1" + assert prompt_default.version == "1" + + prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version) + assert prompt2.version == "3" async def test_prompt_id_generation_and_validation(self, store): prompt = await store.create_prompt("Test") @@ -105,8 +110,8 @@ class TestPrompts: async def test_list_shows_default_versions(self, store): prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2") - await store.update_prompt(prompt.prompt_id, "V3") + await store.update_prompt(prompt.prompt_id, "V2", "1") + await store.update_prompt(prompt.prompt_id, "V3", "2") response = await store.list_prompts() listed_prompt = response.data[0] @@ -120,8 +125,8 @@ class TestPrompts: async def test_get_all_prompt_versions(self, store): prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2") - await store.update_prompt(prompt.prompt_id, "V3") + await store.update_prompt(prompt.prompt_id, "V2", "1") + await store.update_prompt(prompt.prompt_id, "V3", "2") versions = (await store.list_prompt_versions(prompt.prompt_id)).data assert len(versions) == 3