mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 04:17:32 +00:00
update prompt variables to list for validation
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
60361b910c
commit
1390660dcf
5 changed files with 65 additions and 37 deletions
17
docs/_static/llama-stack-spec.html
vendored
17
docs/_static/llama-stack-spec.html
vendored
|
@ -9985,11 +9985,11 @@
|
||||||
"description": "The prompt text content with variable placeholders."
|
"description": "The prompt text content with variable placeholders."
|
||||||
},
|
},
|
||||||
"variables": {
|
"variables": {
|
||||||
"type": "object",
|
"type": "array",
|
||||||
"additionalProperties": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Dictionary of variable names to their default values."
|
"description": "List of variable names that can be used in the prompt template."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -10014,8 +10014,8 @@
|
||||||
"description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'"
|
"description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'"
|
||||||
},
|
},
|
||||||
"variables": {
|
"variables": {
|
||||||
"type": "object",
|
"type": "array",
|
||||||
"additionalProperties": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Dictionary of prompt variable names and values"
|
"description": "Dictionary of prompt variable names and values"
|
||||||
|
@ -10030,6 +10030,7 @@
|
||||||
"required": [
|
"required": [
|
||||||
"version",
|
"version",
|
||||||
"prompt_id",
|
"prompt_id",
|
||||||
|
"variables",
|
||||||
"is_default"
|
"is_default"
|
||||||
],
|
],
|
||||||
"title": "Prompt",
|
"title": "Prompt",
|
||||||
|
@ -17824,11 +17825,11 @@
|
||||||
"description": "The updated prompt text content."
|
"description": "The updated prompt text content."
|
||||||
},
|
},
|
||||||
"variables": {
|
"variables": {
|
||||||
"type": "object",
|
"type": "array",
|
||||||
"additionalProperties": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"description": "Updated dictionary of variable names to their default values."
|
"description": "Updated list of variable names that can be used in the prompt template."
|
||||||
},
|
},
|
||||||
"version": {
|
"version": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|
17
docs/_static/llama-stack-spec.yaml
vendored
17
docs/_static/llama-stack-spec.yaml
vendored
|
@ -7373,11 +7373,11 @@ components:
|
||||||
description: >-
|
description: >-
|
||||||
The prompt text content with variable placeholders.
|
The prompt text content with variable placeholders.
|
||||||
variables:
|
variables:
|
||||||
type: object
|
type: array
|
||||||
additionalProperties:
|
items:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Dictionary of variable names to their default values.
|
List of variable names that can be used in the prompt template.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- prompt
|
- prompt
|
||||||
|
@ -7399,8 +7399,8 @@ components:
|
||||||
description: >-
|
description: >-
|
||||||
Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||||
variables:
|
variables:
|
||||||
type: object
|
type: array
|
||||||
additionalProperties:
|
items:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Dictionary of prompt variable names and values
|
Dictionary of prompt variable names and values
|
||||||
|
@ -7414,6 +7414,7 @@ components:
|
||||||
required:
|
required:
|
||||||
- version
|
- version
|
||||||
- prompt_id
|
- prompt_id
|
||||||
|
- variables
|
||||||
- is_default
|
- is_default
|
||||||
title: Prompt
|
title: Prompt
|
||||||
description: >-
|
description: >-
|
||||||
|
@ -13226,11 +13227,11 @@ components:
|
||||||
type: string
|
type: string
|
||||||
description: The updated prompt text content.
|
description: The updated prompt text content.
|
||||||
variables:
|
variables:
|
||||||
type: object
|
type: array
|
||||||
additionalProperties:
|
items:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Updated dictionary of variable names to their default values.
|
Updated list of variable names that can be used in the prompt template.
|
||||||
version:
|
version:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
|
|
|
@ -4,10 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -27,8 +28,8 @@ class Prompt(BaseModel):
|
||||||
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
||||||
version: str = Field(description="Version string (integer start at 1 cast as string)")
|
version: str = Field(description="Version string (integer start at 1 cast as string)")
|
||||||
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
||||||
variables: dict[str, str] | None = Field(
|
variables: list[str] = Field(
|
||||||
default_factory=dict, description="Variables for dynamic injection using {{variable}} syntax"
|
default_factory=list, description="List of variable names that can be used in the prompt template"
|
||||||
)
|
)
|
||||||
is_default: bool = Field(
|
is_default: bool = Field(
|
||||||
default=False, description="Boolean indicating whether this version is the default version"
|
default=False, description="Boolean indicating whether this version is the default version"
|
||||||
|
@ -66,6 +67,21 @@ class Prompt(BaseModel):
|
||||||
raise
|
raise
|
||||||
return prompt_version
|
return prompt_version
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_prompt_variables(self):
|
||||||
|
"""Validate that all variables used in the prompt are declared in the variables list."""
|
||||||
|
if not self.prompt:
|
||||||
|
return self
|
||||||
|
|
||||||
|
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
|
||||||
|
declared_variables = set(self.variables)
|
||||||
|
|
||||||
|
undeclared = prompt_variables - declared_variables
|
||||||
|
if undeclared:
|
||||||
|
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_prompt_id(cls) -> str:
|
def generate_prompt_id(cls) -> str:
|
||||||
# Generate 48 hex characters (24 bytes)
|
# Generate 48 hex characters (24 bytes)
|
||||||
|
@ -78,14 +94,14 @@ class CreatePromptRequest(BaseModel):
|
||||||
"""Request model to create a prompt."""
|
"""Request model to create a prompt."""
|
||||||
|
|
||||||
prompt: str = Field(description="The prompt text content")
|
prompt: str = Field(description="The prompt text content")
|
||||||
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
|
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
|
||||||
|
|
||||||
|
|
||||||
class UpdatePromptRequest(BaseModel):
|
class UpdatePromptRequest(BaseModel):
|
||||||
"""Request model for updating a prompt."""
|
"""Request model for updating a prompt."""
|
||||||
|
|
||||||
prompt: str = Field(description="The prompt text content")
|
prompt: str = Field(description="The prompt text content")
|
||||||
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
|
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
|
||||||
|
|
||||||
|
|
||||||
class ListPromptsResponse(BaseModel):
|
class ListPromptsResponse(BaseModel):
|
||||||
|
@ -137,12 +153,12 @@ class Prompts(Protocol):
|
||||||
async def create_prompt(
|
async def create_prompt(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
variables: dict[str, str] | None = None,
|
variables: list[str] | None = None,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Create a new prompt.
|
"""Create a new prompt.
|
||||||
|
|
||||||
:param prompt: The prompt text content with variable placeholders.
|
:param prompt: The prompt text content with variable placeholders.
|
||||||
:param variables: Dictionary of variable names to their default values.
|
:param variables: List of variable names that can be used in the prompt template.
|
||||||
:returns: The created Prompt resource.
|
:returns: The created Prompt resource.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
@ -152,14 +168,14 @@ class Prompts(Protocol):
|
||||||
self,
|
self,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
variables: dict[str, str] | None = None,
|
variables: list[str] | None = None,
|
||||||
version: str | None = None,
|
version: str | None = None,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version).
|
"""Update an existing prompt (increments version).
|
||||||
|
|
||||||
:param prompt_id: The identifier of the prompt to update.
|
:param prompt_id: The identifier of the prompt to update.
|
||||||
:param prompt: The updated prompt text content.
|
:param prompt: The updated prompt text content.
|
||||||
:param variables: Updated dictionary of variable names to their default values.
|
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||||
:param version: The current version of the prompt being updated (as a string).
|
:param version: The current version of the prompt being updated (as a string).
|
||||||
:returns: The updated Prompt resource with incremented version.
|
:returns: The updated Prompt resource with incremented version.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -76,7 +76,7 @@ class PromptServiceImpl(Prompts):
|
||||||
"prompt_id": prompt.prompt_id,
|
"prompt_id": prompt.prompt_id,
|
||||||
"prompt": prompt.prompt,
|
"prompt": prompt.prompt,
|
||||||
"version": prompt.version,
|
"version": prompt.version,
|
||||||
"variables": prompt.variables or {},
|
"variables": prompt.variables or [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ class PromptServiceImpl(Prompts):
|
||||||
"""Deserialize a prompt from JSON string."""
|
"""Deserialize a prompt from JSON string."""
|
||||||
obj = json.loads(data)
|
obj = json.loads(data)
|
||||||
return Prompt(
|
return Prompt(
|
||||||
prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", {})
|
prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", [])
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_prompts(self) -> ListPromptsResponse:
|
async def list_prompts(self) -> ListPromptsResponse:
|
||||||
|
@ -121,11 +121,11 @@ class PromptServiceImpl(Prompts):
|
||||||
async def create_prompt(
|
async def create_prompt(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
variables: dict[str, str] | None = None,
|
variables: list[str] | None = None,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Create a new prompt."""
|
"""Create a new prompt."""
|
||||||
if variables is None:
|
if variables is None:
|
||||||
variables = {}
|
variables = []
|
||||||
|
|
||||||
prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)
|
prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)
|
||||||
|
|
||||||
|
@ -142,12 +142,12 @@ class PromptServiceImpl(Prompts):
|
||||||
self,
|
self,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
variables: dict[str, str] | None = None,
|
variables: list[str] | None = None,
|
||||||
version: str | None = None,
|
version: str | None = None,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""Update an existing prompt (increments version)."""
|
"""Update an existing prompt (increments version)."""
|
||||||
if variables is None:
|
if variables is None:
|
||||||
variables = {}
|
variables = []
|
||||||
|
|
||||||
current_prompt = await self.get_prompt(prompt_id)
|
current_prompt = await self.get_prompt(prompt_id)
|
||||||
if version and current_prompt.version != version:
|
if version and current_prompt.version != version:
|
||||||
|
|
|
@ -30,11 +30,11 @@ class TestPrompts:
|
||||||
yield store
|
yield store
|
||||||
|
|
||||||
async def test_create_and_get_prompt(self, store):
|
async def test_create_and_get_prompt(self, store):
|
||||||
prompt = await store.create_prompt("Hello world!", {"name": "John"})
|
prompt = await store.create_prompt("Hello world!", ["name"])
|
||||||
assert prompt.prompt == "Hello world!"
|
assert prompt.prompt == "Hello world!"
|
||||||
assert prompt.version == "1"
|
assert prompt.version == "1"
|
||||||
assert prompt.prompt_id.startswith("pmpt_")
|
assert prompt.prompt_id.startswith("pmpt_")
|
||||||
assert prompt.variables == {"name": "John"}
|
assert prompt.variables == ["name"]
|
||||||
|
|
||||||
retrieved = await store.get_prompt(prompt.prompt_id)
|
retrieved = await store.get_prompt(prompt.prompt_id)
|
||||||
assert retrieved.prompt_id == prompt.prompt_id
|
assert retrieved.prompt_id == prompt.prompt_id
|
||||||
|
@ -42,7 +42,7 @@ class TestPrompts:
|
||||||
|
|
||||||
async def test_update_prompt(self, store):
|
async def test_update_prompt(self, store):
|
||||||
prompt = await store.create_prompt("Original")
|
prompt = await store.create_prompt("Original")
|
||||||
updated = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"})
|
updated = await store.update_prompt(prompt.prompt_id, "Updated", ["v"])
|
||||||
assert updated.version == "2"
|
assert updated.version == "2"
|
||||||
assert updated.prompt == "Updated"
|
assert updated.prompt == "Updated"
|
||||||
|
|
||||||
|
@ -51,16 +51,16 @@ class TestPrompts:
|
||||||
|
|
||||||
prompt = await store.create_prompt("Original")
|
prompt = await store.create_prompt("Original")
|
||||||
assert prompt.version == "1"
|
assert prompt.version == "1"
|
||||||
prompt = await store.update_prompt(prompt.prompt_id, "Updated", {"v": "2"}, version_for_update)
|
prompt = await store.update_prompt(prompt.prompt_id, "Updated", ["v"], version_for_update)
|
||||||
assert prompt.version == "2"
|
assert prompt.version == "2"
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# now this is a stale version
|
# now this is a stale version
|
||||||
await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, version_for_update)
|
await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], version_for_update)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# this version does not exist
|
# this version does not exist
|
||||||
await store.update_prompt(prompt.prompt_id, "Another Update", {"v": "2"}, "99")
|
await store.update_prompt(prompt.prompt_id, "Another Update", ["v"], "99")
|
||||||
|
|
||||||
async def test_delete_prompt(self, store):
|
async def test_delete_prompt(self, store):
|
||||||
prompt = await store.create_prompt("to be deleted")
|
prompt = await store.create_prompt("to be deleted")
|
||||||
|
@ -134,3 +134,13 @@ class TestPrompts:
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await store.list_prompt_versions("nonexistent")
|
await store.list_prompt_versions("nonexistent")
|
||||||
|
|
||||||
|
async def test_prompt_variable_validation(self, store):
|
||||||
|
prompt = await store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
|
||||||
|
assert prompt.variables == ["name", "city"]
|
||||||
|
|
||||||
|
prompt_no_vars = await store.create_prompt("Hello world!", [])
|
||||||
|
assert prompt_no_vars.variables == []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="undeclared variables"):
|
||||||
|
await store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue