diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a2015810a..a036e5dc0 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -17834,12 +17834,17 @@ "type": "string" }, "description": "Updated list of variable names that can be used in the prompt template." + }, + "set_as_default": { + "type": "boolean", + "description": "Set the new version as the default (default=True)." } }, "additionalProperties": false, "required": [ "prompt", - "version" + "version", + "set_as_default" ], "title": "UpdatePromptRequest" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 74bc7ef41..8ed04c1f8 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -13236,10 +13236,15 @@ components: type: string description: >- Updated list of variable names that can be used in the prompt template. + set_as_default: + type: boolean + description: >- + Set the new version as the default (default=True). additionalProperties: false required: - prompt - version + - set_as_default title: UpdatePromptRequest VersionInfo: type: object diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py index 8018b7fff..e6a376c3f 100644 --- a/llama_stack/apis/prompts/prompts.py +++ b/llama_stack/apis/prompts/prompts.py @@ -150,6 +150,7 @@ class Prompts(Protocol): prompt: str, version: int, variables: list[str] | None = None, + set_as_default: bool = True, ) -> Prompt: """Update an existing prompt (increments version). @@ -157,6 +158,7 @@ class Prompts(Protocol): :param prompt: The updated prompt text content. :param version: The current version of the prompt being updated. :param variables: Updated list of variable names that can be used in the prompt template. + :param set_as_default: Set the new version as the default (default=True). :returns: The updated Prompt resource with incremented version. """ ... diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py index a5fe79ed3..26e8f5cef 100644 --- a/llama_stack/core/prompts/prompts.py +++ b/llama_stack/core/prompts/prompts.py @@ -154,6 +154,7 @@ class PromptServiceImpl(Prompts): prompt: str, version: int, variables: list[str] | None = None, + set_as_default: bool = True, ) -> Prompt: """Update an existing prompt (increments version).""" if version < 1: @@ -178,8 +179,8 @@ class PromptServiceImpl(Prompts): data = self._serialize_prompt(updated_prompt) await self.kvstore.set(version_key, data) - default_key = self._get_default_key(prompt_id) - await self.kvstore.set(default_key, str(new_version)) + if set_as_default: + await self.set_default_version(prompt_id, new_version) return updated_prompt diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index f778fbf16..b2c619e49 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -18,7 +18,13 @@ async def temp_prompt_store(tmp_path_factory): temp_dir = tmp_path_factory.getbasetemp() db_path = str(temp_dir / f"{unique_id}.db") - config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path)) + from llama_stack.core.datatypes import StackRunConfig + from llama_stack.providers.utils.kvstore import kvstore_impl + + mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={}) + config = PromptServiceConfig(run_config=mock_run_config) store = PromptServiceImpl(config, deps={}) - await store.initialize() + + store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + yield store diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py index 94bbbdbf0..792e55530 100644 --- a/tests/unit/prompts/prompts/test_prompts.py +++ b/tests/unit/prompts/prompts/test_prompts.py @@ -4,149 +4,141 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os -import tempfile import pytest -from llama_stack.core.datatypes import StackRunConfig -from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl - class TestPrompts: - @pytest.fixture - async def store(self): - with tempfile.TemporaryDirectory() as temp_dir: - mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={}) - config = PromptServiceConfig(run_config=mock_run_config) - store = PromptServiceImpl(config, deps={}) - - from llama_stack.providers.utils.kvstore import kvstore_impl - from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig - - test_db_path = os.path.join(temp_dir, "test_prompts.db") - store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=test_db_path)) - - yield store - - async def test_create_and_get_prompt(self, store): - prompt = await store.create_prompt("Hello world!", ["name"]) + async def test_create_and_get_prompt(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Hello world!", ["name"]) assert prompt.prompt == "Hello world!" assert prompt.version == 1 assert prompt.prompt_id.startswith("pmpt_") assert prompt.variables == ["name"] - retrieved = await store.get_prompt(prompt.prompt_id) + retrieved = await temp_prompt_store.get_prompt(prompt.prompt_id) assert retrieved.prompt_id == prompt.prompt_id assert retrieved.prompt == prompt.prompt - async def test_update_prompt(self, store): - prompt = await store.create_prompt("Original") - updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"]) + async def test_update_prompt(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Original") + updated = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"]) assert updated.version == 2 assert updated.prompt == "Updated" - async def test_update_prompt_with_version(self, store): + async def test_update_prompt_with_version(self, temp_prompt_store): version_for_update = 1 - prompt = await store.create_prompt("Original") + prompt = await temp_prompt_store.create_prompt("Original") assert prompt.version == 1 - prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"]) + prompt = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"]) assert prompt.version == 2 with pytest.raises(ValueError): # now this is a stale version - await store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"]) + await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"]) with pytest.raises(ValueError): # this version does not exist - await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"]) + await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"]) - async def test_delete_prompt(self, store): - prompt = await store.create_prompt("to be deleted") - await store.delete_prompt(prompt.prompt_id) + async def test_delete_prompt(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("to be deleted") + await temp_prompt_store.delete_prompt(prompt.prompt_id) with pytest.raises(ValueError): - await store.get_prompt(prompt.prompt_id) + await temp_prompt_store.get_prompt(prompt.prompt_id) - async def test_list_prompts(self, store): - response = await store.list_prompts() + async def test_list_prompts(self, temp_prompt_store): + response = await temp_prompt_store.list_prompts() assert response.data == [] - await store.create_prompt("first") - await store.create_prompt("second") + await temp_prompt_store.create_prompt("first") + await temp_prompt_store.create_prompt("second") - response = await store.list_prompts() + response = await temp_prompt_store.list_prompts() assert len(response.data) == 2 - async def test_version(self, store): - prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2", 1) + async def test_version(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1) - v1 = await store.get_prompt(prompt.prompt_id, version=1) + v1 = await temp_prompt_store.get_prompt(prompt.prompt_id, version=1) assert v1.version == 1 and v1.prompt == "V1" - latest = await store.get_prompt(prompt.prompt_id) + latest = await temp_prompt_store.get_prompt(prompt.prompt_id) assert latest.version == 2 and latest.prompt == "V2" - async def test_set_default_version(self, store): - prompt0 = await store.create_prompt("V1") - prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1) + async def test_set_default_version(self, temp_prompt_store): + prompt0 = await temp_prompt_store.create_prompt("V1") + prompt1 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V2", 1) - assert (await store.get_prompt(prompt0.prompt_id)).version == 2 - prompt_default = await store.set_default_version(prompt0.prompt_id, 1) - assert (await store.get_prompt(prompt0.prompt_id)).version == 1 + assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 2 + prompt_default = await temp_prompt_store.set_default_version(prompt0.prompt_id, 1) + assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 1 assert prompt_default.version == 1 - prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version) + prompt2 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V3", prompt1.version) assert prompt2.version == 3 - async def test_prompt_id_generation_and_validation(self, store): - prompt = await store.create_prompt("Test") + async def test_prompt_id_generation_and_validation(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Test") assert prompt.prompt_id.startswith("pmpt_") assert len(prompt.prompt_id) == 53 with pytest.raises(ValueError): - await store.get_prompt("invalid_id") + await temp_prompt_store.get_prompt("invalid_id") - async def test_list_shows_default_versions(self, store): - prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2", 1) - await store.update_prompt(prompt.prompt_id, "V3", 2) + async def test_list_shows_default_versions(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1) + await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2) - response = await store.list_prompts() + response = await temp_prompt_store.list_prompts() listed_prompt = response.data[0] assert listed_prompt.version == 3 and listed_prompt.prompt == "V3" - await store.set_default_version(prompt.prompt_id, 1) + await temp_prompt_store.set_default_version(prompt.prompt_id, 1) - response = await store.list_prompts() + response = await temp_prompt_store.list_prompts() listed_prompt = response.data[0] assert listed_prompt.version == 1 and listed_prompt.prompt == "V1" - assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default + assert not (await temp_prompt_store.get_prompt(prompt.prompt_id, 3)).is_default - async def test_get_all_prompt_versions(self, store): - prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2", 1) - await store.update_prompt(prompt.prompt_id, "V3", 2) + async def test_get_all_prompt_versions(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1) + await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2) - versions = (await store.list_prompt_versions(prompt.prompt_id)).data + versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data assert len(versions) == 3 assert [v.version for v in versions] == [1, 2, 3] assert [v.is_default for v in versions] == [False, False, True] - await store.set_default_version(prompt.prompt_id, 2) - versions = (await store.list_prompt_versions(prompt.prompt_id)).data + await temp_prompt_store.set_default_version(prompt.prompt_id, 2) + versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data assert [v.is_default for v in versions] == [False, True, False] with pytest.raises(ValueError): - await store.list_prompt_versions("nonexistent") + await temp_prompt_store.list_prompt_versions("nonexistent") - async def test_prompt_variable_validation(self, store): - prompt = await store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"]) + async def test_prompt_variable_validation(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"]) assert prompt.variables == ["name", "city"] - prompt_no_vars = await store.create_prompt("Hello world!", []) + prompt_no_vars = await temp_prompt_store.create_prompt("Hello world!", []) assert prompt_no_vars.variables == [] with pytest.raises(ValueError, match="undeclared variables"): - await store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"]) + await temp_prompt_store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"]) + + async def test_update_prompt_set_as_default_behavior(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 1 + + prompt_v2 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1, [], set_as_default=True) + assert prompt_v2.version == 2 + assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2 + + prompt_v3 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2, [], set_as_default=False) + assert prompt_v3.version == 3 + assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2