mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
fix: Responses API previous_response input items
This adds storing of input items with previous responses and then restores those input items to prepend to the user's messages list when using conversation state. I missed this in the initial implementation, but it makes sense that we have to store the input items from previous responses so that we can reconstruct the proper messages stack for multi-turn conversations - just the output from previous responses isn't enough context for the models to follow the turns and the original instructions. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
150b9a0834
commit
5b2e850754
2 changed files with 104 additions and 10 deletions
|
@ -131,3 +131,20 @@ OpenAIResponseInputTool = Annotated[
|
|||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputItemMessage(OpenAIResponseInputMessage):
|
||||
id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputItemList(BaseModel):
|
||||
data: list[OpenAIResponseInputItemMessage]
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: OpenAIResponseInputItemList
|
||||
response: OpenAIResponseObject
|
||||
|
|
|
@ -12,7 +12,10 @@ from typing import cast
|
|||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputItemList,
|
||||
OpenAIResponseInputItemMessage,
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
|
@ -24,6 +27,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponsePreviousResponseWithInputItems,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
Inference,
|
||||
|
@ -52,9 +56,56 @@ logger = get_logger(name=__name__, category="openai_responses")
|
|||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
|
||||
async def _convert_response_input_content_to_chat_content_parts(
|
||||
input_content: list[OpenAIResponseInputMessageContent],
|
||||
) -> list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert a list of input content items to a list of chat completion content parts
|
||||
"""
|
||||
content_parts = []
|
||||
for input_content_part in input_content:
|
||||
if isinstance(input_content_part, OpenAIResponseInputMessageContentText):
|
||||
content_parts.append(OpenAIChatCompletionContentPartTextParam(text=input_content_part.text))
|
||||
elif isinstance(input_content_part, OpenAIResponseInputMessageContentImage):
|
||||
if input_content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=input_content_part.image_url, detail=input_content_part.detail)
|
||||
content_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
return content_parts
|
||||
|
||||
|
||||
async def _convert_response_input_to_chat_user_content(
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||
if isinstance(input, list):
|
||||
user_content = []
|
||||
for user_input in input:
|
||||
if isinstance(user_input.content, list):
|
||||
user_content.extend(await _convert_response_input_content_to_chat_content_parts(user_input.content))
|
||||
else:
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
|
||||
else:
|
||||
user_content = input
|
||||
return user_content
|
||||
|
||||
|
||||
async def _previous_response_to_messages(
|
||||
previous_response: OpenAIResponsePreviousResponseWithInputItems,
|
||||
) -> list[OpenAIMessageParam]:
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
for previous_message in previous_response.input_items.data:
|
||||
previous_content = await _convert_response_input_content_to_chat_content_parts(previous_message.content)
|
||||
if previous_message.role == "user":
|
||||
converted_message = OpenAIUserMessageParam(content=previous_content)
|
||||
elif previous_message.role == "assistant":
|
||||
converted_message = OpenAIAssistantMessageParam(content=previous_content)
|
||||
else:
|
||||
# TODO: handle other message roles? unclear if system/developer roles are
|
||||
# used in previous responses
|
||||
continue
|
||||
messages.append(converted_message)
|
||||
|
||||
for output_message in previous_response.response.output:
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
return messages
|
||||
|
@ -102,15 +153,19 @@ class OpenAIResponsesImpl:
|
|||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
|
||||
response_json = await self.persistence_store.get(key=key)
|
||||
if response_json is None:
|
||||
raise ValueError(f"OpenAI response with id '{id}' not found")
|
||||
return OpenAIResponseObject.model_validate_json(response_json)
|
||||
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self._get_previous_response_with_input(id)
|
||||
return response_with_input.response
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
|
@ -126,8 +181,8 @@ class OpenAIResponsesImpl:
|
|||
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if previous_response_id:
|
||||
previous_response = await self.get_openai_response(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response_with_input))
|
||||
|
||||
# TODO: refactor this user_content parsing out into a separate method
|
||||
content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||
|
@ -216,10 +271,32 @@ class OpenAIResponsesImpl:
|
|||
|
||||
if store:
|
||||
# Store in kvstore
|
||||
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseInputItemMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
input_items_data.append(
|
||||
OpenAIResponseInputItemMessage(id=f"msg_{uuid.uuid4()}", **input_item.model_dump())
|
||||
)
|
||||
input_items = OpenAIResponseInputItemList(data=input_items_data)
|
||||
prev_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
|
||||
await self.persistence_store.set(
|
||||
key=key,
|
||||
value=response.model_dump_json(),
|
||||
value=prev_response.model_dump_json(),
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue