From b4974d411d3a2557ad679822c2372e355cd49532 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 14:59:30 -0700 Subject: [PATCH 01/20] chore: simplify authorized sqlstore # What does this PR do? ## Test Plan --- .../providers/inline/files/localfs/files.py | 5 ++--- .../providers/remote/files/s3/files.py | 5 ++--- .../utils/inference/inference_store.py | 4 +--- .../utils/responses/responses_store.py | 7 ++----- .../utils/sqlstore/authorized_sqlstore.py | 11 +++++------ .../sqlstore/test_authorized_sqlstore.py | 19 +++++++++++-------- tests/unit/utils/test_authorized_sqlstore.py | 18 +++++++++--------- 7 files changed, 32 insertions(+), 37 deletions(-) diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 9c610c1ba..65cf8d815 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files): storage_path.mkdir(parents=True, exist_ok=True) # Initialize SQL store for metadata - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy) await self.sql_store.create_table( "openai_files", { @@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "client.files.list()") @@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 54742d900..8ea96af9e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -137,7 +137,7 @@ class S3FilesImpl(Files): where: dict[str, str | dict] = {"id": file_id} if not return_expired: where["expires_at"] = {">": self._now()} - if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): raise ResourceNotFoundError(file_id, "File", "files.list()") return row @@ -164,7 +164,7 @@ class S3FilesImpl(Files): self._client = _create_s3_client(self._config) await _create_bucket_if_not_exists(self._client, self._config) - self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) await self._sql_store.create_table( "openai_files", { @@ -268,7 +268,6 @@ class S3FilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 17f4c6268..ffc9f3e11 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -54,7 +54,7 @@ class InferenceStore: async def initialize(self): """Create the necessary tables if they don't exist.""" - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "chat_completions", { @@ -202,7 +202,6 @@ class InferenceStore: order_by=[("created", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [ @@ -229,7 +228,6 @@ class InferenceStore: row = await self.sql_store.fetch_one( table="chat_completions", where={"id": completion_id}, - policy=self.policy, ) if not row: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..829cd8a62 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -28,8 +28,7 @@ class ResponsesStore: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) - self.policy = policy + self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) async def initialize(self): """Create the necessary tables if they don't exist.""" @@ -87,7 +86,6 @@ class ResponsesStore: order_by=[("created_at", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] @@ -105,7 +103,6 @@ class ResponsesStore: row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, - policy=self.policy, ) if not row: @@ -116,7 +113,7 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") await self.sql_store.delete("openai_responses", where={"id": response_id}) diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index acb688f96..ab67f7052 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -53,13 +53,15 @@ class AuthorizedSqlStore: access control policies, user attribute capture, and SQL filtering optimization. """ - def __init__(self, sql_store: SqlStore): + def __init__(self, sql_store: SqlStore, policy: list[AccessRule]): """ Initialize the authorization layer. :param sql_store: Base SqlStore implementation to wrap + :param policy: Access control policy to use for authorization """ self.sql_store = sql_store + self.policy = policy self._detect_database_type() self._validate_sql_optimized_policy() @@ -117,14 +119,13 @@ class AuthorizedSqlStore: async def fetch_all( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, ) -> PaginatedResponse: """Fetch all rows with automatic access control filtering.""" - access_where = self._build_access_control_where_clause(policy) + access_where = self._build_access_control_where_clause(self.policy) rows = await self.sql_store.fetch_all( table=table, where=where, @@ -146,7 +147,7 @@ class AuthorizedSqlStore: str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) ) - if is_action_allowed(policy, Action.READ, sql_record, current_user): + if is_action_allowed(self.policy, Action.READ, sql_record, current_user): filtered_rows.append(row) return PaginatedResponse( @@ -157,14 +158,12 @@ class AuthorizedSqlStore: async def fetch_one( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: """Fetch one row with automatic access control checking.""" results = await self.fetch_all( table=table, - policy=policy, where=where, limit=1, order_by=order_by, diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 4002f2e1f..98bef0f2c 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -57,7 +57,7 @@ def authorized_store(backend_config): config = config_func() base_sqlstore = sqlstore_impl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) yield authorized_store @@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" assert result.data[0]["access_attributes"] is None @@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 2 # Test with non-admin user @@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz mock_get_authenticated_user.return_value = regular_user # Should only see public record - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" @@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz # Now test with the multi-user who has both roles=admin and teams=dev mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) # Should see: # - public record (1) - no access_attributes @@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto ), ] + # Create a new authorized store with the owner-only policy + owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy) + # Test user1 access - should only see their own record mock_get_authenticated_user.return_value = user1 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" # Test user2 access - should only see their own record mock_get_authenticated_user.return_value = user2 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" # Test with anonymous user - should see no records mock_get_authenticated_user.return_value = None - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" finally: diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 90eb706e4..d85e784a9 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) # Create table with access control await sqlstore.create_table( @@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic mock_get_authenticated_user.return_value = admin_user # Admin should see both documents - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 1 assert result.data[0]["title"] == "Admin Document" # User should only see their document mock_get_authenticated_user.return_value = regular_user - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 0 - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) + result = await sqlstore.fetch_all("documents", where={"id": 2}) assert len(result.data) == 1 assert result.data[0]["title"] == "User Document" - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) + row = await sqlstore.fetch_one("documents", where={"id": 1}) assert row is None - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) + row = await sqlstore.fetch_one("documents", where={"id": 2}) assert row is not None assert row["title"] == "User Document" @@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) await sqlstore.create_table( table="resources", @@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): user = User(principal=user_data["principal"], attributes=user_data["attributes"]) mock_get_authenticated_user.return_value = user - sql_results = await sqlstore.fetch_all("resources", policy=policy) + sql_results = await sqlstore.fetch_all("resources") sql_ids = {row["id"] for row in sql_results.data} policy_ids = set() for scenario in test_scenarios: @@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us db_path=tmp_dir + "/" + db_name, ) ) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) await authorized_store.create_table( table="user_data", From f0da887e793dce0a084a2239c1cef4ab11cd0156 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 15:49:40 -0700 Subject: [PATCH 02/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../utils/responses/test_responses_store.py | 21 ++++ 3 files changed, 123 insertions(+), 6 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From b93b7798adb380db33a8f0eeb5f742b279339e26 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 15:53:26 -0700 Subject: [PATCH 03/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 2 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 124 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..fd128f585 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -677,7 +677,7 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore(None, policy=default_policy()) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From b0115674a4e0100c1449720e9e5e8fa4c0fc5f46 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 15:59:36 -0700 Subject: [PATCH 04/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 3 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 125 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..df89986af 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,6 +42,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) @@ -677,7 +678,7 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore(ResponsesStoreConfig(), policy=default_policy()) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From 7660ba844f64d01d40bbf9631f309acb05067cfd Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 16:02:02 -0700 Subject: [PATCH 05/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 128 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..e467e910d 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From 04fd837d2fc26a26e1655c4ba80cc3652eab3d2b Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 19 Sep 2025 16:13:43 -0700 Subject: [PATCH 06/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 102 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 7 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 129 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..367b8aa94 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,70 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning( + f"Write queue full; adding response id={getattr(response_object, 'id', '')}" + ) + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..67ab87504 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,10 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), + policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From ce9a62aa840933e83c47825a0d207da02c4fc153 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 21 Sep 2025 20:37:58 -0700 Subject: [PATCH 07/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 100 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 126 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..f952d0880 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.sql_store = None self.policy = policy + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) + async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "openai_responses", { @@ -43,9 +72,68 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '')}") + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..38ce365c1 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From c0b6c9d7179dc83da9efeef70a20260a30f64e30 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 21 Sep 2025 20:40:25 -0700 Subject: [PATCH 08/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 ++ .../utils/responses/responses_store.py | 101 ++++++++++++++++-- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 127 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 829cd8a62..8dec807a3 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,24 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) + self.sql_store = None + self.policy = policy + + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "openai_responses", { @@ -42,9 +72,68 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '')}") + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..38ce365c1 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From a772f0a42dc165ee9993fc1eaebfc422e4666b85 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 21 Sep 2025 20:46:34 -0700 Subject: [PATCH 09/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 + .../utils/responses/responses_store.py | 110 +++++++++++++++++- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 136 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 829cd8a62..b9fceb1ab 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,24 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) + self.sql_store = None + self.policy = policy + + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "openai_responses", { @@ -42,9 +72,68 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '')}") + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] @@ -73,6 +162,9 @@ class ResponsesStore: :param model: The model to filter by. :param order: The order to sort the responses by. """ + if not self.sql_store: + raise ValueError("Responses store is not initialized") + if not order: order = Order.desc @@ -100,6 +192,9 @@ class ResponsesStore: """ Get a response object with automatic access control checking. """ + if not self.sql_store: + raise ValueError("Responses store is not initialized") + row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, @@ -113,6 +208,9 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: + if not self.sql_store: + raise ValueError("Responses store is not initialized") + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..38ce365c1 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From f0211ffb7004e1aec58fd88c85741317d08b46fc Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 22 Sep 2025 21:25:09 -0700 Subject: [PATCH 10/20] chore: fix build # What does this PR do? ## Test Plan --- llama_stack/core/build_container.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/core/build_container.sh b/llama_stack/core/build_container.sh index 424b40a9d..29964f324 100755 --- a/llama_stack/core/build_container.sh +++ b/llama_stack/core/build_container.sh @@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \ procps psmisc lsof \ traceroute \ bubblewrap \ - gcc \ + gcc g++ \ && rm -rf /var/lib/apt/lists/* ENV UV_SYSTEM_PYTHON=1 From 7650d2c96a122479ebe629c613d6f11fc3e8c9f7 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 22 Sep 2025 21:33:14 -0700 Subject: [PATCH 11/20] chore: fix build # What does this PR do? ## Test Plan --- llama_stack/core/build_container.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_stack/core/build_container.sh b/llama_stack/core/build_container.sh index 424b40a9d..8e47fc592 100755 --- a/llama_stack/core/build_container.sh +++ b/llama_stack/core/build_container.sh @@ -147,7 +147,7 @@ WORKDIR /app RUN dnf -y update && dnf install -y iputils git net-tools wget \ vim-minimal python3.12 python3.12-pip python3.12-wheel \ - python3.12-setuptools python3.12-devel gcc make && \ + python3.12-setuptools python3.12-devel gcc gcc-c++ make && \ ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all ENV UV_SYSTEM_PYTHON=1 @@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \ procps psmisc lsof \ traceroute \ bubblewrap \ - gcc \ + gcc g++ \ && rm -rf /var/lib/apt/lists/* ENV UV_SYSTEM_PYTHON=1 From 88ad5d6d7319667eae1c0a04d8e99349cb98c13b Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 26 Sep 2025 10:35:11 -0700 Subject: [PATCH 12/20] chore: introduce write queue for response_store # What does this PR do? ## Test Plan --- llama_stack/core/datatypes.py | 6 + .../utils/responses/responses_store.py | 110 +++++++++++++++++- .../meta_reference/test_openai_responses.py | 6 +- .../utils/responses/test_responses_store.py | 21 ++++ 4 files changed, 136 insertions(+), 7 deletions(-) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index b5558c66f..6a297f012 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel): num_writers: int = Field(default=4, description="Number of concurrent background writers") +class ResponsesStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store") + num_writers: int = Field(default=4, description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 829cd8a62..b9fceb1ab 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -3,6 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio +from typing import Any + from llama_stack.apis.agents import ( Order, ) @@ -14,24 +17,51 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.core.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.log import get_logger from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore -from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl + +logger = get_logger(name=__name__, category="responses_store") class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): - if not sql_store_config: - sql_store_config = SqliteSqlStoreConfig( + def __init__( + self, + config: ResponsesStoreConfig | SqlStoreConfig, + policy: list[AccessRule], + ): + # Handle backward compatibility + if not isinstance(config, ResponsesStoreConfig): + # Legacy: SqlStoreConfig passed directly as config + config = ResponsesStoreConfig( + sql_store_config=config, + ) + + self.config = config + self.sql_store_config = config.sql_store_config + if not self.sql_store_config: + self.sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) + self.sql_store = None + self.policy = policy + + # Disable write queue for SQLite to avoid concurrency issues + self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite + + # Async write queue and worker control + self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None + self._worker_tasks: list[asyncio.Task[Any]] = [] + self._max_write_queue_size: int = config.max_write_queue_size + self._num_writers: int = max(1, config.num_writers) async def initialize(self): """Create the necessary tables if they don't exist.""" + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "openai_responses", { @@ -42,9 +72,68 @@ class ResponsesStore: }, ) + if self.enable_write_queue: + self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) + for _ in range(self._num_writers): + self._worker_tasks.append(asyncio.create_task(self._worker_loop())) + else: + logger.info("Write queue disabled for SQLite to avoid concurrency issues") + + async def shutdown(self) -> None: + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + async def flush(self) -> None: + """Wait for all queued writes to complete. Useful for testing.""" + if self.enable_write_queue and self._queue is not None: + await self._queue.join() + async def store_response_object( self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] ) -> None: + if self.enable_write_queue: + if self._queue is None: + raise ValueError("Responses store is not initialized") + try: + self._queue.put_nowait((response_object, input)) + except asyncio.QueueFull: + logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '')}") + await self._queue.put((response_object, input)) + else: + await self._write_response_object(response_object, input) + + async def _worker_loop(self) -> None: + assert self._queue is not None + while True: + try: + item = await self._queue.get() + except asyncio.CancelledError: + break + response_object, input = item + try: + await self._write_response_object(response_object, input) + except Exception as e: # noqa: BLE001 + logger.error(f"Error writing response object: {e}") + finally: + self._queue.task_done() + + async def _write_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + if self.sql_store is None: + raise ValueError("Responses store is not initialized") + data = response_object.model_dump() data["input"] = [input_item.model_dump() for input_item in input] @@ -73,6 +162,9 @@ class ResponsesStore: :param model: The model to filter by. :param order: The order to sort the responses by. """ + if not self.sql_store: + raise ValueError("Responses store is not initialized") + if not order: order = Order.desc @@ -100,6 +192,9 @@ class ResponsesStore: """ Get a response object with automatic access control checking. """ + if not self.sql_store: + raise ValueError("Responses store is not initialized") + row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, @@ -113,6 +208,9 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: + if not self.sql_store: + raise ValueError("Responses store is not initialized") + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index a964bc219..38ce365c1 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -42,10 +42,12 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( OpenAIResponsesImpl, ) from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture @@ -677,7 +679,9 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) + responses_store = ResponsesStore( + ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy() + ) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 44d4b30da..4e5256c1b 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test 1: First page with limit=2, descending order (default) result = await store.list_responses(limit=2, order=Order.desc) assert len(result.data) == 2 @@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test ascending order pagination result = await store.list_responses(limit=1, order=Order.asc) assert len(result.data) == 1 @@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test pagination with model filter result = await store.list_responses(limit=1, model="model-a", order=Order.desc) assert len(result.data) == 1 @@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit(): input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test without limit (should use default of 50) result = await store.list_responses(order=Order.desc) assert len(result.data) == 2 @@ -212,6 +224,9 @@ async def test_responses_store_get_response_object(): input_list = [create_test_response_input("Test input content", "input-test-resp")] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Retrieve the response retrieved = await store.get_response_object("test-resp") assert retrieved.id == "test-resp" @@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Verify all items are stored correctly with explicit IDs all_items = await store.list_response_input_items("test-resp", order=Order.desc) assert len(all_items.data) == 5 @@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination(): ] await store.store_response_object(response, input_list) + # Wait for all queued writes to complete + await store.flush() + # Test before pagination with descending order # In desc order: [Fifth, Fourth, Third, Second, First] # before="before-3" should return [Fifth, Fourth] From 7004ac27b51c7b59effe74fcfc1ec7a65db9573d Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Fri, 26 Sep 2025 14:44:17 -0700 Subject: [PATCH 13/20] chore: remove extra logging # What does this PR do? ## Test Plan --- .../providers/inline/telemetry/meta_reference/telemetry.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 9224c3792..2a4032543 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -224,10 +224,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): return _GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: - # Always log to console if console sink is enabled (debug) - if TelemetrySink.CONSOLE in self.config.sinks: - logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}") - # Add metric as an event to the current span try: with self._lock: From b1cbfe99f96fae717593fe67d878804eea40d175 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 15:52:44 -0700 Subject: [PATCH 14/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 1 + .../meta_reference/responses/__init__.py | 5 + .../responses/test_streaming.py | 147 ++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 tests/unit/providers/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..b6ffb1471 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -568,6 +568,7 @@ class StreamingResponseOrchestrator: description=param.description, required=param.required, default=param.default, + items=param.items, ) for param in t.parameters }, diff --git a/tests/unit/providers/agents/meta_reference/responses/__init__.py b/tests/unit/providers/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..6f3c1df03 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. \ No newline at end of file diff --git a/tests/unit/providers/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..5807dd17e --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Unit tests for MCP tool parameter conversion in streaming responses. + +This tests the fix for handling array-type parameters with 'items' field +when converting MCP tool definitions to OpenAI format. +""" + +from llama_stack.apis.tools import ToolDef, ToolParameter +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition +from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + +def test_mcp_tool_conversion_with_array_items(): + """ + Test that MCP tool parameters with array type and items field are properly converted. + + This is a regression test for the bug where array parameters without 'items' + caused OpenAI API validation errors like: + "Invalid schema for function 'pods_exec': In context=('properties', 'command'), + array schema missing items." + """ + # Create a tool parameter with array type and items specification + # This mimics what kubernetes-mcp-server's pods_exec tool has + tool_param = ToolParameter( + name="command", + parameter_type="array", + description="Command to execute in the pod", + required=True, + items={"type": "string"}, # This is the crucial field + ) + + # Convert to ToolDefinition format (as done in streaming.py) + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with array parameter", + parameters={ + "command": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + default=tool_param.default, + items=tool_param.items, # The fix: ensure items is passed through + ) + }, + ) + + # Convert to OpenAI format + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify the conversion includes the items field + assert openai_tool["type"] == "function" + assert openai_tool["function"]["name"] == "test_tool" + assert "parameters" in openai_tool["function"] + + parameters = openai_tool["function"]["parameters"] + assert "properties" in parameters + assert "command" in parameters["properties"] + + command_param = parameters["properties"]["command"] + assert command_param["type"] == "array" + assert "items" in command_param, "Array parameter must have 'items' field for OpenAI API" + assert command_param["items"] == {"type": "string"} + + +def test_mcp_tool_conversion_without_array(): + """Test that non-array parameters work correctly without items field.""" + tool_param = ToolParameter( + name="name", + parameter_type="string", + description="Name parameter", + required=True, + ) + + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with string parameter", + parameters={ + "name": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + items=tool_param.items, # Will be None for non-array types + ) + }, + ) + + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify basic structure + assert openai_tool["type"] == "function" + parameters = openai_tool["function"]["parameters"] + assert "name" in parameters["properties"] + + name_param = parameters["properties"]["name"] + assert name_param["type"] == "string" + # items should not be present for non-array types + assert "items" not in name_param or name_param.get("items") is None + + +def test_mcp_tool_conversion_complex_array_items(): + """Test array parameter with complex items schema (object type).""" + tool_param = ToolParameter( + name="configs", + parameter_type="array", + description="Array of configuration objects", + required=False, + items={ + "type": "object", + "properties": { + "key": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["key"], + }, + ) + + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with complex array parameter", + parameters={ + "configs": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + items=tool_param.items, + ) + }, + ) + + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify complex items schema is preserved + parameters = openai_tool["function"]["parameters"] + configs_param = parameters["properties"]["configs"] + + assert configs_param["type"] == "array" + assert "items" in configs_param + assert configs_param["items"]["type"] == "object" + assert "properties" in configs_param["items"] + assert "key" in configs_param["items"]["properties"] + assert "value" in configs_param["items"]["properties"] \ No newline at end of file From cd1f6410ceb1f90e80f490fffab46b710b5d74b9 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 15:53:13 -0700 Subject: [PATCH 15/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 1 + .../meta_reference/responses/__init__.py | 5 + .../responses/test_streaming.py | 147 ++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 tests/unit/providers/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..b6ffb1471 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -568,6 +568,7 @@ class StreamingResponseOrchestrator: description=param.description, required=param.required, default=param.default, + items=param.items, ) for param in t.parameters }, diff --git a/tests/unit/providers/agents/meta_reference/responses/__init__.py b/tests/unit/providers/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..6f3c1df03 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. \ No newline at end of file diff --git a/tests/unit/providers/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..f4bba613e --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Unit tests for MCP tool parameter conversion in streaming responses. + +This tests the fix for handling array-type parameters with 'items' field +when converting MCP tool definitions to OpenAI format. +""" + +from llama_stack.apis.tools import ToolParameter +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition +from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + +def test_mcp_tool_conversion_with_array_items(): + """ + Test that MCP tool parameters with array type and items field are properly converted. + + This is a regression test for the bug where array parameters without 'items' + caused OpenAI API validation errors like: + "Invalid schema for function 'pods_exec': In context=('properties', 'command'), + array schema missing items." + """ + # Create a tool parameter with array type and items specification + # This mimics what kubernetes-mcp-server's pods_exec tool has + tool_param = ToolParameter( + name="command", + parameter_type="array", + description="Command to execute in the pod", + required=True, + items={"type": "string"}, # This is the crucial field + ) + + # Convert to ToolDefinition format (as done in streaming.py) + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with array parameter", + parameters={ + "command": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + default=tool_param.default, + items=tool_param.items, # The fix: ensure items is passed through + ) + }, + ) + + # Convert to OpenAI format + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify the conversion includes the items field + assert openai_tool["type"] == "function" + assert openai_tool["function"]["name"] == "test_tool" + assert "parameters" in openai_tool["function"] + + parameters = openai_tool["function"]["parameters"] + assert "properties" in parameters + assert "command" in parameters["properties"] + + command_param = parameters["properties"]["command"] + assert command_param["type"] == "array" + assert "items" in command_param, "Array parameter must have 'items' field for OpenAI API" + assert command_param["items"] == {"type": "string"} + + +def test_mcp_tool_conversion_without_array(): + """Test that non-array parameters work correctly without items field.""" + tool_param = ToolParameter( + name="name", + parameter_type="string", + description="Name parameter", + required=True, + ) + + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with string parameter", + parameters={ + "name": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + items=tool_param.items, # Will be None for non-array types + ) + }, + ) + + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify basic structure + assert openai_tool["type"] == "function" + parameters = openai_tool["function"]["parameters"] + assert "name" in parameters["properties"] + + name_param = parameters["properties"]["name"] + assert name_param["type"] == "string" + # items should not be present for non-array types + assert "items" not in name_param or name_param.get("items") is None + + +def test_mcp_tool_conversion_complex_array_items(): + """Test array parameter with complex items schema (object type).""" + tool_param = ToolParameter( + name="configs", + parameter_type="array", + description="Array of configuration objects", + required=False, + items={ + "type": "object", + "properties": { + "key": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["key"], + }, + ) + + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with complex array parameter", + parameters={ + "configs": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + items=tool_param.items, + ) + }, + ) + + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify complex items schema is preserved + parameters = openai_tool["function"]["parameters"] + configs_param = parameters["properties"]["configs"] + + assert configs_param["type"] == "array" + assert "items" in configs_param + assert configs_param["items"]["type"] == "object" + assert "properties" in configs_param["items"] + assert "key" in configs_param["items"]["properties"] + assert "value" in configs_param["items"]["properties"] From 1b308fd87237ee3bc45431c4f7d28d36551a0d95 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 15:53:13 -0700 Subject: [PATCH 16/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 1 + .../meta_reference/responses/__init__.py | 5 + .../responses/test_streaming.py | 147 ++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 tests/unit/providers/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..b6ffb1471 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -568,6 +568,7 @@ class StreamingResponseOrchestrator: description=param.description, required=param.required, default=param.default, + items=param.items, ) for param in t.parameters }, diff --git a/tests/unit/providers/agents/meta_reference/responses/__init__.py b/tests/unit/providers/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..f4bba613e --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Unit tests for MCP tool parameter conversion in streaming responses. + +This tests the fix for handling array-type parameters with 'items' field +when converting MCP tool definitions to OpenAI format. +""" + +from llama_stack.apis.tools import ToolParameter +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition +from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + +def test_mcp_tool_conversion_with_array_items(): + """ + Test that MCP tool parameters with array type and items field are properly converted. + + This is a regression test for the bug where array parameters without 'items' + caused OpenAI API validation errors like: + "Invalid schema for function 'pods_exec': In context=('properties', 'command'), + array schema missing items." + """ + # Create a tool parameter with array type and items specification + # This mimics what kubernetes-mcp-server's pods_exec tool has + tool_param = ToolParameter( + name="command", + parameter_type="array", + description="Command to execute in the pod", + required=True, + items={"type": "string"}, # This is the crucial field + ) + + # Convert to ToolDefinition format (as done in streaming.py) + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with array parameter", + parameters={ + "command": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + default=tool_param.default, + items=tool_param.items, # The fix: ensure items is passed through + ) + }, + ) + + # Convert to OpenAI format + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify the conversion includes the items field + assert openai_tool["type"] == "function" + assert openai_tool["function"]["name"] == "test_tool" + assert "parameters" in openai_tool["function"] + + parameters = openai_tool["function"]["parameters"] + assert "properties" in parameters + assert "command" in parameters["properties"] + + command_param = parameters["properties"]["command"] + assert command_param["type"] == "array" + assert "items" in command_param, "Array parameter must have 'items' field for OpenAI API" + assert command_param["items"] == {"type": "string"} + + +def test_mcp_tool_conversion_without_array(): + """Test that non-array parameters work correctly without items field.""" + tool_param = ToolParameter( + name="name", + parameter_type="string", + description="Name parameter", + required=True, + ) + + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with string parameter", + parameters={ + "name": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + items=tool_param.items, # Will be None for non-array types + ) + }, + ) + + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify basic structure + assert openai_tool["type"] == "function" + parameters = openai_tool["function"]["parameters"] + assert "name" in parameters["properties"] + + name_param = parameters["properties"]["name"] + assert name_param["type"] == "string" + # items should not be present for non-array types + assert "items" not in name_param or name_param.get("items") is None + + +def test_mcp_tool_conversion_complex_array_items(): + """Test array parameter with complex items schema (object type).""" + tool_param = ToolParameter( + name="configs", + parameter_type="array", + description="Array of configuration objects", + required=False, + items={ + "type": "object", + "properties": { + "key": {"type": "string"}, + "value": {"type": "string"}, + }, + "required": ["key"], + }, + ) + + tool_def = ToolDefinition( + tool_name="test_tool", + description="Test tool with complex array parameter", + parameters={ + "configs": ToolParamDefinition( + param_type=tool_param.parameter_type, + description=tool_param.description, + required=tool_param.required, + items=tool_param.items, + ) + }, + ) + + openai_tool = convert_tooldef_to_openai_tool(tool_def) + + # Verify complex items schema is preserved + parameters = openai_tool["function"]["parameters"] + configs_param = parameters["properties"]["configs"] + + assert configs_param["type"] == "array" + assert "items" in configs_param + assert configs_param["items"]["type"] == "object" + assert "properties" in configs_param["items"] + assert "key" in configs_param["items"]["properties"] + assert "value" in configs_param["items"]["properties"] From fad9f6c4c9f2487a867c1434624ac37838dc6d84 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 22:11:20 -0700 Subject: [PATCH 17/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 49 ++++++++++++------- .../meta_reference/responses/__init__.py | 5 ++ .../responses/test_streaming.py | 42 ++++++++++++++++ 3 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..059d240f1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -50,6 +50,37 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal logger = get_logger(name=__name__, category="agents::meta_reference") +def convert_tooldef_to_chat_tool(tool_def): + """Convert a ToolDef to OpenAI ChatCompletionToolParam format. + + Args: + tool_def: ToolDef from the tools API + + Returns: + ChatCompletionToolParam suitable for OpenAI chat completion + """ + from openai.types.chat import ChatCompletionToolParam + + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + internal_tool_def = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + items=param.items, + ) + for param in tool_def.parameters + }, + ) + return convert_tooldef_to_openai_tool(internal_tool_def) + + class StreamingResponseOrchestrator: def __init__( self, @@ -556,23 +587,7 @@ class StreamingResponseOrchestrator: continue if not always_allowed or t.name in always_allowed: # Add to chat tools for inference - from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition - from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool - - tool_def = ToolDefinition( - tool_name=t.name, - description=t.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in t.parameters - }, - ) - openai_tool = convert_tooldef_to_openai_tool(tool_def) + openai_tool = convert_tooldef_to_chat_tool(t) if self.ctx.chat_tools is None: self.ctx.chat_tools = [] self.ctx.chat_tools.append(openai_tool) diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..6f3c1df03 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. \ No newline at end of file diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..6fda2b508 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.tools import ToolDef, ToolParameter +from llama_stack.providers.inline.agents.meta_reference.responses.streaming import ( + convert_tooldef_to_chat_tool, +) + + +def test_convert_tooldef_to_chat_tool_preserves_items_field(): + """Test that array parameters preserve the items field during conversion. + + This test ensures that when converting ToolDef with array-type parameters + to OpenAI ChatCompletionToolParam format, the 'items' field is preserved. + Without this fix, array parameters would be missing schema information about their items. + """ + tool_def = ToolDef( + name="test_tool", + description="A test tool with array parameter", + parameters=[ + ToolParameter( + name="tags", + parameter_type="array", + description="List of tags", + required=True, + items={"type": "string"}, + ) + ], + ) + + result = convert_tooldef_to_chat_tool(tool_def) + + assert result["type"] == "function" + assert result["function"]["name"] == "test_tool" + + tags_param = result["function"]["parameters"]["properties"]["tags"] + assert tags_param["type"] == "array" + assert "items" in tags_param, "items field should be preserved for array parameters" + assert tags_param["items"] == {"type": "string"} From be97c9f9dfc184653cad974b3736da63af27ad24 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 22:11:20 -0700 Subject: [PATCH 18/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 48 ++++++++++++------- .../meta_reference/responses/__init__.py | 5 ++ .../responses/test_streaming.py | 42 ++++++++++++++++ 3 files changed, 78 insertions(+), 17 deletions(-) create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..2f45ad2a3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -50,6 +50,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal logger = get_logger(name=__name__, category="agents::meta_reference") +def convert_tooldef_to_chat_tool(tool_def): + """Convert a ToolDef to OpenAI ChatCompletionToolParam format. + + Args: + tool_def: ToolDef from the tools API + + Returns: + ChatCompletionToolParam suitable for OpenAI chat completion + """ + + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + internal_tool_def = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + items=param.items, + ) + for param in tool_def.parameters + }, + ) + return convert_tooldef_to_openai_tool(internal_tool_def) + + class StreamingResponseOrchestrator: def __init__( self, @@ -556,23 +586,7 @@ class StreamingResponseOrchestrator: continue if not always_allowed or t.name in always_allowed: # Add to chat tools for inference - from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition - from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool - - tool_def = ToolDefinition( - tool_name=t.name, - description=t.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in t.parameters - }, - ) - openai_tool = convert_tooldef_to_openai_tool(tool_def) + openai_tool = convert_tooldef_to_chat_tool(t) if self.ctx.chat_tools is None: self.ctx.chat_tools = [] self.ctx.chat_tools.append(openai_tool) diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..6fda2b508 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.tools import ToolDef, ToolParameter +from llama_stack.providers.inline.agents.meta_reference.responses.streaming import ( + convert_tooldef_to_chat_tool, +) + + +def test_convert_tooldef_to_chat_tool_preserves_items_field(): + """Test that array parameters preserve the items field during conversion. + + This test ensures that when converting ToolDef with array-type parameters + to OpenAI ChatCompletionToolParam format, the 'items' field is preserved. + Without this fix, array parameters would be missing schema information about their items. + """ + tool_def = ToolDef( + name="test_tool", + description="A test tool with array parameter", + parameters=[ + ToolParameter( + name="tags", + parameter_type="array", + description="List of tags", + required=True, + items={"type": "string"}, + ) + ], + ) + + result = convert_tooldef_to_chat_tool(tool_def) + + assert result["type"] == "function" + assert result["function"]["name"] == "test_tool" + + tags_param = result["function"]["parameters"]["properties"]["tags"] + assert tags_param["type"] == "array" + assert "items" in tags_param, "items field should be preserved for array parameters" + assert tags_param["items"] == {"type": "string"} From d87459790814f5cf87a4e2f33cb25e4dc73317d7 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 22:39:06 -0700 Subject: [PATCH 19/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 48 ++++++++++++------- tests/unit/providers/inline/__init__.py | 0 .../unit/providers/inline/agents/__init__.py | 0 .../inline/agents/meta_reference/__init__.py | 0 .../meta_reference/responses/__init__.py | 5 ++ .../responses/test_streaming.py | 42 ++++++++++++++++ 6 files changed, 78 insertions(+), 17 deletions(-) create mode 100644 tests/unit/providers/inline/__init__.py create mode 100644 tests/unit/providers/inline/agents/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..2f45ad2a3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -50,6 +50,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal logger = get_logger(name=__name__, category="agents::meta_reference") +def convert_tooldef_to_chat_tool(tool_def): + """Convert a ToolDef to OpenAI ChatCompletionToolParam format. + + Args: + tool_def: ToolDef from the tools API + + Returns: + ChatCompletionToolParam suitable for OpenAI chat completion + """ + + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + internal_tool_def = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + items=param.items, + ) + for param in tool_def.parameters + }, + ) + return convert_tooldef_to_openai_tool(internal_tool_def) + + class StreamingResponseOrchestrator: def __init__( self, @@ -556,23 +586,7 @@ class StreamingResponseOrchestrator: continue if not always_allowed or t.name in always_allowed: # Add to chat tools for inference - from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition - from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool - - tool_def = ToolDefinition( - tool_name=t.name, - description=t.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in t.parameters - }, - ) - openai_tool = convert_tooldef_to_openai_tool(tool_def) + openai_tool = convert_tooldef_to_chat_tool(t) if self.ctx.chat_tools is None: self.ctx.chat_tools = [] self.ctx.chat_tools.append(openai_tool) diff --git a/tests/unit/providers/inline/__init__.py b/tests/unit/providers/inline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/providers/inline/agents/__init__.py b/tests/unit/providers/inline/agents/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/providers/inline/agents/meta_reference/__init__.py b/tests/unit/providers/inline/agents/meta_reference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..6fda2b508 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.tools import ToolDef, ToolParameter +from llama_stack.providers.inline.agents.meta_reference.responses.streaming import ( + convert_tooldef_to_chat_tool, +) + + +def test_convert_tooldef_to_chat_tool_preserves_items_field(): + """Test that array parameters preserve the items field during conversion. + + This test ensures that when converting ToolDef with array-type parameters + to OpenAI ChatCompletionToolParam format, the 'items' field is preserved. + Without this fix, array parameters would be missing schema information about their items. + """ + tool_def = ToolDef( + name="test_tool", + description="A test tool with array parameter", + parameters=[ + ToolParameter( + name="tags", + parameter_type="array", + description="List of tags", + required=True, + items={"type": "string"}, + ) + ], + ) + + result = convert_tooldef_to_chat_tool(tool_def) + + assert result["type"] == "function" + assert result["function"]["name"] == "test_tool" + + tags_param = result["function"]["parameters"]["properties"]["tags"] + assert tags_param["type"] == "array" + assert "items" in tags_param, "items field should be preserved for array parameters" + assert tags_param["items"] == {"type": "string"} From b2694a362055ea79afc9db35bb42208fd34b6e52 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 29 Sep 2025 22:43:20 -0700 Subject: [PATCH 20/20] fix: mcp tool with array type should include items # What does this PR do? ## Test Plan --- .../meta_reference/responses/streaming.py | 48 ++++++++++++------- tests/unit/providers/inline/__init__.py | 6 +++ .../unit/providers/inline/agents/__init__.py | 6 +++ .../inline/agents/meta_reference/__init__.py | 6 +++ .../meta_reference/responses/__init__.py | 5 ++ .../responses/test_streaming.py | 42 ++++++++++++++++ 6 files changed, 96 insertions(+), 17 deletions(-) create mode 100644 tests/unit/providers/inline/__init__.py create mode 100644 tests/unit/providers/inline/agents/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/__init__.py create mode 100644 tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 3e69fa5cd..2f45ad2a3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -50,6 +50,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal logger = get_logger(name=__name__, category="agents::meta_reference") +def convert_tooldef_to_chat_tool(tool_def): + """Convert a ToolDef to OpenAI ChatCompletionToolParam format. + + Args: + tool_def: ToolDef from the tools API + + Returns: + ChatCompletionToolParam suitable for OpenAI chat completion + """ + + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + internal_tool_def = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + items=param.items, + ) + for param in tool_def.parameters + }, + ) + return convert_tooldef_to_openai_tool(internal_tool_def) + + class StreamingResponseOrchestrator: def __init__( self, @@ -556,23 +586,7 @@ class StreamingResponseOrchestrator: continue if not always_allowed or t.name in always_allowed: # Add to chat tools for inference - from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition - from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool - - tool_def = ToolDefinition( - tool_name=t.name, - description=t.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in t.parameters - }, - ) - openai_tool = convert_tooldef_to_openai_tool(tool_def) + openai_tool = convert_tooldef_to_chat_tool(t) if self.ctx.chat_tools is None: self.ctx.chat_tools = [] self.ctx.chat_tools.append(openai_tool) diff --git a/tests/unit/providers/inline/__init__.py b/tests/unit/providers/inline/__init__.py new file mode 100644 index 000000000..d4a3e15c8 --- /dev/null +++ b/tests/unit/providers/inline/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + diff --git a/tests/unit/providers/inline/agents/__init__.py b/tests/unit/providers/inline/agents/__init__.py new file mode 100644 index 000000000..d4a3e15c8 --- /dev/null +++ b/tests/unit/providers/inline/agents/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + diff --git a/tests/unit/providers/inline/agents/meta_reference/__init__.py b/tests/unit/providers/inline/agents/meta_reference/__init__.py new file mode 100644 index 000000000..d4a3e15c8 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py new file mode 100644 index 000000000..6fda2b508 --- /dev/null +++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.tools import ToolDef, ToolParameter +from llama_stack.providers.inline.agents.meta_reference.responses.streaming import ( + convert_tooldef_to_chat_tool, +) + + +def test_convert_tooldef_to_chat_tool_preserves_items_field(): + """Test that array parameters preserve the items field during conversion. + + This test ensures that when converting ToolDef with array-type parameters + to OpenAI ChatCompletionToolParam format, the 'items' field is preserved. + Without this fix, array parameters would be missing schema information about their items. + """ + tool_def = ToolDef( + name="test_tool", + description="A test tool with array parameter", + parameters=[ + ToolParameter( + name="tags", + parameter_type="array", + description="List of tags", + required=True, + items={"type": "string"}, + ) + ], + ) + + result = convert_tooldef_to_chat_tool(tool_def) + + assert result["type"] == "function" + assert result["function"]["name"] == "test_tool" + + tags_param = result["function"]["parameters"]["properties"]["tags"] + assert tags_param["type"] == "array" + assert "items" in tags_param, "items field should be preserved for array parameters" + assert tags_param["items"] == {"type": "string"}