llama-stack-mirror/tests/unit/providers/prompts/test_prompts.py
Francisco Javier Arceo 5c02661b79 adding GET /prompts/{prompt_id}/versions
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
2025-09-04 21:45:16 -04:00

119 lines
4.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import pytest
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestPrompts:
    """Unit tests for ``PromptServiceImpl`` backed by a temporary SQLite kvstore."""

    @pytest.fixture
    async def store(self):
        """Yield an initialized ``PromptServiceImpl`` using a throwaway SQLite file.

        The database file is removed after the test, including when
        ``initialize()`` or the test itself fails.
        """
        # delete=False: only the path is needed; the file must survive this
        # ``with`` block because its path is handed to the SQLite kvstore config.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp_file:
            db_path = tmp_file.name

        try:
            config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=db_path))
            store = PromptServiceImpl(config, deps={})
            await store.initialize()
            yield store
        finally:
            # EAFP instead of exists()+unlink(): no check-then-act race, and a
            # file that is already gone is not an error during teardown.
            try:
                os.unlink(db_path)
            except FileNotFoundError:
                pass

    async def test_create_and_get_prompt(self, store):
        """A created prompt starts at version "1" and round-trips through get_prompt."""
        prompt = await store.create_prompt("Hello world!", {"name": "John"})
        assert prompt.prompt == "Hello world!"
        assert prompt.version == "1"
        assert prompt.prompt_id.startswith("pmpt_")
        assert prompt.variables == {"name": "John"}

        retrieved = await store.get_prompt(prompt.prompt_id)
        assert retrieved.prompt_id == prompt.prompt_id
        assert retrieved.prompt == prompt.prompt

    async def test_update_prompt(self, store):
        """Updating a prompt returns the new text under the next version number."""
        prompt = await store.create_prompt("Original")
        updated = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"})
        assert updated.version == "2"
        assert updated.prompt == "Updated"

    async def test_delete_prompt(self, store):
        """Fetching a deleted prompt raises ValueError."""
        prompt = await store.create_prompt("to be deleted")
        await store.delete_prompt(prompt.prompt_id)
        with pytest.raises(ValueError):
            await store.get_prompt(prompt.prompt_id)

    async def test_list_prompts(self, store):
        """An empty store lists nothing; each created prompt then appears once."""
        response = await store.list_prompts()
        assert response.data == []

        await store.create_prompt("first")
        await store.create_prompt("second")

        response = await store.list_prompts()
        assert len(response.data) == 2
        # Listing order is not part of this test; only which prompts are listed.
        assert sorted(p.prompt for p in response.data) == ["first", "second"]

    async def test_version(self, store):
        """Older versions stay retrievable; without ``version`` the newest is returned."""
        prompt = await store.create_prompt("V1")
        await store.update_prompt(prompt.prompt_id, "V2")

        v1 = await store.get_prompt(prompt.prompt_id, version="1")
        assert v1.version == "1"
        assert v1.prompt == "V1"

        latest = await store.get_prompt(prompt.prompt_id)
        assert latest.version == "2"
        assert latest.prompt == "V2"

    async def test_set_default_version(self, store):
        """After set_default_version, get_prompt without a version returns that version."""
        prompt = await store.create_prompt("V1")
        await store.update_prompt(prompt.prompt_id, "V2")
        await store.set_default_version(prompt.prompt_id, "1")
        assert (await store.get_prompt(prompt.prompt_id)).version == "1"

    async def test_prompt_id_generation_and_validation(self, store):
        """Generated IDs have a fixed format, and a malformed ID raises ValueError."""
        prompt = await store.create_prompt("Test")
        assert prompt.prompt_id.startswith("pmpt_")
        # Expected total length 53 = the 5-character "pmpt_" prefix + 48 characters.
        assert len(prompt.prompt_id) == len("pmpt_") + 48

        with pytest.raises(ValueError):
            await store.get_prompt("invalid_id")

    async def test_list_shows_default_versions(self, store):
        """list_prompts reports each prompt's current default version, not always the newest."""
        prompt = await store.create_prompt("V1")
        await store.update_prompt(prompt.prompt_id, "V2")
        await store.update_prompt(prompt.prompt_id, "V3")

        # Before any explicit default is set, the newest version is listed.
        response = await store.list_prompts()
        listed_prompt = response.data[0]
        assert listed_prompt.version == "3"
        assert listed_prompt.prompt == "V3"

        await store.set_default_version(prompt.prompt_id, "1")

        response = await store.list_prompts()
        listed_prompt = response.data[0]
        assert listed_prompt.version == "1"
        assert listed_prompt.prompt == "V1"

    async def test_get_all_prompt_versions(self, store):
        """list_prompt_versions returns every version with exactly one marked default."""
        prompt = await store.create_prompt("V1")
        await store.update_prompt(prompt.prompt_id, "V2")
        await store.update_prompt(prompt.prompt_id, "V3")

        versions = (await store.list_prompt_versions(prompt.prompt_id)).data
        assert len(versions) == 3
        assert [v.version for v in versions] == ["1", "2", "3"]
        assert [v.is_default for v in versions] == [False, False, True]

        # Moving the default must move the is_default flag with it.
        await store.set_default_version(prompt.prompt_id, "2")
        versions = (await store.list_prompt_versions(prompt.prompt_id)).data
        assert [v.is_default for v in versions] == [False, True, False]

        with pytest.raises(ValueError):
            await store.list_prompt_versions("nonexistent")