mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 04:17:32 +00:00
update prompt variables to list for validation
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
60361b910c
commit
1390660dcf
5 changed files with 65 additions and 37 deletions
|
@ -4,10 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import secrets
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
@ -27,8 +28,8 @@ class Prompt(BaseModel):
|
|||
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
||||
version: str = Field(description="Version string (integer start at 1 cast as string)")
|
||||
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
||||
variables: dict[str, str] | None = Field(
|
||||
default_factory=dict, description="Variables for dynamic injection using {{variable}} syntax"
|
||||
variables: list[str] = Field(
|
||||
default_factory=list, description="List of variable names that can be used in the prompt template"
|
||||
)
|
||||
is_default: bool = Field(
|
||||
default=False, description="Boolean indicating whether this version is the default version"
|
||||
|
@ -66,6 +67,21 @@ class Prompt(BaseModel):
|
|||
raise
|
||||
return prompt_version
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_prompt_variables(self):
|
||||
"""Validate that all variables used in the prompt are declared in the variables list."""
|
||||
if not self.prompt:
|
||||
return self
|
||||
|
||||
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
|
||||
declared_variables = set(self.variables)
|
||||
|
||||
undeclared = prompt_variables - declared_variables
|
||||
if undeclared:
|
||||
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def generate_prompt_id(cls) -> str:
|
||||
# Generate 48 hex characters (24 bytes)
|
||||
|
@ -78,14 +94,14 @@ class CreatePromptRequest(BaseModel):
|
|||
"""Request model to create a prompt."""
|
||||
|
||||
prompt: str = Field(description="The prompt text content")
|
||||
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
|
||||
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
|
||||
|
||||
|
||||
class UpdatePromptRequest(BaseModel):
|
||||
"""Request model for updating a prompt."""
|
||||
|
||||
prompt: str = Field(description="The prompt text content")
|
||||
variables: dict[str, str] = Field(default_factory=dict, description="Variables for dynamic injection")
|
||||
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
|
||||
|
||||
|
||||
class ListPromptsResponse(BaseModel):
|
||||
|
@ -137,12 +153,12 @@ class Prompts(Protocol):
|
|||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: dict[str, str] | None = None,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: Dictionary of variable names to their default values.
|
||||
:param variables: List of variable names that can be used in the prompt template.
|
||||
:returns: The created Prompt resource.
|
||||
"""
|
||||
...
|
||||
|
@ -152,14 +168,14 @@ class Prompts(Protocol):
|
|||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
variables: dict[str, str] | None = None,
|
||||
variables: list[str] | None = None,
|
||||
version: str | None = None,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
:param variables: Updated dictionary of variable names to their default values.
|
||||
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||
:param version: The current version of the prompt being updated (as a string).
|
||||
:returns: The updated Prompt resource with incremented version.
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue