diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index afa8a2049..5575eb36e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -5261,7 +5261,7 @@ } } }, - "/v1/prompts/{prompt_id}/default-version": { + "/v1/prompts/{prompt_id}/set-default-version": { "post": { "responses": { "200": { @@ -17829,6 +17829,10 @@ "type": "string" }, "description": "Updated dictionary of variable names to their default values." + }, + "version": { + "type": "string", + "description": "The current version of the prompt being updated (as a string)." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 23b76c05b..2c0993a9e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3726,7 +3726,7 @@ paths: schema: $ref: '#/components/schemas/ScoreBatchRequest' required: true - /v1/prompts/{prompt_id}/default-version: + /v1/prompts/{prompt_id}/set-default-version: post: responses: '200': @@ -13231,6 +13231,10 @@ components: type: string description: >- Updated dictionary of variable names to their default values. + version: + type: string + description: >- + The current version of the prompt being updated (as a string). additionalProperties: false required: - prompt diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py index 6b1dddee6..ca102d9c6 100644 --- a/llama_stack/apis/prompts/prompts.py +++ b/llama_stack/apis/prompts/prompts.py @@ -153,12 +153,14 @@ class Prompts(Protocol): prompt_id: str, prompt: str, variables: dict[str, str] | None = None, + version: str | None = None, ) -> Prompt: """Update an existing prompt (increments version). :param prompt_id: The identifier of the prompt to update. :param prompt: The updated prompt text content. :param variables: Updated dictionary of variable names to their default values. + :param version: The current version of the prompt being updated (as a string). :returns: The updated Prompt resource with incremented version. """ ... @@ -174,7 +176,7 @@ class Prompts(Protocol): """ ... - @webmethod(route="/prompts/{prompt_id}/default-version", method="PUT") + @webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT") async def set_default_version( self, prompt_id: str, diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py index 6e7385a57..730ce00e3 100644 --- a/llama_stack/core/prompts/prompts.py +++ b/llama_stack/core/prompts/prompts.py @@ -46,10 +46,20 @@ class PromptServiceImpl(Prompts): ) self.kvstore = await kvstore_impl(kvstore_config) - def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str: + def _get_default_key(self, prompt_id: str) -> str: + """Get the KVStore key that stores the default version number.""" + return f"prompts:v1:{prompt_id}:default" + + async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str: + """Get the KVStore key for prompt data, returning default version if applicable.""" if version: return self._get_version_key(prompt_id, version) - return f"prompts:v1:{prompt_id}:default" + + default_key = self._get_default_key(prompt_id) + resolved_version = await self.kvstore.get(default_key) + if resolved_version is None: + raise ValueError(f"Prompt {prompt_id}:default not found") + return self._get_version_key(prompt_id, resolved_version) def _get_version_key(self, prompt_id: str, version: str) -> str: """Get the KVStore key for a specific prompt version.""" @@ -102,22 +112,10 @@ class PromptServiceImpl(Prompts): async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt: """Get a prompt by its identifier and optional version.""" - if version: - key = self._get_version_key(prompt_id, version) - data = await self.kvstore.get(key) - if data is None: - raise ValueError(f"Prompt {prompt_id} version {version} not found") - else: - default_key = self._get_prompt_key(prompt_id) - default_version = await self.kvstore.get(default_key) - if default_version is None: - raise ValueError(f"Prompt with ID '{prompt_id}' not found") - - key = self._get_version_key(prompt_id, default_version) - data = await self.kvstore.get(key) - if data is None: - raise ValueError(f"Prompt with ID '{prompt_id}' not found") - + key = await self._get_prompt_key(prompt_id, version) + data = await self.kvstore.get(key) + if data is None: + raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found") return self._deserialize_prompt(data) async def create_prompt( @@ -135,7 +133,7 @@ class PromptServiceImpl(Prompts): data = self._serialize_prompt(prompt_obj) await self.kvstore.set(version_key, data) - default_key = self._get_prompt_key(prompt_obj.prompt_id) + default_key = self._get_default_key(prompt_obj.prompt_id) await self.kvstore.set(default_key, "1") return prompt_obj @@ -145,13 +143,20 @@ class PromptServiceImpl(Prompts): prompt_id: str, prompt: str, variables: dict[str, str] | None = None, + version: str | None = None, ) -> Prompt: """Update an existing prompt (increments version).""" if variables is None: variables = {} current_prompt = await self.get_prompt(prompt_id) - new_version = str(int(current_prompt.version) + 1) + if version and current_prompt.version != version: + raise ValueError( + f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{current_prompt.version}' in request." + ) + + current_version = current_prompt.version if version is None else version + new_version = str(int(current_version) + 1) updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables) @@ -159,7 +164,7 @@ class PromptServiceImpl(Prompts): data = self._serialize_prompt(updated_prompt) await self.kvstore.set(version_key, data) - default_key = self._get_prompt_key(prompt_id) + default_key = self._get_default_key(prompt_id) await self.kvstore.set(default_key, new_version) return updated_prompt @@ -207,7 +212,7 @@ class PromptServiceImpl(Prompts): if data is None: raise ValueError(f"Prompt {prompt_id} version {version} not found") - default_key = self._get_prompt_key(prompt_id) + default_key = self._get_default_key(prompt_id) await self.kvstore.set(default_key, version) return self._deserialize_prompt(data) diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py index 0527b6b89..01395df84 100644 --- a/tests/unit/prompts/prompts/test_prompts.py +++ b/tests/unit/prompts/prompts/test_prompts.py @@ -46,6 +46,22 @@ class TestPrompts: assert updated.version == "2" assert updated.prompt == "Updated" + async def test_update_prompt_with_version(self, store): + version_for_update = "1" + + prompt = await store.create_prompt("Original") + assert prompt.version == "1" + prompt = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}, version_for_update) + assert prompt.version == "2" + + with pytest.raises(ValueError): + # now this is a stale version + await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, version_for_update) + + with pytest.raises(ValueError): + # this version does not exist + await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, "99") + async def test_delete_prompt(self, store): prompt = await store.create_prompt("to be deleted") await store.delete_prompt(prompt.prompt_id)