mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
add set_as_default in update_prompt and in API with default to True, also fix tests to actualy use temp_prompt_store (with some modifications)
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
64f6840195
commit
54fcdf1d3d
6 changed files with 91 additions and 80 deletions
7
docs/_static/llama-stack-spec.html
vendored
7
docs/_static/llama-stack-spec.html
vendored
|
@ -17834,12 +17834,17 @@
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Updated list of variable names that can be used in the prompt template."
|
"description": "Updated list of variable names that can be used in the prompt template."
|
||||||
|
},
|
||||||
|
"set_as_default": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Set the new version as the default (default=True)."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"prompt",
|
"prompt",
|
||||||
"version"
|
"version",
|
||||||
|
"set_as_default"
|
||||||
],
|
],
|
||||||
"title": "UpdatePromptRequest"
|
"title": "UpdatePromptRequest"
|
||||||
},
|
},
|
||||||
|
|
5
docs/_static/llama-stack-spec.yaml
vendored
5
docs/_static/llama-stack-spec.yaml
vendored
|
@ -13236,10 +13236,15 @@ components:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Updated list of variable names that can be used in the prompt template.
|
Updated list of variable names that can be used in the prompt template.
|
||||||
|
set_as_default:
|
||||||
|
type: boolean
|
||||||
|
description: >-
|
||||||
|
Set the new version as the default (default=True).
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- prompt
|
- prompt
|
||||||
- version
|
- version
|
||||||
|
- set_as_default
|
||||||
title: UpdatePromptRequest
|
title: UpdatePromptRequest
|
||||||
VersionInfo:
|
VersionInfo:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
@ -150,6 +150,7 @@ class Prompts(Protocol):
|
||||||
prompt: str,
|
prompt: str,
|
||||||
version: int,
|
version: int,
|
||||||
variables: list[str] | None = None,
|
variables: list[str] | None = None,
|
||||||
|
set_as_default: bool = True,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version).
|
"""Update an existing prompt (increments version).
|
||||||
|
|
||||||
|
@ -157,6 +158,7 @@ class Prompts(Protocol):
|
||||||
:param prompt: The updated prompt text content.
|
:param prompt: The updated prompt text content.
|
||||||
:param version: The current version of the prompt being updated.
|
:param version: The current version of the prompt being updated.
|
||||||
:param variables: Updated list of variable names that can be used in the prompt template.
|
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||||
|
:param set_as_default: Set the new version as the default (default=True).
|
||||||
:returns: The updated Prompt resource with incremented version.
|
:returns: The updated Prompt resource with incremented version.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -154,6 +154,7 @@ class PromptServiceImpl(Prompts):
|
||||||
prompt: str,
|
prompt: str,
|
||||||
version: int,
|
version: int,
|
||||||
variables: list[str] | None = None,
|
variables: list[str] | None = None,
|
||||||
|
set_as_default: bool = True,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version)."""
|
"""Update an existing prompt (increments version)."""
|
||||||
if version < 1:
|
if version < 1:
|
||||||
|
@ -178,8 +179,8 @@ class PromptServiceImpl(Prompts):
|
||||||
data = self._serialize_prompt(updated_prompt)
|
data = self._serialize_prompt(updated_prompt)
|
||||||
await self.kvstore.set(version_key, data)
|
await self.kvstore.set(version_key, data)
|
||||||
|
|
||||||
default_key = self._get_default_key(prompt_id)
|
if set_as_default:
|
||||||
await self.kvstore.set(default_key, str(new_version))
|
await self.set_default_version(prompt_id, new_version)
|
||||||
|
|
||||||
return updated_prompt
|
return updated_prompt
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,13 @@ async def temp_prompt_store(tmp_path_factory):
|
||||||
temp_dir = tmp_path_factory.getbasetemp()
|
temp_dir = tmp_path_factory.getbasetemp()
|
||||||
db_path = str(temp_dir / f"{unique_id}.db")
|
db_path = str(temp_dir / f"{unique_id}.db")
|
||||||
|
|
||||||
config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
|
from llama_stack.core.datatypes import StackRunConfig
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
|
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
||||||
|
config = PromptServiceConfig(run_config=mock_run_config)
|
||||||
store = PromptServiceImpl(config, deps={})
|
store = PromptServiceImpl(config, deps={})
|
||||||
await store.initialize()
|
|
||||||
|
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||||
|
|
||||||
yield store
|
yield store
|
||||||
|
|
|
@ -4,149 +4,141 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
|
||||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
|
||||||
|
|
||||||
|
|
||||||
class TestPrompts:
|
class TestPrompts:
|
||||||
@pytest.fixture
|
async def test_create_and_get_prompt(self, temp_prompt_store):
|
||||||
async def store(self):
|
prompt = await temp_prompt_store.create_prompt("Hello world!", ["name"])
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
|
||||||
config = PromptServiceConfig(run_config=mock_run_config)
|
|
||||||
store = PromptServiceImpl(config, deps={})
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
||||||
|
|
||||||
test_db_path = os.path.join(temp_dir, "test_prompts.db")
|
|
||||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=test_db_path))
|
|
||||||
|
|
||||||
yield store
|
|
||||||
|
|
||||||
async def test_create_and_get_prompt(self, store):
|
|
||||||
prompt = await store.create_prompt("Hello world!", ["name"])
|
|
||||||
assert prompt.prompt == "Hello world!"
|
assert prompt.prompt == "Hello world!"
|
||||||
assert prompt.version == 1
|
assert prompt.version == 1
|
||||||
assert prompt.prompt_id.startswith("pmpt_")
|
assert prompt.prompt_id.startswith("pmpt_")
|
||||||
assert prompt.variables == ["name"]
|
assert prompt.variables == ["name"]
|
||||||
|
|
||||||
retrieved = await store.get_prompt(prompt.prompt_id)
|
retrieved = await temp_prompt_store.get_prompt(prompt.prompt_id)
|
||||||
assert retrieved.prompt_id == prompt.prompt_id
|
assert retrieved.prompt_id == prompt.prompt_id
|
||||||
assert retrieved.prompt == prompt.prompt
|
assert retrieved.prompt == prompt.prompt
|
||||||
|
|
||||||
async def test_update_prompt(self, store):
|
async def test_update_prompt(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("Original")
|
prompt = await temp_prompt_store.create_prompt("Original")
|
||||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
|
updated = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
|
||||||
assert updated.version == 2
|
assert updated.version == 2
|
||||||
assert updated.prompt == "Updated"
|
assert updated.prompt == "Updated"
|
||||||
|
|
||||||
async def test_update_prompt_with_version(self, store):
|
async def test_update_prompt_with_version(self, temp_prompt_store):
|
||||||
version_for_update = 1
|
version_for_update = 1
|
||||||
|
|
||||||
prompt = await store.create_prompt("Original")
|
prompt = await temp_prompt_store.create_prompt("Original")
|
||||||
assert prompt.version == 1
|
assert prompt.version == 1
|
||||||
prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
|
prompt = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
|
||||||
assert prompt.version == 2
|
assert prompt.version == 2
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# now this is a stale version
|
# now this is a stale version
|
||||||
await store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# this version does not exist
|
# this version does not exist
|
||||||
await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
|
||||||
|
|
||||||
async def test_delete_prompt(self, store):
|
async def test_delete_prompt(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("to be deleted")
|
prompt = await temp_prompt_store.create_prompt("to be deleted")
|
||||||
await store.delete_prompt(prompt.prompt_id)
|
await temp_prompt_store.delete_prompt(prompt.prompt_id)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await store.get_prompt(prompt.prompt_id)
|
await temp_prompt_store.get_prompt(prompt.prompt_id)
|
||||||
|
|
||||||
async def test_list_prompts(self, store):
|
async def test_list_prompts(self, temp_prompt_store):
|
||||||
response = await store.list_prompts()
|
response = await temp_prompt_store.list_prompts()
|
||||||
assert response.data == []
|
assert response.data == []
|
||||||
|
|
||||||
await store.create_prompt("first")
|
await temp_prompt_store.create_prompt("first")
|
||||||
await store.create_prompt("second")
|
await temp_prompt_store.create_prompt("second")
|
||||||
|
|
||||||
response = await store.list_prompts()
|
response = await temp_prompt_store.list_prompts()
|
||||||
assert len(response.data) == 2
|
assert len(response.data) == 2
|
||||||
|
|
||||||
async def test_version(self, store):
|
async def test_version(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt = await temp_prompt_store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2", 1)
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||||
|
|
||||||
v1 = await store.get_prompt(prompt.prompt_id, version=1)
|
v1 = await temp_prompt_store.get_prompt(prompt.prompt_id, version=1)
|
||||||
assert v1.version == 1 and v1.prompt == "V1"
|
assert v1.version == 1 and v1.prompt == "V1"
|
||||||
|
|
||||||
latest = await store.get_prompt(prompt.prompt_id)
|
latest = await temp_prompt_store.get_prompt(prompt.prompt_id)
|
||||||
assert latest.version == 2 and latest.prompt == "V2"
|
assert latest.version == 2 and latest.prompt == "V2"
|
||||||
|
|
||||||
async def test_set_default_version(self, store):
|
async def test_set_default_version(self, temp_prompt_store):
|
||||||
prompt0 = await store.create_prompt("V1")
|
prompt0 = await temp_prompt_store.create_prompt("V1")
|
||||||
prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1)
|
prompt1 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V2", 1)
|
||||||
|
|
||||||
assert (await store.get_prompt(prompt0.prompt_id)).version == 2
|
assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 2
|
||||||
prompt_default = await store.set_default_version(prompt0.prompt_id, 1)
|
prompt_default = await temp_prompt_store.set_default_version(prompt0.prompt_id, 1)
|
||||||
assert (await store.get_prompt(prompt0.prompt_id)).version == 1
|
assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 1
|
||||||
assert prompt_default.version == 1
|
assert prompt_default.version == 1
|
||||||
|
|
||||||
prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
|
prompt2 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
|
||||||
assert prompt2.version == 3
|
assert prompt2.version == 3
|
||||||
|
|
||||||
async def test_prompt_id_generation_and_validation(self, store):
|
async def test_prompt_id_generation_and_validation(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("Test")
|
prompt = await temp_prompt_store.create_prompt("Test")
|
||||||
assert prompt.prompt_id.startswith("pmpt_")
|
assert prompt.prompt_id.startswith("pmpt_")
|
||||||
assert len(prompt.prompt_id) == 53
|
assert len(prompt.prompt_id) == 53
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await store.get_prompt("invalid_id")
|
await temp_prompt_store.get_prompt("invalid_id")
|
||||||
|
|
||||||
async def test_list_shows_default_versions(self, store):
|
async def test_list_shows_default_versions(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt = await temp_prompt_store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2", 1)
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||||
await store.update_prompt(prompt.prompt_id, "V3", 2)
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
|
||||||
|
|
||||||
response = await store.list_prompts()
|
response = await temp_prompt_store.list_prompts()
|
||||||
listed_prompt = response.data[0]
|
listed_prompt = response.data[0]
|
||||||
assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
|
assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
|
||||||
|
|
||||||
await store.set_default_version(prompt.prompt_id, 1)
|
await temp_prompt_store.set_default_version(prompt.prompt_id, 1)
|
||||||
|
|
||||||
response = await store.list_prompts()
|
response = await temp_prompt_store.list_prompts()
|
||||||
listed_prompt = response.data[0]
|
listed_prompt = response.data[0]
|
||||||
assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
|
assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
|
||||||
assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default
|
assert not (await temp_prompt_store.get_prompt(prompt.prompt_id, 3)).is_default
|
||||||
|
|
||||||
async def test_get_all_prompt_versions(self, store):
|
async def test_get_all_prompt_versions(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("V1")
|
prompt = await temp_prompt_store.create_prompt("V1")
|
||||||
await store.update_prompt(prompt.prompt_id, "V2", 1)
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||||
await store.update_prompt(prompt.prompt_id, "V3", 2)
|
await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
|
||||||
|
|
||||||
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
|
||||||
assert len(versions) == 3
|
assert len(versions) == 3
|
||||||
assert [v.version for v in versions] == [1, 2, 3]
|
assert [v.version for v in versions] == [1, 2, 3]
|
||||||
assert [v.is_default for v in versions] == [False, False, True]
|
assert [v.is_default for v in versions] == [False, False, True]
|
||||||
|
|
||||||
await store.set_default_version(prompt.prompt_id, 2)
|
await temp_prompt_store.set_default_version(prompt.prompt_id, 2)
|
||||||
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
|
||||||
assert [v.is_default for v in versions] == [False, True, False]
|
assert [v.is_default for v in versions] == [False, True, False]
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await store.list_prompt_versions("nonexistent")
|
await temp_prompt_store.list_prompt_versions("nonexistent")
|
||||||
|
|
||||||
async def test_prompt_variable_validation(self, store):
|
async def test_prompt_variable_validation(self, temp_prompt_store):
|
||||||
prompt = await store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
|
prompt = await temp_prompt_store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
|
||||||
assert prompt.variables == ["name", "city"]
|
assert prompt.variables == ["name", "city"]
|
||||||
|
|
||||||
prompt_no_vars = await store.create_prompt("Hello world!", [])
|
prompt_no_vars = await temp_prompt_store.create_prompt("Hello world!", [])
|
||||||
assert prompt_no_vars.variables == []
|
assert prompt_no_vars.variables == []
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="undeclared variables"):
|
with pytest.raises(ValueError, match="undeclared variables"):
|
||||||
await store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
|
await temp_prompt_store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
|
||||||
|
|
||||||
|
async def test_update_prompt_set_as_default_behavior(self, temp_prompt_store):
|
||||||
|
prompt = await temp_prompt_store.create_prompt("V1")
|
||||||
|
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 1
|
||||||
|
|
||||||
|
prompt_v2 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1, [], set_as_default=True)
|
||||||
|
assert prompt_v2.version == 2
|
||||||
|
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
|
||||||
|
|
||||||
|
prompt_v3 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2, [], set_as_default=False)
|
||||||
|
assert prompt_v3.version == 3
|
||||||
|
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue