store prev messages

# What does this PR do?

Stores the chat completion messages used to produce a response alongside the response object in the responses store, via a new internal `_OpenAIResponseObjectWithInputAndMessages` model. When a request passes `previous_response_id`, the stored messages are reused directly and only the new input is converted, instead of reconstructing the whole message list from stored input items. The `messages` field is optional, so rows written before this change still load and fall back to the old reconstruction path.

## Test Plan

- New integration test `test_response_sequential_file_search` chaining two responses via `previous_response_id`.
- Updated unit tests for `_prepend_previous_response` / `_process_input_with_previous_response` and the responses-store storage and pagination tests to cover the new `messages` parameter.
Eric Huang 2025-10-02 15:53:31 -07:00
parent 4819a2e0ee
commit 2ec9f8770e
7 changed files with 202 additions and 58 deletions


@@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
     input: list[OpenAIResponseInput]
 
+    def to_response_object(self) -> OpenAIResponseObject:
+        """Convert to OpenAIResponseObject by excluding input field."""
+        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
+
 
 @json_schema_type
 class ListOpenAIResponseObject(BaseModel):
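
The exclusion pattern above is plain Pydantic; a minimal self-contained sketch (stand-in models, not the real llama-stack classes) of what `to_response_object()` does:

```python
from pydantic import BaseModel


class Response(BaseModel):
    id: str
    status: str


class ResponseWithInput(Response):
    # internal-only field that must not leak into the public response
    input: list[str] = []

    def to_response(self) -> Response:
        # model_dump() includes "input"; filtering it out leaves only public fields
        return Response(**{k: v for k, v in self.model_dump().items() if k != "input"})


stored = ResponseWithInput(id="resp_1", status="completed", input=["hi"])
assert stored.to_response() == Response(id="resp_1", status="completed")
```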


@@ -8,7 +8,7 @@ import time
 import uuid
 from collections.abc import AsyncIterator
 
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
@@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
 )
 from llama_stack.apis.inference import (
     Inference,
+    OpenAIMessageParam,
     OpenAISystemMessageParam,
 )
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)
 
 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
@@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
     async def _prepend_previous_response(
         self,
         input: str | list[OpenAIResponseInput],
-        previous_response_id: str | None = None,
+        previous_response: _OpenAIResponseObjectWithInputAndMessages,
     ):
+        new_input_items = previous_response.input.copy()
+        new_input_items.extend(previous_response.output)
+
+        if isinstance(input, str):
+            new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
+        else:
+            new_input_items.extend(input)
+
+        return new_input_items
+
+    async def _process_input_with_previous_response(
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None,
+    ) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
+        """Process input with optional previous response context.
+
+        Returns:
+            tuple: (all_input for storage, messages for chat completion)
+        """
         if previous_response_id:
-            previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
-
-            # previous response input items
-            new_input_items = previous_response_with_input.input
-
-            # previous response output items
-            new_input_items.extend(previous_response_with_input.output)
-
-            # new input items from the current request
-            if isinstance(input, str):
-                new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
-            else:
-                new_input_items.extend(input)
-
-            input = new_input_items
-
-        return input
+            previous_response: _OpenAIResponseObjectWithInputAndMessages = (
+                await self.responses_store.get_response_object(previous_response_id)
+            )
+            all_input = await self._prepend_previous_response(input, previous_response)
+
+            if previous_response.messages:
+                # Use stored messages directly and convert only new input
+                message_adapter = TypeAdapter(list[OpenAIMessageParam])
+                messages = message_adapter.validate_python(previous_response.messages)
+                new_messages = await convert_response_input_to_chat_messages(input)
+                messages.extend(new_messages)
+            else:
+                # Backward compatibility: reconstruct from inputs
+                messages = await convert_response_input_to_chat_messages(all_input)
+        else:
+            all_input = input
+            messages = await convert_response_input_to_chat_messages(input)
+
+        return all_input, messages
 
     async def _prepend_instructions(self, messages, instructions):
         if instructions:
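
Worth noting why `TypeAdapter` is needed here: messages read back from the SQL store are plain dicts, and `OpenAIMessageParam` is a union type, so there is no single model class to instantiate. A hedged sketch with simplified stand-in message types (the real union lives in `llama_stack.apis.inference`):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: str


class AssistantMessage(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: str


# a discriminated union dispatched on the "role" key, standing in for OpenAIMessageParam
Message = Annotated[Union[UserMessage, AssistantMessage], Field(discriminator="role")]

adapter = TypeAdapter(list[Message])
messages = adapter.validate_python(
    [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 equals 4."},
    ]
)
assert isinstance(messages[0], UserMessage)
assert isinstance(messages[1], AssistantMessage)
```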
@@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
         response_id: str,
     ) -> OpenAIResponseObject:
         response_with_input = await self.responses_store.get_response_object(response_id)
-        return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
+        return response_with_input.to_response_object()
 
     async def list_openai_responses(
         self,
@@ -138,6 +164,7 @@
         self,
         response: OpenAIResponseObject,
         input: str | list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
         if isinstance(input, str):
@@ -165,6 +192,7 @@
         await self.responses_store.store_response_object(
             response_object=response,
             input=input_items_data,
+            messages=messages,
         )
 
     async def create_openai_response(
@@ -224,8 +252,7 @@
         max_infer_iters: int | None = 10,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Input preprocessing
-        input = await self._prepend_previous_response(input, previous_response_id)
-        messages = await convert_response_input_to_chat_messages(input)
+        all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
         await self._prepend_instructions(messages, instructions)
 
         # Structured outputs
@@ -265,7 +292,8 @@
             if store and final_response:
                 await self._store_response(
                     response=final_response,
-                    input=input,
+                    input=all_input,
+                    messages=orchestrator.final_messages,
                 )
 
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
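
A condensed sketch of what the storage call now receives (hypothetical `FakeStore` and plain dicts, not the real models): the rehydrated `all_input` feeds the list-input-items endpoint, while `final_messages` is what a later `previous_response_id` request rehydrates instead of re-converting inputs.

```python
import asyncio


class FakeStore:
    """Stand-in for ResponsesStore; only the call shape matters here."""

    def __init__(self):
        self.rows = {}

    async def store_response_object(self, response_object, input, messages):
        row = dict(response_object)
        row["input"] = input        # served by the list-input-items endpoint
        row["messages"] = messages  # reused on the next previous_response_id hop
        self.rows[response_object["id"]] = row


async def main():
    store = FakeStore()
    response = {"id": "resp_1", "status": "completed"}
    all_input = [{"type": "message", "role": "user", "content": "hi"}]
    final_messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    await store.store_response_object(response, all_input, final_messages)
    assert store.rows["resp_1"]["messages"][-1]["role"] == "assistant"


asyncio.run(main())
```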


@@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIMessageParam,
 )
 
 from llama_stack.log import get_logger
@@ -103,6 +104,8 @@ class StreamingResponseOrchestrator:
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
+        # Track final messages after all tool executions
+        self.final_messages: list[OpenAIMessageParam] = []
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -192,6 +195,8 @@
 
             messages = next_turn_messages
 
+        self.final_messages = messages.copy() + [current_response.choices[0].message]
+
         # Create final response
         final_response = OpenAIResponseObject(
             created_at=self.created_at,
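
For intuition, a toy walk-through (plain dicts as stand-ins for `OpenAIMessageParam`, illustrative tool names) of what `final_messages` ends up holding after a tool-call turn, i.e. everything a later `previous_response_id` request needs without replaying tool execution:

```python
messages = [{"role": "user", "content": "How many experts does Maverick have?"}]

# a turn that triggered a tool call: the assistant's call and the tool result
# were appended to next_turn_messages inside the loop
messages.append({"role": "assistant", "tool_calls": [{"id": "call_1", "name": "search"}]})
messages.append({"role": "tool", "tool_call_id": "call_1", "content": "128 experts"})

# the closing assistant message comes from current_response.choices[0].message
closing = {"role": "assistant", "content": "The model has 128 experts."}
final_messages = messages.copy() + [closing]

assert [m["role"] for m in final_messages] == ["user", "assistant", "tool", "assistant"]
```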


@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectWithInput,
 )
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
 from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.log import get_logger
@@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType
 
 logger = get_logger(name=__name__, category="responses_store")
 
+
+class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
+    """Internal class for storing responses with chat completion messages.
+
+    This extends the public OpenAIResponseObjectWithInput with messages field
+    for internal storage. The messages field is not exposed in the public API.
+
+    The messages field is optional for backward compatibility with responses
+    stored before this feature was added.
+    """
+
+    messages: list[OpenAIMessageParam] | None = None
+
 
 class ResponsesStore:
     def __init__(
         self,
@@ -54,7 +68,9 @@ class ResponsesStore:
         self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
 
         # Async write queue and worker control
-        self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
+        self._queue: (
+            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
+        ) = None
         self._worker_tasks: list[asyncio.Task[Any]] = []
         self._max_write_queue_size: int = config.max_write_queue_size
         self._num_writers: int = max(1, config.num_writers)
@@ -100,18 +116,21 @@
         await self._queue.join()
 
     async def store_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.enable_write_queue:
             if self._queue is None:
                 raise ValueError("Responses store is not initialized")
             try:
-                self._queue.put_nowait((response_object, input))
+                self._queue.put_nowait((response_object, input, messages))
             except asyncio.QueueFull:
                 logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
-                await self._queue.put((response_object, input))
+                await self._queue.put((response_object, input, messages))
         else:
-            await self._write_response_object(response_object, input)
+            await self._write_response_object(response_object, input, messages)
 
     async def _worker_loop(self) -> None:
         assert self._queue is not None
@@ -120,22 +139,26 @@
                 item = await self._queue.get()
             except asyncio.CancelledError:
                 break
-            response_object, input = item
+            response_object, input, messages = item
             try:
-                await self._write_response_object(response_object, input)
+                await self._write_response_object(response_object, input, messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing response object: {e}")
             finally:
                 self._queue.task_done()
 
     async def _write_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.sql_store is None:
             raise ValueError("Responses store is not initialized")
 
         data = response_object.model_dump()
         data["input"] = [input_item.model_dump() for input_item in input]
+        data["messages"] = [msg.model_dump() for msg in messages]
 
         await self.sql_store.insert(
             "openai_responses",
@@ -188,7 +211,7 @@
             last_id=data[-1].id if data else "",
         )
 
-    async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
+    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
         """
         Get a response object with automatic access control checking.
         """
@@ -205,7 +228,7 @@
             # This provides security by not revealing whether the record exists
             raise ValueError(f"Response with id {response_id} not found") from None
 
-        return OpenAIResponseObjectWithInput(**row["response_object"])
+        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])
 
     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
         if not self.sql_store:
@@ -241,8 +264,8 @@
         if before and after:
             raise ValueError("Cannot specify both 'before' and 'after' parameters")
 
-        response_with_input = await self.get_response_object(response_id)
-        items = response_with_input.input
+        response_with_input_and_messages = await self.get_response_object(response_id)
+        items = response_with_input_and_messages.input
 
         if order == Order.desc:
             items = list(reversed(items))
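
Rough shape of a stored `openai_responses` row after this change (values abbreviated, plain dicts): the serialized response gains a `messages` key next to `input`; rows written before the change simply lack it, which is why the model field is `| None` and reads fall back to reconstructing messages from inputs.

```python
row = {
    "id": "resp_123",
    "object": "response",
    "status": "completed",
    "input": [{"type": "message", "role": "user", "content": "What is 2+2?"}],
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 equals 4."},
    ],
}

legacy_row = {k: v for k, v in row.items() if k != "messages"}  # pre-change row
assert legacy_row.get("messages") is None  # model then validates with messages=None
```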


@@ -127,6 +127,70 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
     assert response.output_text
 
 
+def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
+    """Test file search with sequential responses using previous_response_id."""
+    if isinstance(compat_client, LlamaStackAsLibraryClient):
+        pytest.skip("Responses API file search is not yet supported in library client.")
+
+    vector_store = new_vector_store(compat_client, "test_vector_store")
+
+    # Create a test file with content
+    file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
+    file_name = "test_sequential_file_search.txt"
+    file_path = tmp_path / file_name
+    file_path.write_text(file_content)
+
+    file_response = upload_file(compat_client, file_name, file_path)
+
+    # Attach the file to the vector store
+    compat_client.vector_stores.files.create(
+        vector_store_id=vector_store.id,
+        file_id=file_response.id,
+    )
+
+    # Wait for the file to be attached
+    wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
+
+    tools = [{"type": "file_search", "vector_store_ids": [vector_store.id]}]
+
+    # First response request with file search
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input="How many experts does the Llama 4 Maverick model have?",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    # Verify the file_search_tool was called
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries
+    assert response.output[0].results
+    assert "128" in response.output_text or "experts" in response.output_text.lower()
+
+    # Second response request using previous_response_id
+    response2 = compat_client.responses.create(
+        model=text_model_id,
+        input="Can you tell me more about the architecture?",
+        tools=tools,
+        stream=False,
+        previous_response_id=response.id,
+        include=["file_search_call.results"],
+    )
+
+    # Verify the second response has output
+    assert len(response2.output) >= 1
+    assert response2.output_text
+
+    # The second response should maintain context from the first
+    final_message = [output for output in response2.output if output.type == "message"]
+    assert len(final_message) >= 1
+    assert final_message[-1].role == "assistant"
+    assert final_message[-1].status == "completed"
+
+
 @pytest.mark.parametrize("case", mcp_tool_test_cases)
 def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
     if not isinstance(compat_client, LlamaStackAsLibraryClient):
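
The same flow from a client's point of view, as a hedged usage sketch (base URL, API key, and model name are placeholders for wherever the stack serves its OpenAI-compatible API):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

first = client.responses.create(
    model="my-model",
    input="How many experts does the Llama 4 Maverick model have?",
)
followup = client.responses.create(
    model="my-model",
    input="Can you tell me more about the architecture?",
    # the server now rehydrates the stored messages for first.id instead of
    # rebuilding the chat history from stored input items
    previous_response_id=first.id,
)
print(followup.output_text)
```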


@@ -22,7 +22,6 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolWebSearch,
     OpenAIResponseMessage,
-    OpenAIResponseObjectWithInput,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -45,7 +44,10 @@ from llama_stack.core.datatypes import ResponsesStoreConfig
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
 
@@ -498,13 +500,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
         assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
 
 
-async def test_prepend_previous_response_none(openai_responses_impl):
-    """Test prepending no previous response to a new response."""
-    input = await openai_responses_impl._prepend_previous_response("fake_input", None)
-
-    assert input == "fake_input"
-
-
 async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
     """Test prepending a basic previous response to a new response."""
@@ -519,7 +514,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
         status="completed",
         role="assistant",
     )
-    previous_response = OpenAIResponseObjectWithInput(
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -527,10 +522,11 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="fake_previous_input")],
     )
     mock_responses_store.get_response_object.return_value = previous_response
 
-    input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
+    input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)
 
     assert len(input) == 3
     # Check for previous input
@@ -561,7 +557,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -569,11 +565,12 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="test input")],
     )
     mock_responses_store.get_response_object.return_value = response
 
     input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
-    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+    input = await openai_responses_impl._prepend_previous_response(input_messages, response)
 
     assert len(input) == 4
     # Check for previous input
@@ -608,7 +605,7 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -616,11 +613,12 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="test input")],
     )
     mock_responses_store.get_response_object.return_value = response
 
     input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
-    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+    input = await openai_responses_impl._prepend_previous_response(input_messages, response)
 
     assert len(input) == 4
     # Check for previous input
@@ -724,7 +722,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -732,6 +730,10 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[
+            OpenAIUserMessageParam(content="Name some towns in Ireland"),
+            OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
+        ],
     )
     mock_responses_store.get_response_object.return_value = response
 
@@ -817,7 +819,7 @@ async def test_responses_store_list_input_items_logic():
         OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"),
     ]
 
-    response_with_input = OpenAIResponseObjectWithInput(
+    response_with_input = _OpenAIResponseObjectWithInputAndMessages(
         id="resp_123",
         model="test_model",
         created_at=1234567890,
@@ -826,6 +828,7 @@ async def test_responses_store_list_input_items_logic():
         output=[],
         text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
         input=input_items,
+        messages=[OpenAIUserMessageParam(content="First message")],
     )
 
     # Mock the get_response_object method to return our test data
@@ -886,7 +889,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     rather than just the original input when previous_response_id is provided."""
 
     # Setup - Create a previous response that should be included in the stored input
-    previous_response = OpenAIResponseObjectWithInput(
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
         id="resp-previous-123",
         object="response",
         created_at=1234567890,
@@ -905,6 +908,10 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
                 content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
             )
         ],
+        messages=[
+            OpenAIUserMessageParam(content="What is 2+2?"),
+            OpenAIAssistantMessageParam(content="2+2 equals 4."),
+        ],
     )
     mock_responses_store.get_response_object.return_value = previous_response


@@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInput,
     OpenAIResponseObject,
 )
+from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 
@@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInput:
     )
 
 
+def create_test_messages(content: str) -> list[OpenAIMessageParam]:
+    """Helper to create test messages for chat completion."""
+    return [OpenAIUserMessageParam(content=content)]
+
+
 async def test_responses_store_pagination_basic():
     """Test basic pagination functionality for responses store."""
     with TemporaryDirectory() as tmp_dir:
@@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
     for response_id, timestamp in test_data:
         response = create_test_response_object(response_id, timestamp)
         input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages(f"Input for {response_id}")
+        await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()
@@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
     for response_id, timestamp in test_data:
         response = create_test_response_object(response_id, timestamp)
         input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages(f"Input for {response_id}")
+        await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()
@@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
     for response_id, timestamp, model in test_data:
         response = create_test_response_object(response_id, timestamp, model)
         input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages(f"Input for {response_id}")
+        await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()
@@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
     for response_id, timestamp in test_data:
         response = create_test_response_object(response_id, timestamp)
         input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages(f"Input for {response_id}")
+        await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()
@@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
     # Store a test response
     response = create_test_response_object("test-resp", int(time.time()))
     input_list = [create_test_response_input("Test input content", "input-test-resp")]
-    await store.store_response_object(response, input_list)
+    messages = create_test_messages("Test input content")
+    await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()
@@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
         create_test_response_input("Fourth input", "input-4"),
         create_test_response_input("Fifth input", "input-5"),
     ]
-    await store.store_response_object(response, input_list)
+    messages = create_test_messages("First input")
+    await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()
@@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
         create_test_response_input("Fourth input", "before-4"),
         create_test_response_input("Fifth input", "before-5"),
     ]
-    await store.store_response_object(response, input_list)
+    messages = create_test_messages("First input")
+    await store.store_response_object(response, input_list, messages)
 
     # Wait for all queued writes to complete
     await store.flush()