mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
fixed update_prompt to properly handle latest and default version, made version a required parameter, and removed unused CreatePromptRequest
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
1390660dcf
commit
beb2db487d
6 changed files with 43 additions and 41 deletions
13
docs/_static/llama-stack-spec.html
vendored
13
docs/_static/llama-stack-spec.html
vendored
|
@ -10018,7 +10018,7 @@
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Dictionary of prompt variable names and values"
|
"description": "List of prompt variable names that can be used in the prompt template"
|
||||||
},
|
},
|
||||||
"is_default": {
|
"is_default": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
|
@ -17824,21 +17824,22 @@
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The updated prompt text content."
|
"description": "The updated prompt text content."
|
||||||
},
|
},
|
||||||
|
"version": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The current version of the prompt being updated (as a string)."
|
||||||
|
},
|
||||||
"variables": {
|
"variables": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Updated list of variable names that can be used in the prompt template."
|
"description": "Updated list of variable names that can be used in the prompt template."
|
||||||
},
|
|
||||||
"version": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The current version of the prompt being updated (as a string)."
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"prompt"
|
"prompt",
|
||||||
|
"version"
|
||||||
],
|
],
|
||||||
"title": "UpdatePromptRequest"
|
"title": "UpdatePromptRequest"
|
||||||
},
|
},
|
||||||
|
|
11
docs/_static/llama-stack-spec.yaml
vendored
11
docs/_static/llama-stack-spec.yaml
vendored
|
@ -7403,7 +7403,7 @@ components:
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Dictionary of prompt variable names and values
|
List of prompt variable names that can be used in the prompt template
|
||||||
is_default:
|
is_default:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
|
@ -13226,19 +13226,20 @@ components:
|
||||||
prompt:
|
prompt:
|
||||||
type: string
|
type: string
|
||||||
description: The updated prompt text content.
|
description: The updated prompt text content.
|
||||||
|
version:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The current version of the prompt being updated (as a string).
|
||||||
variables:
|
variables:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Updated list of variable names that can be used in the prompt template.
|
Updated list of variable names that can be used in the prompt template.
|
||||||
version:
|
|
||||||
type: string
|
|
||||||
description: >-
|
|
||||||
The current version of the prompt being updated (as a string).
|
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- prompt
|
- prompt
|
||||||
|
- version
|
||||||
title: UpdatePromptRequest
|
title: UpdatePromptRequest
|
||||||
VersionInfo:
|
VersionInfo:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
@ -4,6 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .prompts import CreatePromptRequest, ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest
|
from .prompts import ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest
|
||||||
|
|
||||||
__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "CreatePromptRequest", "UpdatePromptRequest"]
|
__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "UpdatePromptRequest"]
|
||||||
|
|
|
@ -21,7 +21,7 @@ class Prompt(BaseModel):
|
||||||
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
|
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
|
||||||
:param version: Version string (integer start at 1 cast as string, incremented on save)
|
:param version: Version string (integer start at 1 cast as string, incremented on save)
|
||||||
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||||
:param variables: Dictionary of prompt variable names and values
|
:param variables: List of prompt variable names that can be used in the prompt template
|
||||||
:param is_default: Boolean indicating whether this version is the default version for this prompt
|
:param is_default: Boolean indicating whether this version is the default version for this prompt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -90,13 +90,6 @@ class Prompt(BaseModel):
|
||||||
return f"pmpt_{hex_string}"
|
return f"pmpt_{hex_string}"
|
||||||
|
|
||||||
|
|
||||||
class CreatePromptRequest(BaseModel):
|
|
||||||
"""Request model to create a prompt."""
|
|
||||||
|
|
||||||
prompt: str = Field(description="The prompt text content")
|
|
||||||
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
|
|
||||||
|
|
||||||
|
|
||||||
class UpdatePromptRequest(BaseModel):
|
class UpdatePromptRequest(BaseModel):
|
||||||
"""Request model for updating a prompt."""
|
"""Request model for updating a prompt."""
|
||||||
|
|
||||||
|
@ -168,15 +161,15 @@ class Prompts(Protocol):
|
||||||
self,
|
self,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
version: str,
|
||||||
variables: list[str] | None = None,
|
variables: list[str] | None = None,
|
||||||
version: str | None = None,
|
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version).
|
"""Update an existing prompt (increments version).
|
||||||
|
|
||||||
:param prompt_id: The identifier of the prompt to update.
|
:param prompt_id: The identifier of the prompt to update.
|
||||||
:param prompt: The updated prompt text content.
|
:param prompt: The updated prompt text content.
|
||||||
:param variables: Updated list of variable names that can be used in the prompt template.
|
|
||||||
:param version: The current version of the prompt being updated (as a string).
|
:param version: The current version of the prompt being updated (as a string).
|
||||||
|
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||||
:returns: The updated Prompt resource with incremented version.
|
:returns: The updated Prompt resource with incremented version.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -142,20 +142,22 @@ class PromptServiceImpl(Prompts):
|
||||||
self,
|
self,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
version: str,
|
||||||
variables: list[str] | None = None,
|
variables: list[str] | None = None,
|
||||||
version: str | None = None,
|
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version)."""
|
"""Update an existing prompt (increments version)."""
|
||||||
if variables is None:
|
if variables is None:
|
||||||
variables = []
|
variables = []
|
||||||
|
|
||||||
current_prompt = await self.get_prompt(prompt_id)
|
prompt_versions = await self.list_prompt_versions(prompt_id)
|
||||||
if version and current_prompt.version != version:
|
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
||||||
|
|
||||||
|
if version and latest_prompt.version != version:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{current_prompt.version}' in request."
|
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
||||||
)
|
)
|
||||||
|
|
||||||
current_version = current_prompt.version if version is None else version
|
current_version = latest_prompt.version if version is None else version
|
||||||
new_version = str(int(current_version) + 1)
|
new_version = str(int(current_version) + 1)
|
||||||
|
|
||||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class TestPrompts:
|
||||||
|
|
||||||
async def test_update_prompt(self, store):
|
async def test_update_prompt(self, store):
|
||||||
prompt = await store.create_prompt("Original")
|
prompt = await store.create_prompt("Original")
|
||||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", ["v"])
|
updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"])
|
||||||
assert updated.version == "2"
|
assert updated.version == "2"
|
||||||
assert updated.prompt == "Updated"
|
assert updated.prompt == "Updated"
|
||||||
|
|
||||||
|
@ -51,16 +51,16 @@ class TestPrompts:
|
||||||
|
|
||||||
prompt = await store.create_prompt("Original")
|
prompt = await store.create_prompt("Original")
|
||||||
assert prompt.version == "1"
|
assert prompt.version == "1"
|
||||||
prompt = await store.update_prompt(prompt.prompt_id, "Updated", ["v"], version_for_update)
|
prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
|
||||||
assert prompt.version == "2"
|
assert prompt.version == "2"
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# now this is a stale version
|
# now this is a stale version
|
||||||
await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], version_for_update)
|
await store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# this version does not exist
|
# this version does not exist
|
||||||
await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], "99")
|
await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"])
|
||||||
|
|
||||||
async def test_delete_prompt(self, store):
|
async def test_delete_prompt(self, store):
|
||||||
prompt = await store.create_prompt("to be deleted")
|
prompt = await store.create_prompt("to be deleted")
|
||||||
|
@ -80,7 +80,7 @@ class TestPrompts:
|
||||||
|
|
||||||
async def test_version(self, store):
|
async def test_version(self, store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt = await store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2")
|
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||||
|
|
||||||
v1 = await store.get_prompt(prompt.prompt_id, version="1")
|
v1 = await store.get_prompt(prompt.prompt_id, version="1")
|
||||||
assert v1.version == "1" and v1.prompt == "V1"
|
assert v1.version == "1" and v1.prompt == "V1"
|
||||||
|
@ -89,11 +89,16 @@ class TestPrompts:
|
||||||
assert latest.version == "2" and latest.prompt == "V2"
|
assert latest.version == "2" and latest.prompt == "V2"
|
||||||
|
|
||||||
async def test_set_default_version(self, store):
|
async def test_set_default_version(self, store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt0 = await store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2")
|
prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1")
|
||||||
|
|
||||||
await store.set_default_version(prompt.prompt_id, "1")
|
assert (await store.get_prompt(prompt0.prompt_id)).version == "2"
|
||||||
assert (await store.get_prompt(prompt.prompt_id)).version == "1"
|
prompt_default = await store.set_default_version(prompt0.prompt_id, "1")
|
||||||
|
assert (await store.get_prompt(prompt0.prompt_id)).version == "1"
|
||||||
|
assert prompt_default.version == "1"
|
||||||
|
|
||||||
|
prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
|
||||||
|
assert prompt2.version == "3"
|
||||||
|
|
||||||
async def test_prompt_id_generation_and_validation(self, store):
|
async def test_prompt_id_generation_and_validation(self, store):
|
||||||
prompt = await store.create_prompt("Test")
|
prompt = await store.create_prompt("Test")
|
||||||
|
@ -105,8 +110,8 @@ class TestPrompts:
|
||||||
|
|
||||||
async def test_list_shows_default_versions(self, store):
|
async def test_list_shows_default_versions(self, store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt = await store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2")
|
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V3")
|
await store.update_prompt(prompt.prompt_id, "V3", "2")
|
||||||
|
|
||||||
response = await store.list_prompts()
|
response = await store.list_prompts()
|
||||||
listed_prompt = response.data[0]
|
listed_prompt = response.data[0]
|
||||||
|
@ -120,8 +125,8 @@ class TestPrompts:
|
||||||
|
|
||||||
async def test_get_all_prompt_versions(self, store):
|
async def test_get_all_prompt_versions(self, store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt = await store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2")
|
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V3")
|
await store.update_prompt(prompt.prompt_id, "V3", "2")
|
||||||
|
|
||||||
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
||||||
assert len(versions) == 3
|
assert len(versions) == 3
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue