diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 2996f4790..a2015810a 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -1021,7 +1021,7 @@
"description": "The version of the prompt to get (defaults to latest).",
"required": false,
"schema": {
- "type": "string"
+ "type": "integer"
}
}
]
@@ -10006,8 +10006,8 @@
"description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API."
},
"version": {
- "type": "string",
- "description": "Version string (integer start at 1 cast as string, incremented on save)"
+ "type": "integer",
+ "description": "Version (integer starting at 1, incremented on save)"
},
"prompt_id": {
"type": "string",
@@ -17523,7 +17523,7 @@
"type": "object",
"properties": {
"version": {
- "type": "string",
+ "type": "integer",
"description": "The version to set as default."
}
},
@@ -17825,8 +17825,8 @@
"description": "The updated prompt text content."
},
"version": {
- "type": "string",
- "description": "The current version of the prompt being updated (as a string)."
+ "type": "integer",
+ "description": "The current version of the prompt being updated."
},
"variables": {
"type": "array",
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 106464f79..74bc7ef41 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -704,7 +704,7 @@ paths:
The version of the prompt to get (defaults to latest).
required: false
schema:
- type: string
+ type: integer
post:
responses:
'200':
@@ -7391,9 +7391,9 @@ components:
The system prompt text with variable placeholders. Variables are only
supported when using the Responses API.
version:
- type: string
+ type: integer
description: >-
- Version string (integer start at 1 cast as string, incremented on save)
+ Version (integer starting at 1, incremented on save)
prompt_id:
type: string
description: >-
@@ -13018,7 +13018,7 @@ components:
type: object
properties:
version:
- type: string
+ type: integer
description: The version to set as default.
additionalProperties: false
required:
@@ -13227,9 +13227,9 @@ components:
type: string
description: The updated prompt text content.
version:
- type: string
+ type: integer
description: >-
- The current version of the prompt being updated (as a string).
+ The current version of the prompt being updated.
variables:
type: array
items:
diff --git a/llama_stack/apis/prompts/__init__.py b/llama_stack/apis/prompts/__init__.py
index 3534f5864..6070f3450 100644
--- a/llama_stack/apis/prompts/__init__.py
+++ b/llama_stack/apis/prompts/__init__.py
@@ -4,6 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .prompts import ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest
+from .prompts import ListPromptsResponse, Prompt, Prompts
-__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "UpdatePromptRequest"]
+__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]
diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py
index 31073b750..8018b7fff 100644
--- a/llama_stack/apis/prompts/prompts.py
+++ b/llama_stack/apis/prompts/prompts.py
@@ -19,14 +19,14 @@ class Prompt(BaseModel):
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
- :param version: Version string (integer start at 1 cast as string, incremented on save)
+ :param version: Version (integer starting at 1, incremented on save)
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
:param variables: List of prompt variable names that can be used in the prompt template
:param is_default: Boolean indicating whether this version is the default version for this prompt
"""
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
- version: str = Field(description="Version string (integer start at 1 cast as string)")
+ version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: list[str] = Field(
default_factory=list, description="List of variable names that can be used in the prompt template"
@@ -56,15 +56,9 @@ class Prompt(BaseModel):
@field_validator("version")
@classmethod
- def validate_version(cls, prompt_version: str) -> str:
- try:
- int_version = int(prompt_version)
- if int_version < 1:
- raise ValueError("version must be >= 1")
- except ValueError as e:
- if "invalid literal" in str(e):
- raise ValueError("version must be a string representation of an integer") from e
- raise
+ def validate_version(cls, prompt_version: int) -> int:
+ if prompt_version < 1:
+ raise ValueError("version must be >= 1")
return prompt_version
@model_validator(mode="after")
@@ -90,13 +84,6 @@ class Prompt(BaseModel):
return f"pmpt_{hex_string}"
-class UpdatePromptRequest(BaseModel):
- """Request model for updating a prompt."""
-
- prompt: str = Field(description="The prompt text content")
- variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
-
-
class ListPromptsResponse(BaseModel):
"""Response model to list prompts."""
@@ -132,7 +119,7 @@ class Prompts(Protocol):
async def get_prompt(
self,
prompt_id: str,
- version: str | None = None,
+ version: int | None = None,
) -> Prompt:
"""Get a prompt by its identifier and optional version.
@@ -161,14 +148,14 @@ class Prompts(Protocol):
self,
prompt_id: str,
prompt: str,
- version: str,
+ version: int,
variables: list[str] | None = None,
) -> Prompt:
"""Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
- :param version: The current version of the prompt being updated (as a string).
+ :param version: The current version of the prompt being updated.
:param variables: Updated list of variable names that can be used in the prompt template.
:returns: The updated Prompt resource with incremented version.
"""
@@ -189,7 +176,7 @@ class Prompts(Protocol):
async def set_default_version(
self,
prompt_id: str,
- version: str,
+ version: int,
) -> Prompt:
"""Set which version of a prompt should be the default in get_prompt (latest).
diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py
index ecd4ff2c2..a5fe79ed3 100644
--- a/llama_stack/core/prompts/prompts.py
+++ b/llama_stack/core/prompts/prompts.py
@@ -50,10 +50,10 @@ class PromptServiceImpl(Prompts):
"""Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default"
- async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
+ async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable."""
if version:
- return self._get_version_key(prompt_id, version)
+ return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key)
@@ -77,6 +77,7 @@ class PromptServiceImpl(Prompts):
"prompt": prompt.prompt,
"version": prompt.version,
"variables": prompt.variables or [],
+ "is_default": prompt.is_default,
}
)
@@ -84,7 +85,11 @@ class PromptServiceImpl(Prompts):
"""Deserialize a prompt from JSON string."""
obj = json.loads(data)
return Prompt(
- prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", [])
+ prompt_id=obj["prompt_id"],
+ prompt=obj["prompt"],
+ version=obj["version"],
+ variables=obj.get("variables", []),
+ is_default=obj.get("is_default", False),
)
async def list_prompts(self) -> ListPromptsResponse:
@@ -110,7 +115,7 @@ class PromptServiceImpl(Prompts):
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts)
- async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
+ async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key)
@@ -127,14 +132,19 @@ class PromptServiceImpl(Prompts):
if variables is None:
variables = []
- prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)
+ prompt_obj = Prompt(
+ prompt_id=Prompt.generate_prompt_id(),
+ prompt=prompt,
+ version=1,
+ variables=variables,
+ )
- version_key = self._get_version_key(prompt_obj.prompt_id, "1")
+ version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id)
- await self.kvstore.set(default_key, "1")
+ await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj
@@ -142,10 +152,12 @@ class PromptServiceImpl(Prompts):
self,
prompt_id: str,
prompt: str,
- version: str,
+ version: int,
variables: list[str] | None = None,
) -> Prompt:
"""Update an existing prompt (increments version)."""
+ if version < 1:
+ raise ValueError("Version must be >= 1")
if variables is None:
variables = []
@@ -158,16 +170,16 @@ class PromptServiceImpl(Prompts):
)
current_version = latest_prompt.version if version is None else version
- new_version = str(int(current_version) + 1)
+ new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
- version_key = self._get_version_key(prompt_id, new_version)
+ version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_id)
- await self.kvstore.set(default_key, new_version)
+ await self.kvstore.set(default_key, str(new_version))
return updated_prompt
@@ -202,19 +214,19 @@ class PromptServiceImpl(Prompts):
raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts:
- prompt.is_default = prompt.version == default_version
+ prompt.is_default = str(prompt.version) == default_version
- prompts.sort(key=lambda x: int(x.version))
+ prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts)
- async def set_default_version(self, prompt_id: str, version: str) -> Prompt:
- """Set which version of a prompt should be the default (latest)."""
- version_key = self._get_version_key(prompt_id, version)
+ async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
+ """Set which version of a prompt should be the default, If not set. the default is the latest."""
+ version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key)
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id)
- await self.kvstore.set(default_key, version)
+ await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data)
diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py
index 70b3bd12b..f778fbf16 100644
--- a/tests/unit/prompts/prompts/conftest.py
+++ b/tests/unit/prompts/prompts/conftest.py
@@ -4,8 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import os
-import tempfile
+import random
import pytest
@@ -14,32 +13,12 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
-async def temp_prompt_store():
- with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file:
- db_path = tmp_file.name
+async def temp_prompt_store(tmp_path_factory):
+ unique_id = f"prompt_store_{random.randint(1, 1000000)}"
+ temp_dir = tmp_path_factory.getbasetemp()
+ db_path = str(temp_dir / f"{unique_id}.db")
- try:
- config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
- store = PromptServiceImpl(config, deps={})
- await store.initialize()
- yield store
- finally:
- if os.path.exists(db_path):
- os.unlink(db_path)
-
-
-@pytest.fixture
-def sample_prompt_data():
- return {
- "prompt": "Hello {{name}}, welcome to {{platform}}!",
- "variables": {"name": "John", "platform": "LlamaStack"},
- }
-
-
-@pytest.fixture
-def sample_prompts_data():
- return [
- {"prompt": "Hello {{name}}!", "variables": {"name": "Alice"}},
- {"prompt": "Welcome to {{platform}}, {{user}}!", "variables": {"platform": "LlamaStack", "user": "Bob"}},
- {"prompt": "Your order {{order_id}} is ready for pickup.", "variables": {"order_id": "12345"}},
- ]
+ config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
+ store = PromptServiceImpl(config, deps={})
+ await store.initialize()
+ yield store
diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py
index a4512ed02..94bbbdbf0 100644
--- a/tests/unit/prompts/prompts/test_prompts.py
+++ b/tests/unit/prompts/prompts/test_prompts.py
@@ -32,7 +32,7 @@ class TestPrompts:
async def test_create_and_get_prompt(self, store):
prompt = await store.create_prompt("Hello world!", ["name"])
assert prompt.prompt == "Hello world!"
- assert prompt.version == "1"
+ assert prompt.version == 1
assert prompt.prompt_id.startswith("pmpt_")
assert prompt.variables == ["name"]
@@ -42,17 +42,17 @@ class TestPrompts:
async def test_update_prompt(self, store):
prompt = await store.create_prompt("Original")
- updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"])
- assert updated.version == "2"
+ updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
+ assert updated.version == 2
assert updated.prompt == "Updated"
async def test_update_prompt_with_version(self, store):
- version_for_update = "1"
+ version_for_update = 1
prompt = await store.create_prompt("Original")
- assert prompt.version == "1"
+ assert prompt.version == 1
prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
- assert prompt.version == "2"
+ assert prompt.version == 2
with pytest.raises(ValueError):
# now this is a stale version
@@ -60,7 +60,7 @@ class TestPrompts:
with pytest.raises(ValueError):
# this version does not exist
- await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"])
+ await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
async def test_delete_prompt(self, store):
prompt = await store.create_prompt("to be deleted")
@@ -80,25 +80,25 @@ class TestPrompts:
async def test_version(self, store):
prompt = await store.create_prompt("V1")
- await store.update_prompt(prompt.prompt_id, "V2", "1")
+ await store.update_prompt(prompt.prompt_id, "V2", 1)
- v1 = await store.get_prompt(prompt.prompt_id, version="1")
- assert v1.version == "1" and v1.prompt == "V1"
+ v1 = await store.get_prompt(prompt.prompt_id, version=1)
+ assert v1.version == 1 and v1.prompt == "V1"
latest = await store.get_prompt(prompt.prompt_id)
- assert latest.version == "2" and latest.prompt == "V2"
+ assert latest.version == 2 and latest.prompt == "V2"
async def test_set_default_version(self, store):
prompt0 = await store.create_prompt("V1")
- prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1")
+ prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1)
- assert (await store.get_prompt(prompt0.prompt_id)).version == "2"
- prompt_default = await store.set_default_version(prompt0.prompt_id, "1")
- assert (await store.get_prompt(prompt0.prompt_id)).version == "1"
- assert prompt_default.version == "1"
+ assert (await store.get_prompt(prompt0.prompt_id)).version == 2
+ prompt_default = await store.set_default_version(prompt0.prompt_id, 1)
+ assert (await store.get_prompt(prompt0.prompt_id)).version == 1
+ assert prompt_default.version == 1
prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
- assert prompt2.version == "3"
+ assert prompt2.version == 3
async def test_prompt_id_generation_and_validation(self, store):
prompt = await store.create_prompt("Test")
@@ -110,30 +110,31 @@ class TestPrompts:
async def test_list_shows_default_versions(self, store):
prompt = await store.create_prompt("V1")
- await store.update_prompt(prompt.prompt_id, "V2", "1")
- await store.update_prompt(prompt.prompt_id, "V3", "2")
+ await store.update_prompt(prompt.prompt_id, "V2", 1)
+ await store.update_prompt(prompt.prompt_id, "V3", 2)
response = await store.list_prompts()
listed_prompt = response.data[0]
- assert listed_prompt.version == "3" and listed_prompt.prompt == "V3"
+ assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
- await store.set_default_version(prompt.prompt_id, "1")
+ await store.set_default_version(prompt.prompt_id, 1)
response = await store.list_prompts()
listed_prompt = response.data[0]
- assert listed_prompt.version == "1" and listed_prompt.prompt == "V1"
+ assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
+ assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default
async def test_get_all_prompt_versions(self, store):
prompt = await store.create_prompt("V1")
- await store.update_prompt(prompt.prompt_id, "V2", "1")
- await store.update_prompt(prompt.prompt_id, "V3", "2")
+ await store.update_prompt(prompt.prompt_id, "V2", 1)
+ await store.update_prompt(prompt.prompt_id, "V3", 2)
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
assert len(versions) == 3
- assert [v.version for v in versions] == ["1", "2", "3"]
+ assert [v.version for v in versions] == [1, 2, 3]
assert [v.is_default for v in versions] == [False, False, True]
- await store.set_default_version(prompt.prompt_id, "2")
+ await store.set_default_version(prompt.prompt_id, 2)
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
assert [v.is_default for v in versions] == [False, True, False]