update prompt variables to list for validation

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-07 14:10:19 -04:00
parent 60361b910c
commit 1390660dcf
5 changed files with 65 additions and 37 deletions

View file

@ -9985,11 +9985,11 @@
"description": "The prompt text content with variable placeholders." "description": "The prompt text content with variable placeholders."
}, },
"variables": { "variables": {
"type": "object", "type": "array",
"additionalProperties": { "items": {
"type": "string" "type": "string"
}, },
"description": "Dictionary of variable names to their default values." "description": "List of variable names that can be used in the prompt template."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -10014,8 +10014,8 @@
"description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'" "description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'"
}, },
"variables": { "variables": {
"type": "object", "type": "array",
"additionalProperties": { "items": {
"type": "string" "type": "string"
}, },
"description": "Dictionary of prompt variable names and values" "description": "Dictionary of prompt variable names and values"
@ -10030,6 +10030,7 @@
"required": [ "required": [
"version", "version",
"prompt_id", "prompt_id",
"variables",
"is_default" "is_default"
], ],
"title": "Prompt", "title": "Prompt",
@ -17824,11 +17825,11 @@
"description": "The updated prompt text content." "description": "The updated prompt text content."
}, },
"variables": { "variables": {
"type": "object", "type": "array",
"additionalProperties": { "items": {
"type": "string" "type": "string"
}, },
"description": "Updated dictionary of variable names to their default values." "description": "Updated list of variable names that can be used in the prompt template."
}, },
"version": { "version": {
"type": "string", "type": "string",

View file

@ -7373,11 +7373,11 @@ components:
description: >- description: >-
The prompt text content with variable placeholders. The prompt text content with variable placeholders.
variables: variables:
type: object type: array
additionalProperties: items:
type: string type: string
description: >- description: >-
Dictionary of variable names to their default values. List of variable names that can be used in the prompt template.
additionalProperties: false additionalProperties: false
required: required:
- prompt - prompt
@ -7399,8 +7399,8 @@ components:
description: >- description: >-
Unique identifier formatted as 'pmpt_<48-digit-hash>' Unique identifier formatted as 'pmpt_<48-digit-hash>'
variables: variables:
type: object type: array
additionalProperties: items:
type: string type: string
description: >- description: >-
Dictionary of prompt variable names and values Dictionary of prompt variable names and values
@ -7414,6 +7414,7 @@ components:
required: required:
- version - version
- prompt_id - prompt_id
- variables
- is_default - is_default
title: Prompt title: Prompt
description: >- description: >-
@ -13226,11 +13227,11 @@ components:
type: string type: string
description: The updated prompt text content. description: The updated prompt text content.
variables: variables:
type: object type: array
additionalProperties: items:
type: string type: string
description: >- description: >-
Updated dictionary of variable names to their default values. Updated list of variable names that can be used in the prompt template.
version: version:
type: string type: string
description: >- description: >-

View file

@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import re
import secrets import secrets
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, webmethod
@ -27,8 +28,8 @@ class Prompt(BaseModel):
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders") prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
version: str = Field(description="Version string (integer start at 1 cast as string)") version: str = Field(description="Version string (integer start at 1 cast as string)")
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'") prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: dict[str, str] | None = Field( variables: list[str] = Field(
default_factory=dict, description="Variables for dynamic injection using {{variable}} syntax" default_factory=list, description="List of variable names that can be used in the prompt template"
) )
is_default: bool = Field( is_default: bool = Field(
default=False, description="Boolean indicating whether this version is the default version" default=False, description="Boolean indicating whether this version is the default version"
@ -66,6 +67,21 @@ class Prompt(BaseModel):
raise raise
return prompt_version return prompt_version
@model_validator(mode="after")
def validate_prompt_variables(self):
"""Validate that all variables used in the prompt are declared in the variables list."""
if not self.prompt:
return self
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
declared_variables = set(self.variables)
undeclared = prompt_variables - declared_variables
if undeclared:
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
return self
@classmethod @classmethod
def generate_prompt_id(cls) -> str: def generate_prompt_id(cls) -> str:
# Generate 48 hex characters (24 bytes) # Generate 48 hex characters (24 bytes)
@ -78,14 +94,14 @@ class CreatePromptRequest(BaseModel):
"""Request model to create a prompt.""" """Request model to create a prompt."""
prompt: str = Field(description="The prompt text content") prompt: str = Field(description="The prompt text content")
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection") variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
class UpdatePromptRequest(BaseModel): class UpdatePromptRequest(BaseModel):
"""Request model for updating a prompt.""" """Request model for updating a prompt."""
prompt: str = Field(description="The prompt text content") prompt: str = Field(description="The prompt text content")
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection") variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
class ListPromptsResponse(BaseModel): class ListPromptsResponse(BaseModel):
@ -137,12 +153,12 @@ class Prompts(Protocol):
async def create_prompt( async def create_prompt(
self, self,
prompt: str, prompt: str,
variables: dict[str, str] | None = None, variables: list[str] | None = None,
) -> Prompt: ) -> Prompt:
"""Create a new prompt. """Create a new prompt.
:param prompt: The prompt text content with variable placeholders. :param prompt: The prompt text content with variable placeholders.
:param variables: Dictionary of variable names to their default values. :param variables: List of variable names that can be used in the prompt template.
:returns: The created Prompt resource. :returns: The created Prompt resource.
""" """
... ...
@ -152,14 +168,14 @@ class Prompts(Protocol):
self, self,
prompt_id: str, prompt_id: str,
prompt: str, prompt: str,
variables: dict[str, str] | None = None, variables: list[str] | None = None,
version: str | None = None, version: str | None = None,
) -> Prompt: ) -> Prompt:
"""Update an existing prompt (increments version). """Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update. :param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content. :param prompt: The updated prompt text content.
:param variables: Updated dictionary of variable names to their default values. :param variables: Updated list of variable names that can be used in the prompt template.
:param version: The current version of the prompt being updated (as a string). :param version: The current version of the prompt being updated (as a string).
:returns: The updated Prompt resource with incremented version. :returns: The updated Prompt resource with incremented version.
""" """

