store prev messages

# What does this PR do?

Stores the chat messages for each response alongside the response object in the responses store. `store_response_object` now takes the conversation `messages` as a third argument, and `_prepend_previous_response` receives the previously fetched response object (which carries both the original input and the stored messages) rather than a raw response id, so requests that chain via `previous_response_id` can rehydrate the prior conversation directly. A minimal sketch of the new storage round-trip follows the commit metadata below.

## Test Plan

Adds an integration test, `test_response_sequential_file_search`, that chains two file-search responses via `previous_response_id`, and updates the responses-store unit tests to pass (and assert on) the stored messages.
Commit 2ec9f8770e by Eric Huang, 2025-10-02 15:53:31 -07:00 (parent 4819a2e0ee).
7 changed files with 202 additions and 58 deletions.
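A minimal sketch of the new round-trip, assuming a SQLite-backed store as in the unit tests below. The constructor shape, field defaults, and demo ids are assumptions; the three-argument `store_response_object(response, input, messages)`, `flush()`, and `get_response_object()` come from this commit's diffs.

```python
import asyncio
import time
from tempfile import TemporaryDirectory

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def demo() -> None:
    with TemporaryDirectory() as tmp_dir:
        # Assumed constructor shape; the tests below build the store from a
        # SQLite config pointing at a temporary database.
        store = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{tmp_dir}/responses.db"))

        response = OpenAIResponseObject(
            id="resp-demo",
            model="test_model",
            created_at=int(time.time()),
            object="response",
            status="completed",
            output=[],
        )
        input_list = [OpenAIResponseMessage(id="input-demo", content="What is 2+2?", role="user")]

        # New in this commit: the chat messages are persisted alongside the
        # response so previous_response_id chaining can rehydrate them.
        messages = [OpenAIUserMessageParam(content="What is 2+2?")]
        await store.store_response_object(response, input_list, messages)

        await store.flush()  # writes are queued; flush before reading back
        stored = await store.get_response_object("resp-demo")
        assert stored.id == "resp-demo"


asyncio.run(demo())
```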


Integration tests (Responses API file search):

```diff
@@ -127,6 +127,70 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
     assert response.output_text


+def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
+    """Test file search with sequential responses using previous_response_id."""
+    if isinstance(compat_client, LlamaStackAsLibraryClient):
+        pytest.skip("Responses API file search is not yet supported in library client.")
+
+    vector_store = new_vector_store(compat_client, "test_vector_store")
+
+    # Create a test file with content
+    file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
+    file_name = "test_sequential_file_search.txt"
+    file_path = tmp_path / file_name
+    file_path.write_text(file_content)
+
+    file_response = upload_file(compat_client, file_name, file_path)
+
+    # Attach the file to the vector store
+    compat_client.vector_stores.files.create(
+        vector_store_id=vector_store.id,
+        file_id=file_response.id,
+    )
+
+    # Wait for the file to be attached
+    wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
+
+    tools = [{"type": "file_search", "vector_store_ids": [vector_store.id]}]
+
+    # First response request with file search
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input="How many experts does the Llama 4 Maverick model have?",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    # Verify the file_search tool was called
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries
+    assert response.output[0].results
+    assert "128" in response.output_text or "experts" in response.output_text.lower()
+
+    # Second response request using previous_response_id
+    response2 = compat_client.responses.create(
+        model=text_model_id,
+        input="Can you tell me more about the architecture?",
+        tools=tools,
+        stream=False,
+        previous_response_id=response.id,
+        include=["file_search_call.results"],
+    )
+
+    # Verify the second response has output
+    assert len(response2.output) >= 1
+    assert response2.output_text
+
+    # The second response should maintain context from the first
+    final_message = [output for output in response2.output if output.type == "message"]
+    assert len(final_message) >= 1
+    assert final_message[-1].role == "assistant"
+    assert final_message[-1].status == "completed"
+
+
 @pytest.mark.parametrize("case", mcp_tool_test_cases)
 def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
     if not isinstance(compat_client, LlamaStackAsLibraryClient):
```
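Distilled from the diff above, a hedged client-side sketch of the chaining contract this test exercises: the follow-up request sends only the new user turn plus `previous_response_id`, and the server prepends the stored conversation. The client construction, model id, and vector store id are placeholders; only `responses.create` and its `tools`/`previous_response_id` parameters come from the test.

```python
from openai import OpenAI

# Assumption: any OpenAI-compatible endpoint serving the Responses API.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

first = client.responses.create(
    model="my-model",  # placeholder model id
    input="How many experts does the Llama 4 Maverick model have?",
    tools=[{"type": "file_search", "vector_store_ids": ["vs_123"]}],  # placeholder store id
)

# Only the new turn is sent; the previously stored messages are rehydrated
# server-side, which is what this commit's stored `messages` make possible.
followup = client.responses.create(
    model="my-model",
    input="Can you tell me more about the architecture?",
    previous_response_id=first.id,
)
print(followup.output_text)
```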


Unit tests (meta-reference Responses implementation):

```diff
@@ -22,7 +22,6 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolWebSearch,
     OpenAIResponseMessage,
-    OpenAIResponseObjectWithInput,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -45,7 +44,10 @@ from llama_stack.core.datatypes import ResponsesStoreConfig
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@@ -498,13 +500,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
         assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)


-async def test_prepend_previous_response_none(openai_responses_impl):
-    """Test prepending no previous response to a new response."""
-    input = await openai_responses_impl._prepend_previous_response("fake_input", None)
-    assert input == "fake_input"
-
-
 async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
     """Test prepending a basic previous response to a new response."""
@@ -519,7 +514,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
         status="completed",
         role="assistant",
     )
-    previous_response = OpenAIResponseObjectWithInput(
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -527,10 +522,11 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="fake_previous_input")],
     )
     mock_responses_store.get_response_object.return_value = previous_response

-    input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
+    input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)

     assert len(input) == 3
     # Check for previous input
@@ -561,7 +557,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -569,11 +565,12 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="test input")],
     )
     mock_responses_store.get_response_object.return_value = response

     input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
-    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+    input = await openai_responses_impl._prepend_previous_response(input_messages, response)

     assert len(input) == 4
     # Check for previous input
@@ -608,7 +605,7 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -616,11 +613,12 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="test input")],
     )
     mock_responses_store.get_response_object.return_value = response

     input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
-    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+    input = await openai_responses_impl._prepend_previous_response(input_messages, response)

     assert len(input) == 4
     # Check for previous input
@@ -724,7 +722,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -732,6 +730,10 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[
+            OpenAIUserMessageParam(content="Name some towns in Ireland"),
+            OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
+        ],
     )

     mock_responses_store.get_response_object.return_value = response
@@ -817,7 +819,7 @@ async def test_responses_store_list_input_items_logic():
         OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"),
     ]

-    response_with_input = OpenAIResponseObjectWithInput(
+    response_with_input = _OpenAIResponseObjectWithInputAndMessages(
         id="resp_123",
         model="test_model",
         created_at=1234567890,
@@ -826,6 +828,7 @@ async def test_responses_store_list_input_items_logic():
         output=[],
         text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
         input=input_items,
+        messages=[OpenAIUserMessageParam(content="First message")],
     )

     # Mock the get_response_object method to return our test data
@@ -886,7 +889,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     rather than just the original input when previous_response_id is provided."""
     # Setup - Create a previous response that should be included in the stored input

-    previous_response = OpenAIResponseObjectWithInput(
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
         id="resp-previous-123",
         object="response",
         created_at=1234567890,
@@ -905,6 +908,10 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
                 content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
             )
         ],
+        messages=[
+            OpenAIUserMessageParam(content="What is 2+2?"),
+            OpenAIAssistantMessageParam(content="2+2 equals 4."),
+        ],
     )
     mock_responses_store.get_response_object.return_value = previous_response
```
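The signature change running through these hunks is that `_prepend_previous_response` now takes the fetched response object instead of a response id, which also explains why `test_prepend_previous_response_none` could be deleted: the no-previous-response case never reaches the helper anymore. A hedged sketch of the implied call order; the wrapper name is hypothetical, while `get_response_object` and the new helper signature come from the tests above.

```python
async def resolve_input(impl, responses_store, input, previous_response_id):
    """Hypothetical caller-side wrapper: resolve previous_response_id to the
    stored object (which now carries input items and chat messages), then
    hand the object, not the id, to _prepend_previous_response."""
    if previous_response_id is None:
        # Formerly covered by the removed test_prepend_previous_response_none;
        # with this refactor the None path skips the helper entirely.
        return input
    previous_response = await responses_store.get_response_object(previous_response_id)
    return await impl._prepend_previous_response(input, previous_response)
```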


Unit tests (responses store):

```diff
@@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInput,
     OpenAIResponseObject,
 )
+from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
     )


+def create_test_messages(content: str) -> list[OpenAIMessageParam]:
+    """Helper to create test messages for chat completion."""
+    return [OpenAIUserMessageParam(content=content)]
+
+
 async def test_responses_store_pagination_basic():
     """Test basic pagination functionality for responses store."""
     with TemporaryDirectory() as tmp_dir:
@@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
         for response_id, timestamp in test_data:
             response = create_test_response_object(response_id, timestamp)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
@@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
         for response_id, timestamp in test_data:
             response = create_test_response_object(response_id, timestamp)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
@@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
         for response_id, timestamp, model in test_data:
             response = create_test_response_object(response_id, timestamp, model)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
@@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
         for response_id, timestamp in test_data:
             response = create_test_response_object(response_id, timestamp)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
@@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
         # Store a test response
         response = create_test_response_object("test-resp", int(time.time()))
         input_list = [create_test_response_input("Test input content", "input-test-resp")]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages("Test input content")
+        await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
@@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
             create_test_response_input("Fourth input", "input-4"),
             create_test_response_input("Fifth input", "input-5"),
         ]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages("First input")
+        await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
@@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
             create_test_response_input("Fourth input", "before-4"),
             create_test_response_input("Fifth input", "before-5"),
         ]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages("First input")
+        await store.store_response_object(response, input_list, messages)

         # Wait for all queued writes to complete
         await store.flush()
```