# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any

from pydantic import BaseModel

from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig


class PromptServiceConfig(BaseModel):
    """Configuration for the built-in prompt service.

    :param kvstore: Configuration for the key-value store backend
    """

    kvstore: KVStoreConfig


async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
    """Get the prompt service implementation.

    Builds a :class:`PromptServiceImpl` and awaits its ``initialize()`` before
    returning it, so the KVStore backend is already set up for callers.
    """
    impl = PromptServiceImpl(config, deps)
    await impl.initialize()
    return impl


class PromptServiceImpl(Prompts):
    """Built-in prompt service implementation using KVStore.

    Storage layout (all stored values are strings):

    - ``prompts:v1:{prompt_id}:{version}`` -> JSON-serialized prompt for that version
    - ``prompts:v1:{prompt_id}:default``   -> version string currently marked as default

    Versions are written by this class as decimal strings starting at ``"1"``.
    """

    def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
        self.config = config
        self.deps = deps
        # Only declared here; assigned by initialize(), which must run before other methods.
        self.kvstore: KVStore

    async def initialize(self) -> None:
        """Create the KVStore backend from ``self.config.kvstore``."""
        self.kvstore = await kvstore_impl(self.config.kvstore)

    def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
        """Get the KVStore key of the default-version pointer for ``prompt_id``.

        If ``version`` is given (and non-empty), the key of that specific version is returned instead.
        """
        if version:
            return self._get_version_key(prompt_id, version)
        return f"{self._get_prompt_key_prefix(prompt_id)}default"

    def _get_version_key(self, prompt_id: str, version: str) -> str:
        """Get the KVStore key for a specific prompt version."""
        return f"{self._get_prompt_key_prefix(prompt_id)}{version}"

    def _get_list_key_prefix(self) -> str:
        """Get the key prefix for listing prompts."""
        return "prompts:v1:"

    def _get_prompt_key_prefix(self, prompt_id: str) -> str:
        """Get the prefix shared by every key of one prompt (all versions plus its default pointer)."""
        # The trailing ":" keeps range scans from also matching ids that merely start with prompt_id.
        return f"{self._get_list_key_prefix()}{prompt_id}:"

    async def _get_latest_version(self, prompt_id: str) -> int:
        """Return the highest numeric version stored for ``prompt_id``, or 0 if none is stored."""
        prefix = self._get_prompt_key_prefix(prompt_id)
        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
        # The ":default" pointer key has a non-numeric suffix and is skipped by the isdecimal() filter.
        suffixes = (key.removeprefix(prefix) for key in keys)
        return max((int(suffix) for suffix in suffixes if suffix.isdecimal()), default=0)

    def _serialize_prompt(self, prompt: Prompt) -> str:
        """Serialize a prompt to JSON string for storage."""
        return json.dumps(
            {
                "prompt_id": prompt.prompt_id,
                "prompt": prompt.prompt,
                "version": prompt.version,
                "variables": prompt.variables or {},
            }
        )

    def _deserialize_prompt(self, data: str) -> Prompt:
        """Deserialize a prompt from JSON string.

        Raises ``json.JSONDecodeError`` on invalid JSON and ``KeyError`` if a required field is missing.
        """
        obj = json.loads(data)
        return Prompt(
            prompt_id=obj["prompt_id"],
            prompt=obj["prompt"],
            version=obj["version"],
            variables=obj.get("variables", {}),
        )

    async def list_prompts(self) -> ListPromptsResponse:
        """List all prompts (default versions only).

        Prompts are skipped when any of the following holds:

        - the default pointer is empty;
        - the default version's data is missing;
        - the stored JSON cannot be decoded or lacks a required field.

        Results are sorted by ``prompt_id`` in descending order.
        """
        prefix = self._get_list_key_prefix()
        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")

        prompts = []
        for key in keys:
            if not key.endswith(":default"):
                continue
            try:
                default_version = await self.kvstore.get(key)
                if not default_version:
                    continue
                # Slice the id out of "prompts:v1:{prompt_id}:default". str.replace would also
                # strip those substrings if they ever occurred inside the id itself.
                prompt_id = key.removeprefix(prefix).removesuffix(":default")
                data = await self.kvstore.get(self._get_version_key(prompt_id, default_version))
                if data:
                    prompts.append(self._deserialize_prompt(data))
            except (json.JSONDecodeError, KeyError):
                # Best effort: one corrupt entry should not prevent listing the others.
                continue

        prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
        return ListPromptsResponse(data=prompts)

    async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
        """Get a prompt by its identifier and optional version.

        :param prompt_id: Identifier of the prompt.
        :param version: Specific version to fetch; when omitted, the default version is returned.
        :raises ValueError: If the prompt (or the requested version) does not exist.
        """
        if version:
            key = self._get_version_key(prompt_id, version)
            data = await self.kvstore.get(key)
            if data is None:
                raise ValueError(f"Prompt {prompt_id} version {version} not found")
        else:
            default_key = self._get_prompt_key(prompt_id)
            default_version = await self.kvstore.get(default_key)
            if default_version is None:
                raise ValueError(f"Prompt with ID '{prompt_id}' not found")

            key = self._get_version_key(prompt_id, default_version)
            data = await self.kvstore.get(key)
            if data is None:
                raise ValueError(f"Prompt with ID '{prompt_id}' not found")

        return self._deserialize_prompt(data)

    async def create_prompt(
        self,
        prompt: str,
        variables: dict[str, str] | None = None,
    ) -> Prompt:
        """Create a new prompt with a freshly generated id as version "1", and mark it as the default."""
        if variables is None:
            variables = {}

        prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)

        version_key = self._get_version_key(prompt_obj.prompt_id, "1")
        data = self._serialize_prompt(prompt_obj)
        await self.kvstore.set(version_key, data)

        default_key = self._get_prompt_key(prompt_obj.prompt_id)
        await self.kvstore.set(default_key, "1")

        return prompt_obj

    async def update_prompt(
        self,
        prompt_id: str,
        prompt: str,
        variables: dict[str, str] | None = None,
    ) -> Prompt:
        """Update an existing prompt by storing a new version and making it the default.

        The new version number is one past the highest version already stored.

        :raises ValueError: If the prompt does not exist.
        """
        if variables is None:
            variables = {}

        # Existence check only; raises ValueError for an unknown prompt_id.
        await self.get_prompt(prompt_id)

        # Derive the next version from the highest stored version, not from the current
        # default: after set_default_version() points at an older version, "default + 1"
        # may already exist and would be silently overwritten.
        new_version = str(await self._get_latest_version(prompt_id) + 1)

        updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)

        version_key = self._get_version_key(prompt_id, new_version)
        data = self._serialize_prompt(updated_prompt)
        await self.kvstore.set(version_key, data)

        default_key = self._get_prompt_key(prompt_id)
        await self.kvstore.set(default_key, new_version)

        return updated_prompt

    async def delete_prompt(self, prompt_id: str) -> None:
        """Delete a prompt and all its versions.

        :raises ValueError: If the prompt does not exist.
        """
        await self.get_prompt(prompt_id)

        prefix = self._get_prompt_key_prefix(prompt_id)
        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")

        for key in keys:
            await self.kvstore.delete(key)

    async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
        """List all versions of a specific prompt, sorted by ascending version number.

        Each returned prompt has ``is_default`` set according to the stored default pointer.

        :raises ValueError: If no versions of the prompt are stored.
        """
        prefix = self._get_prompt_key_prefix(prompt_id)
        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")

        default_version = None
        prompts = []

        for key in keys:
            data = await self.kvstore.get(key)
            if key.endswith(":default"):
                default_version = data
            elif data:
                prompts.append(self._deserialize_prompt(data))

        if not prompts:
            raise ValueError(f"Prompt {prompt_id} not found")

        for prompt in prompts:
            prompt.is_default = prompt.version == default_version

        prompts.sort(key=lambda x: int(x.version))

        return ListPromptsResponse(data=prompts)

    async def set_default_version(self, prompt_id: str, version: str) -> Prompt:
        """Set which stored version of a prompt is returned by default.

        The chosen version need not be the latest one.

        :returns: The prompt at the newly selected default version.
        :raises ValueError: If that version does not exist.
        """
        version_key = self._get_version_key(prompt_id, version)
        data = await self.kvstore.get(version_key)
        if data is None:
            raise ValueError(f"Prompt {prompt_id} version {version} not found")

        default_key = self._get_prompt_key(prompt_id)
        await self.kvstore.set(default_key, version)

        return self._deserialize_prompt(data)