mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)

feat: add Prompts API to Responses API

This commit is contained in:
parent 9f6c658f2a
commit bdc16ea392

15 changed files with 526 additions and 4 deletions
@@ -38,6 +38,7 @@ from .openai_responses import (
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponsePromptParam,
    OpenAIResponseText,
)
@@ -796,6 +797,7 @@ class Agents(Protocol):
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePromptParam | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
@@ -807,9 +809,9 @@
        max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
        """Create a new OpenAI response.

        :param input: Input message(s) to create the response.
        :param model: The underlying LLM used for completions.
        :param prompt: Prompt object with ID, version, and variables.
        :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
        :param include: (Optional) Additional fields to include in the response.
        :returns: An OpenAIResponseObject.
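For illustration, a minimal sketch of how a client might exercise the new prompt parameter once this lands. It assumes an OpenAI-compatible client pointed at a running Llama Stack server; the base URL, model name, and the stored prompt ID pmpt_123 are all hypothetical.

    # Hedged sketch, not part of the diff: endpoint URL, model, and prompt ID are assumptions.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed local stack endpoint

    response = client.responses.create(
        model="meta-llama/Llama-3.2-3B-Instruct",     # any model registered with the stack
        input="What is the capital of Ireland?",
        prompt={
            "id": "pmpt_123",                          # hypothetical stored prompt ID
            "version": "1",                            # optional; latest version is used if omitted
            "variables": {"area_name": "geography"},   # substituted into {{ area_name }} in the stored prompt
        },
    )
    print(response.output_text)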
@@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from llama_stack.apis.prompts.prompts import Prompt
from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
from llama_stack.schema_utils import json_schema_type, register_schema
@@ -336,6 +337,20 @@ class OpenAIResponseTextFormat(TypedDict, total=False):
    strict: bool | None


@json_schema_type
class OpenAIResponsePromptParam(BaseModel):
    """Prompt object that is used for OpenAI responses.

    :param id: Unique identifier of the prompt template
    :param variables: Dictionary of variable names to values for template substitution
    :param version: Version number of the prompt to use (defaults to latest if not specified)
    """

    id: str
    variables: dict[str, Any] | None = None
    version: str | None = None


@json_schema_type
class OpenAIResponseText(BaseModel):
    """Text response configuration for OpenAI responses.
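To make the request shape concrete, a small sketch of constructing the new Pydantic model directly; the field values are purely illustrative.

    # Illustrative only: builds the new param model and prints its JSON-ready form.
    from llama_stack.apis.agents.openai_responses import OpenAIResponsePromptParam

    prompt_param = OpenAIResponsePromptParam(
        id="pmpt_123",                       # hypothetical prompt identifier
        version="2",                         # kept as a string here; parsed to an int before lookup
        variables={"customer_name": "Ana"},  # values substituted into the template
    )
    print(prompt_param.model_dump(exclude_none=True))
    # {'id': 'pmpt_123', 'variables': {'customer_name': 'Ana'}, 'version': '2'}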
@@ -357,6 +372,7 @@ class OpenAIResponseObject(BaseModel):
    :param object: Object type identifier, always "response"
    :param output: List of generated output items (messages, tool calls, etc.)
    :param parallel_tool_calls: Whether tool calls can be executed in parallel
    :param prompt: (Optional) Prompt object with ID, version, and variables
    :param previous_response_id: (Optional) ID of the previous response in a conversation
    :param status: Current status of the response generation
    :param temperature: (Optional) Sampling temperature used for generation
@@ -373,6 +389,7 @@ class OpenAIResponseObject(BaseModel):
    output: list[OpenAIResponseOutput]
    parallel_tool_calls: bool = False
    previous_response_id: str | None = None
    prompt: Prompt | None = None
    status: str
    temperature: float | None = None
    # Default to text format to avoid breaking the loading of old responses
@@ -321,6 +321,10 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
    )
    impls[Api.conversations] = conversations_impl

    # Set prompts API on agents provider if it exists
    if Api.agents in impls and hasattr(impls[Api.agents], "set_prompts_api"):
        impls[Api.agents].set_prompts_api(prompts_impl)


class Stack:
    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
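The prompts implementation is attached to the agents provider after construction rather than passed in as a constructor dependency, apparently so the agents provider can be built before the internal prompts implementation exists. A generic sketch of this late-injection pattern (names are illustrative, not the stack's own):

    # Generic late-injection sketch; not llama-stack code.
    class SomeProvider:
        def __init__(self) -> None:
            self.prompts_api = None  # optional dependency, may be filled in later

        def set_prompts_api(self, prompts_api) -> None:
            self.prompts_api = prompts_api

    provider = SomeProvider()                  # constructed before the prompts impl exists
    prompts_impl = object()                    # stand-in for the real Prompts implementation
    if hasattr(provider, "set_prompts_api"):   # same guard the stack uses
        provider.set_prompts_api(prompts_impl)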
@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
        deps[Api.safety],
        deps[Api.tool_runtime],
        deps[Api.tool_groups],
        None,  # prompts_api will be set later when available
        policy,
    )
    await impl.initialize()
@@ -28,7 +28,7 @@ from llama_stack.apis.agents import (
    Session,
    Turn,
)
-from llama_stack.apis.agents.openai_responses import OpenAIResponseText
+from llama_stack.apis.agents.openai_responses import OpenAIResponsePromptParam, OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
    Inference,
@@ -37,6 +37,7 @@ from llama_stack.apis.inference import (
    ToolResponseMessage,
    UserMessage,
)
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
@@ -63,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
        safety_api: Safety,
        tool_runtime_api: ToolRuntime,
        tool_groups_api: ToolGroups,
        prompts_api: Prompts | None,
        policy: list[AccessRule],
    ):
        self.config = config
@@ -71,6 +73,7 @@
        self.safety_api = safety_api
        self.tool_runtime_api = tool_runtime_api
        self.tool_groups_api = tool_groups_api
        self.prompts_api = prompts_api

        self.in_memory_store = InmemoryKVStoreImpl()
        self.openai_responses_impl: OpenAIResponsesImpl | None = None
@@ -86,8 +89,14 @@ class MetaReferenceAgentsImpl(Agents):
            tool_runtime_api=self.tool_runtime_api,
            responses_store=self.responses_store,
            vector_io_api=self.vector_io_api,
            prompts_api=self.prompts_api,
        )

    def set_prompts_api(self, prompts_api: Prompts) -> None:
        self.prompts_api = prompts_api
        if hasattr(self, "openai_responses_impl") and self.openai_responses_impl:
            self.openai_responses_impl.prompts_api = prompts_api

    async def create_agent(
        self,
        agent_config: AgentConfig,
@@ -320,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePromptParam | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
@@ -333,6 +343,7 @@
        return await self.openai_responses_impl.create_openai_response(
            input,
            model,
            prompt,
            instructions,
            previous_response_id,
            store,
@@ -21,6 +21,7 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponsePromptParam,
    OpenAIResponseText,
    OpenAIResponseTextFormat,
)
@@ -29,6 +30,8 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
    OpenAISystemMessageParam,
)
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.prompts.prompts import Prompt
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@@ -61,12 +64,14 @@ class OpenAIResponsesImpl:
        tool_runtime_api: ToolRuntime,
        responses_store: ResponsesStore,
        vector_io_api: VectorIO,  # VectorIO
        prompts_api: Prompts,
    ):
        self.inference_api = inference_api
        self.tool_groups_api = tool_groups_api
        self.tool_runtime_api = tool_runtime_api
        self.responses_store = responses_store
        self.vector_io_api = vector_io_api
        self.prompts_api = prompts_api
        self.tool_executor = ToolExecutor(
            tool_groups_api=tool_groups_api,
            tool_runtime_api=tool_runtime_api,
@@ -123,6 +128,41 @@ class OpenAIResponsesImpl:
        if instructions:
            messages.insert(0, OpenAISystemMessageParam(content=instructions))

    async def _prepend_prompt(
        self, messages: list[OpenAIMessageParam], prompt_params: OpenAIResponsePromptParam
    ) -> Prompt:
        if not prompt_params or not prompt_params.id:
            return None

        try:
            # Check if prompt exists in Llama Stack and retrieve it
            prompt_version = int(prompt_params.version) if prompt_params.version else None
            cur_prompt = await self.prompts_api.get_prompt(prompt_params.id, prompt_version)
            if cur_prompt and cur_prompt.prompt:
                cur_prompt_text = cur_prompt.prompt
                cur_prompt_variables = cur_prompt.variables

                final_prompt_text = cur_prompt_text
                if prompt_params.variables:
                    # check if the variables are valid
                    for name in prompt_params.variables.keys():
                        if name not in cur_prompt_variables:
                            raise ValueError(f"Variable {name} not found in prompt {prompt_params.id}")

                    # replace the variables in the prompt text
                    for name, value in prompt_params.variables.items():
                        final_prompt_text = final_prompt_text.replace(f"{{{{ {name} }}}}", str(value))

                messages.insert(0, OpenAISystemMessageParam(content=final_prompt_text))
                logger.info(f"Prompt {prompt_params.id} found and applied\nFinal prompt text: {final_prompt_text}")
                return cur_prompt

        except ValueError:
            logger.warning(
                f"Prompt {prompt_params.id} with version {prompt_params.version} not found, skipping prompt prepending"
            )
            return None

    async def get_openai_response(
        self,
        response_id: str,
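For clarity, a standalone sketch of the substitution behaviour implemented by _prepend_prompt above, assuming the stored prompt text uses {{ variable }} placeholders with a single space on each side (the form the replace call targets); the template and variables here are made up.

    # Minimal re-creation of the variable check and substitution; illustrative only.
    template = "You are a {{ tone }} assistant helping {{ customer_name }}."
    declared_variables = ["tone", "customer_name"]   # plays the role of cur_prompt.variables
    supplied = {"tone": "friendly", "customer_name": "Ana"}

    for name in supplied:
        if name not in declared_variables:
            raise ValueError(f"Variable {name} not found in prompt")

    text = template
    for name, value in supplied.items():
        text = text.replace(f"{{{{ {name} }}}}", str(value))

    print(text)  # You are a friendly assistant helping Ana.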
@@ -199,6 +239,7 @@ class OpenAIResponsesImpl:
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePromptParam | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
@@ -215,6 +256,7 @@
        stream_gen = self._create_streaming_response(
            input=input,
            model=model,
            prompt=prompt,
            instructions=instructions,
            previous_response_id=previous_response_id,
            store=store,
@@ -243,6 +285,7 @@
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePromptParam | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
@@ -253,6 +296,10 @@
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Input preprocessing
        all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)

        # Prepend reusable prompt (if provided)
        prompt_obj = await self._prepend_prompt(messages, prompt)

        await self._prepend_instructions(messages, instructions)

        # Structured outputs
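Note that both helpers insert a system message at index 0 and _prepend_prompt runs before _prepend_instructions, so the instructions message ends up first, followed by the rendered prompt text, then the conversation. A small sketch of the resulting order (plain dicts stand in for the stack's OpenAISystemMessageParam, and the strings are placeholders):

    # Sketch of the final message ordering implied by the two insert(0, ...) calls above.
    messages = [{"role": "user", "content": "Hi"}]

    messages.insert(0, {"role": "system", "content": "rendered prompt text"})  # _prepend_prompt
    messages.insert(0, {"role": "system", "content": "instructions"})          # _prepend_instructions

    print([m["content"] for m in messages])
    # ['instructions', 'rendered prompt text', 'Hi']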
@@ -276,6 +323,7 @@
            ctx=ctx,
            response_id=response_id,
            created_at=created_at,
            prompt=prompt_obj,
            text=text,
            max_infer_iters=max_infer_iters,
            tool_executor=self.tool_executor,
@@ -45,6 +45,7 @@ from llama_stack.apis.inference import (
    OpenAIChoice,
    OpenAIMessageParam,
)
from llama_stack.apis.prompts.prompts import Prompt
from llama_stack.log import get_logger

from .types import ChatCompletionContext, ChatCompletionResult
@@ -81,6 +82,7 @@ class StreamingResponseOrchestrator:
        ctx: ChatCompletionContext,
        response_id: str,
        created_at: int,
        prompt: Prompt | None,
        text: OpenAIResponseText,
        max_infer_iters: int,
        tool_executor,  # Will be the tool execution logic from the main class
@@ -89,6 +91,7 @@
        self.ctx = ctx
        self.response_id = response_id
        self.created_at = created_at
        self.prompt = prompt
        self.text = text
        self.max_infer_iters = max_infer_iters
        self.tool_executor = tool_executor
@@ -109,6 +112,7 @@
            object="response",
            status="in_progress",
            output=output_messages.copy(),
            prompt=self.prompt,
            text=self.text,
        )
@@ -195,6 +199,7 @@
            model=self.ctx.model,
            object="response",
            status="completed",
            prompt=self.prompt,
            text=self.text,
            output=output_messages,
        )