mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
fix: responses <> chat completion input conversion (#3645)
# What does this PR do? closes #3268 closes #3498 When resuming from previous response ID, currently we attempt to convert from the stored responses input to chat completion messages, which is not always possible, e.g. for tool calls where some data is lost once converted from chat completion message to repsonses input format. This PR stores the chat completion messages that correspond to the _last_ call to chat completion, which is sufficient to be resumed from in the next responses API call, where we load these saved messages and skip conversion entirely. Separate issue to optimize storage: https://github.com/llamastack/llama-stack/issues/3646 ## Test Plan existing CI tests
This commit is contained in:
parent
2e544ecd8a
commit
cf422da825
7 changed files with 202 additions and 58 deletions
|
@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
|||
|
||||
input: list[OpenAIResponseInput]
|
||||
|
||||
def to_response_object(self) -> OpenAIResponseObject:
|
||||
"""Convert to OpenAIResponseObject by excluding input field."""
|
||||
return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIResponseObject(BaseModel):
|
||||
|
|
|
@ -8,7 +8,7 @@ import time
|
|||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
|
@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
|
@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
|
|||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None = None,
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||
):
|
||||
new_input_items = previous_response.input.copy()
|
||||
new_input_items.extend(previous_response.output)
|
||||
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
return new_input_items
|
||||
|
||||
async def _process_input_with_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
tuple: (all_input for storage, messages for chat completion)
|
||||
"""
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
|
||||
await self.responses_store.get_response_object(previous_response_id)
|
||||
)
|
||||
all_input = await self._prepend_previous_response(input, previous_response)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
if previous_response.messages:
|
||||
# Use stored messages directly and convert only new input
|
||||
message_adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
messages = message_adapter.validate_python(previous_response.messages)
|
||||
new_messages = await convert_response_input_to_chat_messages(input)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
return all_input, messages
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
|
@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
|
|||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
return response_with_input.to_response_object()
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
|
@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
|
|||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
|
@ -224,8 +252,7 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
|
@ -265,7 +292,8 @@ class OpenAIResponsesImpl:
|
|||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
input=all_input,
|
||||
messages=orchestrator.final_messages,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
|
|
|
@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -94,6 +95,8 @@ class StreamingResponseOrchestrator:
|
|||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
|
@ -183,6 +186,8 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
messages = next_turn_messages
|
||||
|
||||
self.final_messages = messages.copy() + [current_response.choices[0].message]
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreTy
|
|||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
||||
class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
|
||||
"""Internal class for storing responses with chat completion messages.
|
||||
|
||||
This extends the public OpenAIResponseObjectWithInput with messages field
|
||||
for internal storage. The messages field is not exposed in the public API.
|
||||
|
||||
The messages field is optional for backward compatibility with responses
|
||||
stored before this feature was added.
|
||||
"""
|
||||
|
||||
messages: list[OpenAIMessageParam] | None = None
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -54,7 +68,9 @@ class ResponsesStore:
|
|||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
|
||||
self._queue: (
|
||||
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
|
||||
) = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
|
@ -100,18 +116,21 @@ class ResponsesStore:
|
|||
await self._queue.join()
|
||||
|
||||
async def store_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
self,
|
||||
response_object: OpenAIResponseObject,
|
||||
input: list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
if self.enable_write_queue:
|
||||
if self._queue is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
try:
|
||||
self._queue.put_nowait((response_object, input))
|
||||
self._queue.put_nowait((response_object, input, messages))
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
|
||||
await self._queue.put((response_object, input))
|
||||
await self._queue.put((response_object, input, messages))
|
||||
else:
|
||||
await self._write_response_object(response_object, input)
|
||||
await self._write_response_object(response_object, input, messages)
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
assert self._queue is not None
|
||||
|
@ -120,22 +139,26 @@ class ResponsesStore:
|
|||
item = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
response_object, input = item
|
||||
response_object, input, messages = item
|
||||
try:
|
||||
await self._write_response_object(response_object, input)
|
||||
await self._write_response_object(response_object, input, messages)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error writing response object: {e}")
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _write_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
self,
|
||||
response_object: OpenAIResponseObject,
|
||||
input: list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
if self.sql_store is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
data = response_object.model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in input]
|
||||
data["messages"] = [msg.model_dump() for msg in messages]
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_responses",
|
||||
|
@ -188,7 +211,7 @@ class ResponsesStore:
|
|||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
||||
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
|
||||
async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
|
||||
"""
|
||||
Get a response object with automatic access control checking.
|
||||
"""
|
||||
|
@ -205,7 +228,7 @@ class ResponsesStore:
|
|||
# This provides security by not revealing whether the record exists
|
||||
raise ValueError(f"Response with id {response_id} not found") from None
|
||||
|
||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
if not self.sql_store:
|
||||
|
@ -241,8 +264,8 @@ class ResponsesStore:
|
|||
if before and after:
|
||||
raise ValueError("Cannot specify both 'before' and 'after' parameters")
|
||||
|
||||
response_with_input = await self.get_response_object(response_id)
|
||||
items = response_with_input.input
|
||||
response_with_input_and_messages = await self.get_response_object(response_id)
|
||||
items = response_with_input_and_messages.input
|
||||
|
||||
if order == Order.desc:
|
||||
items = list(reversed(items))
|
||||
|
|
|
@ -127,6 +127,70 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
|
|||
assert response.output_text
|
||||
|
||||
|
||||
def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
|
||||
"""Test file search with sequential responses using previous_response_id."""
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = new_vector_store(compat_client, "test_vector_store")
|
||||
|
||||
# Create a test file with content
|
||||
file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
|
||||
file_name = "test_sequential_file_search.txt"
|
||||
file_path = tmp_path / file_name
|
||||
file_path.write_text(file_content)
|
||||
|
||||
file_response = upload_file(compat_client, file_name, file_path)
|
||||
|
||||
# Attach the file to the vector store
|
||||
compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
|
||||
# Wait for the file to be attached
|
||||
wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
|
||||
|
||||
tools = [{"type": "file_search", "vector_store_ids": [vector_store.id]}]
|
||||
|
||||
# First response request with file search
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries
|
||||
assert response.output[0].results
|
||||
assert "128" in response.output_text or "experts" in response.output_text.lower()
|
||||
|
||||
# Second response request using previous_response_id
|
||||
response2 = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Can you tell me more about the architecture?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
previous_response_id=response.id,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the second response has output
|
||||
assert len(response2.output) >= 1
|
||||
assert response2.output_text
|
||||
|
||||
# The second response should maintain context from the first
|
||||
final_message = [output for output in response2.output if output.type == "message"]
|
||||
assert len(final_message) >= 1
|
||||
assert final_message[-1].role == "assistant"
|
||||
assert final_message[-1].status == "completed"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", mcp_tool_test_cases)
|
||||
def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
|
||||
if not isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
|
|
|
@ -22,7 +22,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObjectWithInput,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
|
@ -45,7 +44,10 @@ from llama_stack.core.datatypes import ResponsesStoreConfig
|
|||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
@ -499,13 +501,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
|||
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
|
||||
|
||||
|
||||
async def test_prepend_previous_response_none(openai_responses_impl):
|
||||
"""Test prepending no previous response to a new response."""
|
||||
|
||||
input = await openai_responses_impl._prepend_previous_response("fake_input", None)
|
||||
assert input == "fake_input"
|
||||
|
||||
|
||||
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a basic previous response to a new response."""
|
||||
|
||||
|
@ -520,7 +515,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
|
|||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
previous_response = OpenAIResponseObjectWithInput(
|
||||
previous_response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
|
@ -528,10 +523,11 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
|
|||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
messages=[OpenAIUserMessageParam(content="fake_previous_input")],
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = previous_response
|
||||
|
||||
input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
|
||||
input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)
|
||||
|
||||
assert len(input) == 3
|
||||
# Check for previous input
|
||||
|
@ -562,7 +558,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
|
|||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
|
@ -570,11 +566,12 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
|
|||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
messages=[OpenAIUserMessageParam(content="test input")],
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, response)
|
||||
|
||||
assert len(input) == 4
|
||||
# Check for previous input
|
||||
|
@ -609,7 +606,7 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
|
|||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
|
@ -617,11 +614,12 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
|
|||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
messages=[OpenAIUserMessageParam(content="test input")],
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, response)
|
||||
|
||||
assert len(input) == 4
|
||||
# Check for previous input
|
||||
|
@ -725,7 +723,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
|
@ -733,6 +731,10 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
messages=[
|
||||
OpenAIUserMessageParam(content="Name some towns in Ireland"),
|
||||
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
|
||||
],
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
|
@ -818,7 +820,7 @@ async def test_responses_store_list_input_items_logic():
|
|||
OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"),
|
||||
]
|
||||
|
||||
response_with_input = OpenAIResponseObjectWithInput(
|
||||
response_with_input = _OpenAIResponseObjectWithInputAndMessages(
|
||||
id="resp_123",
|
||||
model="test_model",
|
||||
created_at=1234567890,
|
||||
|
@ -827,6 +829,7 @@ async def test_responses_store_list_input_items_logic():
|
|||
output=[],
|
||||
text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
|
||||
input=input_items,
|
||||
messages=[OpenAIUserMessageParam(content="First message")],
|
||||
)
|
||||
|
||||
# Mock the get_response_object method to return our test data
|
||||
|
@ -887,7 +890,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
rather than just the original input when previous_response_id is provided."""
|
||||
|
||||
# Setup - Create a previous response that should be included in the stored input
|
||||
previous_response = OpenAIResponseObjectWithInput(
|
||||
previous_response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
id="resp-previous-123",
|
||||
object="response",
|
||||
created_at=1234567890,
|
||||
|
@ -906,6 +909,10 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
|
||||
)
|
||||
],
|
||||
messages=[
|
||||
OpenAIUserMessageParam(content="What is 2+2?"),
|
||||
OpenAIAssistantMessageParam(content="2+2 equals 4."),
|
||||
],
|
||||
)
|
||||
|
||||
mock_responses_store.get_response_object.return_value = previous_response
|
||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInput,
|
||||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
|
|||
)
|
||||
|
||||
|
||||
def create_test_messages(content: str) -> list[OpenAIMessageParam]:
|
||||
"""Helper to create test messages for chat completion."""
|
||||
return [OpenAIUserMessageParam(content=content)]
|
||||
|
||||
|
||||
async def test_responses_store_pagination_basic():
|
||||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
|
@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
|
|||
for response_id, timestamp in test_data:
|
||||
response = create_test_response_object(response_id, timestamp)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
|
|||
for response_id, timestamp in test_data:
|
||||
response = create_test_response_object(response_id, timestamp)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
for response_id, timestamp, model in test_data:
|
||||
response = create_test_response_object(response_id, timestamp, model)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
|
|||
for response_id, timestamp in test_data:
|
||||
response = create_test_response_object(response_id, timestamp)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
|
|||
# Store a test response
|
||||
response = create_test_response_object("test-resp", int(time.time()))
|
||||
input_list = [create_test_response_input("Test input content", "input-test-resp")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages("Test input content")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
|
|||
create_test_response_input("Fourth input", "input-4"),
|
||||
create_test_response_input("Fifth input", "input-5"),
|
||||
]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages("First input")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
|
|||
create_test_response_input("Fourth input", "before-4"),
|
||||
create_test_response_input("Fifth input", "before-5"),
|
||||
]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages("First input")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue