# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import tempfile

import pytest

from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class TestPrompts:
    """Tests for ``PromptServiceImpl`` running against a temporary SQLite kvstore."""

    @pytest.fixture
    async def store(self):
        """Yield an initialized ``PromptServiceImpl`` backed by a throwaway ``.db`` file.

        The file is created with ``delete=False`` so that only its path is
        handed to the kvstore config; it is removed again once the test that
        requested the fixture has finished.
        """
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as handle:
            sqlite_path = handle.name

        try:
            kvstore_config = SqliteKVStoreConfig(db_path=sqlite_path)
            service = PromptServiceImpl(PromptServiceConfig(kvstore=kvstore_config), deps={})
            await service.initialize()
            yield service
        finally:
            # Clean up even when initialization or the test body raised.
            if os.path.exists(sqlite_path):
                os.unlink(sqlite_path)

    async def test_create_and_get_prompt(self, store):
        """A created prompt starts at version "1" and can be fetched by its id."""
        created = await store.create_prompt("Hello world!", {"name": "John"})

        assert created.prompt == "Hello world!"
        assert created.version == "1"
        assert created.prompt_id.startswith("pmpt_")
        assert created.variables == {"name": "John"}

        fetched = await store.get_prompt(created.prompt_id)

        assert fetched.prompt_id == created.prompt_id
        assert fetched.prompt == created.prompt

    async def test_update_prompt(self, store):
        """Updating a prompt returns version "2" carrying the new text."""
        original = await store.create_prompt("Original")

        revised = await store.update_prompt(original.prompt_id, "Updated", {"v": "2"})

        assert revised.version == "2"
        assert revised.prompt == "Updated"

    async def test_delete_prompt(self, store):
        """Fetching a deleted prompt raises ``ValueError``."""
        doomed = await store.create_prompt("to be deleted")
        await store.delete_prompt(doomed.prompt_id)

        with pytest.raises(ValueError):
            await store.get_prompt(doomed.prompt_id)

    async def test_list_prompts(self, store):
        """Listing is empty for a fresh store and grows with each created prompt."""
        listing = await store.list_prompts()
        assert listing.data == []

        await store.create_prompt("first")
        await store.create_prompt("second")

        listing = await store.list_prompts()
        assert len(listing.data) == 2

    async def test_version(self, store):
        """An explicit ``version`` returns that revision; omitting it returns the latest."""
        created = await store.create_prompt("V1")
        await store.update_prompt(created.prompt_id, "V2")

        first = await store.get_prompt(created.prompt_id, version="1")
        assert first.version == "1"
        assert first.prompt == "V1"

        newest = await store.get_prompt(created.prompt_id)
        assert newest.version == "2"
        assert newest.prompt == "V2"

    async def test_set_default_version(self, store):
        """After ``set_default_version``, a plain ``get_prompt`` returns that version."""
        created = await store.create_prompt("V1")
        await store.update_prompt(created.prompt_id, "V2")
        await store.set_default_version(created.prompt_id, "1")

        current = await store.get_prompt(created.prompt_id)
        assert current.version == "1"

    async def test_prompt_id_generation_and_validation(self, store):
        """Generated ids have the expected prefix and length; a malformed id is rejected."""
        created = await store.create_prompt("Test")
        generated_id = created.prompt_id

        assert generated_id.startswith("pmpt_")
        # 53 characters in total: the 5-character "pmpt_" prefix plus 48 more.
        assert len(generated_id) == 53

        with pytest.raises(ValueError):
            await store.get_prompt("invalid_id")

    async def test_list_shows_default_versions(self, store):
        """``list_prompts`` reports the default version of each prompt."""
        created = await store.create_prompt("V1")
        await store.update_prompt(created.prompt_id, "V2")
        await store.update_prompt(created.prompt_id, "V3")

        # Without an explicit default, the most recent update is what gets listed.
        listing = await store.list_prompts()
        entry = listing.data[0]
        assert entry.version == "3"
        assert entry.prompt == "V3"

        await store.set_default_version(created.prompt_id, "1")

        listing = await store.list_prompts()
        entry = listing.data[0]
        assert entry.version == "1"
        assert entry.prompt == "V1"

    async def test_get_all_prompt_versions(self, store):
        """``list_prompt_versions`` returns every revision with its default flag."""
        created = await store.create_prompt("V1")
        await store.update_prompt(created.prompt_id, "V2")
        await store.update_prompt(created.prompt_id, "V3")

        history = (await store.list_prompt_versions(created.prompt_id)).data
        assert len(history) == 3

        version_labels = [entry.version for entry in history]
        assert version_labels == ["1", "2", "3"]

        default_flags = [entry.is_default for entry in history]
        assert default_flags == [False, False, True]

        await store.set_default_version(created.prompt_id, "2")

        history = (await store.list_prompt_versions(created.prompt_id)).data
        default_flags = [entry.is_default for entry in history]
        assert default_flags == [False, True, False]

        with pytest.raises(ValueError):
            await store.list_prompt_versions("nonexistent")