mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
removed UpdatePromptRequest, convert prompt version to integer, update tests, and remove TempPathFile in favor of tmp_path_factory
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
2a8469f156
commit
64f6840195
7 changed files with 88 additions and 109 deletions
|
@ -4,8 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -14,32 +13,12 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_prompt_store():
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file:
|
||||
db_path = tmp_file.name
|
||||
async def temp_prompt_store(tmp_path_factory):
|
||||
unique_id = f"prompt_store_{random.randint(1, 1000000)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
try:
|
||||
config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
await store.initialize()
|
||||
yield store
|
||||
finally:
|
||||
if os.path.exists(db_path):
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prompt_data():
|
||||
return {
|
||||
"prompt": "Hello {{name}}, welcome to {{platform}}!",
|
||||
"variables": {"name": "John", "platform": "LlamaStack"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prompts_data():
|
||||
return [
|
||||
{"prompt": "Hello {{name}}!", "variables": {"name": "Alice"}},
|
||||
{"prompt": "Welcome to {{platform}}, {{user}}!", "variables": {"platform": "LlamaStack", "user": "Bob"}},
|
||||
{"prompt": "Your order {{order_id}} is ready for pickup.", "variables": {"order_id": "12345"}},
|
||||
]
|
||||
config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
await store.initialize()
|
||||
yield store
|
||||
|
|
|
@ -32,7 +32,7 @@ class TestPrompts:
|
|||
async def test_create_and_get_prompt(self, store):
|
||||
prompt = await store.create_prompt("Hello world!", ["name"])
|
||||
assert prompt.prompt == "Hello world!"
|
||||
assert prompt.version == "1"
|
||||
assert prompt.version == 1
|
||||
assert prompt.prompt_id.startswith("pmpt_")
|
||||
assert prompt.variables == ["name"]
|
||||
|
||||
|
@ -42,17 +42,17 @@ class TestPrompts:
|
|||
|
||||
async def test_update_prompt(self, store):
|
||||
prompt = await store.create_prompt("Original")
|
||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", "1", ["v"])
|
||||
assert updated.version == "2"
|
||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
|
||||
assert updated.version == 2
|
||||
assert updated.prompt == "Updated"
|
||||
|
||||
async def test_update_prompt_with_version(self, store):
|
||||
version_for_update = "1"
|
||||
version_for_update = 1
|
||||
|
||||
prompt = await store.create_prompt("Original")
|
||||
assert prompt.version == "1"
|
||||
assert prompt.version == 1
|
||||
prompt = await store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
|
||||
assert prompt.version == "2"
|
||||
assert prompt.version == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# now this is a stale version
|
||||
|
@ -60,7 +60,7 @@ class TestPrompts:
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
# this version does not exist
|
||||
await store.update_prompt(prompt.prompt_id, "Another Update", "99", ["v"])
|
||||
await store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
|
||||
|
||||
async def test_delete_prompt(self, store):
|
||||
prompt = await store.create_prompt("to be deleted")
|
||||
|
@ -80,25 +80,25 @@ class TestPrompts:
|
|||
|
||||
async def test_version(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||
|
||||
v1 = await store.get_prompt(prompt.prompt_id, version="1")
|
||||
assert v1.version == "1" and v1.prompt == "V1"
|
||||
v1 = await store.get_prompt(prompt.prompt_id, version=1)
|
||||
assert v1.version == 1 and v1.prompt == "V1"
|
||||
|
||||
latest = await store.get_prompt(prompt.prompt_id)
|
||||
assert latest.version == "2" and latest.prompt == "V2"
|
||||
assert latest.version == 2 and latest.prompt == "V2"
|
||||
|
||||
async def test_set_default_version(self, store):
|
||||
prompt0 = await store.create_prompt("V1")
|
||||
prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", "1")
|
||||
prompt1 = await store.update_prompt(prompt0.prompt_id, "V2", 1)
|
||||
|
||||
assert (await store.get_prompt(prompt0.prompt_id)).version == "2"
|
||||
prompt_default = await store.set_default_version(prompt0.prompt_id, "1")
|
||||
assert (await store.get_prompt(prompt0.prompt_id)).version == "1"
|
||||
assert prompt_default.version == "1"
|
||||
assert (await store.get_prompt(prompt0.prompt_id)).version == 2
|
||||
prompt_default = await store.set_default_version(prompt0.prompt_id, 1)
|
||||
assert (await store.get_prompt(prompt0.prompt_id)).version == 1
|
||||
assert prompt_default.version == 1
|
||||
|
||||
prompt2 = await store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
|
||||
assert prompt2.version == "3"
|
||||
assert prompt2.version == 3
|
||||
|
||||
async def test_prompt_id_generation_and_validation(self, store):
|
||||
prompt = await store.create_prompt("Test")
|
||||
|
@ -110,30 +110,31 @@ class TestPrompts:
|
|||
|
||||
async def test_list_shows_default_versions(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||
await store.update_prompt(prompt.prompt_id, "V3", "2")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||
await store.update_prompt(prompt.prompt_id, "V3", 2)
|
||||
|
||||
response = await store.list_prompts()
|
||||
listed_prompt = response.data[0]
|
||||
assert listed_prompt.version == "3" and listed_prompt.prompt == "V3"
|
||||
assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
|
||||
|
||||
await store.set_default_version(prompt.prompt_id, "1")
|
||||
await store.set_default_version(prompt.prompt_id, 1)
|
||||
|
||||
response = await store.list_prompts()
|
||||
listed_prompt = response.data[0]
|
||||
assert listed_prompt.version == "1" and listed_prompt.prompt == "V1"
|
||||
assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
|
||||
assert not (await store.get_prompt(prompt.prompt_id, 3)).is_default
|
||||
|
||||
async def test_get_all_prompt_versions(self, store):
|
||||
prompt = await store.create_prompt("V1")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", "1")
|
||||
await store.update_prompt(prompt.prompt_id, "V3", "2")
|
||||
await store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||
await store.update_prompt(prompt.prompt_id, "V3", 2)
|
||||
|
||||
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
||||
assert len(versions) == 3
|
||||
assert [v.version for v in versions] == ["1", "2", "3"]
|
||||
assert [v.version for v in versions] == [1, 2, 3]
|
||||
assert [v.is_default for v in versions] == [False, False, True]
|
||||
|
||||
await store.set_default_version(prompt.prompt_id, "2")
|
||||
await store.set_default_version(prompt.prompt_id, 2)
|
||||
versions = (await store.list_prompt_versions(prompt.prompt_id)).data
|
||||
assert [v.is_default for v in versions] == [False, True, False]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue