From 1390660dcf5e2dcbc136049598187ff6f7364e81 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Sun, 7 Sep 2025 14:10:19 -0400 Subject: [PATCH] update prompt variables to list for validation Signed-off-by: Francisco Javier Arceo --- docs/_static/llama-stack-spec.html | 17 ++++++----- docs/_static/llama-stack-spec.yaml | 17 ++++++----- llama_stack/apis/prompts/prompts.py | 34 ++++++++++++++++------ llama_stack/core/prompts/prompts.py | 12 ++++---- tests/unit/prompts/prompts/test_prompts.py | 22 ++++++++++---- 5 files changed, 65 insertions(+), 37 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 5575eb36e..31c637bbb 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9985,11 +9985,11 @@ "description": "The prompt text content with variable placeholders." }, "variables": { - "type": "object", - "additionalProperties": { + "type": "array", + "items": { "type": "string" }, - "description": "Dictionary of variable names to their default values." + "description": "List of variable names that can be used in the prompt template." } }, "additionalProperties": false, @@ -10014,8 +10014,8 @@ "description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'" }, "variables": { - "type": "object", - "additionalProperties": { + "type": "array", + "items": { "type": "string" }, "description": "Dictionary of prompt variable names and values" @@ -10030,6 +10030,7 @@ "required": [ "version", "prompt_id", + "variables", "is_default" ], "title": "Prompt", @@ -17824,11 +17825,11 @@ "description": "The updated prompt text content." }, "variables": { - "type": "object", - "additionalProperties": { + "type": "array", + "items": { "type": "string" }, - "description": "Updated dictionary of variable names to their default values." + "description": "Updated list of variable names that can be used in the prompt template." }, "version": { "type": "string", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 2c0993a9e..03457ecf0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7373,11 +7373,11 @@ components: description: >- The prompt text content with variable placeholders. variables: - type: object - additionalProperties: + type: array + items: type: string description: >- - Dictionary of variable names to their default values. + List of variable names that can be used in the prompt template. additionalProperties: false required: - prompt @@ -7399,8 +7399,8 @@ components: description: >- Unique identifier formatted as 'pmpt_<48-digit-hash>' variables: - type: object - additionalProperties: + type: array + items: type: string description: >- Dictionary of prompt variable names and values @@ -7414,6 +7414,7 @@ components: required: - version - prompt_id + - variables - is_default title: Prompt description: >- @@ -13226,11 +13227,11 @@ components: type: string description: The updated prompt text content. variables: - type: object - additionalProperties: + type: array + items: type: string description: >- - Updated dictionary of variable names to their default values. + Updated list of variable names that can be used in the prompt template. version: type: string description: >- diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py index ca102d9c6..8d34aa5ca 100644 --- a/llama_stack/apis/prompts/prompts.py +++ b/llama_stack/apis/prompts/prompts.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import re import secrets from typing import Protocol, runtime_checkable -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod @@ -27,8 +28,8 @@ class Prompt(BaseModel): prompt: str | None = Field(default=None, description="The system prompt with variable placeholders") version: str = Field(description="Version string (integer start at 1 cast as string)") prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'") - variables: dict[str, str] | None = Field( - default_factory=dict, description="Variables for dynamic injection using {{variable}} syntax" + variables: list[str] = Field( + default_factory=list, description="List of variable names that can be used in the prompt template" ) is_default: bool = Field( default=False, description="Boolean indicating whether this version is the default version" @@ -66,6 +67,21 @@ class Prompt(BaseModel): raise return prompt_version + @model_validator(mode="after") + def validate_prompt_variables(self): + """Validate that all variables used in the prompt are declared in the variables list.""" + if not self.prompt: + return self + + prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt)) + declared_variables = set(self.variables) + + undeclared = prompt_variables - declared_variables + if undeclared: + raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}") + + return self + @classmethod def generate_prompt_id(cls) -> str: # Generate 48 hex characters (24 bytes) @@ -78,14 +94,14 @@ class CreatePromptRequest(BaseModel): """Request model to create a prompt.""" prompt: str = Field(description="The prompt text content") - variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection") + variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection") class UpdatePromptRequest(BaseModel): """Request model for updating a prompt.""" prompt: str = Field(description="The prompt text content") - variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection") + variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection") class ListPromptsResponse(BaseModel): @@ -137,12 +153,12 @@ class Prompts(Protocol): async def create_prompt( self, prompt: str, - variables: dict[str, str] | None = None, + variables: list[str] | None = None, ) -> Prompt: """Create a new prompt. :param prompt: The prompt text content with variable placeholders. - :param variables: Dictionary of variable names to their default values. + :param variables: List of variable names that can be used in the prompt template. :returns: The created Prompt resource. """ ... @@ -152,14 +168,14 @@ class Prompts(Protocol): self, prompt_id: str, prompt: str, - variables: dict[str, str] | None = None, + variables: list[str] | None = None, version: str | None = None, ) -> Prompt: """Update an existing prompt (increments version). :param prompt_id: The identifier of the prompt to update. :param prompt: The updated prompt text content. - :param variables: Updated dictionary of variable names to their default values. + :param variables: Updated list of variable names that can be used in the prompt template. :param version: The current version of the prompt being updated (as a string). :returns: The updated Prompt resource with incremented version. """ diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py index 730ce00e3..624764020 100644 --- a/llama_stack/core/prompts/prompts.py +++ b/llama_stack/core/prompts/prompts.py @@ -76,7 +76,7 @@ class PromptServiceImpl(Prompts): "prompt_id": prompt.prompt_id, "prompt": prompt.prompt, "version": prompt.version, - "variables": prompt.variables or {}, + "variables": prompt.variables or [], } ) @@ -84,7 +84,7 @@ class PromptServiceImpl(Prompts): """Deserialize a prompt from JSON string.""" obj = json.loads(data) return Prompt( - prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", {}) + prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", []) ) async def list_prompts(self) -> ListPromptsResponse: @@ -121,11 +121,11 @@ class PromptServiceImpl(Prompts): async def create_prompt( self, prompt: str, - variables: dict[str, str] | None = None, + variables: list[str] | None = None, ) -> Prompt: """Create a new prompt.""" if variables is None: - variables = {} + variables = [] prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables) @@ -142,12 +142,12 @@ class PromptServiceImpl(Prompts): self, prompt_id: str, prompt: str, - variables: dict[str, str] | None = None, + variables: list[str] | None = None, version: str | None = None, ) -> Prompt: """Update an existing prompt (increments version).""" if variables is None: - variables = {} + variables = [] current_prompt = await self.get_prompt(prompt_id) if version and current_prompt.version != version: diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py index 01395df84..93f443374 100644 --- a/tests/unit/prompts/prompts/test_prompts.py +++ b/tests/unit/prompts/prompts/test_prompts.py @@ -30,11 +30,11 @@ class TestPrompts: yield store async def test_create_and_get_prompt(self, store): - prompt = await store.create_prompt("Hello world!", {"name": "John"}) + prompt = await store.create_prompt("Hello world!", ["name"]) assert prompt.prompt == "Hello world!" assert prompt.version == "1" assert prompt.prompt_id.startswith("pmpt_") - assert prompt.variables == {"name": "John"} + assert prompt.variables == ["name"] retrieved = await store.get_prompt(prompt.prompt_id) assert retrieved.prompt_id == prompt.prompt_id @@ -42,7 +42,7 @@ class TestPrompts: async def test_update_prompt(self, store): prompt = await store.create_prompt("Original") - updated = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}) + updated = await store.update_prompt(prompt.prompt_id, "Updated", ["v"]) assert updated.version == "2" assert updated.prompt == "Updated" @@ -51,16 +51,16 @@ class TestPrompts: prompt = await store.create_prompt("Original") assert prompt.version == "1" - prompt = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}, version_for_update) + prompt = await store.update_prompt(prompt.prompt_id, "Updated", ["v"], version_for_update) assert prompt.version == "2" with pytest.raises(ValueError): # now this is a stale version - await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, version_for_update) + await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], version_for_update) with pytest.raises(ValueError): # this version does not exist - await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, "99") + await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], "99") async def test_delete_prompt(self, store): prompt = await store.create_prompt("to be deleted") @@ -134,3 +134,13 @@ class TestPrompts: with pytest.raises(ValueError): await store.list_prompt_versions("nonexistent") + + async def test_prompt_variable_validation(self, store): + prompt = await store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"]) + assert prompt.variables == ["name", "city"] + + prompt_no_vars = await store.create_prompt("Hello world!", []) + assert prompt_no_vars.variables == [] + + with pytest.raises(ValueError, match="undeclared variables"): + await store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])