update prompt variables to list for validation

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-07 14:10:19 -04:00
parent 60361b910c
commit 1390660dcf
5 changed files with 65 additions and 37 deletions

View file

@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import secrets
from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -27,8 +28,8 @@ class Prompt(BaseModel):
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
version: str = Field(description="Version string (integer start at 1 cast as string)")
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: dict[str, str] | None = Field(
default_factory=dict, description="Variables for dynamic injection using {{variable}} syntax"
variables: list[str] = Field(
default_factory=list, description="List of variable names that can be used in the prompt template"
)
is_default: bool = Field(
default=False, description="Boolean indicating whether this version is the default version"
@ -66,6 +67,21 @@ class Prompt(BaseModel):
raise
return prompt_version
@model_validator(mode="after")
def validate_prompt_variables(self):
"""Validate that all variables used in the prompt are declared in the variables list."""
if not self.prompt:
return self
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
declared_variables = set(self.variables)
undeclared = prompt_variables - declared_variables
if undeclared:
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
return self
@classmethod
def generate_prompt_id(cls) -> str:
# Generate 48 hex characters (24 bytes)
@ -78,14 +94,14 @@ class CreatePromptRequest(BaseModel):
"""Request model to create a prompt."""
prompt: str = Field(description="The prompt text content")
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
class UpdatePromptRequest(BaseModel):
"""Request model for updating a prompt."""
prompt: str = Field(description="The prompt text content")
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
class ListPromptsResponse(BaseModel):
@ -137,12 +153,12 @@ class Prompts(Protocol):
async def create_prompt(
self,
prompt: str,
variables: dict[str, str] | None = None,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt.
:param prompt: The prompt text content with variable placeholders.
:param variables: Dictionary of variable names to their default values.
:param variables: List of variable names that can be used in the prompt template.
:returns: The created Prompt resource.
"""
...
@ -152,14 +168,14 @@ class Prompts(Protocol):
self,
prompt_id: str,
prompt: str,
variables: dict[str, str] | None = None,
variables: list[str] | None = None,
version: str | None = None,
) -> Prompt:
"""Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
:param variables: Updated dictionary of variable names to their default values.
:param variables: Updated list of variable names that can be used in the prompt template.
:param version: The current version of the prompt being updated (as a string).
:returns: The updated Prompt resource with incremented version.
"""