From b4974d411d3a2557ad679822c2372e355cd49532 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 14:59:30 -0700 Subject: [PATCH 1/5] chore: simplify authorized sqlstore # What does this PR do? ## Test Plan --- .../providers/inline/files/localfs/files.py | 5 ++--- .../providers/remote/files/s3/files.py | 5 ++--- .../utils/inference/inference_store.py | 4 +--- .../utils/responses/responses_store.py | 7 ++----- .../utils/sqlstore/authorized_sqlstore.py | 11 +++++------ .../sqlstore/test_authorized_sqlstore.py | 19 +++++++++++-------- tests/unit/utils/test_authorized_sqlstore.py | 18 +++++++++--------- 7 files changed, 32 insertions(+), 37 deletions(-) diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 9c610c1ba..65cf8d815 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files): storage_path.mkdir(parents=True, exist_ok=True) # Initialize SQL store for metadata - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy) await self.sql_store.create_table( "openai_files", { @@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "client.files.list()") @@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 54742d900..8ea96af9e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -137,7 +137,7 @@ class S3FilesImpl(Files): where: dict[str, str | dict] = {"id": file_id} if not return_expired: where["expires_at"] = {">": self._now()} - if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): raise ResourceNotFoundError(file_id, "File", "files.list()") return row @@ -164,7 +164,7 @@ class S3FilesImpl(Files): self._client = _create_s3_client(self._config) await _create_bucket_if_not_exists(self._client, self._config) - self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) await self._sql_store.create_table( "openai_files", { @@ -268,7 +268,6 @@ class S3FilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 17f4c6268..ffc9f3e11 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -54,7 +54,7 @@ class InferenceStore: async def initialize(self): """Create the necessary tables if they don't exist.""" - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "chat_completions", { @@ -202,7 +202,6 @@ class InferenceStore: order_by=[("created", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [ @@ -229,7 +228,6 @@ class InferenceStore: row = await self.sql_store.fetch_one( table="chat_completions", where={"id": completion_id}, - policy=self.policy, ) if not row: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..829cd8a62 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -28,8 +28,7 @@ class ResponsesStore: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) - self.policy = policy + self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) async def initialize(self): """Create the necessary tables if they don't exist.""" @@ -87,7 +86,6 @@ class ResponsesStore: order_by=[("created_at", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] @@ -105,7 +103,6 @@ class ResponsesStore: row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, - policy=self.policy, ) if not row: @@ -116,7 +113,7 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") await self.sql_store.delete("openai_responses", where={"id": response_id}) diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index acb688f96..ab67f7052 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -53,13 +53,15 @@ class AuthorizedSqlStore: access control policies, user attribute capture, and SQL filtering optimization. """ - def __init__(self, sql_store: SqlStore): + def __init__(self, sql_store: SqlStore, policy: list[AccessRule]): """ Initialize the authorization layer. :param sql_store: Base SqlStore implementation to wrap + :param policy: Access control policy to use for authorization """ self.sql_store = sql_store + self.policy = policy self._detect_database_type() self._validate_sql_optimized_policy() @@ -117,14 +119,13 @@ class AuthorizedSqlStore: async def fetch_all( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, ) -> PaginatedResponse: """Fetch all rows with automatic access control filtering.""" - access_where = self._build_access_control_where_clause(policy) + access_where = self._build_access_control_where_clause(self.policy) rows = await self.sql_store.fetch_all( table=table, where=where, @@ -146,7 +147,7 @@ class AuthorizedSqlStore: str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) ) - if is_action_allowed(policy, Action.READ, sql_record, current_user): + if is_action_allowed(self.policy, Action.READ, sql_record, current_user): filtered_rows.append(row) return PaginatedResponse( @@ -157,14 +158,12 @@ class AuthorizedSqlStore: async def fetch_one( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: """Fetch one row with automatic access control checking.""" results = await self.fetch_all( table=table, - policy=policy, where=where, limit=1, order_by=order_by, diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 4002f2e1f..98bef0f2c 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -57,7 +57,7 @@ def authorized_store(backend_config): config = config_func() base_sqlstore = sqlstore_impl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) yield authorized_store @@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" assert result.data[0]["access_attributes"] is None @@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 2 # Test with non-admin user @@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz mock_get_authenticated_user.return_value = regular_user # Should only see public record - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" @@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz # Now test with the multi-user who has both roles=admin and teams=dev mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) # Should see: # - public record (1) - no access_attributes @@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto ), ] + # Create a new authorized store with the owner-only policy + owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy) + # Test user1 access - should only see their own record mock_get_authenticated_user.return_value = user1 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" # Test user2 access - should only see their own record mock_get_authenticated_user.return_value = user2 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" # Test with anonymous user - should see no records mock_get_authenticated_user.return_value = None - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" finally: diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 90eb706e4..d85e784a9 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) # Create table with access control await sqlstore.create_table( @@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic mock_get_authenticated_user.return_value = admin_user # Admin should see both documents - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 1 assert result.data[0]["title"] == "Admin Document" # User should only see their document mock_get_authenticated_user.return_value = regular_user - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 0 - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) + result = await sqlstore.fetch_all("documents", where={"id": 2}) assert len(result.data) == 1 assert result.data[0]["title"] == "User Document" - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) + row = await sqlstore.fetch_one("documents", where={"id": 1}) assert row is None - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) + row = await sqlstore.fetch_one("documents", where={"id": 2}) assert row is not None assert row["title"] == "User Document" @@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) await sqlstore.create_table( table="resources", @@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): user = User(principal=user_data["principal"], attributes=user_data["attributes"]) mock_get_authenticated_user.return_value = user - sql_results = await sqlstore.fetch_all("resources", policy=policy) + sql_results = await sqlstore.fetch_all("resources") sql_ids = {row["id"] for row in sql_results.data} policy_ids = set() for scenario in test_scenarios: @@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us db_path=tmp_dir + "/" + db_name, ) ) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) await authorized_store.create_table( table="user_data", From f0da887e793dce0a084a2239c1cef4ab11cd0156 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 15:49:40 -0700 Subject: [PATCH 2/5] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../utils/responses/test_responses_store.py | 21 ++++ 3 files changed, 123 insertions(+), 6 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From b93b7798adb380db33a8f0eeb5f742b279339e26 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 15:53:26 -0700 Subject: [PATCH 3/5] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 2 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 124 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..fd128f585 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -677,7 +677,7 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore(None, policy=default_policy()) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From b0115674a4e0100c1449720e9e5e8fa4c0fc5f46 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 15:59:36 -0700 Subject: [PATCH 4/5] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 3 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 125 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..df89986af 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,6 +42,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) @@ -677,7 +678,7 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore(ResponsesStoreConfig(), policy=default_policy()) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From 7660ba844f64d01d40bbf9631f309acb05067cfd Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 16:02:02 -0700 Subject: [PATCH 5/5] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 128 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..e467e910d 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth]