feat: Add OpenAI-Compatible Prompts API

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-09-03 14:14:54 -04:00
parent 30117dea22
commit 8b00883abd
181 changed files with 21356 additions and 10332 deletions

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any

from pydantic import BaseModel

from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig


class PromptServiceConfig(BaseModel):
    """Configuration for the built-in prompt service.

    :param kvstore: Configuration for the key-value store backend
    """

    kvstore: KVStoreConfig


async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
    """Get the prompt service implementation."""
    impl = PromptServiceImpl(config, deps)
    await impl.initialize()
    return impl


class PromptServiceImpl(Prompts):
    """Built-in prompt service implementation using KVStore."""

    def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
        self.config = config
        self.deps = deps
        self.kvstore: KVStore

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)

    def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
        """Get the KVStore key for a prompt; without a version, return the default-pointer key."""
        if version:
            return f"prompts:v1:{prompt_id}:{version}"
        return f"prompts:v1:{prompt_id}:default"

    def _get_version_key(self, prompt_id: str, version: str) -> str:
        """Get the KVStore key for a specific prompt version."""
        return f"prompts:v1:{prompt_id}:{version}"

    def _get_list_key_prefix(self) -> str:
        """Get the key prefix for listing prompts."""
        return "prompts:v1:"

    def _serialize_prompt(self, prompt: Prompt) -> str:
        """Serialize a prompt to a JSON string for storage."""
        return json.dumps(
            {
                "prompt_id": prompt.prompt_id,
                "prompt": prompt.prompt,
                "version": prompt.version,
                "variables": prompt.variables or {},
            }
        )

    def _deserialize_prompt(self, data: str) -> Prompt:
        """Deserialize a prompt from a JSON string."""
        obj = json.loads(data)
        return Prompt(
            prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", {})
        )

    async def list_prompts(self) -> ListPromptsResponse:
        """List all prompts (default versions only)."""
        prefix = self._get_list_key_prefix()
        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
        prompts = []
        for key in keys:
            if key.endswith(":default"):
                try:
                    default_version = await self.kvstore.get(key)
                    if default_version:
                        prompt_id = key.replace(prefix, "").replace(":default", "")
                        version_key = self._get_version_key(prompt_id, default_version)
                        data = await self.kvstore.get(version_key)
                        if data:
                            prompt = self._deserialize_prompt(data)
                            prompts.append(prompt)
                except (json.JSONDecodeError, KeyError):
                    continue
        prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
        return ListPromptsResponse(data=prompts)

    async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
        """Get a prompt by its identifier and optional version."""
        if version:
            key = self._get_version_key(prompt_id, version)
            data = await self.kvstore.get(key)
            if data is None:
                raise ValueError(f"Prompt {prompt_id} version {version} not found")
        else:
            default_key = self._get_prompt_key(prompt_id)
            default_version = await self.kvstore.get(default_key)
            if default_version is None:
                raise ValueError(f"Prompt with ID '{prompt_id}' not found")
            key = self._get_version_key(prompt_id, default_version)
            data = await self.kvstore.get(key)
            if data is None:
                raise ValueError(f"Prompt with ID '{prompt_id}' not found")
        return self._deserialize_prompt(data)

    async def create_prompt(
        self,
        prompt: str,
        variables: dict[str, str] | None = None,
    ) -> Prompt:
        """Create a new prompt."""
        if variables is None:
            variables = {}
        prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)

        version_key = self._get_version_key(prompt_obj.prompt_id, "1")
        data = self._serialize_prompt(prompt_obj)
        await self.kvstore.set(version_key, data)

        default_key = self._get_prompt_key(prompt_obj.prompt_id)
        await self.kvstore.set(default_key, "1")
        return prompt_obj

    async def update_prompt(
        self,
        prompt_id: str,
        prompt: str,
        variables: dict[str, str] | None = None,
    ) -> Prompt:
        """Update an existing prompt (increments version)."""
        if variables is None:
            variables = {}
        current_prompt = await self.get_prompt(prompt_id)
        new_version = str(int(current_prompt.version) + 1)
        updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)

        version_key = self._get_version_key(prompt_id, new_version)
        data = self._serialize_prompt(updated_prompt)
        await self.kvstore.set(version_key, data)

        default_key = self._get_prompt_key(prompt_id)
        await self.kvstore.set(default_key, new_version)
        return updated_prompt

    async def delete_prompt(self, prompt_id: str) -> None:
        """Delete a prompt and all its versions."""
        await self.get_prompt(prompt_id)  # raises if the prompt does not exist

        prefix = f"prompts:v1:{prompt_id}:"
        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
        for key in keys:
            await self.kvstore.delete(key)

    async def set_default_version(self, prompt_id: str, version: str) -> Prompt:
        """Set which version of a prompt should be the default (latest)."""
        version_key = self._get_version_key(prompt_id, version)
        data = await self.kvstore.get(version_key)
        if data is None:
            raise ValueError(f"Prompt {prompt_id} version {version} not found")

        default_key = self._get_prompt_key(prompt_id)
        await self.kvstore.set(default_key, version)
        return self._deserialize_prompt(data)
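
Design note: each version is stored whole at prompts:v1:<prompt_id>:<version>, while prompts:v1:<prompt_id>:default holds only the default version string, so repointing the default is a single KVStore set. A minimal usage sketch of the service above (standalone; the throwaway db_path is assumed for illustration only):

import asyncio

from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def demo() -> None:
    # Throwaway database path, assumed for this sketch.
    config = PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path="/tmp/prompts-demo.db"))
    service = PromptServiceImpl(config, deps={})
    await service.initialize()

    # create_prompt() writes version "1" and points the default key at it.
    created = await service.create_prompt("You are a concise assistant.")

    # update_prompt() writes version "2" and repoints the default key.
    updated = await service.update_prompt(created.prompt_id, "You are a verbose assistant.")
    assert updated.version == "2"

    # The default now resolves to v2; explicit versions stay addressable.
    assert (await service.get_prompt(created.prompt_id)).version == "2"
    assert (await service.get_prompt(created.prompt_id, version="1")).version == "1"

    # Roll the default back to v1 without touching stored versions.
    rolled_back = await service.set_default_version(created.prompt_id, "1")
    assert rolled_back.version == "1"


asyncio.run(demo())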

@@ -19,6 +19,7 @@ from llama_stack.apis.inference import Inference, InferenceProvider
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
+from llama_stack.apis.prompts import Prompts
 from llama_stack.apis.providers import Providers as ProvidersAPI
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring

@@ -93,6 +94,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.tool_groups: ToolGroups,
         Api.tool_runtime: ToolRuntime,
         Api.files: Files,
+        Api.prompts: Prompts,
     }

     if external_apis:
@@ -284,7 +286,15 @@ async def instantiate_providers(
         if provider.provider_id is None:
             continue

-        deps = {a: impls[a] for a in provider.spec.api_dependencies}
+        try:
+            deps = {a: impls[a] for a in provider.spec.api_dependencies}
+        except KeyError as e:
+            missing_api = e.args[0]
+            raise RuntimeError(
+                f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
+                f"required dependency '{missing_api.value}' is not available. "
+                f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
+            ) from e
         for a in provider.spec.optional_api_dependencies:
             if a in impls:
                 deps[a] = impls[a]
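
The shape of that guard, reduced to a standalone sketch (the Api enum and impls contents below are made up for illustration): a missing entry in impls now surfaces as an actionable RuntimeError naming the absent dependency instead of a bare KeyError.

from enum import Enum


class Api(Enum):
    inference = "inference"
    safety = "safety"


impls = {Api.inference: object()}  # no safety provider configured
api_dependencies = [Api.inference, Api.safety]

try:
    deps = {a: impls[a] for a in api_dependencies}
except KeyError as e:
    # e.args[0] is the missing dict key, i.e. the unresolved Api member.
    missing_api = e.args[0]
    raise RuntimeError(f"required dependency '{missing_api.value}' is not available") from e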

@@ -755,7 +755,7 @@ class InferenceRouter(Inference):
                 choices_data[idx] = {
                     "content_parts": [],
                     "tool_calls_builder": {},
-                    "finish_reason": None,
+                    "finish_reason": "stop",
                     "logprobs_content_parts": [],
                 }
                 current_choice_data = choices_data[idx]

@@ -132,9 +132,9 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
             },
         )
     elif isinstance(exc, ConflictError):
-        return HTTPException(status_code=409, detail=str(exc))
+        return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
     elif isinstance(exc, ResourceNotFoundError):
-        return HTTPException(status_code=404, detail=str(exc))
+        return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
     elif isinstance(exc, ValueError):
         return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
     elif isinstance(exc, BadRequestError):

@@ -513,6 +513,7 @@ def main(args: argparse.Namespace | None = None):
     apis_to_serve.add("inspect")
     apis_to_serve.add("providers")
+    apis_to_serve.add("prompts")

     for api_str in apis_to_serve:
         api = Api(api_str)
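
The status-code substitutions are behavior-preserving: httpx.codes is an IntEnum, so its members compare equal to the bare integers they replace. A quick check (assumes httpx is installed, which the server already requires):

import httpx

# IntEnum members compare equal to plain ints, so wire responses are unchanged.
assert httpx.codes.CONFLICT == 409
assert httpx.codes.NOT_FOUND == 404
assert httpx.codes.BAD_REQUEST == 400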

@@ -24,6 +24,7 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
+from llama_stack.apis.prompts import Prompts
 from llama_stack.apis.providers import Providers
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring

@@ -37,6 +38,7 @@ from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import Provider, StackRunConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
+from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
 from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.core.resolver import ProviderRegistry, resolve_impls
 from llama_stack.core.routing_tables.common import CommonRoutingTableImpl

@@ -44,6 +46,7 @@ from llama_stack.core.store.registry import create_dist_registry
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 logger = get_logger(name=__name__, category="core")

@@ -72,6 +75,7 @@ class LlamaStack(
     ToolRuntime,
     RAGToolRuntime,
     Files,
+    Prompts,
 ):
     pass
@@ -105,12 +109,12 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
         method = getattr(impls[api], register_method)
         for obj in objects:
-            logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
-            # Do not register models on disabled providers
-            if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
-                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
-                continue
+            if hasattr(obj, "provider_id"):
+                # Do not register models on disabled providers
+                if not obj.provider_id or obj.provider_id == "__disabled__":
+                    logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
+                    continue
+            logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")

             # we want to maintain the type information in arguments to method.
             # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
@@ -305,6 +309,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig):
     )
     impls[Api.providers] = providers_impl

+    prompts_impl = PromptServiceImpl(
+        PromptServiceConfig(kvstore=SqliteKVStoreConfig(db_path=os.path.expanduser("~/.llama-stack/prompts.db"))),
+        deps=impls,
+    )
+    impls[Api.prompts] = prompts_impl


 # Produces a stack of providers for the given run config. Not all APIs may be
 # asked for in the run config.

@@ -329,6 +339,9 @@ async def construct_stack(
     # Add internal implementations after all other providers are resolved
     add_internal_implementations(impls, run_config)

+    if Api.prompts in impls:
+        await impls[Api.prompts].initialize()
+
     await register_resources(run_config, impls)

     await refresh_registry_once(impls)
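
Taken together, once the stack is constructed the built-in service is reachable like any other API impl. A rough sketch, assuming construct_stack(run_config) returns the impls mapping as used in the hunks above (the run_config value itself is left to the caller):

from llama_stack.core.stack import construct_stack
from llama_stack.providers.datatypes import Api


async def use_prompts(run_config) -> None:
    # The prompts impl is added by add_internal_implementations() and initialized internally.
    impls = await construct_stack(run_config)

    prompts = impls[Api.prompts]
    prompt = await prompts.create_prompt("Summarize the user's message in one sentence.")
    print(prompt.prompt_id, prompt.version)  # version "1" becomes the default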