From f44eb935c4ff110278758ff2972bd5a5fd544915 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 19 Sep 2025 16:13:56 -0700 Subject: [PATCH] chore: simplify authorized sqlstore (#3496) # What does this PR do? This PR is generated with AI and reviewed by me. Refactors the AuthorizedSqlStore class to store the access policy as an instance variable rather than passing it as a parameter to each method call. This simplifies the API. # Test Plan existing tests --- .../providers/inline/files/localfs/files.py | 5 ++--- .../providers/remote/files/s3/files.py | 5 ++--- .../utils/inference/inference_store.py | 4 +--- .../utils/responses/responses_store.py | 7 ++----- .../utils/sqlstore/authorized_sqlstore.py | 11 +++++------ .../sqlstore/test_authorized_sqlstore.py | 19 +++++++++++-------- tests/unit/utils/test_authorized_sqlstore.py | 18 +++++++++--------- 7 files changed, 32 insertions(+), 37 deletions(-) diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 9c610c1ba..65cf8d815 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files): storage_path.mkdir(parents=True, exist_ok=True) # Initialize SQL store for metadata - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy) await self.sql_store.create_table( "openai_files", { @@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "client.files.list()") @@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 54742d900..8ea96af9e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -137,7 +137,7 @@ class S3FilesImpl(Files): where: dict[str, str | dict] = {"id": file_id} if not return_expired: where["expires_at"] = {">": self._now()} - if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): raise ResourceNotFoundError(file_id, "File", "files.list()") return row @@ -164,7 +164,7 @@ class S3FilesImpl(Files): self._client = _create_s3_client(self._config) await _create_bucket_if_not_exists(self._client, self._config) - self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) await self._sql_store.create_table( "openai_files", { @@ -268,7 +268,6 @@ class S3FilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 17f4c6268..ffc9f3e11 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -54,7 +54,7 @@ class InferenceStore: async def initialize(self): """Create the necessary tables if they don't exist.""" - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "chat_completions", { @@ -202,7 +202,6 @@ class InferenceStore: order_by=[("created", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [ @@ -229,7 +228,6 @@ class InferenceStore: row = await self.sql_store.fetch_one( table="chat_completions", where={"id": completion_id}, - policy=self.policy, ) if not row: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..829cd8a62 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -28,8 +28,7 @@ class ResponsesStore: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) - self.policy = policy + self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) async def initialize(self): """Create the necessary tables if they don't exist.""" @@ -87,7 +86,6 @@ class ResponsesStore: order_by=[("created_at", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] @@ -105,7 +103,6 @@ class ResponsesStore: row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, - policy=self.policy, ) if not row: @@ -116,7 +113,7 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") await self.sql_store.delete("openai_responses", where={"id": response_id}) diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index acb688f96..ab67f7052 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -53,13 +53,15 @@ class AuthorizedSqlStore: access control policies, user attribute capture, and SQL filtering optimization. """ - def __init__(self, sql_store: SqlStore): + def __init__(self, sql_store: SqlStore, policy: list[AccessRule]): """ Initialize the authorization layer. :param sql_store: Base SqlStore implementation to wrap + :param policy: Access control policy to use for authorization """ self.sql_store = sql_store + self.policy = policy self._detect_database_type() self._validate_sql_optimized_policy() @@ -117,14 +119,13 @@ class AuthorizedSqlStore: async def fetch_all( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, ) -> PaginatedResponse: """Fetch all rows with automatic access control filtering.""" - access_where = self._build_access_control_where_clause(policy) + access_where = self._build_access_control_where_clause(self.policy) rows = await self.sql_store.fetch_all( table=table, where=where, @@ -146,7 +147,7 @@ class AuthorizedSqlStore: str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) ) - if is_action_allowed(policy, Action.READ, sql_record, current_user): + if is_action_allowed(self.policy, Action.READ, sql_record, current_user): filtered_rows.append(row) return PaginatedResponse( @@ -157,14 +158,12 @@ class AuthorizedSqlStore: async def fetch_one( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: """Fetch one row with automatic access control checking.""" results = await self.fetch_all( table=table, - policy=policy, where=where, limit=1, order_by=order_by, diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 4002f2e1f..98bef0f2c 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -57,7 +57,7 @@ def authorized_store(backend_config): config = config_func() base_sqlstore = sqlstore_impl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) yield authorized_store @@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" assert result.data[0]["access_attributes"] is None @@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 2 # Test with non-admin user @@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz mock_get_authenticated_user.return_value = regular_user # Should only see public record - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" @@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz # Now test with the multi-user who has both roles=admin and teams=dev mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) # Should see: # - public record (1) - no access_attributes @@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto ), ] + # Create a new authorized store with the owner-only policy + owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy) + # Test user1 access - should only see their own record mock_get_authenticated_user.return_value = user1 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" # Test user2 access - should only see their own record mock_get_authenticated_user.return_value = user2 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" # Test with anonymous user - should see no records mock_get_authenticated_user.return_value = None - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" finally: diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 90eb706e4..d85e784a9 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) # Create table with access control await sqlstore.create_table( @@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic mock_get_authenticated_user.return_value = admin_user # Admin should see both documents - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 1 assert result.data[0]["title"] == "Admin Document" # User should only see their document mock_get_authenticated_user.return_value = regular_user - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 0 - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) + result = await sqlstore.fetch_all("documents", where={"id": 2}) assert len(result.data) == 1 assert result.data[0]["title"] == "User Document" - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) + row = await sqlstore.fetch_one("documents", where={"id": 1}) assert row is None - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) + row = await sqlstore.fetch_one("documents", where={"id": 2}) assert row is not None assert row["title"] == "User Document" @@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) await sqlstore.create_table( table="resources", @@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): user = User(principal=user_data["principal"], attributes=user_data["attributes"]) mock_get_authenticated_user.return_value = user - sql_results = await sqlstore.fetch_all("resources", policy=policy) + sql_results = await sqlstore.fetch_all("resources") sql_ids = {row["id"] for row in sql_results.data} policy_ids = set() for scenario in test_scenarios: @@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us db_path=tmp_dir + "/" + db_name, ) ) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) await authorized_store.create_table( table="user_data",