mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 04:45:44 +00:00
removed UpdatePromptRequest, convert prompt version to integer, update tests, and remove TempPathFile in favor of tmp_path_factory
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
2a8469f156
commit
64f6840195
7 changed files with 88 additions and 109 deletions
|
@ -4,6 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .prompts import ListPromptsResponse, Prompt, Prompts, UpdatePromptRequest
|
||||
from .prompts import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
__all__ = ["Prompt", "Prompts", "ListPromptsResponse", "UpdatePromptRequest"]
|
||||
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]
|
||||
|
|
|
@ -19,14 +19,14 @@ class Prompt(BaseModel):
|
|||
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
|
||||
|
||||
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
|
||||
:param version: Version string (integer start at 1 cast as string, incremented on save)
|
||||
:param version: Version (integer starting at 1, incremented on save)
|
||||
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
|
||||
:param variables: List of prompt variable names that can be used in the prompt template
|
||||
:param is_default: Boolean indicating whether this version is the default version for this prompt
|
||||
"""
|
||||
|
||||
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
|
||||
version: str = Field(description="Version string (integer start at 1 cast as string)")
|
||||
version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
|
||||
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
|
||||
variables: list[str] = Field(
|
||||
default_factory=list, description="List of variable names that can be used in the prompt template"
|
||||
|
@ -56,15 +56,9 @@ class Prompt(BaseModel):
|
|||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, prompt_version: str) -> str:
|
||||
try:
|
||||
int_version = int(prompt_version)
|
||||
if int_version < 1:
|
||||
raise ValueError("version must be >= 1")
|
||||
except ValueError as e:
|
||||
if "invalid literal" in str(e):
|
||||
raise ValueError("version must be a string representation of an integer") from e
|
||||
raise
|
||||
def validate_version(cls, prompt_version: int) -> int:
|
||||
if prompt_version < 1:
|
||||
raise ValueError("version must be >= 1")
|
||||
return prompt_version
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -90,13 +84,6 @@ class Prompt(BaseModel):
|
|||
return f"pmpt_{hex_string}"
|
||||
|
||||
|
||||
class UpdatePromptRequest(BaseModel):
|
||||
"""Request model for updating a prompt."""
|
||||
|
||||
prompt: str = Field(description="The prompt text content")
|
||||
variables: list[str] = Field(default_factory=list, description="List of variable names for dynamic injection")
|
||||
|
||||
|
||||
class ListPromptsResponse(BaseModel):
|
||||
"""Response model to list prompts."""
|
||||
|
||||
|
@ -132,7 +119,7 @@ class Prompts(Protocol):
|
|||
async def get_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: str | None = None,
|
||||
version: int | None = None,
|
||||
) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version.
|
||||
|
||||
|
@ -161,14 +148,14 @@ class Prompts(Protocol):
|
|||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
:param version: The current version of the prompt being updated (as a string).
|
||||
:param version: The current version of the prompt being updated.
|
||||
:param variables: Updated list of variable names that can be used in the prompt template.
|
||||
:returns: The updated Prompt resource with incremented version.
|
||||
"""
|
||||
|
@ -189,7 +176,7 @@ class Prompts(Protocol):
|
|||
async def set_default_version(
|
||||
self,
|
||||
prompt_id: str,
|
||||
version: str,
|
||||
version: int,
|
||||
) -> Prompt:
|
||||
"""Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
|
|
|
@ -50,10 +50,10 @@ class PromptServiceImpl(Prompts):
|
|||
"""Get the KVStore key that stores the default version number."""
|
||||
return f"prompts:v1:{prompt_id}:default"
|
||||
|
||||
async def _get_prompt_key(self, prompt_id: str, version: str | None = None) -> str:
|
||||
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
|
||||
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
||||
if version:
|
||||
return self._get_version_key(prompt_id, version)
|
||||
return self._get_version_key(prompt_id, str(version))
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
resolved_version = await self.kvstore.get(default_key)
|
||||
|
@ -77,6 +77,7 @@ class PromptServiceImpl(Prompts):
|
|||
"prompt": prompt.prompt,
|
||||
"version": prompt.version,
|
||||
"variables": prompt.variables or [],
|
||||
"is_default": prompt.is_default,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -84,7 +85,11 @@ class PromptServiceImpl(Prompts):
|
|||
"""Deserialize a prompt from JSON string."""
|
||||
obj = json.loads(data)
|
||||
return Prompt(
|
||||
prompt_id=obj["prompt_id"], prompt=obj["prompt"], version=obj["version"], variables=obj.get("variables", [])
|
||||
prompt_id=obj["prompt_id"],
|
||||
prompt=obj["prompt"],
|
||||
version=obj["version"],
|
||||
variables=obj.get("variables", []),
|
||||
is_default=obj.get("is_default", False),
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
|
@ -110,7 +115,7 @@ class PromptServiceImpl(Prompts):
|
|||
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def get_prompt(self, prompt_id: str, version: str | None = None) -> Prompt:
|
||||
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version."""
|
||||
key = await self._get_prompt_key(prompt_id, version)
|
||||
data = await self.kvstore.get(key)
|
||||
|
@ -127,14 +132,19 @@ class PromptServiceImpl(Prompts):
|
|||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_obj = Prompt(prompt_id=Prompt.generate_prompt_id(), prompt=prompt, version="1", variables=variables)
|
||||
prompt_obj = Prompt(
|
||||
prompt_id=Prompt.generate_prompt_id(),
|
||||
prompt=prompt,
|
||||
version=1,
|
||||
variables=variables,
|
||||
)
|
||||
|
||||
version_key = self._get_version_key(prompt_obj.prompt_id, "1")
|
||||
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
||||
data = self._serialize_prompt(prompt_obj)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||
await self.kvstore.set(default_key, "1")
|
||||
await self.kvstore.set(default_key, str(prompt_obj.version))
|
||||
|
||||
return prompt_obj
|
||||
|
||||
|
@ -142,10 +152,12 @@ class PromptServiceImpl(Prompts):
|
|||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version)."""
|
||||
if version < 1:
|
||||
raise ValueError("Version must be >= 1")
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
|
@ -158,16 +170,16 @@ class PromptServiceImpl(Prompts):
|
|||
)
|
||||
|
||||
current_version = latest_prompt.version if version is None else version
|
||||
new_version = str(int(current_version) + 1)
|
||||
new_version = current_version + 1
|
||||
|
||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||
|
||||
version_key = self._get_version_key(prompt_id, new_version)
|
||||
version_key = self._get_version_key(prompt_id, str(new_version))
|
||||
data = self._serialize_prompt(updated_prompt)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
await self.kvstore.set(default_key, new_version)
|
||||
await self.kvstore.set(default_key, str(new_version))
|
||||
|
||||
return updated_prompt
|
||||
|
||||
|
@ -202,19 +214,19 @@ class PromptServiceImpl(Prompts):
|
|||
raise ValueError(f"Prompt {prompt_id} not found")
|
||||
|
||||
for prompt in prompts:
|
||||
prompt.is_default = prompt.version == default_version
|
||||
prompt.is_default = str(prompt.version) == default_version
|
||||
|
||||
prompts.sort(key=lambda x: int(x.version))
|
||||
prompts.sort(key=lambda x: x.version)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def set_default_version(self, prompt_id: str, version: str) -> Prompt:
|
||||
"""Set which version of a prompt should be the default (latest)."""
|
||||
version_key = self._get_version_key(prompt_id, version)
|
||||
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
||||
version_key = self._get_version_key(prompt_id, str(version))
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
await self.kvstore.set(default_key, version)
|
||||
await self.kvstore.set(default_key, str(version))
|
||||
|
||||
return self._deserialize_prompt(data)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue