removed UpdatePromptRequest, convert prompt version to integer, update tests, and remove TempPathFile in favor of tmp_path_factory

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-08 10:26:29 -04:00
parent 2a8469f156
commit 64f6840195
7 changed files with 88 additions and 109 deletions

View file

@ -1021,7 +1021,7 @@
"description": "The version of the prompt to get (defaults to latest).", "description": "The version of the prompt to get (defaults to latest).",
"required": false, "required": false,
"schema": { "schema": {
"type": "string" "type": "integer"
} }
} }
] ]
@ -10006,8 +10006,8 @@
"description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API." "description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API."
}, },
"version": { "version": {
"type": "string", "type": "integer",
"description": "Version string (integer start at 1 cast as string, incremented on save)" "description": "Version (integer starting at 1, incremented on save)"
}, },
"prompt_id": { "prompt_id": {
"type": "string", "type": "string",
@ -17523,7 +17523,7 @@
"type": "object", "type": "object",
"properties": { "properties": {
"version": { "version": {
"type": "string", "type": "integer",
"description": "The version to set as default." "description": "The version to set as default."
} }
}, },
@ -17825,8 +17825,8 @@
"description": "The updated prompt text content." "description": "The updated prompt text content."
}, },
"version": { "version": {
"type": "string", "type": "integer",
"description": "The current version of the prompt being updated (as a string)." "description": "The current version of the prompt being updated."
}, },
"variables": { "variables": {
"type": "array", "type": "array",

View file

@ -704,7 +704,7 @@ paths:
The version of the prompt to get (defaults to latest). The version of the prompt to get (defaults to latest).
required: false required: false
schema: schema:
type: string type: integer
post: post:
responses: responses:
'200': '200':
@ -7391,9 +7391,9 @@ components:
The system prompt text with variable placeholders. Variables are only The system prompt text with variable placeholders. Variables are only
supported when using the Responses API. supported when using the Responses API.
version: version:
type: string type: integer
description: >- description: >-
Version string (integer start at 1 cast as string, incremented on save) Version (integer starting at 1, incremented on save)
prompt_id: prompt_id:
type: string type: string
description: >- description: >-
@ -13018,7 +13018,7 @@ components:
type: object type: object
properties: properties:
version: version:
type: string type: integer
description: The version to set as default. description: The version to set as default.
additionalProperties: false additionalProperties: false
required: required:
@ -13227,9 +13227,9 @@ components:
type: string type: string
description: The updated prompt text content. description: The updated prompt text content.
version: version:
type: string type: integer
description: >- description: >-
The current version of the prompt being updated (as a string). The current version of the prompt being updated.
variables: variables:
type: array type: array
items: items:

View file

@ -4,6 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .prompts import ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest from .prompts import ListPromptsResponse, Prompt, Prompts
__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "UpdatePromptRequest"] __all__ = ["Prompt", "Prompts", "ListPromptsResponse"]

View file

@ -19,14 +19,14 @@ class Prompt(BaseModel):
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack. """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API. :param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
:param version: Version string (integer start at 1 cast as string, incremented on save) :param version: Version (integer starting at 1, incremented on save)
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>' :param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
:param variables: List of prompt variable names that can be used in the prompt template :param variables: List of prompt variable names that can be used in the prompt template
:param is_default: Boolean indicating whether this version is the default version for this prompt :param is_default: Boolean indicating whether this version is the default version for this prompt
""" """
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders") prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
version: str = Field(description="Version string (integer start at 1 cast as string)") version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'") prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: list[str] = Field( variables: list[str] = Field(
default_factory=list, description="List of variable names that can be used in the prompt template" default_factory=list, description="List of variable names that can be used in the prompt template"
@ -56,15 +56,9 @@ class Prompt(BaseModel):
@field_validator("version") @field_validator("version")
@classmethod @classmethod
def validate_version(cls, prompt_version: str) -> str: def validate_version(cls, prompt_version: int) -> int:
try: if prompt_version < 1:
int_version = int(prompt_version)
if int_version < 1:
raise ValueError("version must be >= 1") raise ValueError("version must be >= 1")
except ValueError as e:
if "invalid literal" in str(e):
raise ValueError("version must be a string representation of an integer") from e
raise
return prompt_version return prompt_version
@model_validator(mode="after") @model_validator(mode="after")
@ -90,13 +84,6 @@ class Prompt(BaseModel):
return f"pmpt_{hex_string}" return f"pmpt_{hex_string}"
class UpdatePromptRequest(BaseModel):
"""Request model for updating a prompt."""
prompt: str = Field(description="The prompt text content")
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
class ListPromptsResponse(BaseModel): class ListPromptsResponse(BaseModel):
"""Response model to list prompts.""" """Response model to list prompts."""
@ -132,7 +119,7 @@ class Prompts(Protocol):
async def get_prompt( async def get_prompt(
self, self,
prompt_id: str, prompt_id: str,
version: str | None = None, version: int | None = None,
) -> Prompt: ) -> Prompt:
"""Get a prompt by its identifier and optional version. """Get a prompt by its identifier and optional version.
@ -161,14 +148,14 @@ class Prompts(Protocol):
self, self,
prompt_id: str, prompt_id: str,
prompt: str, prompt: str,
version: str, version: int,
variables: list[str] | None = None, variables: list[str] | None = None,
) -> Prompt: ) -> Prompt:
"""Update an existing prompt (increments version). """Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update. :param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content. :param prompt: The updated prompt text content.
:param version: The current version of the prompt being updated (as a string). :param version: The current version of the prompt being updated.
:param variables: Updated list of variable names that can be used in the prompt template. :param variables: Updated list of variable names that can be used in the prompt template.
:returns: The updated Prompt resource with incremented version. :returns: The updated Prompt resource with incremented version.
""" """
@ -189,7 +176,7 @@ class Prompts(Protocol):
async def set_default_version( async def set_default_version(
self, self,
prompt_id: str, prompt_id: str,
version: str, version: int,
) -> Prompt: ) -> Prompt:
"""Set which version of a prompt should be the default in get_prompt (latest). """Set which version of a prompt should be the default in get_prompt (latest).

View file

@ -50,10 +50,10 @@ class PromptServiceImpl(Prompts):
"""Get the KVStore key that stores the default version number.""" """Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default" return f"prompts:v1:{prompt_id}:default"
async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str: async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable.""" """Get the KVStore key for prompt data, returning default version if applicable."""
if version: if version:
return self._get_version_key(prompt_id, version) return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id) default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key) resolved_version = await self.kvstore.get(default_key)
@ -77,6 +77,7 @@ class PromptServiceImpl(Prompts):
"prompt": prompt.prompt, "prompt": prompt.prompt,
"version": prompt.version, "version": prompt.version,
"variables": prompt.variables or [], "variables": prompt.variables or [],
"is_default": prompt.is_default,
} }
) )
@ -84,7 +85,11 @@ class PromptServiceImpl(Prompts):
"""Deserialize a prompt from JSON string.""" """Deserialize a prompt from JSON string."""
obj = json.loads(data) obj = json.loads(data)
return Prompt( return Prompt(
prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", []) prompt_id=obj["prompt_id"],
prompt=obj["prompt"],
version=obj["version"],
variables=obj.get("variables", []),
is_default=obj.get("is_default", False),
) )
async def list_prompts(self) -> ListPromptsResponse: async def list_prompts(self) -> ListPromptsResponse:
@ -110,7 +115,7 @@ class PromptServiceImpl(Prompts):
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True) prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts) return ListPromptsResponse(data=prompts)
async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt: async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version.""" """Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version) key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key) data = await self.kvstore.get(key)
@ -127,14 +132,19 @@ class PromptServiceImpl(Prompts):
if variables is None: if variables is None:
variables = [] variables = []
prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables) prompt_obj = Prompt(
prompt_id=Prompt.generate_prompt_id(),
prompt=prompt,
version=1,
variables=variables,
)
version_key = self._get_version_key(prompt_obj.prompt_id, "1") version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj) data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data) await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id) default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, "1") await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj return prompt_obj
@ -142,10 +152,12 @@ class PromptServiceImpl(Prompts):
self, self,
prompt_id: str, prompt_id: str,
prompt: str, prompt: str,
version: str, version: int,
variables: list[str] | None = None, variables: list[str] | None = None,
) -> Prompt: ) -> Prompt:
"""Update an existing prompt (increments version).""" """Update an existing prompt (increments version)."""
if version < 1:
raise ValueError("Version must be >= 1")
if variables is None: if variables is None:
variables = [] variables = []
@ -158,16 +170,16 @@ class PromptServiceImpl(Prompts):
) )
current_version = latest_prompt.version if version is None else version current_version = latest_prompt.version if version is None else version
new_version = str(int(current_version) + 1) new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables) updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
version_key = self._get_version_key(prompt_id, new_version) version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt) data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data) await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_id) default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, new_version) await self.kvstore.set(default_key, str(new_version))
return updated_prompt return updated_prompt
@ -202,19 +214,19 @@ class PromptServiceImpl(Prompts):
raise ValueError(f"Prompt {prompt_id} not found") raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts: for prompt in prompts:
prompt.is_default = prompt.version == default_version prompt.is_default = str(prompt.version) == default_version
prompts.sort(key=lambda x: int(x.version)) prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts) return ListPromptsResponse(data=prompts)
async def set_default_version(self, prompt_id: str, version: str) -> Prompt: async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set which version of a prompt should be the default (latest).""" """Set which version of a prompt should be the default, If not set. the default is the latest."""
version_key = self._get_version_key(prompt_id, version) version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key) data = await self.kvstore.get(version_key)
if data is None: if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found") raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id) default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, version) await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data) return self._deserialize_prompt(data)

View file

@ -4,8 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os import random
import tempfile
import pytest import pytest
@ -14,32 +13,12 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture @pytest.fixture
async def temp_prompt_store(): async def temp_prompt_store(tmp_path_factory):
with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file: unique_id = f"prompt_store_{random.randint(1, 1000000)}"
db_path = tmp_file.name temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
try:
config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path)) config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
store = PromptServiceImpl(config, deps={}) store = PromptServiceImpl(config, deps={})
await store.initialize() await store.initialize()
yield store yield store
finally:
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.fixture
def sample_prompt_data():
return {
"prompt": "Hello {{name}}, welcome to {{platform}}!",
"variables": {"name": "John", "platform": "LlamaStack"},
}
@pytest.fixture
def sample_prompts_data():
return [
{"prompt": "Hello {{name}}!", "variables": {"name": "Alice"}},
{"prompt": "Welcome to {{platform}}, {{user}}!", "variables": {"platform": "LlamaStack", "user": "Bob"}},
{"prompt": "Your order {{order_id}} is ready for pickup.", "variables": {"order_id": "12345"}},
]

View file

@ -32,7 +32,7 @@ class TestPrompts:
async def test_create_and_get_prompt(self, store): async def test_create_and_get_prompt(self, store):
prompt = await store.create_prompt("Hello world!", ["name"]) prompt = await store.create_prompt("Hello world!", ["name"])
assert prompt.prompt == "Hello world!" assert prompt.prompt == "Hello world!"
assert prompt.version == "1" assert prompt.version == 1
assert prompt.prompt_id.startswith("pmpt_") assert prompt.prompt_id.startswith("pmpt_")
assert prompt.variables == ["name"] assert prompt.variables == ["name"]
@ -42,17 +42,17 @@ class TestPrompts:
async def test_update_prompt(self, store): async def test_update_prompt(self, store):
prompt = await store.create_prompt("Original") prompt = await store.create_prompt("Original")
updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"]) updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
assert updated.version == "2" assert updated.version == 2
assert updated.prompt == "Updated" assert updated.prompt == "Updated"
async def test_update_prompt_with_version(self, store): async def test_update_prompt_with_version(self, store):
version_for_update = "1" version_for_update = 1
prompt = await store.create_prompt("Original") prompt = await store.create_prompt("Original")
assert prompt.version == "1" assert prompt.version == 1
prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"]) prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
assert prompt.version == "2" assert prompt.version == 2
with pytest.raises(ValueError): with pytest.raises(ValueError):
# now this is a stale version # now this is a stale version
@ -60,7 +60,7 @@ class TestPrompts:
with pytest.raises(ValueError): with pytest.raises(ValueError):
# this version does not exist # this version does not exist
await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"]) await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
async def test_delete_prompt(self, store): async def test_delete_prompt(self, store):
prompt = await store.create_prompt("to be deleted") prompt = await store.create_prompt("to be deleted")
@ -80,25 +80,25 @@ class TestPrompts:
async def test_version(self, store): async def test_version(self, store):
prompt = await store.create_prompt("V1") prompt = await store.create_prompt("V1")
await store.update_prompt(prompt.prompt_id, "V2", "1") await store.update_prompt(prompt.prompt_id, "V2", 1)
v1 = await store.get_prompt(prompt.prompt_id, version="1") v1 = await store.get_prompt(prompt.prompt_id, version=1)
assert v1.version == "1" and v1.prompt == "V1" assert v1.version == 1 and v1.prompt == "V1"
latest = await store.get_prompt(prompt.prompt_id) latest = await store.get_prompt(prompt.prompt_id)
assert latest.version == "2" and latest.prompt == "V2" assert latest.version == 2 and latest.prompt == "V2"
async def test_set_default_version(self, store): async def test_set_default_version(self, store):
prompt0 = await store.create_prompt("V1") prompt0 = await store.create_prompt("V1")
prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1") prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1)
assert (await store.get_prompt(prompt0.prompt_id)).version == "2" assert (await store.get_prompt(prompt0.prompt_id)).version == 2
prompt_default = await store.set_default_version(prompt0.prompt_id, "1") prompt_default = await store.set_default_version(prompt0.prompt_id, 1)
assert (await store.get_prompt(prompt0.prompt_id)).version == "1" assert (await store.get_prompt(prompt0.prompt_id)).version == 1
assert prompt_default.version == "1" assert prompt_default.version == 1
prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version) prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
assert prompt2.version == "3" assert prompt2.version == 3
async def test_prompt_id_generation_and_validation(self, store): async def test_prompt_id_generation_and_validation(self, store):
prompt = await store.create_prompt("Test") prompt = await store.create_prompt("Test")
@ -110,30 +110,31 @@ class TestPrompts:
async def test_list_shows_default_versions(self, store): async def test_list_shows_default_versions(self, store):
prompt = await store.create_prompt("V1") prompt = await store.create_prompt("V1")
await store.update_prompt(prompt.prompt_id, "V2", "1") await store.update_prompt(prompt.prompt_id, "V2", 1)
await store.update_prompt(prompt.prompt_id, "V3", "2") await store.update_prompt(prompt.prompt_id, "V3", 2)
response = await store.list_prompts() response = await store.list_prompts()
listed_prompt = response.data[0] listed_prompt = response.data[0]
assert listed_prompt.version == "3" and listed_prompt.prompt == "V3" assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
await store.set_default_version(prompt.prompt_id, "1") await store.set_default_version(prompt.prompt_id, 1)
response = await store.list_prompts() response = await store.list_prompts()
listed_prompt = response.data[0] listed_prompt = response.data[0]
assert listed_prompt.version == "1" and listed_prompt.prompt == "V1" assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default
async def test_get_all_prompt_versions(self, store): async def test_get_all_prompt_versions(self, store):
prompt = await store.create_prompt("V1") prompt = await store.create_prompt("V1")
await store.update_prompt(prompt.prompt_id, "V2", "1") await store.update_prompt(prompt.prompt_id, "V2", 1)
await store.update_prompt(prompt.prompt_id, "V3", "2") await store.update_prompt(prompt.prompt_id, "V3", 2)
versions = (await store.list_prompt_versions(prompt.prompt_id)).data versions = (await store.list_prompt_versions(prompt.prompt_id)).data
assert len(versions) == 3 assert len(versions) == 3
assert [v.version for v in versions] == ["1", "2", "3"] assert [v.version for v in versions] == [1, 2, 3]
assert [v.is_default for v in versions] == [False, False, True] assert [v.is_default for v in versions] == [False, False, True]
await store.set_default_version(prompt.prompt_id, "2") await store.set_default_version(prompt.prompt_id, 2)
versions = (await store.list_prompt_versions(prompt.prompt_id)).data versions = (await store.list_prompt_versions(prompt.prompt_id)).data
assert [v.is_default for v in versions] == [False, True, False] assert [v.is_default for v in versions] == [False, True, False]