mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
feat: Adding OpenAI Compatible Prompts API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
30117dea22
commit
8b00883abd
181 changed files with 21356 additions and 10332 deletions
173
llama_stack/apis/prompts/prompts.py
Normal file
173
llama_stack/apis/prompts/prompts.py
Normal file
|
@ -0,0 +1,173 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import secrets
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
|
||||
|
||||
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
|
||||
:param version: Version string (integer start at 1 cast as string, incremented on save)
|
||||
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||
:param variables: Dictionary of prompt variable names and values
|
||||
"""
|
||||
|
||||
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
||||
version: str = Field(description="Version string (integer start at 1 cast as string)")
|
||||
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
||||
variables: dict[str, str] | None = Field(
|
||||
default_factory=dict, description="Variables for dynamic injection using {{variable}} syntax"
|
||||
)
|
||||
|
||||
@field_validator("prompt_id")
|
||||
@classmethod
|
||||
def validate_prompt_id(cls, prompt_id: str) -> str:
|
||||
if not isinstance(prompt_id, str):
|
||||
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
|
||||
|
||||
if not prompt_id.startswith("pmpt_"):
|
||||
raise ValueError("prompt_id must start with 'pmpt_' prefix")
|
||||
|
||||
hex_part = prompt_id[5:]
|
||||
if len(hex_part) != 48:
|
||||
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
|
||||
|
||||
for char in hex_part:
|
||||
if char not in "0123456789abcdef":
|
||||
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
|
||||
|
||||
return prompt_id
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, prompt_version: str) -> str:
|
||||
try:
|
||||
int_version = int(prompt_version)
|
||||
if int_version < 1:
|
||||
raise ValueError("version must be >= 1")
|
||||
except ValueError as e:
|
||||
if "invalid literal" in str(e):
|
||||
raise ValueError("version must be a string representation of an integer") from e
|
||||
raise
|
||||
return prompt_version
|
||||
|
||||
@classmethod
|
||||
def generate_prompt_id(cls) -> str:
|
||||
# Generate 48 hex characters (24 bytes)
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
hex_string = random_bytes.hex()
|
||||
return f"pmpt_{hex_string}"
|
||||
|
||||
|
||||
class CreatePromptRequest(BaseModel):
|
||||
"""Request model to create a prompt."""
|
||||
|
||||
prompt: str = Field(description="The prompt text content")
|
||||
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
|
||||
|
||||
|
||||
class UpdatePromptRequest(BaseModel):
|
||||
"""Request model for updating a prompt."""
|
||||
|
||||
prompt: str = Field(description="The prompt text content")
|
||||
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
|
||||
|
||||
|
||||
class ListPromptsResponse(BaseModel):
|
||||
"""Response model to list prompts."""
|
||||
|
||||
data: list[Prompt]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Prompts(Protocol):
|
||||
"""Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET")
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts.
|
||||
|
||||
:returns: A ListPromptsResponse containing all prompts.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id:path}", method="GET")
|
||||
async def get_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: str | None = None,
|
||||
) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to get.
|
||||
:param version: The version of the prompt to get (defaults to latest).
|
||||
:returns: A Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts", method="POST")
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: dict[str, str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: Dictionary of variable names to their default values.
|
||||
:returns: The created Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id:path}", method="PUT")
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
variables: dict[str, str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
:param variables: Updated dictionary of variable names to their default values.
|
||||
:returns: The updated Prompt resource with incremented version.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id:path}", method="DELETE")
|
||||
async def delete_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
) -> None:
|
||||
"""Delete a prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id:path}/default-version", method="PUT")
|
||||
async def set_default_version(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: str,
|
||||
) -> Prompt:
|
||||
"""Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
:param prompt_id: The identifier of the prompt.
|
||||
:param version: The version to set as default.
|
||||
:returns: The prompt with the specified version now set as default.
|
||||
"""
|
||||
...
|
Loading…
Add table
Add a link
Reference in a new issue