From bfdd15d1fa2abcd40b56cf6bb895a4fb3c4211b2 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 28 May 2025 13:17:48 -0700
Subject: [PATCH] fix(responses): use input, not original_input when storing
 the Response (#2300)

We must store the full (re-hydrated) input, not just the original input, in
the Response object. Of course, this is not very space efficient, and we
should likely find a better storage scheme so that we store only unique
entries in the database and re-hydrate them efficiently later. But that can
be done safely later.
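To make that follow-up concrete, here is a minimal sketch of what such a
deduplicated scheme could look like, assuming a simple in-memory mapping;
`DedupResponseStore` and its methods are hypothetical illustrations, not part
of this patch or of the existing `ResponsesStore`:

```python
from dataclasses import dataclass, field


@dataclass
class DedupResponseStore:
    # item_id -> full input item; each unique item is stored exactly once
    items: dict[str, dict] = field(default_factory=dict)
    # response_id -> ordered list of item ids making up that response's input
    responses: dict[str, list[str]] = field(default_factory=dict)

    def store(self, response_id: str, input_items: list[dict]) -> None:
        item_ids: list[str] = []
        for item in input_items:
            # setdefault writes the item only if this id has not been seen before
            self.items.setdefault(item["id"], item)
            item_ids.append(item["id"])
        self.responses[response_id] = item_ids

    def rehydrate(self, response_id: str) -> list[dict]:
        # Re-assemble the full input list for a response from its references
        return [self.items[item_id] for item_id in self.responses[response_id]]
```

Under a scheme like this, a message shared across N chained responses would be
stored once and referenced N times, instead of being copied into every
Response row.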
Closes https://github.com/meta-llama/llama-stack/issues/2299

## Test Plan

Unit test
---
 .../agents/meta_reference/openai_responses.py | 21 +++---
 .../meta_reference/test_openai_responses.py   | 66 +++++++++++++++++++
 2 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 3a56d41ef..1fcb1c461 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -292,12 +292,12 @@ class OpenAIResponsesImpl:
     async def _store_response(
         self,
         response: OpenAIResponseObject,
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
-        if isinstance(original_input, str):
+        if isinstance(input, str):
             # synthesize a message from the input string
-            input_content = OpenAIResponseInputMessageContentText(text=original_input)
+            input_content = OpenAIResponseInputMessageContentText(text=input)
             input_content_item = OpenAIResponseMessage(
                 role="user",
                 content=[input_content],
@@ -307,7 +307,7 @@ class OpenAIResponsesImpl:
         else:
             # we already have a list of messages
             input_items_data = []
-            for input_item in original_input:
+            for input_item in input:
                 if isinstance(input_item, OpenAIResponseMessage):
                     # These may or may not already have an id, so dump to dict, check for id, and add if missing
                     input_item_dict = input_item.model_dump()
@@ -334,7 +334,6 @@ class OpenAIResponsesImpl:
         tools: list[OpenAIResponseInputTool] | None = None,
     ):
         stream = False if stream is None else stream
-        original_input = input  # Keep reference for storage
 
         output_messages: list[OpenAIResponseOutput] = []
 
@@ -372,7 +371,7 @@ class OpenAIResponsesImpl:
                 inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
-                original_input=original_input,
+                input=input,
                 model=model,
                 store=store,
                 tools=tools,
@@ -382,7 +381,7 @@ class OpenAIResponsesImpl:
                 inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
-                original_input=original_input,
+                input=input,
                 model=model,
                 store=store,
                 tools=tools,
@@ -393,7 +392,7 @@ class OpenAIResponsesImpl:
         inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
         tools: list[OpenAIResponseInputTool] | None,
@@ -423,7 +422,7 @@ class OpenAIResponsesImpl:
         if store:
             await self._store_response(
                 response=response,
-                original_input=original_input,
+                input=input,
             )
 
         return response
@@ -433,7 +432,7 @@ class OpenAIResponsesImpl:
         inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
         tools: list[OpenAIResponseInputTool] | None,
@@ -544,7 +543,7 @@ class OpenAIResponsesImpl:
         if store:
             await self._store_response(
                 response=final_response,
-                original_input=original_input,
+                input=input,
             )
 
         # Emit response.completed
diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
index 9c491accb..5b6cee0ec 100644
--- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py
+++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
@@ -628,3 +628,69 @@ async def test_responses_store_list_input_items_logic():
     result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
     assert result.object == "list"
     assert len(result.data) == 0  # Should return no items
+
+
+@pytest.mark.asyncio
+async def test_store_response_uses_rehydrated_input_with_previous_response(
+    openai_responses_impl, mock_responses_store, mock_inference_api
+):
+    """Test that _store_response uses the full re-hydrated input (including previous
+    responses) rather than just the original input when previous_response_id is provided."""
+
+    # Setup - Create a previous response that should be included in the stored input
+    previous_response = OpenAIResponseObjectWithInput(
+        id="resp-previous-123",
+        object="response",
+        created_at=1234567890,
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        status="completed",
+        input=[
+            OpenAIResponseMessage(
+                id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
+            )
+        ],
+        output=[
+            OpenAIResponseMessage(
+                id="msg-prev-assistant",
+                role="assistant",
+                content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
+            )
+        ],
+    )
+
+    mock_responses_store.get_response_object.return_value = previous_response
+
+    current_input = "Now what is 3+3?"
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
+    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+
+    # Execute - Create response with previous_response_id
+    result = await openai_responses_impl.create_openai_response(
+        input=current_input,
+        model=model,
+        previous_response_id="resp-previous-123",
+        store=True,
+    )
+
+    store_call_args = mock_responses_store.store_response_object.call_args
+    stored_input = store_call_args.kwargs["input"]
+
+    # Verify that the stored input contains the full re-hydrated conversation:
+    # 1. Previous user message
+    # 2. Previous assistant response
+    # 3. Current user message
+    assert len(stored_input) == 3
+
+    assert stored_input[0].role == "user"
+    assert stored_input[0].content[0].text == "What is 2+2?"
+
+    assert stored_input[1].role == "assistant"
+    assert stored_input[1].content[0].text == "2+2 equals 4."
+
+    assert stored_input[2].role == "user"
+    assert stored_input[2].content == "Now what is 3+3?"
+
+    # Verify the response itself is correct
+    assert result.model == model
+    assert result.status == "completed"