From c0b6c9d7179dc83da9efeef70a20260a30f64e30 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 21 Sep 2025 20:40:25 -0700 Subject: [PATCH] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 101 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 127 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 829cd8a62..8dec807a3 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,24 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) + self.sql_store = None + self.policy = policy + + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: 
list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "openai_responses", { @@ -42,9 +72,68 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '')}") + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..38ce365c1 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # 
Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth]
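
Reviewer note: the write path added above reduces to a bounded `asyncio.Queue` drained by N worker tasks, with `flush()` delegating to `queue.join()` and a full queue applying back-pressure on the caller. The standalone sketch below illustrates that pattern in isolation; the names (`DemoWriteQueue`, `submit`) are illustrative only and are not part of this patch or the `ResponsesStore` API.

```python
import asyncio


class DemoWriteQueue:
    """Minimal sketch of the bounded-queue-plus-workers pattern used by the responses store."""

    def __init__(self, num_writers: int = 4, max_size: int = 10000):
        self._queue: asyncio.Queue[str] = asyncio.Queue(maxsize=max_size)
        # One long-running worker task per writer, mirroring num_writers in the config.
        self._workers = [asyncio.create_task(self._worker()) for _ in range(num_writers)]

    async def _worker(self) -> None:
        while True:
            item = await self._queue.get()
            try:
                await asyncio.sleep(0)  # stand-in for the actual database write
            finally:
                self._queue.task_done()  # always mark done so join() can complete

    async def submit(self, item: str) -> None:
        try:
            self._queue.put_nowait(item)  # fast path: enqueue without blocking
        except asyncio.QueueFull:
            await self._queue.put(item)  # queue full: block the caller (back-pressure)

    async def flush(self) -> None:
        await self._queue.join()  # wait until every queued item has been written

    async def shutdown(self) -> None:
        await self.flush()
        for t in self._workers:
            t.cancel()
        await asyncio.gather(*self._workers, return_exceptions=True)


async def main() -> None:
    q = DemoWriteQueue(num_writers=2, max_size=100)
    for i in range(10):
        await q.submit(f"response-{i}")
    await q.flush()  # same guarantee the updated unit tests rely on before asserting
    await q.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

This mirrors why the unit tests above now call `await store.flush()` after `store_response_object()`: writes are asynchronous, so tests must drain the queue before asserting on reads.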