diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 7cb2a73f3..a036e5dc0 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -633,6 +633,80 @@ } } }, + "/v1/prompts": { + "get": { + "responses": { + "200": { + "description": "A ListPromptsResponse containing all prompts.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListPromptsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "List all prompts.", + "parameters": [] + }, + "post": { + "responses": { + "200": { + "description": "The created Prompt resource.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Prompt" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "Create a new prompt.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreatePromptRequest" + } + } + }, + "required": true + } + } + }, "/v1/agents/{agent_id}": { "get": { "responses": { @@ -901,6 +975,143 @@ ] } }, + "/v1/prompts/{prompt_id}": { + "get": { + "responses": { + "200": { + "description": "A Prompt resource.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Prompt" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "Get a prompt by its identifier and optional version.", + "parameters": [ + { + "name": "prompt_id", + "in": "path", + "description": "The identifier of the prompt to get.", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "version", + "in": "query", + "description": "The version of the prompt to get (defaults to latest).", + "required": false, + "schema": { + "type": "integer" + } + } + ] + }, + "post": { + "responses": { + "200": { + "description": "The updated Prompt resource with incremented version.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Prompt" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "Update an existing prompt (increments version).", + "parameters": [ + { + "name": "prompt_id", + "in": "path", + "description": "The identifier of the prompt to update.", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdatePromptRequest" + 
} + } + }, + "required": true + } + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "Delete a prompt.", + "parameters": [ + { + "name": "prompt_id", + "in": "path", + "description": "The identifier of the prompt to delete.", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/inference/embeddings": { "post": { "responses": { @@ -2836,6 +3047,49 @@ ] } }, + "/v1/prompts/{prompt_id}/versions": { + "get": { + "responses": { + "200": { + "description": "A ListPromptsResponse containing all versions of the prompt.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListPromptsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "List all versions of a specific prompt.", + "parameters": [ + { + "name": "prompt_id", + "in": "path", + "description": "The identifier of the prompt to list versions for.", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/providers": { "get": { "responses": { @@ -5007,6 +5261,59 @@ } } }, + "/v1/prompts/{prompt_id}/set-default-version": { + "post": { + "responses": { + "200": { + "description": "The prompt with the specified version now set as default.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Prompt" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Prompts" + ], + "description": "Set which version of a prompt should be the default in get_prompt (latest).", + "parameters": [ + { + "name": "prompt_id", + "in": "path", + "description": "The identifier of the prompt.", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SetDefaultVersionRequest" + } + } + }, + "required": true + } + } + }, "/v1/post-training/supervised-fine-tune": { "post": { "responses": { @@ -9670,6 +9977,65 @@ ], "title": "OpenAIResponseObjectStreamResponseWebSearchCallSearching" }, + "CreatePromptRequest": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt text content with variable placeholders." + }, + "variables": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of variable names that can be used in the prompt template." + } + }, + "additionalProperties": false, + "required": [ + "prompt" + ], + "title": "CreatePromptRequest" + }, + "Prompt": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API." 
+ }, + "version": { + "type": "integer", + "description": "Version (integer starting at 1, incremented on save)" + }, + "prompt_id": { + "type": "string", + "description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'" + }, + "variables": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of prompt variable names that can be used in the prompt template" + }, + "is_default": { + "type": "boolean", + "default": false, + "description": "Boolean indicating whether this version is the default version for this prompt" + } + }, + "additionalProperties": false, + "required": [ + "version", + "prompt_id", + "variables", + "is_default" + ], + "title": "Prompt", + "description": "A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack." + }, "OpenAIDeleteResponseObject": { "type": "object", "properties": { @@ -10296,7 +10662,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "benchmark", "default": "benchmark", @@ -10923,7 +11290,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "dataset", "default": "dataset", @@ -11073,7 +11441,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "model", "default": "model", @@ -11338,7 +11707,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "scoring_function", "default": "scoring_function", @@ -11446,7 +11816,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "shield", "default": "shield", @@ -11691,7 +12062,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "tool", "default": "tool", @@ -11773,7 +12145,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "tool_group", "default": "tool_group", @@ -12067,7 +12440,8 @@ "scoring_function", "benchmark", "tool", - "tool_group" + "tool_group", + "prompt" ], "const": "vector_db", "default": "vector_db", @@ -12882,6 +13256,23 @@ "title": "OpenAIResponseObjectWithInput", "description": "OpenAI response object extended with input context information." }, + "ListPromptsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Prompt" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "ListPromptsResponse", + "description": "Response model to list prompts." + }, "ListProvidersResponse": { "type": "object", "properties": { @@ -17128,6 +17519,20 @@ "title": "ScoreBatchResponse", "description": "Response from batch scoring operations on datasets." }, + "SetDefaultVersionRequest": { + "type": "object", + "properties": { + "version": { + "type": "integer", + "description": "The version to set as default." + } + }, + "additionalProperties": false, + "required": [ + "version" + ], + "title": "SetDefaultVersionRequest" + }, "AlgorithmConfig": { "oneOf": [ { @@ -17412,6 +17817,37 @@ "title": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." }, + "UpdatePromptRequest": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The updated prompt text content." + }, + "version": { + "type": "integer", + "description": "The current version of the prompt being updated." 
+ }, + "variables": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Updated list of variable names that can be used in the prompt template." + }, + "set_as_default": { + "type": "boolean", + "description": "Set the new version as the default (default=True)." + } + }, + "additionalProperties": false, + "required": [ + "prompt", + "version", + "set_as_default" + ], + "title": "UpdatePromptRequest" + }, "VersionInfo": { "type": "object", "properties": { @@ -17537,6 +17973,10 @@ { "name": "PostTraining (Coming Soon)" }, + { + "name": "Prompts", + "x-displayName": "Protocol for prompt management operations." + }, { "name": "Providers", "x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations." @@ -17587,6 +18027,7 @@ "Inspect", "Models", "PostTraining (Coming Soon)", + "Prompts", "Providers", "Safety", "Scoring", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 25089868c..8ed04c1f8 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -427,6 +427,58 @@ paths: schema: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true + /v1/prompts: + get: + responses: + '200': + description: >- + A ListPromptsResponse containing all prompts. + content: + application/json: + schema: + $ref: '#/components/schemas/ListPromptsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: List all prompts. + parameters: [] + post: + responses: + '200': + description: The created Prompt resource. + content: + application/json: + schema: + $ref: '#/components/schemas/Prompt' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: Create a new prompt. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreatePromptRequest' + required: true /v1/agents/{agent_id}: get: responses: @@ -616,6 +668,103 @@ paths: required: true schema: type: string + /v1/prompts/{prompt_id}: + get: + responses: + '200': + description: A Prompt resource. + content: + application/json: + schema: + $ref: '#/components/schemas/Prompt' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: >- + Get a prompt by its identifier and optional version. + parameters: + - name: prompt_id + in: path + description: The identifier of the prompt to get. + required: true + schema: + type: string + - name: version + in: query + description: >- + The version of the prompt to get (defaults to latest). + required: false + schema: + type: integer + post: + responses: + '200': + description: >- + The updated Prompt resource with incremented version. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Prompt' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: >- + Update an existing prompt (increments version). + parameters: + - name: prompt_id + in: path + description: The identifier of the prompt to update. + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdatePromptRequest' + required: true + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: Delete a prompt. + parameters: + - name: prompt_id + in: path + description: The identifier of the prompt to delete. + required: true + schema: + type: string /v1/inference/embeddings: post: responses: @@ -1983,6 +2132,37 @@ paths: required: false schema: $ref: '#/components/schemas/Order' + /v1/prompts/{prompt_id}/versions: + get: + responses: + '200': + description: >- + A ListPromptsResponse containing all versions of the prompt. + content: + application/json: + schema: + $ref: '#/components/schemas/ListPromptsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: List all versions of a specific prompt. + parameters: + - name: prompt_id + in: path + description: >- + The identifier of the prompt to list versions for. + required: true + schema: + type: string /v1/providers: get: responses: @@ -3546,6 +3726,43 @@ paths: schema: $ref: '#/components/schemas/ScoreBatchRequest' required: true + /v1/prompts/{prompt_id}/set-default-version: + post: + responses: + '200': + description: >- + The prompt with the specified version now set as default. + content: + application/json: + schema: + $ref: '#/components/schemas/Prompt' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Prompts + description: >- + Set which version of a prompt should be the default in get_prompt (latest). + parameters: + - name: prompt_id + in: path + description: The identifier of the prompt. + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SetDefaultVersionRequest' + required: true /v1/post-training/supervised-fine-tune: post: responses: @@ -7148,6 +7365,61 @@ components: - type title: >- OpenAIResponseObjectStreamResponseWebSearchCallSearching + CreatePromptRequest: + type: object + properties: + prompt: + type: string + description: >- + The prompt text content with variable placeholders. + variables: + type: array + items: + type: string + description: >- + List of variable names that can be used in the prompt template. 
+ additionalProperties: false + required: + - prompt + title: CreatePromptRequest + Prompt: + type: object + properties: + prompt: + type: string + description: >- + The system prompt text with variable placeholders. Variables are only + supported when using the Responses API. + version: + type: integer + description: >- + Version (integer starting at 1, incremented on save) + prompt_id: + type: string + description: >- + Unique identifier formatted as 'pmpt_<48-digit-hash>' + variables: + type: array + items: + type: string + description: >- + List of prompt variable names that can be used in the prompt template + is_default: + type: boolean + default: false + description: >- + Boolean indicating whether this version is the default version for this + prompt + additionalProperties: false + required: + - version + - prompt_id + - variables + - is_default + title: Prompt + description: >- + A prompt resource representing a stored OpenAI Compatible prompt template + in Llama Stack. OpenAIDeleteResponseObject: type: object properties: @@ -7621,6 +7893,7 @@ components: - benchmark - tool - tool_group + - prompt const: benchmark default: benchmark description: The resource type, always benchmark @@ -8107,6 +8380,7 @@ components: - benchmark - tool - tool_group + - prompt const: dataset default: dataset description: >- @@ -8219,6 +8493,7 @@ components: - benchmark - tool - tool_group + - prompt const: model default: model description: >- @@ -8410,6 +8685,7 @@ components: - benchmark - tool - tool_group + - prompt const: scoring_function default: scoring_function description: >- @@ -8486,6 +8762,7 @@ components: - benchmark - tool - tool_group + - prompt const: shield default: shield description: The resource type, always shield @@ -8665,6 +8942,7 @@ components: - benchmark - tool - tool_group + - prompt const: tool default: tool description: Type of resource, always 'tool' @@ -8723,6 +9001,7 @@ components: - benchmark - tool - tool_group + - prompt const: tool_group default: tool_group description: Type of resource, always 'tool_group' @@ -8951,6 +9230,7 @@ components: - benchmark - tool - tool_group + - prompt const: vector_db default: vector_db description: >- @@ -9577,6 +9857,18 @@ components: title: OpenAIResponseObjectWithInput description: >- OpenAI response object extended with input context information. + ListPromptsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/Prompt' + additionalProperties: false + required: + - data + title: ListPromptsResponse + description: Response model to list prompts. ListProvidersResponse: type: object properties: @@ -12722,6 +13014,16 @@ components: title: ScoreBatchResponse description: >- Response from batch scoring operations on datasets. + SetDefaultVersionRequest: + type: object + properties: + version: + type: integer + description: The version to set as default. + additionalProperties: false + required: + - version + title: SetDefaultVersionRequest AlgorithmConfig: oneOf: - $ref: '#/components/schemas/LoraFinetuningConfig' @@ -12918,6 +13220,32 @@ components: description: >- Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold. + UpdatePromptRequest: + type: object + properties: + prompt: + type: string + description: The updated prompt text content. + version: + type: integer + description: >- + The current version of the prompt being updated. 
+ variables: + type: array + items: + type: string + description: >- + Updated list of variable names that can be used in the prompt template. + set_as_default: + type: boolean + description: >- + Set the new version as the default (default=True). + additionalProperties: false + required: + - prompt + - version + - set_as_default + title: UpdatePromptRequest VersionInfo: type: object properties: @@ -13029,6 +13357,9 @@ tags: - name: Inspect - name: Models - name: PostTraining (Coming Soon) + - name: Prompts + x-displayName: >- + Protocol for prompt management operations. - name: Providers x-displayName: >- Providers API for inspecting, listing, and modifying providers and their configurations. @@ -13056,6 +13387,7 @@ x-tagGroups: - Inspect - Models - PostTraining (Coming Soon) + - Prompts - Providers - Safety - Scoring diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 87fc95917..8d0f2e26d 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -102,6 +102,7 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar benchmarks: Benchmark suite management :cvar tool_groups: Tool group organization :cvar files: File storage and management + :cvar prompts: Prompt versions and management :cvar inspect: Built-in system inspection and introspection """ @@ -127,6 +128,7 @@ class Api(Enum, metaclass=DynamicApiMeta): benchmarks = "benchmarks" tool_groups = "tool_groups" files = "files" + prompts = "prompts" # built-in API inspect = "inspect" diff --git a/llama_stack/apis/prompts/__init__.py b/llama_stack/apis/prompts/__init__.py new file mode 100644 index 000000000..6070f3450 --- /dev/null +++ b/llama_stack/apis/prompts/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .prompts import ListPromptsResponse, Prompt, Prompts + +__all__ = ["Prompt", "Prompts", "ListPromptsResponse"] diff --git a/llama_stack/apis/prompts/prompts.py b/llama_stack/apis/prompts/prompts.py new file mode 100644 index 000000000..e6a376c3f --- /dev/null +++ b/llama_stack/apis/prompts/prompts.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import re +import secrets +from typing import Protocol, runtime_checkable + +from pydantic import BaseModel, Field, field_validator, model_validator + +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod + + +@json_schema_type +class Prompt(BaseModel): + """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack. + + :param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API. 
+ :param version: Version (integer starting at 1, incremented on save) + :param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>' + :param variables: List of prompt variable names that can be used in the prompt template + :param is_default: Boolean indicating whether this version is the default version for this prompt + """ + + prompt: str | None = Field(default=None, description="The system prompt with variable placeholders") + version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1) + prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'") + variables: list[str] = Field( + default_factory=list, description="List of variable names that can be used in the prompt template" + ) + is_default: bool = Field( + default=False, description="Boolean indicating whether this version is the default version" + ) + + @field_validator("prompt_id") + @classmethod + def validate_prompt_id(cls, prompt_id: str) -> str: + if not isinstance(prompt_id, str): + raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'") + + if not prompt_id.startswith("pmpt_"): + raise ValueError("prompt_id must start with 'pmpt_' prefix") + + hex_part = prompt_id[5:] + if len(hex_part) != 48: + raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)") + + for char in hex_part: + if char not in "0123456789abcdef": + raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]") + + return prompt_id + + @field_validator("version") + @classmethod + def validate_version(cls, prompt_version: int) -> int: + if prompt_version < 1: + raise ValueError("version must be >= 1") + return prompt_version + + @model_validator(mode="after") + def validate_prompt_variables(self): + """Validate that all variables used in the prompt are declared in the variables list.""" + if not self.prompt: + return self + + prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt)) + declared_variables = set(self.variables) + + undeclared = prompt_variables - declared_variables + if undeclared: + raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}") + + return self + + @classmethod + def generate_prompt_id(cls) -> str: + # Generate 48 hex characters (24 bytes) + random_bytes = secrets.token_bytes(24) + hex_string = random_bytes.hex() + return f"pmpt_{hex_string}" + + +class ListPromptsResponse(BaseModel): + """Response model to list prompts.""" + + data: list[Prompt] + + +@runtime_checkable +@trace_protocol +class Prompts(Protocol): + """Protocol for prompt management operations.""" + + @webmethod(route="/prompts", method="GET") + async def list_prompts(self) -> ListPromptsResponse: + """List all prompts. + + :returns: A ListPromptsResponse containing all prompts. + """ + ... + + @webmethod(route="/prompts/{prompt_id}/versions", method="GET") + async def list_prompt_versions( + self, + prompt_id: str, + ) -> ListPromptsResponse: + """List all versions of a specific prompt. + + :param prompt_id: The identifier of the prompt to list versions for. + :returns: A ListPromptsResponse containing all versions of the prompt. + """ + ... + + @webmethod(route="/prompts/{prompt_id}", method="GET") + async def get_prompt( + self, + prompt_id: str, + version: int | None = None, + ) -> Prompt: + """Get a prompt by its identifier and optional version. + + :param prompt_id: The identifier of the prompt to get. 
+ :param version: The version of the prompt to get (defaults to latest). + :returns: A Prompt resource. + """ + ... + + @webmethod(route="/prompts", method="POST") + async def create_prompt( + self, + prompt: str, + variables: list[str] | None = None, + ) -> Prompt: + """Create a new prompt. + + :param prompt: The prompt text content with variable placeholders. + :param variables: List of variable names that can be used in the prompt template. + :returns: The created Prompt resource. + """ + ... + + @webmethod(route="/prompts/{prompt_id}", method="PUT") + async def update_prompt( + self, + prompt_id: str, + prompt: str, + version: int, + variables: list[str] | None = None, + set_as_default: bool = True, + ) -> Prompt: + """Update an existing prompt (increments version). + + :param prompt_id: The identifier of the prompt to update. + :param prompt: The updated prompt text content. + :param version: The current version of the prompt being updated. + :param variables: Updated list of variable names that can be used in the prompt template. + :param set_as_default: Set the new version as the default (default=True). + :returns: The updated Prompt resource with incremented version. + """ + ... + + @webmethod(route="/prompts/{prompt_id}", method="DELETE") + async def delete_prompt( + self, + prompt_id: str, + ) -> None: + """Delete a prompt. + + :param prompt_id: The identifier of the prompt to delete. + """ + ... + + @webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT") + async def set_default_version( + self, + prompt_id: str, + version: int, + ) -> Prompt: + """Set which version of a prompt should be the default in get_prompt (latest). + + :param prompt_id: The identifier of the prompt. + :param version: The version to set as default. + :returns: The prompt with the specified version now set as default. + """ + ... diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 3731fbf1d..7c4130f7d 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -19,6 +19,7 @@ class ResourceType(StrEnum): benchmark = "benchmark" tool = "tool" tool_group = "tool_group" + prompt = "prompt" class Resource(BaseModel): diff --git a/llama_stack/core/prompts/__init__.py b/llama_stack/core/prompts/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/core/prompts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/core/prompts/prompts.py b/llama_stack/core/prompts/prompts.py new file mode 100644 index 000000000..26e8f5cef --- /dev/null +++ b/llama_stack/core/prompts/prompts.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any + +from pydantic import BaseModel + +from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts +from llama_stack.core.datatypes import StackRunConfig +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +class PromptServiceConfig(BaseModel): + """Configuration for the built-in prompt service. 
+ + :param run_config: Stack run configuration containing distribution info + """ + + run_config: StackRunConfig + + +async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]): + """Get the prompt service implementation.""" + impl = PromptServiceImpl(config, deps) + await impl.initialize() + return impl + + +class PromptServiceImpl(Prompts): + """Built-in prompt service implementation using KVStore.""" + + def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]): + self.config = config + self.deps = deps + self.kvstore: KVStore + + async def initialize(self) -> None: + kvstore_config = SqliteKVStoreConfig( + db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix() + ) + self.kvstore = await kvstore_impl(kvstore_config) + + def _get_default_key(self, prompt_id: str) -> str: + """Get the KVStore key that stores the default version number.""" + return f"prompts:v1:{prompt_id}:default" + + async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str: + """Get the KVStore key for prompt data, returning default version if applicable.""" + if version: + return self._get_version_key(prompt_id, str(version)) + + default_key = self._get_default_key(prompt_id) + resolved_version = await self.kvstore.get(default_key) + if resolved_version is None: + raise ValueError(f"Prompt {prompt_id}:default not found") + return self._get_version_key(prompt_id, resolved_version) + + def _get_version_key(self, prompt_id: str, version: str) -> str: + """Get the KVStore key for a specific prompt version.""" + return f"prompts:v1:{prompt_id}:{version}" + + def _get_list_key_prefix(self) -> str: + """Get the key prefix for listing prompts.""" + return "prompts:v1:" + + def _serialize_prompt(self, prompt: Prompt) -> str: + """Serialize a prompt to JSON string for storage.""" + return json.dumps( + { + "prompt_id": prompt.prompt_id, + "prompt": prompt.prompt, + "version": prompt.version, + "variables": prompt.variables or [], + "is_default": prompt.is_default, + } + ) + + def _deserialize_prompt(self, data: str) -> Prompt: + """Deserialize a prompt from JSON string.""" + obj = json.loads(data) + return Prompt( + prompt_id=obj["prompt_id"], + prompt=obj["prompt"], + version=obj["version"], + variables=obj.get("variables", []), + is_default=obj.get("is_default", False), + ) + + async def list_prompts(self) -> ListPromptsResponse: + """List all prompts (default versions only).""" + prefix = self._get_list_key_prefix() + keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") + + prompts = [] + for key in keys: + if key.endswith(":default"): + try: + default_version = await self.kvstore.get(key) + if default_version: + prompt_id = key.replace(prefix, "").replace(":default", "") + version_key = self._get_version_key(prompt_id, default_version) + data = await self.kvstore.get(version_key) + if data: + prompt = self._deserialize_prompt(data) + prompts.append(prompt) + except (json.JSONDecodeError, KeyError): + continue + + prompts.sort(key=lambda p: p.prompt_id or "", reverse=True) + return ListPromptsResponse(data=prompts) + + async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt: + """Get a prompt by its identifier and optional version.""" + key = await self._get_prompt_key(prompt_id, version) + data = await self.kvstore.get(key) + if data is None: + raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found") + return self._deserialize_prompt(data) + + async def create_prompt( + 
self,
+        prompt: str,
+        variables: list[str] | None = None,
+    ) -> Prompt:
+        """Create a new prompt."""
+        if variables is None:
+            variables = []
+
+        prompt_obj = Prompt(
+            prompt_id=Prompt.generate_prompt_id(),
+            prompt=prompt,
+            version=1,
+            variables=variables,
+        )
+
+        version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
+        data = self._serialize_prompt(prompt_obj)
+        await self.kvstore.set(version_key, data)
+
+        default_key = self._get_default_key(prompt_obj.prompt_id)
+        await self.kvstore.set(default_key, str(prompt_obj.version))
+
+        return prompt_obj
+
+    async def update_prompt(
+        self,
+        prompt_id: str,
+        prompt: str,
+        version: int,
+        variables: list[str] | None = None,
+        set_as_default: bool = True,
+    ) -> Prompt:
+        """Update an existing prompt (increments version)."""
+        if version < 1:
+            raise ValueError("Version must be >= 1")
+        if variables is None:
+            variables = []
+
+        prompt_versions = await self.list_prompt_versions(prompt_id)
+        latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
+
+        if version and latest_prompt.version != version:
+            raise ValueError(
+                f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in the request."
+            )
+
+        current_version = latest_prompt.version if version is None else version
+        new_version = current_version + 1
+
+        updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
+
+        version_key = self._get_version_key(prompt_id, str(new_version))
+        data = self._serialize_prompt(updated_prompt)
+        await self.kvstore.set(version_key, data)
+
+        if set_as_default:
+            await self.set_default_version(prompt_id, new_version)
+
+        return updated_prompt
+
+    async def delete_prompt(self, prompt_id: str) -> None:
+        """Delete a prompt and all its versions."""
+        await self.get_prompt(prompt_id)
+
+        prefix = f"prompts:v1:{prompt_id}:"
+        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
+
+        for key in keys:
+            await self.kvstore.delete(key)
+
+    async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
+        """List all versions of a specific prompt."""
+        prefix = f"prompts:v1:{prompt_id}:"
+        keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
+
+        default_version = None
+        prompts = []
+
+        for key in keys:
+            data = await self.kvstore.get(key)
+            if key.endswith(":default"):
+                default_version = data
+            else:
+                if data:
+                    prompt_obj = self._deserialize_prompt(data)
+                    prompts.append(prompt_obj)
+
+        if not prompts:
+            raise ValueError(f"Prompt {prompt_id} not found")
+
+        for prompt in prompts:
+            prompt.is_default = str(prompt.version) == default_version
+
+        prompts.sort(key=lambda x: x.version)
+        return ListPromptsResponse(data=prompts)
+
+    async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
+        """Set which version of a prompt should be the default. If not set,
the default is the latest.""" + version_key = self._get_version_key(prompt_id, str(version)) + data = await self.kvstore.get(version_key) + if data is None: + raise ValueError(f"Prompt {prompt_id} version {version} not found") + + default_key = self._get_default_key(prompt_id) + await self.kvstore.set(default_key, str(version)) + + return self._deserialize_prompt(data) diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index a8ad03e1a..373446de6 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -19,6 +19,7 @@ from llama_stack.apis.inference import Inference, InferenceProvider from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining +from llama_stack.apis.prompts import Prompts from llama_stack.apis.providers import Providers as ProvidersAPI from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring @@ -93,6 +94,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.tool_groups: ToolGroups, Api.tool_runtime: ToolRuntime, Api.files: Files, + Api.prompts: Prompts, } if external_apis: diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 288bf46e1..d3e875fec 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -515,6 +515,7 @@ def main(args: argparse.Namespace | None = None): apis_to_serve.add("inspect") apis_to_serve.add("providers") + apis_to_serve.add("prompts") for api_str in apis_to_serve: api = Api(api_str) diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index bccea48d3..7ab8d2c64 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining +from llama_stack.apis.prompts import Prompts from llama_stack.apis.providers import Providers from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring @@ -37,6 +38,7 @@ from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import Provider, StackRunConfig from llama_stack.core.distribution import get_provider_registry from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl +from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl from llama_stack.core.providers import ProviderImpl, ProviderImplConfig from llama_stack.core.resolver import ProviderRegistry, resolve_impls from llama_stack.core.routing_tables.common import CommonRoutingTableImpl @@ -72,6 +74,7 @@ class LlamaStack( ToolRuntime, RAGToolRuntime, Files, + Prompts, ): pass @@ -305,6 +308,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf ) impls[Api.providers] = providers_impl + prompts_impl = PromptServiceImpl( + PromptServiceConfig(run_config=run_config), + deps=impls, + ) + impls[Api.prompts] = prompts_impl + # Produces a stack of providers for the given run config. Not all APIs may be # asked for in the run config. 
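An aside on the storage model the `PromptServiceImpl` registered above relies on: each saved version is its own KVStore record, and a per-prompt `:default` pointer stores the version number that `get_prompt` resolves when no version is given. The range scans in `list_prompts`, `delete_prompt`, and `list_prompt_versions` (from `prefix` through `prefix + "\xff"`) depend on this flat layout. A minimal sketch of the keys follows; the identifier is fabricated for illustration:

```python
# Key layout produced by PromptServiceImpl's helpers
# (_get_version_key / _get_default_key). Real identifiers come from
# Prompt.generate_prompt_id(); this one is made up.
prompt_id = "pmpt_" + "0123456789abcdef" * 3  # 48 lowercase hex chars

# One JSON-serialized record per saved version...
for version in (1, 2, 3):
    print(f"prompts:v1:{prompt_id}:{version}")

# ...plus one pointer record whose value is the default version number as a string.
print(f"prompts:v1:{prompt_id}:default")
```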
@@ -329,6 +338,9 @@ async def construct_stack( # Add internal implementations after all other providers are resolved add_internal_implementations(impls, run_config) + if Api.prompts in impls: + await impls[Api.prompts].initialize() + await register_resources(run_config, impls) await refresh_registry_once(impls) diff --git a/tests/unit/prompts/prompts/__init__.py b/tests/unit/prompts/prompts/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/prompts/prompts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py new file mode 100644 index 000000000..b2c619e49 --- /dev/null +++ b/tests/unit/prompts/prompts/conftest.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import random + +import pytest + +from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +@pytest.fixture +async def temp_prompt_store(tmp_path_factory): + unique_id = f"prompt_store_{random.randint(1, 1000000)}" + temp_dir = tmp_path_factory.getbasetemp() + db_path = str(temp_dir / f"{unique_id}.db") + + from llama_stack.core.datatypes import StackRunConfig + from llama_stack.providers.utils.kvstore import kvstore_impl + + mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={}) + config = PromptServiceConfig(run_config=mock_run_config) + store = PromptServiceImpl(config, deps={}) + + store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + + yield store diff --git a/tests/unit/prompts/prompts/test_prompts.py b/tests/unit/prompts/prompts/test_prompts.py new file mode 100644 index 000000000..792e55530 --- /dev/null +++ b/tests/unit/prompts/prompts/test_prompts.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + + +class TestPrompts: + async def test_create_and_get_prompt(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Hello world!", ["name"]) + assert prompt.prompt == "Hello world!" 
+ assert prompt.version == 1 + assert prompt.prompt_id.startswith("pmpt_") + assert prompt.variables == ["name"] + + retrieved = await temp_prompt_store.get_prompt(prompt.prompt_id) + assert retrieved.prompt_id == prompt.prompt_id + assert retrieved.prompt == prompt.prompt + + async def test_update_prompt(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Original") + updated = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"]) + assert updated.version == 2 + assert updated.prompt == "Updated" + + async def test_update_prompt_with_version(self, temp_prompt_store): + version_for_update = 1 + + prompt = await temp_prompt_store.create_prompt("Original") + assert prompt.version == 1 + prompt = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"]) + assert prompt.version == 2 + + with pytest.raises(ValueError): + # now this is a stale version + await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"]) + + with pytest.raises(ValueError): + # this version does not exist + await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"]) + + async def test_delete_prompt(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("to be deleted") + await temp_prompt_store.delete_prompt(prompt.prompt_id) + with pytest.raises(ValueError): + await temp_prompt_store.get_prompt(prompt.prompt_id) + + async def test_list_prompts(self, temp_prompt_store): + response = await temp_prompt_store.list_prompts() + assert response.data == [] + + await temp_prompt_store.create_prompt("first") + await temp_prompt_store.create_prompt("second") + + response = await temp_prompt_store.list_prompts() + assert len(response.data) == 2 + + async def test_version(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1) + + v1 = await temp_prompt_store.get_prompt(prompt.prompt_id, version=1) + assert v1.version == 1 and v1.prompt == "V1" + + latest = await temp_prompt_store.get_prompt(prompt.prompt_id) + assert latest.version == 2 and latest.prompt == "V2" + + async def test_set_default_version(self, temp_prompt_store): + prompt0 = await temp_prompt_store.create_prompt("V1") + prompt1 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V2", 1) + + assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 2 + prompt_default = await temp_prompt_store.set_default_version(prompt0.prompt_id, 1) + assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 1 + assert prompt_default.version == 1 + + prompt2 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V3", prompt1.version) + assert prompt2.version == 3 + + async def test_prompt_id_generation_and_validation(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Test") + assert prompt.prompt_id.startswith("pmpt_") + assert len(prompt.prompt_id) == 53 + + with pytest.raises(ValueError): + await temp_prompt_store.get_prompt("invalid_id") + + async def test_list_shows_default_versions(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1) + await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2) + + response = await temp_prompt_store.list_prompts() + listed_prompt = response.data[0] + assert listed_prompt.version == 3 and listed_prompt.prompt == "V3" + + await 
temp_prompt_store.set_default_version(prompt.prompt_id, 1) + + response = await temp_prompt_store.list_prompts() + listed_prompt = response.data[0] + assert listed_prompt.version == 1 and listed_prompt.prompt == "V1" + assert not (await temp_prompt_store.get_prompt(prompt.prompt_id, 3)).is_default + + async def test_get_all_prompt_versions(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1) + await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2) + + versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data + assert len(versions) == 3 + assert [v.version for v in versions] == [1, 2, 3] + assert [v.is_default for v in versions] == [False, False, True] + + await temp_prompt_store.set_default_version(prompt.prompt_id, 2) + versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data + assert [v.is_default for v in versions] == [False, True, False] + + with pytest.raises(ValueError): + await temp_prompt_store.list_prompt_versions("nonexistent") + + async def test_prompt_variable_validation(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"]) + assert prompt.variables == ["name", "city"] + + prompt_no_vars = await temp_prompt_store.create_prompt("Hello world!", []) + assert prompt_no_vars.variables == [] + + with pytest.raises(ValueError, match="undeclared variables"): + await temp_prompt_store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"]) + + async def test_update_prompt_set_as_default_behavior(self, temp_prompt_store): + prompt = await temp_prompt_store.create_prompt("V1") + assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 1 + + prompt_v2 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1, [], set_as_default=True) + assert prompt_v2.version == 2 + assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2 + + prompt_v3 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2, [], set_as_default=False) + assert prompt_v3.version == 3 + assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
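Taken together, the pieces above compose into a small end-to-end flow. The sketch below mirrors the unit-test fixture rather than a running server: it instantiates `PromptServiceImpl` directly and swaps in a throwaway SQLite KVStore, so the `StackRunConfig` arguments and the database path are illustrative assumptions, not part of the patch.

```python
import asyncio

from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def main() -> None:
    # Wire the service to a throwaway SQLite store, as conftest.py does, instead
    # of calling initialize() (which derives the db path from the run config).
    run_config = StackRunConfig(image_name="demo-distribution", apis=[], providers={})
    service = PromptServiceImpl(PromptServiceConfig(run_config=run_config), deps={})
    service.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path="/tmp/prompts-demo.db"))

    # Version 1: every {{ variable }} in the template must be declared up front,
    # or the Prompt model validator raises.
    prompt = await service.create_prompt("Hello {{ name }}!", variables=["name"])

    # Updates must name the current latest version; the save bumps it to 2 and,
    # by default, repoints the ':default' record at the new version.
    v2 = await service.update_prompt(prompt.prompt_id, "Hi {{ name }}!", version=1, variables=["name"])
    assert (await service.get_prompt(prompt.prompt_id)).version == v2.version == 2

    # Pin version 1 again; unversioned reads now resolve to it.
    await service.set_default_version(prompt.prompt_id, 1)
    assert (await service.get_prompt(prompt.prompt_id)).version == 1


asyncio.run(main())
```

Over HTTP, the same flow runs through the `/v1/prompts` routes in the spec above, with `GET /v1/prompts/{prompt_id}?version=N` retrieving a pinned version and `/v1/prompts/{prompt_id}/set-default-version` repointing the default.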