diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 693f224c3..17832246d 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -108,3 +108,8 @@ class QuotaMiddleware: ) body = json.dumps({"error": {"message": message}}).encode() await send({"type": "http.response.body", "body": body}) + + async def close(self): + """Close the KV store connection.""" + if self.kv is not None: + await self.kv.close() diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0c37b05bc..a67d8ade9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -315,6 +315,7 @@ class MetaReferenceAgentsImpl(Agents): async def shutdown(self) -> None: await self.persistence_store.close() + await self.responses_store.shutdown() # OpenAI responses async def get_openai_response( diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index a76b982ce..b48975702 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -62,7 +62,8 @@ class LocalfsFilesImpl(Files): ) async def shutdown(self) -> None: - pass + if self.sql_store: + await self.sql_store.close() def _generate_file_id(self) -> str: """Generate a unique file ID for OpenAI API.""" diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index c0e9f81d6..938f6142d 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -181,7 +181,8 @@ class S3FilesImpl(Files): ) async def shutdown(self) -> None: - pass + if self._sql_store: + await self._sql_store.close() @property def client(self) -> boto3.client: diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 901f77c67..44ab8c0ce 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -74,19 +74,21 @@ class InferenceStore: logger.info("Write queue disabled for SQLite to avoid concurrency issues") async def shutdown(self) -> None: - if not self._worker_tasks: - return - if self._queue is not None: - await self._queue.join() - for t in self._worker_tasks: - if not t.done(): - t.cancel() - for t in self._worker_tasks: - try: - await t - except asyncio.CancelledError: - pass - self._worker_tasks.clear() + if self._worker_tasks: + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + if self.sql_store: + await self.sql_store.close() async def flush(self) -> None: """Wait for all queued writes to complete. Useful for testing.""" diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index e610a1ba2..80b17d116 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -96,19 +96,21 @@ class ResponsesStore: logger.info("Write queue disabled for SQLite to avoid concurrency issues") async def shutdown(self) -> None: - if not self._worker_tasks: - return - if self._queue is not None: - await self._queue.join() - for t in self._worker_tasks: - if not t.done(): - t.cancel() - for t in self._worker_tasks: - try: - await t - except asyncio.CancelledError: - pass - self._worker_tasks.clear() + if self._worker_tasks: + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + if self.sql_store: + await self.sql_store.close() async def flush(self) -> None: """Wait for all queued writes to complete. Useful for testing.""" diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py index a61fd1090..9061a2ead 100644 --- a/llama_stack/providers/utils/sqlstore/api.py +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -126,3 +126,9 @@ class SqlStore(Protocol): :param nullable: Whether the column should be nullable (default: True) """ pass + + async def close(self) -> None: + """ + Close any persistent database connections. + """ + pass diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index e1da4db6e..373deeb23 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -197,6 +197,10 @@ class AuthorizedSqlStore: """Delete rows with automatic access control filtering.""" await self.sql_store.delete(table, where) + async def close(self) -> None: + """Close the underlying SQL store connection.""" + await self.sql_store.close() + def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str: """Build SQL WHERE clause for access control filtering. diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 23cd6444e..088b1c554 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -311,3 +311,11 @@ class SqlAlchemySqlStoreImpl(SqlStore): # The table creation will handle adding the column logger.error(f"Error adding column {column_name} to table {table}: {e}") pass + + async def close(self) -> None: + """Close the database engine and all connections.""" + if hasattr(self, "async_session"): + # Get the engine from the session maker + engine = self.async_session.kw.get("bind") + if engine: + await engine.dispose() diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index e14e033b9..b227d69de 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -43,6 +43,7 @@ async def files_provider(tmp_path): provider = LocalfsFilesImpl(config, default_policy()) await provider.initialize() yield provider + await provider.shutdown() @pytest.fixture diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index b2c619e49..94c10f6bc 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -28,3 +28,5 @@ async def temp_prompt_store(tmp_path_factory): store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) yield store + + await store.kvstore.close() diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py index 85acbc66a..d3c569049 100644 --- a/tests/unit/server/test_quota.py +++ b/tests/unit/server/test_quota.py @@ -52,17 +52,21 @@ def auth_app(tmp_path, request): db_path = tmp_path / f"quota_{request.node.name}.db" quota = build_quota_config(db_path) - app = InjectClientIDMiddleware( - QuotaMiddleware( - inner_app, - kv_config=quota.kvstore, - anonymous_max_requests=quota.anonymous_max_requests, - authenticated_max_requests=quota.authenticated_max_requests, - window_seconds=86400, - ), - client_id=f"client_{request.node.name}", + quota_middleware = QuotaMiddleware( + inner_app, + kv_config=quota.kvstore, + anonymous_max_requests=quota.anonymous_max_requests, + authenticated_max_requests=quota.authenticated_max_requests, + window_seconds=86400, ) - return app + app = InjectClientIDMiddleware(quota_middleware, client_id=f"client_{request.node.name}") + + yield app + + # Cleanup + import asyncio + + asyncio.run(quota_middleware.close()) def test_authenticated_quota_allows_up_to_limit(auth_app): @@ -81,6 +85,8 @@ def test_authenticated_quota_blocks_after_limit(auth_app): def test_anonymous_quota_allows_up_to_limit(tmp_path, request): + import asyncio + inner_app = FastAPI() @inner_app.get("/test") @@ -101,8 +107,12 @@ def test_anonymous_quota_allows_up_to_limit(tmp_path, request): client = TestClient(app) assert client.get("/test").status_code == 200 + asyncio.run(app.close()) + def test_anonymous_quota_blocks_after_limit(tmp_path, request): + import asyncio + inner_app = FastAPI() @inner_app.get("/test") @@ -125,3 +135,5 @@ def test_anonymous_quota_blocks_after_limit(tmp_path, request): resp = client.get("/test") assert resp.status_code == 429 assert resp.json()["error"]["message"] == "Quota exceeded" + + asyncio.run(app.close()) diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index f6d63490a..4bea03b88 100644 --- a/tests/unit/utils/inference/test_inference_store.py +++ b/tests/unit/utils/inference/test_inference_store.py @@ -89,6 +89,8 @@ async def test_inference_store_pagination_basic(): assert result3.data[0].id == "zebra-task" assert result3.has_more is False + await store.sql_store.close() + async def test_inference_store_pagination_ascending(): """Test pagination with ascending order.""" @@ -126,6 +128,8 @@ async def test_inference_store_pagination_ascending(): assert result2.data[0].id == "charlie-task" assert result2.has_more is True + await store.sql_store.close() + async def test_inference_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" @@ -166,6 +170,8 @@ async def test_inference_store_pagination_with_model_filter(): assert result2.data[0].model == "model-a" assert result2.has_more is False + await store.sql_store.close() + async def test_inference_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" @@ -178,6 +184,8 @@ async def test_inference_store_pagination_invalid_after(): with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"): await store.list_chat_completions(after="non-existent", limit=2) + await store.sql_store.close() + async def test_inference_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" @@ -208,3 +216,5 @@ async def test_inference_store_pagination_no_limit(): assert result.data[0].id == "beta-second" # Most recent first assert result.data[1].id == "omega-first" assert result.has_more is False + + await store.sql_store.close() diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index c27b5a8e5..aa5c1a7e8 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -98,6 +98,8 @@ async def test_responses_store_pagination_basic(): assert result3.data[0].id == "zebra-resp" assert result3.has_more is False + await store.sql_store.close() + async def test_responses_store_pagination_ascending(): """Test pagination with ascending order.""" @@ -136,6 +138,8 @@ async def test_responses_store_pagination_ascending(): assert result2.data[0].id == "charlie-resp" assert result2.has_more is True + await store.sql_store.close() + async def test_responses_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" @@ -177,6 +181,8 @@ async def test_responses_store_pagination_with_model_filter(): assert result2.data[0].model == "model-a" assert result2.has_more is False + await store.sql_store.close() + async def test_responses_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" @@ -189,6 +195,8 @@ async def test_responses_store_pagination_invalid_after(): with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"): await store.list_responses(after="non-existent", limit=2) + await store.sql_store.close() + async def test_responses_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" @@ -221,6 +229,8 @@ async def test_responses_store_pagination_no_limit(): assert result.data[1].id == "omega-resp" assert result.has_more is False + await store.sql_store.close() + async def test_responses_store_get_response_object(): """Test retrieving a single response object.""" @@ -249,6 +259,8 @@ async def test_responses_store_get_response_object(): with pytest.raises(ValueError, match="Response with id non-existent not found"): await store.get_response_object("non-existent") + await store.sql_store.close() + async def test_responses_store_input_items_pagination(): """Test pagination functionality for input items.""" @@ -330,6 +342,8 @@ async def test_responses_store_input_items_pagination(): with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"): await store.list_response_input_items("test-resp", before="some-id", after="other-id") + await store.sql_store.close() + async def test_responses_store_input_items_before_pagination(): """Test before pagination functionality for input items.""" @@ -390,3 +404,5 @@ async def test_responses_store_input_items_before_pagination(): ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'" ): await store.list_response_input_items("test-resp-before", before="non-existent") + + await store.sql_store.close() diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index 00669b698..a68a5b681 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -64,6 +64,9 @@ async def test_sqlite_sqlstore(): assert result.data == [{"id": 12, "name": "test12"}] assert result.has_more is False + # cleanup + await sqlstore.close() + async def test_sqlstore_pagination_basic(): """Test basic pagination functionality at the SQL store level.""" @@ -128,6 +131,8 @@ async def test_sqlstore_pagination_basic(): assert result3.data[0]["id"] == "zebra" assert result3.has_more is False + await store.close() + async def test_sqlstore_pagination_with_filter(): """Test pagination with WHERE conditions.""" @@ -180,6 +185,8 @@ async def test_sqlstore_pagination_with_filter(): assert result2.data[0]["id"] == "xyz" assert result2.has_more is False + await store.close() + async def test_sqlstore_pagination_ascending_order(): """Test pagination with ascending order.""" @@ -228,6 +235,8 @@ async def test_sqlstore_pagination_ascending_order(): assert result2.data[0]["id"] == "alpha" assert result2.has_more is True + await store.close() + async def test_sqlstore_pagination_multi_column_ordering_error(): """Test that multi-column ordering raises an error when using cursor pagination.""" @@ -265,6 +274,8 @@ async def test_sqlstore_pagination_multi_column_ordering_error(): assert len(result.data) == 1 assert result.data[0]["id"] == "task1" + await store.close() + async def test_sqlstore_pagination_cursor_requires_order_by(): """Test that cursor pagination requires order_by parameter.""" @@ -282,6 +293,8 @@ async def test_sqlstore_pagination_cursor_requires_order_by(): cursor=("id", "task1"), ) + await store.close() + async def test_sqlstore_pagination_error_handling(): """Test error handling for invalid columns and cursor IDs.""" @@ -414,6 +427,8 @@ async def test_where_operator_edge_cases(): with pytest.raises(ValueError, match="Unsupported operator"): await store.fetch_all("events", where={"ts": {"!=": base}}) + await store.close() + async def test_sqlstore_pagination_custom_key_column(): """Test pagination with custom primary key column (not 'id').""" @@ -463,3 +478,5 @@ async def test_sqlstore_pagination_custom_key_column(): assert len(result2.data) == 1 assert result2.data[0]["uuid"] == "uuid-alpha" assert result2.has_more is False + + await store.close() diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index d85e784a9..8f395e616 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -77,6 +77,8 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic assert row is not None assert row["title"] == "User Document" + await base_sqlstore.close() + @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_sql_policy_consistency(mock_get_authenticated_user): @@ -163,6 +165,8 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}" ) + await base_sqlstore.close() + @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user): @@ -211,3 +215,5 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us # Third item should have null attributes (no authenticated user) assert result.data[2]["id"] == "item3" assert result.data[2]["access_attributes"] is None + + await base_sqlstore.close()