From 64f68401951f901838fee669d6ccf365d487d817 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Mon, 8 Sep 2025 10:26:29 -0400 Subject: [PATCH] removed UpdatePromptRequest, convert prompt version to integer, update tests, and remove TempPathFile in favor of tmp_path_factory Signed-off-by: Francisco Javier Arceo --- docs/_static/llama-stack-spec.html | 12 ++--- docs/_static/llama-stack-spec.yaml | 12 ++--- llama_stack/apis/prompts/__init__.py | 4 +- llama_stack/apis/prompts/prompts.py | 31 ++++--------- llama_stack/core/prompts/prompts.py | 46 ++++++++++++------- tests/unit/prompts/prompts/conftest.py | 39 ++++------------ tests/unit/prompts/prompts/test_prompts.py | 53 +++++++++++----------- 7 files changed, 88 insertions(+), 109 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 2996f4790..a2015810a 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1021,7 +1021,7 @@ "description": "The version of the prompt to get (defaults to latest).", "required": false, "schema": { - "type": "string" + "type": "integer" } } ] @@ -10006,8 +10006,8 @@ "description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API." }, "version": { - "type": "string", - "description": "Version string (integer start at 1 cast as string, incremented on save)" + "type": "integer", + "description": "Version (integer starting at 1, incremented on save)" }, "prompt_id": { "type": "string", @@ -17523,7 +17523,7 @@ "type": "object", "properties": { "version": { - "type": "string", + "type": "integer", "description": "The version to set as default." } }, @@ -17825,8 +17825,8 @@ "description": "The updated prompt text content." }, "version": { - "type": "string", - "description": "The current version of the prompt being updated (as a string)." + "type": "integer", + "description": "The current version of the prompt being updated." }, "variables": { "type": "array", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 106464f79..74bc7ef41 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -704,7 +704,7 @@ paths: The version of the prompt to get (defaults to latest). required: false schema: - type: string + type: integer post: responses: '200': @@ -7391,9 +7391,9 @@ components: The system prompt text with variable placeholders. Variables are only supported when using the Responses API. version: - type: string + type: integer description: >- - Version string (integer start at 1 cast as string, incremented on save) + Version (integer starting at 1, incremented on save) prompt_id: type: string description: >- @@ -13018,7 +13018,7 @@ components: type: object properties: version: - type: string + type: integer description: The version to set as default. additionalProperties: false required: @@ -13227,9 +13227,9 @@ components: type: string description: The updated prompt text content. version: - type: string + type: integer description: >- - The current version of the prompt being updated (as a string). + The current version of the prompt being updated. variables: type: array items: diff --git a/llama_stack/apis/prompts/__init__.py b/llama_stack/apis/prompts/__init__.py index 3534f5864..6070f3450 100644 --- a/llama_stack/apis/prompts/__init__.py +++ b/llama_stack/apis/prompts/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .prompts import ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest +from .prompts import ListPromptsResponse, Prompt, Prompts -__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "UpdatePromptRequest"] +__all__ = ["Prompt", "Prompts", "ListPromptsResponse"] diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py index 31073b750..8018b7fff 100644 --- a/llama_stack/apis/prompts/prompts.py +++ b/llama_stack/apis/prompts/prompts.py @@ -19,14 +19,14 @@ class Prompt(BaseModel): """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack. :param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API. - :param version: Version string (integer start at 1 cast as string, incremented on save) + :param version: Version (integer starting at 1, incremented on save) :param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>' :param variables: List of prompt variable names that can be used in the prompt template :param is_default: Boolean indicating whether this version is the default version for this prompt """ prompt: str | None = Field(default=None, description="The system prompt with variable placeholders") - version: str = Field(description="Version string (integer start at 1 cast as string)") + version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1) prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'") variables: list[str] = Field( default_factory=list, description="List of variable names that can be used in the prompt template" @@ -56,15 +56,9 @@ class Prompt(BaseModel): @field_validator("version") @classmethod - def validate_version(cls, prompt_version: str) -> str: - try: - int_version = int(prompt_version) - if int_version < 1: - raise ValueError("version must be >= 1") - except ValueError as e: - if "invalid literal" in str(e): - raise ValueError("version must be a string representation of an integer") from e - raise + def validate_version(cls, prompt_version: int) -> int: + if prompt_version < 1: + raise ValueError("version must be >= 1") return prompt_version @model_validator(mode="after") @@ -90,13 +84,6 @@ class Prompt(BaseModel): return f"pmpt_{hex_string}" -class UpdatePromptRequest(BaseModel): - """Request model for updating a prompt.""" - - prompt: str = Field(description="The prompt text content") - variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection") - - class ListPromptsResponse(BaseModel): """Response model to list prompts.""" @@ -132,7 +119,7 @@ class Prompts(Protocol): async def get_prompt( self, prompt_id: str, - version: str | None = None, + version: int | None = None, ) -> Prompt: """Get a prompt by its identifier and optional version. @@ -161,14 +148,14 @@ class Prompts(Protocol): self, prompt_id: str, prompt: str, - version: str, + version: int, variables: list[str] | None = None, ) -> Prompt: """Update an existing prompt (increments version). :param prompt_id: The identifier of the prompt to update. :param prompt: The updated prompt text content. - :param version: The current version of the prompt being updated (as a string). + :param version: The current version of the prompt being updated. :param variables: Updated list of variable names that can be used in the prompt template. :returns: The updated Prompt resource with incremented version. """ @@ -189,7 +176,7 @@ class Prompts(Protocol): async def set_default_version( self, prompt_id: str, - version: str, + version: int, ) -> Prompt: """Set which version of a prompt should be the default in get_prompt (latest). diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py index ecd4ff2c2..a5fe79ed3 100644 --- a/llama_stack/core/prompts/prompts.py +++ b/llama_stack/core/prompts/prompts.py @@ -50,10 +50,10 @@ class PromptServiceImpl(Prompts): """Get the KVStore key that stores the default version number.""" return f"prompts:v1:{prompt_id}:default" - async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str: + async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str: """Get the KVStore key for prompt data, returning default version if applicable.""" if version: - return self._get_version_key(prompt_id, version) + return self._get_version_key(prompt_id, str(version)) default_key = self._get_default_key(prompt_id) resolved_version = await self.kvstore.get(default_key) @@ -77,6 +77,7 @@ class PromptServiceImpl(Prompts): "prompt": prompt.prompt, "version": prompt.version, "variables": prompt.variables or [], + "is_default": prompt.is_default, } ) @@ -84,7 +85,11 @@ class PromptServiceImpl(Prompts): """Deserialize a prompt from JSON string.""" obj = json.loads(data) return Prompt( - prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", []) + prompt_id=obj["prompt_id"], + prompt=obj["prompt"], + version=obj["version"], + variables=obj.get("variables", []), + is_default=obj.get("is_default", False), ) async def list_prompts(self) -> ListPromptsResponse: @@ -110,7 +115,7 @@ class PromptServiceImpl(Prompts): prompts.sort(key=lambda p: p.prompt_id or "", reverse=True) return ListPromptsResponse(data=prompts) - async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt: + async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt: """Get a prompt by its identifier and optional version.""" key = await self._get_prompt_key(prompt_id, version) data = await self.kvstore.get(key) @@ -127,14 +132,19 @@ class PromptServiceImpl(Prompts): if variables is None: variables = [] - prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables) + prompt_obj = Prompt( + prompt_id=Prompt.generate_prompt_id(), + prompt=prompt, + version=1, + variables=variables, + ) - version_key = self._get_version_key(prompt_obj.prompt_id, "1") + version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version)) data = self._serialize_prompt(prompt_obj) await self.kvstore.set(version_key, data) default_key = self._get_default_key(prompt_obj.prompt_id) - await self.kvstore.set(default_key, "1") + await self.kvstore.set(default_key, str(prompt_obj.version)) return prompt_obj @@ -142,10 +152,12 @@ class PromptServiceImpl(Prompts): self, prompt_id: str, prompt: str, - version: str, + version: int, variables: list[str] | None = None, ) -> Prompt: """Update an existing prompt (increments version).""" + if version < 1: + raise ValueError("Version must be >= 1") if variables is None: variables = [] @@ -158,16 +170,16 @@ class PromptServiceImpl(Prompts): ) current_version = latest_prompt.version if version is None else version - new_version = str(int(current_version) + 1) + new_version = current_version + 1 updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables) - version_key = self._get_version_key(prompt_id, new_version) + version_key = self._get_version_key(prompt_id, str(new_version)) data = self._serialize_prompt(updated_prompt) await self.kvstore.set(version_key, data) default_key = self._get_default_key(prompt_id) - await self.kvstore.set(default_key, new_version) + await self.kvstore.set(default_key, str(new_version)) return updated_prompt @@ -202,19 +214,19 @@ class PromptServiceImpl(Prompts): raise ValueError(f"Prompt {prompt_id} not found") for prompt in prompts: - prompt.is_default = prompt.version == default_version + prompt.is_default = str(prompt.version) == default_version - prompts.sort(key=lambda x: int(x.version)) + prompts.sort(key=lambda x: x.version) return ListPromptsResponse(data=prompts) - async def set_default_version(self, prompt_id: str, version: str) -> Prompt: - """Set which version of a prompt should be the default (latest).""" - version_key = self._get_version_key(prompt_id, version) + async def set_default_version(self, prompt_id: str, version: int) -> Prompt: + """Set which version of a prompt should be the default, If not set. the default is the latest.""" + version_key = self._get_version_key(prompt_id, str(version)) data = await self.kvstore.get(version_key) if data is None: raise ValueError(f"Prompt {prompt_id} version {version} not found") default_key = self._get_default_key(prompt_id) - await self.kvstore.set(default_key, version) + await self.kvstore.set(default_key, str(version)) return self._deserialize_prompt(data) diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index 70b3bd12b..f778fbf16 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -4,8 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os -import tempfile +import random import pytest @@ -14,32 +13,12 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @pytest.fixture -async def temp_prompt_store(): - with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file: - db_path = tmp_file.name +async def temp_prompt_store(tmp_path_factory): + unique_id = f"prompt_store_{random.randint(1, 1000000)}" + temp_dir = tmp_path_factory.getbasetemp() + db_path = str(temp_dir / f"{unique_id}.db") - try: - config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path)) - store = PromptServiceImpl(config, deps={}) - await store.initialize() - yield store - finally: - if os.path.exists(db_path): - os.unlink(db_path) - - -@pytest.fixture -def sample_prompt_data(): - return { - "prompt": "Hello {{name}}, welcome to {{platform}}!", - "variables": {"name": "John", "platform": "LlamaStack"}, - } - - -@pytest.fixture -def sample_prompts_data(): - return [ - {"prompt": "Hello {{name}}!", "variables": {"name": "Alice"}}, - {"prompt": "Welcome to {{platform}}, {{user}}!", "variables": {"platform": "LlamaStack", "user": "Bob"}}, - {"prompt": "Your order {{order_id}} is ready for pickup.", "variables": {"order_id": "12345"}}, - ] + config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path)) + store = PromptServiceImpl(config, deps={}) + await store.initialize() + yield store diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py index a4512ed02..94bbbdbf0 100644 --- a/tests/unit/prompts/prompts/test_prompts.py +++ b/tests/unit/prompts/prompts/test_prompts.py @@ -32,7 +32,7 @@ class TestPrompts: async def test_create_and_get_prompt(self, store): prompt = await store.create_prompt("Hello world!", ["name"]) assert prompt.prompt == "Hello world!" - assert prompt.version == "1" + assert prompt.version == 1 assert prompt.prompt_id.startswith("pmpt_") assert prompt.variables == ["name"] @@ -42,17 +42,17 @@ class TestPrompts: async def test_update_prompt(self, store): prompt = await store.create_prompt("Original") - updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"]) - assert updated.version == "2" + updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"]) + assert updated.version == 2 assert updated.prompt == "Updated" async def test_update_prompt_with_version(self, store): - version_for_update = "1" + version_for_update = 1 prompt = await store.create_prompt("Original") - assert prompt.version == "1" + assert prompt.version == 1 prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"]) - assert prompt.version == "2" + assert prompt.version == 2 with pytest.raises(ValueError): # now this is a stale version @@ -60,7 +60,7 @@ class TestPrompts: with pytest.raises(ValueError): # this version does not exist - await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"]) + await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"]) async def test_delete_prompt(self, store): prompt = await store.create_prompt("to be deleted") @@ -80,25 +80,25 @@ class TestPrompts: async def test_version(self, store): prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2", "1") + await store.update_prompt(prompt.prompt_id, "V2", 1) - v1 = await store.get_prompt(prompt.prompt_id, version="1") - assert v1.version == "1" and v1.prompt == "V1" + v1 = await store.get_prompt(prompt.prompt_id, version=1) + assert v1.version == 1 and v1.prompt == "V1" latest = await store.get_prompt(prompt.prompt_id) - assert latest.version == "2" and latest.prompt == "V2" + assert latest.version == 2 and latest.prompt == "V2" async def test_set_default_version(self, store): prompt0 = await store.create_prompt("V1") - prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1") + prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1) - assert (await store.get_prompt(prompt0.prompt_id)).version == "2" - prompt_default = await store.set_default_version(prompt0.prompt_id, "1") - assert (await store.get_prompt(prompt0.prompt_id)).version == "1" - assert prompt_default.version == "1" + assert (await store.get_prompt(prompt0.prompt_id)).version == 2 + prompt_default = await store.set_default_version(prompt0.prompt_id, 1) + assert (await store.get_prompt(prompt0.prompt_id)).version == 1 + assert prompt_default.version == 1 prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version) - assert prompt2.version == "3" + assert prompt2.version == 3 async def test_prompt_id_generation_and_validation(self, store): prompt = await store.create_prompt("Test") @@ -110,30 +110,31 @@ class TestPrompts: async def test_list_shows_default_versions(self, store): prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2", "1") - await store.update_prompt(prompt.prompt_id, "V3", "2") + await store.update_prompt(prompt.prompt_id, "V2", 1) + await store.update_prompt(prompt.prompt_id, "V3", 2) response = await store.list_prompts() listed_prompt = response.data[0] - assert listed_prompt.version == "3" and listed_prompt.prompt == "V3" + assert listed_prompt.version == 3 and listed_prompt.prompt == "V3" - await store.set_default_version(prompt.prompt_id, "1") + await store.set_default_version(prompt.prompt_id, 1) response = await store.list_prompts() listed_prompt = response.data[0] - assert listed_prompt.version == "1" and listed_prompt.prompt == "V1" + assert listed_prompt.version == 1 and listed_prompt.prompt == "V1" + assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default async def test_get_all_prompt_versions(self, store): prompt = await store.create_prompt("V1") - await store.update_prompt(prompt.prompt_id, "V2", "1") - await store.update_prompt(prompt.prompt_id, "V3", "2") + await store.update_prompt(prompt.prompt_id, "V2", 1) + await store.update_prompt(prompt.prompt_id, "V3", 2) versions = (await store.list_prompt_versions(prompt.prompt_id)).data assert len(versions) == 3 - assert [v.version for v in versions] == ["1", "2", "3"] + assert [v.version for v in versions] == [1, 2, 3] assert [v.is_default for v in versions] == [False, False, True] - await store.set_default_version(prompt.prompt_id, "2") + await store.set_default_version(prompt.prompt_id, 2) versions = (await store.list_prompt_versions(prompt.prompt_id)).data assert [v.is_default for v in versions] == [False, True, False]