diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index a2015810a..a036e5dc0 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -17834,12 +17834,17 @@
"type": "string"
},
"description": "Updated list of variable names that can be used in the prompt template."
+ },
+ "set_as_default": {
+ "type": "boolean",
+ "description": "Set the new version as the default (default=True)."
}
},
"additionalProperties": false,
"required": [
"prompt",
- "version"
+ "version",
+ "set_as_default"
],
"title": "UpdatePromptRequest"
},
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 74bc7ef41..8ed04c1f8 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -13236,10 +13236,15 @@ components:
type: string
description: >-
Updated list of variable names that can be used in the prompt template.
+ set_as_default:
+ type: boolean
+ description: >-
+ Set the new version as the default (default=True).
additionalProperties: false
required:
- prompt
- version
+ - set_as_default
title: UpdatePromptRequest
VersionInfo:
type: object
diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py
index 8018b7fff..e6a376c3f 100644
--- a/llama_stack/apis/prompts/prompts.py
+++ b/llama_stack/apis/prompts/prompts.py
@@ -150,6 +150,7 @@ class Prompts(Protocol):
prompt: str,
version: int,
variables: list[str] | None = None,
+ set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version).
@@ -157,6 +158,7 @@ class Prompts(Protocol):
:param prompt: The updated prompt text content.
:param version: The current version of the prompt being updated.
:param variables: Updated list of variable names that can be used in the prompt template.
+ :param set_as_default: Set the new version as the default (default=True).
:returns: The updated Prompt resource with incremented version.
"""
...
diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py
index a5fe79ed3..26e8f5cef 100644
--- a/llama_stack/core/prompts/prompts.py
+++ b/llama_stack/core/prompts/prompts.py
@@ -154,6 +154,7 @@ class PromptServiceImpl(Prompts):
prompt: str,
version: int,
variables: list[str] | None = None,
+ set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if version < 1:
@@ -178,8 +179,8 @@ class PromptServiceImpl(Prompts):
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
- default_key = self._get_default_key(prompt_id)
- await self.kvstore.set(default_key, str(new_version))
+ if set_as_default:
+ await self.set_default_version(prompt_id, new_version)
return updated_prompt
diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py
index f778fbf16..b2c619e49 100644
--- a/tests/unit/prompts/prompts/conftest.py
+++ b/tests/unit/prompts/prompts/conftest.py
@@ -18,7 +18,13 @@ async def temp_prompt_store(tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
- config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
+ from llama_stack.core.datatypes import StackRunConfig
+ from llama_stack.providers.utils.kvstore import kvstore_impl
+
+ mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
+ config = PromptServiceConfig(run_config=mock_run_config)
store = PromptServiceImpl(config, deps={})
- await store.initialize()
+
+ store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
+
yield store
diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py
index 94bbbdbf0..792e55530 100644
--- a/tests/unit/prompts/prompts/test_prompts.py
+++ b/tests/unit/prompts/prompts/test_prompts.py
@@ -4,149 +4,141 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import os
-import tempfile
import pytest
-from llama_stack.core.datatypes import StackRunConfig
-from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
-
class TestPrompts:
- @pytest.fixture
- async def store(self):
- with tempfile.TemporaryDirectory() as temp_dir:
- mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
- config = PromptServiceConfig(run_config=mock_run_config)
- store = PromptServiceImpl(config, deps={})
-
- from llama_stack.providers.utils.kvstore import kvstore_impl
- from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-
- test_db_path = os.path.join(temp_dir, "test_prompts.db")
- store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=test_db_path))
-
- yield store
-
- async def test_create_and_get_prompt(self, store):
- prompt = await store.create_prompt("Hello world!", ["name"])
+ async def test_create_and_get_prompt(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("Hello world!", ["name"])
assert prompt.prompt == "Hello world!"
assert prompt.version == 1
assert prompt.prompt_id.startswith("pmpt_")
assert prompt.variables == ["name"]
- retrieved = await store.get_prompt(prompt.prompt_id)
+ retrieved = await temp_prompt_store.get_prompt(prompt.prompt_id)
assert retrieved.prompt_id == prompt.prompt_id
assert retrieved.prompt == prompt.prompt
- async def test_update_prompt(self, store):
- prompt = await store.create_prompt("Original")
- updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
+ async def test_update_prompt(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("Original")
+ updated = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
assert updated.version == 2
assert updated.prompt == "Updated"
- async def test_update_prompt_with_version(self, store):
+ async def test_update_prompt_with_version(self, temp_prompt_store):
version_for_update = 1
- prompt = await store.create_prompt("Original")
+ prompt = await temp_prompt_store.create_prompt("Original")
assert prompt.version == 1
- prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
+ prompt = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
assert prompt.version == 2
with pytest.raises(ValueError):
# now this is a stale version
- await store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
with pytest.raises(ValueError):
# this version does not exist
- await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
- async def test_delete_prompt(self, store):
- prompt = await store.create_prompt("to be deleted")
- await store.delete_prompt(prompt.prompt_id)
+ async def test_delete_prompt(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("to be deleted")
+ await temp_prompt_store.delete_prompt(prompt.prompt_id)
with pytest.raises(ValueError):
- await store.get_prompt(prompt.prompt_id)
+ await temp_prompt_store.get_prompt(prompt.prompt_id)
- async def test_list_prompts(self, store):
- response = await store.list_prompts()
+ async def test_list_prompts(self, temp_prompt_store):
+ response = await temp_prompt_store.list_prompts()
assert response.data == []
- await store.create_prompt("first")
- await store.create_prompt("second")
+ await temp_prompt_store.create_prompt("first")
+ await temp_prompt_store.create_prompt("second")
- response = await store.list_prompts()
+ response = await temp_prompt_store.list_prompts()
assert len(response.data) == 2
- async def test_version(self, store):
- prompt = await store.create_prompt("V1")
- await store.update_prompt(prompt.prompt_id, "V2", 1)
+ async def test_version(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("V1")
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
- v1 = await store.get_prompt(prompt.prompt_id, version=1)
+ v1 = await temp_prompt_store.get_prompt(prompt.prompt_id, version=1)
assert v1.version == 1 and v1.prompt == "V1"
- latest = await store.get_prompt(prompt.prompt_id)
+ latest = await temp_prompt_store.get_prompt(prompt.prompt_id)
assert latest.version == 2 and latest.prompt == "V2"
- async def test_set_default_version(self, store):
- prompt0 = await store.create_prompt("V1")
- prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1)
+ async def test_set_default_version(self, temp_prompt_store):
+ prompt0 = await temp_prompt_store.create_prompt("V1")
+ prompt1 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V2", 1)
- assert (await store.get_prompt(prompt0.prompt_id)).version == 2
- prompt_default = await store.set_default_version(prompt0.prompt_id, 1)
- assert (await store.get_prompt(prompt0.prompt_id)).version == 1
+ assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 2
+ prompt_default = await temp_prompt_store.set_default_version(prompt0.prompt_id, 1)
+ assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 1
assert prompt_default.version == 1
- prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
+ prompt2 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
assert prompt2.version == 3
- async def test_prompt_id_generation_and_validation(self, store):
- prompt = await store.create_prompt("Test")
+ async def test_prompt_id_generation_and_validation(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("Test")
assert prompt.prompt_id.startswith("pmpt_")
assert len(prompt.prompt_id) == 53
with pytest.raises(ValueError):
- await store.get_prompt("invalid_id")
+ await temp_prompt_store.get_prompt("invalid_id")
- async def test_list_shows_default_versions(self, store):
- prompt = await store.create_prompt("V1")
- await store.update_prompt(prompt.prompt_id, "V2", 1)
- await store.update_prompt(prompt.prompt_id, "V3", 2)
+ async def test_list_shows_default_versions(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("V1")
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
- response = await store.list_prompts()
+ response = await temp_prompt_store.list_prompts()
listed_prompt = response.data[0]
assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
- await store.set_default_version(prompt.prompt_id, 1)
+ await temp_prompt_store.set_default_version(prompt.prompt_id, 1)
- response = await store.list_prompts()
+ response = await temp_prompt_store.list_prompts()
listed_prompt = response.data[0]
assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
- assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default
+ assert not (await temp_prompt_store.get_prompt(prompt.prompt_id, 3)).is_default
- async def test_get_all_prompt_versions(self, store):
- prompt = await store.create_prompt("V1")
- await store.update_prompt(prompt.prompt_id, "V2", 1)
- await store.update_prompt(prompt.prompt_id, "V3", 2)
+ async def test_get_all_prompt_versions(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("V1")
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
+ await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
- versions = (await store.list_prompt_versions(prompt.prompt_id)).data
+ versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
assert len(versions) == 3
assert [v.version for v in versions] == [1, 2, 3]
assert [v.is_default for v in versions] == [False, False, True]
- await store.set_default_version(prompt.prompt_id, 2)
- versions = (await store.list_prompt_versions(prompt.prompt_id)).data
+ await temp_prompt_store.set_default_version(prompt.prompt_id, 2)
+ versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
assert [v.is_default for v in versions] == [False, True, False]
with pytest.raises(ValueError):
- await store.list_prompt_versions("nonexistent")
+ await temp_prompt_store.list_prompt_versions("nonexistent")
- async def test_prompt_variable_validation(self, store):
- prompt = await store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
+ async def test_prompt_variable_validation(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
assert prompt.variables == ["name", "city"]
- prompt_no_vars = await store.create_prompt("Hello world!", [])
+ prompt_no_vars = await temp_prompt_store.create_prompt("Hello world!", [])
assert prompt_no_vars.variables == []
with pytest.raises(ValueError, match="undeclared variables"):
- await store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
+ await temp_prompt_store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
+
+ async def test_update_prompt_set_as_default_behavior(self, temp_prompt_store):
+ prompt = await temp_prompt_store.create_prompt("V1")
+ assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 1
+
+ prompt_v2 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1, [], set_as_default=True)
+ assert prompt_v2.version == 2
+ assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
+
+ prompt_v3 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2, [], set_as_default=False)
+ assert prompt_v3.version == 3
+ assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2