feat: add Prompts API to Responses API

Author: r3v5
Date: 2025-09-21 13:52:55 +01:00
Parent: 9f6c658f2a
Commit: bdc16ea392
15 changed files with 526 additions and 4 deletions
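
In short: create_openai_response() gains an optional prompt parameter that references a stored prompt by id (and, optionally, version); its {{ variable }} placeholders are filled from the supplied variables and the resolved text is prepended to the conversation as a system message. A minimal usage sketch (the client object and payload shape here are illustrative assumptions, not part of this diff):

# Hypothetical caller; only the new `prompt` parameter comes from this commit
response = await client.responses.create(
    model="llama3.2:3b",
    input="Summarize our return policy.",
    prompt={"id": "pmpt_123", "version": "1", "variables": {"tone": "friendly"}},
)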


@@ -21,6 +21,7 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponsePromptParam,
    OpenAIResponseText,
    OpenAIResponseTextFormat,
)
@@ -29,6 +30,8 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
    OpenAISystemMessageParam,
)
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.prompts.prompts import Prompt
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@@ -61,12 +64,14 @@ class OpenAIResponsesImpl:
        tool_runtime_api: ToolRuntime,
        responses_store: ResponsesStore,
        vector_io_api: VectorIO,  # VectorIO
        prompts_api: Prompts,
    ):
        self.inference_api = inference_api
        self.tool_groups_api = tool_groups_api
        self.tool_runtime_api = tool_runtime_api
        self.responses_store = responses_store
        self.vector_io_api = vector_io_api
        self.prompts_api = prompts_api
        self.tool_executor = ToolExecutor(
            tool_groups_api=tool_groups_api,
            tool_runtime_api=tool_runtime_api,
@@ -123,6 +128,41 @@ class OpenAIResponsesImpl:
        if instructions:
            messages.insert(0, OpenAISystemMessageParam(content=instructions))

    async def _prepend_prompt(
        self, messages: list[OpenAIMessageParam], prompt_params: OpenAIResponsePromptParam
    ) -> Prompt | None:
        if not prompt_params or not prompt_params.id:
            return None

        try:
            # Check whether the prompt exists in Llama Stack and retrieve it
            prompt_version = int(prompt_params.version) if prompt_params.version else None
            cur_prompt = await self.prompts_api.get_prompt(prompt_params.id, prompt_version)

            if cur_prompt and cur_prompt.prompt:
                cur_prompt_text = cur_prompt.prompt
                cur_prompt_variables = cur_prompt.variables

                final_prompt_text = cur_prompt_text
                if prompt_params.variables:
                    # Validate that every supplied variable is declared by the stored prompt
                    for name in prompt_params.variables:
                        if name not in cur_prompt_variables:
                            raise ValueError(f"Variable {name} not found in prompt {prompt_params.id}")

                    # Substitute the variables into the prompt text
                    for name, value in prompt_params.variables.items():
                        final_prompt_text = final_prompt_text.replace(f"{{{{ {name} }}}}", str(value))

                messages.insert(0, OpenAISystemMessageParam(content=final_prompt_text))
                logger.info(f"Prompt {prompt_params.id} found and applied\nFinal prompt text: {final_prompt_text}")
                return cur_prompt
        except ValueError:
            logger.warning(
                f"Prompt {prompt_params.id} with version {prompt_params.version} not found, skipping prompt prepending"
            )
        return None

    async def get_openai_response(
        self,
        response_id: str,
@@ -199,6 +239,7 @@
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePromptParam | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
@@ -215,6 +256,7 @@
        stream_gen = self._create_streaming_response(
            input=input,
            model=model,
            prompt=prompt,
            instructions=instructions,
            previous_response_id=previous_response_id,
            store=store,
@@ -243,6 +285,7 @@
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        prompt: OpenAIResponsePromptParam | None = None,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
@@ -253,6 +296,10 @@
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Input preprocessing
        all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)

        # Prepend reusable prompt (if provided)
        prompt_obj = await self._prepend_prompt(messages, prompt)

        await self._prepend_instructions(messages, instructions)

        # Structured outputs
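
Note that _prepend_prompt and _prepend_instructions both insert at position 0, so when both are supplied the instructions message ends up ahead of the resolved prompt text. A tiny ordering sketch (plain strings stand in for the message objects):

messages = ["user message"]
messages.insert(0, "prompt text")         # effect of _prepend_prompt
messages.insert(0, "instructions text")   # effect of _prepend_instructions
assert messages == ["instructions text", "prompt text", "user message"]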
@@ -276,6 +323,7 @@
            ctx=ctx,
            response_id=response_id,
            created_at=created_at,
            prompt=prompt_obj,
            text=text,
            max_infer_iters=max_infer_iters,
            tool_executor=self.tool_executor,
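
For reference, the substitution in _prepend_prompt is a literal string replace on "{{ name }}" placeholders (note the single spaces inside the braces). A self-contained sketch of the same semantics, with made-up values:

template = "You are a {{ tone }} assistant. Answer in {{ language }}."
variables = {"tone": "friendly", "language": "English"}
text = template
for name, value in variables.items():
    # f"{{{{ {name} }}}}" renders as "{{ name }}", matching the replace call in the diff
    text = text.replace(f"{{{{ {name} }}}}", str(value))
assert text == "You are a friendly assistant. Answer in English."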


@@ -45,6 +45,7 @@ from llama_stack.apis.inference import (
    OpenAIChoice,
    OpenAIMessageParam,
)
from llama_stack.apis.prompts.prompts import Prompt
from llama_stack.log import get_logger

from .types import ChatCompletionContext, ChatCompletionResult
@@ -81,6 +82,7 @@ class StreamingResponseOrchestrator:
        ctx: ChatCompletionContext,
        response_id: str,
        created_at: int,
        prompt: Prompt | None,
        text: OpenAIResponseText,
        max_infer_iters: int,
        tool_executor,  # Will be the tool execution logic from the main class
@@ -89,6 +91,7 @@
        self.ctx = ctx
        self.response_id = response_id
        self.created_at = created_at
        self.prompt = prompt
        self.text = text
        self.max_infer_iters = max_infer_iters
        self.tool_executor = tool_executor
@@ -109,6 +112,7 @@
            object="response",
            status="in_progress",
            output=output_messages.copy(),
            prompt=self.prompt,
            text=self.text,
        )
@@ -195,6 +199,7 @@
            model=self.ctx.model,
            object="response",
            status="completed",
            prompt=self.prompt,
            text=self.text,
            output=output_messages,
        )
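
Because both the in_progress and completed snapshots now carry prompt=self.prompt, the resolved Prompt is echoed back on the response object. A rough caller-side check (only the .prompt attribute on the Prompt object is taken from this diff; the rest of the response layout is assumed):

# Hypothetical inspection of a finished response
if response.prompt is not None:
    print(response.prompt.prompt)  # the stored prompt template that was applied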