diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py index f69c81f8b..6e7385a57 100644 --- a/llama_stack/core/prompts/prompts.py +++ b/llama_stack/core/prompts/prompts.py @@ -10,17 +10,19 @@ from typing import Any from pydantic import BaseModel from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts +from llama_stack.core.datatypes import StackRunConfig +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl -from llama_stack.providers.utils.kvstore.config import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class PromptServiceConfig(BaseModel): """Configuration for the built-in prompt service. - :param kvstore: Configuration for the key-value store backend + :param run_config: Stack run configuration containing distribution info """ - kvstore: KVStoreConfig + run_config: StackRunConfig async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]): @@ -39,7 +41,10 @@ class PromptServiceImpl(Prompts): self.kvstore: KVStore async def initialize(self) -> None: - self.kvstore = await kvstore_impl(self.config.kvstore) + kvstore_config = SqliteKVStoreConfig( + db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix() + ) + self.kvstore = await kvstore_impl(kvstore_config) def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str: if version: diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 69760543a..f7b723a93 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -43,6 +43,7 @@ from llama_stack.core.providers import ProviderImpl, ProviderImplConfig from llama_stack.core.resolver import ProviderRegistry, resolve_impls from llama_stack.core.routing_tables.common import CommonRoutingTableImpl from llama_stack.core.store.registry import create_dist_registry +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -310,7 +311,7 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf impls[Api.providers] = providers_impl prompts_impl = PromptServiceImpl( - PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=os.path.expanduser("~/.llama-stack/prompts.db"))), + PromptServiceConfig(run_config=run_config), deps=impls, ) impls[Api.prompts] = prompts_impl