diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index afa8a2049..5575eb36e 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -5261,7 +5261,7 @@
}
}
},
- "/v1/prompts/{prompt_id}/default-version": {
+ "/v1/prompts/{prompt_id}/set-default-version": {
"post": {
"responses": {
"200": {
@@ -17829,6 +17829,10 @@
"type": "string"
},
"description": "Updated dictionary of variable names to their default values."
+ },
+ "version": {
+ "type": "string",
+ "description": "The current version of the prompt being updated (as a string)."
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 23b76c05b..2c0993a9e 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -3726,7 +3726,7 @@ paths:
schema:
$ref: '#/components/schemas/ScoreBatchRequest'
required: true
- /v1/prompts/{prompt_id}/default-version:
+ /v1/prompts/{prompt_id}/set-default-version:
post:
responses:
'200':
@@ -13231,6 +13231,10 @@ components:
type: string
description: >-
Updated dictionary of variable names to their default values.
+ version:
+ type: string
+ description: >-
+ The current version of the prompt being updated (as a string).
additionalProperties: false
required:
- prompt
diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py
index 6b1dddee6..ca102d9c6 100644
--- a/llama_stack/apis/prompts/prompts.py
+++ b/llama_stack/apis/prompts/prompts.py
@@ -153,12 +153,14 @@ class Prompts(Protocol):
prompt_id: str,
prompt: str,
variables: dict[str, str] | None = None,
+ version: str | None = None,
) -> Prompt:
"""Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
:param variables: Updated dictionary of variable names to their default values.
+ :param version: The current version of the prompt being updated (as a string).
:returns: The updated Prompt resource with incremented version.
"""
...
@@ -174,7 +176,7 @@ class Prompts(Protocol):
"""
...
- @webmethod(route="/prompts/{prompt_id}/default-version", method="PUT")
+ @webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
async def set_default_version(
self,
prompt_id: str,
diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py
index 6e7385a57..730ce00e3 100644
--- a/llama_stack/core/prompts/prompts.py
+++ b/llama_stack/core/prompts/prompts.py
@@ -46,10 +46,20 @@ class PromptServiceImpl(Prompts):
)
self.kvstore = await kvstore_impl(kvstore_config)
- def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
+ def _get_default_key(self, prompt_id: str) -> str:
+ """Get the KVStore key that stores the default version number."""
+ return f"prompts:v1:{prompt_id}:default"
+
+ async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
+ """Get the KVStore key for prompt data, returning default version if applicable."""
if version:
return self._get_version_key(prompt_id, version)
- return f"prompts:v1:{prompt_id}:default"
+
+ default_key = self._get_default_key(prompt_id)
+ resolved_version = await self.kvstore.get(default_key)
+ if resolved_version is None:
+ raise ValueError(f"Prompt {prompt_id}:default not found")
+ return self._get_version_key(prompt_id, resolved_version)
def _get_version_key(self, prompt_id: str, version: str) -> str:
"""Get the KVStore key for a specific prompt version."""
@@ -102,22 +112,10 @@ class PromptServiceImpl(Prompts):
async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
- if version:
- key = self._get_version_key(prompt_id, version)
- data = await self.kvstore.get(key)
- if data is None:
- raise ValueError(f"Prompt {prompt_id} version {version} not found")
- else:
- default_key = self._get_prompt_key(prompt_id)
- default_version = await self.kvstore.get(default_key)
- if default_version is None:
- raise ValueError(f"Prompt with ID '{prompt_id}' not found")
-
- key = self._get_version_key(prompt_id, default_version)
- data = await self.kvstore.get(key)
- if data is None:
- raise ValueError(f"Prompt with ID '{prompt_id}' not found")
-
+ key = await self._get_prompt_key(prompt_id, version)
+ data = await self.kvstore.get(key)
+ if data is None:
+ raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
return self._deserialize_prompt(data)
async def create_prompt(
@@ -135,7 +133,7 @@ class PromptServiceImpl(Prompts):
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
- default_key = self._get_prompt_key(prompt_obj.prompt_id)
+ default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, "1")
return prompt_obj
@@ -145,13 +143,20 @@ class PromptServiceImpl(Prompts):
prompt_id: str,
prompt: str,
variables: dict[str, str] | None = None,
+ version: str | None = None,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if variables is None:
variables = {}
current_prompt = await self.get_prompt(prompt_id)
- new_version = str(int(current_prompt.version) + 1)
+ if version and current_prompt.version != version:
+ raise ValueError(
+ f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{current_prompt.version}' in request."
+ )
+
+ current_version = current_prompt.version if version is None else version
+ new_version = str(int(current_version) + 1)
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
@@ -159,7 +164,7 @@ class PromptServiceImpl(Prompts):
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
- default_key = self._get_prompt_key(prompt_id)
+ default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, new_version)
return updated_prompt
@@ -207,7 +212,7 @@ class PromptServiceImpl(Prompts):
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
- default_key = self._get_prompt_key(prompt_id)
+ default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, version)
return self._deserialize_prompt(data)
diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py
index 0527b6b89..01395df84 100644
--- a/tests/unit/prompts/prompts/test_prompts.py
+++ b/tests/unit/prompts/prompts/test_prompts.py
@@ -46,6 +46,22 @@ class TestPrompts:
assert updated.version == "2"
assert updated.prompt == "Updated"
+ async def test_update_prompt_with_version(self, store):
+ version_for_update = "1"
+
+ prompt = await store.create_prompt("Original")
+ assert prompt.version == "1"
+ prompt = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}, version_for_update)
+ assert prompt.version == "2"
+
+ with pytest.raises(ValueError):
+ # now this is a stale version
+ await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, version_for_update)
+
+ with pytest.raises(ValueError):
+ # this version does not exist
+ await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, "99")
+
async def test_delete_prompt(self, store):
prompt = await store.create_prompt("to be deleted")
await store.delete_prompt(prompt.prompt_id)