View file

@ -76,7 +76,7 @@ class PromptServiceImpl(Prompts):
"prompt_id": prompt.prompt_id, "prompt_id": prompt.prompt_id,
"prompt": prompt.prompt, "prompt": prompt.prompt,
"version": prompt.version, "version": prompt.version,
"variables": prompt.variables or {}, "variables": prompt.variables or [],
} }
) )
@ -84,7 +84,7 @@ class PromptServiceImpl(Prompts):
"""Deserialize a prompt from JSON string.""" """Deserialize a prompt from JSON string."""
obj = json.loads(data) obj = json.loads(data)
return Prompt( return Prompt(
prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", {}) prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", [])
) )
async def list_prompts(self) -> ListPromptsResponse: async def list_prompts(self) -> ListPromptsResponse:
@ -121,11 +121,11 @@ class PromptServiceImpl(Prompts):
async def create_prompt( async def create_prompt(
self, self,
prompt: str, prompt: str,
variables: dict[str, str] | None = None, variables: list[str] | None = None,
) -> Prompt: ) -> Prompt:
"""Create a new prompt.""" """Create a new prompt."""
if variables is None: if variables is None:
variables = {} variables = []
prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables) prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)
@ -142,12 +142,12 @@ class PromptServiceImpl(Prompts):
self, self,
prompt_id: str, prompt_id: str,
prompt: str, prompt: str,
variables: dict[str, str] | None = None, variables: list[str] | None = None,
version: str | None = None, version: str | None = None,
) -> Prompt: ) -> Prompt:
"""Update an existing prompt (increments version).""" """Update an existing prompt (increments version)."""
if variables is None: if variables is None:
variables = {} variables = []
current_prompt = await self.get_prompt(prompt_id) current_prompt = await self.get_prompt(prompt_id)
if version and current_prompt.version != version: if version and current_prompt.version != version:

View file

@ -30,11 +30,11 @@ class TestPrompts:
yield store yield store
async def test_create_and_get_prompt(self, store): async def test_create_and_get_prompt(self, store):
prompt = await store.create_prompt("Hello world!", {"name": "John"}) prompt = await store.create_prompt("Hello world!", ["name"])
assert prompt.prompt == "Hello world!" assert prompt.prompt == "Hello world!"
assert prompt.version == "1" assert prompt.version == "1"
assert prompt.prompt_id.startswith("pmpt_") assert prompt.prompt_id.startswith("pmpt_")
assert prompt.variables == {"name": "John"} assert prompt.variables == ["name"]
retrieved = await store.get_prompt(prompt.prompt_id) retrieved = await store.get_prompt(prompt.prompt_id)
assert retrieved.prompt_id == prompt.prompt_id assert retrieved.prompt_id == prompt.prompt_id
@ -42,7 +42,7 @@ class TestPrompts:
async def test_update_prompt(self, store): async def test_update_prompt(self, store):
prompt = await store.create_prompt("Original") prompt = await store.create_prompt("Original")
updated = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}) updated = await store.update_prompt(prompt.prompt_id, "Updated", ["v"])
assert updated.version == "2" assert updated.version == "2"
assert updated.prompt == "Updated" assert updated.prompt == "Updated"
@ -51,16 +51,16 @@ class TestPrompts:
prompt = await store.create_prompt("Original") prompt = await store.create_prompt("Original")
assert prompt.version == "1" assert prompt.version == "1"
prompt = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}, version_for_update) prompt = await store.update_prompt(prompt.prompt_id, "Updated", ["v"], version_for_update)
assert prompt.version == "2" assert prompt.version == "2"
with pytest.raises(ValueError): with pytest.raises(ValueError):
# now this is a stale version # now this is a stale version
await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, version_for_update) await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], version_for_update)
with pytest.raises(ValueError): with pytest.raises(ValueError):
# this version does not exist # this version does not exist
await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, "99") await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], "99")
async def test_delete_prompt(self, store): async def test_delete_prompt(self, store):
prompt = await store.create_prompt("to be deleted") prompt = await store.create_prompt("to be deleted")
@ -134,3 +134,13 @@ class TestPrompts:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await store.list_prompt_versions("nonexistent") await store.list_prompt_versions("nonexistent")
async def test_prompt_variable_validation(self, store):
prompt = await store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
assert prompt.variables == ["name", "city"]
prompt_no_vars = await store.create_prompt("Hello world!", [])
assert prompt_no_vars.variables == []
with pytest.raises(ValueError, match="undeclared variables"):
await store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])