From d880c2df0ed0d1405a5458a25309ad3b66907219 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 10 Jul 2025 14:40:32 -0700 Subject: [PATCH] fix: auth sql store: user is owner policy (#2674) # What does this PR do? The current authorized sql store implementation does not respect user.principal (only checks attributes). This PR addresses that. ## Test Plan Added test cases to integration tests. --- .../access_control/access_control.py | 2 +- .../utils/sqlstore/authorized_sqlstore.py | 54 ++-- .../utils/sqlstore/sqlalchemy_sqlstore.py | 58 ++-- .../sqlstore/test_authorized_sqlstore.py | 304 +++++++++++------- tests/unit/utils/test_authorized_sqlstore.py | 4 +- 5 files changed, 247 insertions(+), 175 deletions(-) diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/distribution/access_control/access_control.py index 075152ce4..64c0122c1 100644 --- a/llama_stack/distribution/access_control/access_control.py +++ b/llama_stack/distribution/access_control/access_control.py @@ -81,7 +81,7 @@ def is_action_allowed( if not len(policy): policy = default_policy() - qualified_resource_id = resource.type + "::" + resource.identifier + qualified_resource_id = f"{resource.type}::{resource.identifier}" for rule in policy: if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): if rule.when: diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index 5dff7f122..864a7dbb6 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [ class SqlRecord(ProtectedResource): - """Simple ProtectedResource implementation for SQL records.""" - - def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None): + def __init__(self, record_id: str, table_name: str, owner: User): self.type = f"sql_record::{table_name}" self.identifier = record_id - - if access_attributes: - self.owner = User( - principal="system", - attributes=access_attributes, - ) - else: - self.owner = User( - principal="system_public", - attributes=None, - ) + self.owner = owner class AuthorizedSqlStore: @@ -101,22 +89,27 @@ class AuthorizedSqlStore: async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: """Create a table with built-in access control support.""" - await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) enhanced_schema = dict(schema) if "access_attributes" not in enhanced_schema: enhanced_schema["access_attributes"] = ColumnType.JSON + if "owner_principal" not in enhanced_schema: + enhanced_schema["owner_principal"] = ColumnType.STRING await self.sql_store.create_table(table, enhanced_schema) + await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) + await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING) async def insert(self, table: str, data: Mapping[str, Any]) -> None: """Insert a row with automatic access control attribute capture.""" enhanced_data = dict(data) current_user = get_authenticated_user() - if current_user and current_user.attributes: + if current_user: + enhanced_data["owner_principal"] = current_user.principal enhanced_data["access_attributes"] = current_user.attributes else: + enhanced_data["owner_principal"] = None enhanced_data["access_attributes"] = None await self.sql_store.insert(table, enhanced_data) @@ -146,9 +139,12 @@ class AuthorizedSqlStore: for row in rows.data: stored_access_attrs = row.get("access_attributes") + stored_owner_principal = row.get("owner_principal") or "" record_id = row.get("id", "unknown") - sql_record = SqlRecord(str(record_id), table, stored_access_attrs) + sql_record = SqlRecord( + str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) + ) if is_action_allowed(policy, Action.READ, sql_record, current_user): filtered_rows.append(row) @@ -186,8 +182,10 @@ class AuthorizedSqlStore: Only applies SQL filtering for the default policy to ensure correctness. For custom policies, uses conservative filtering to avoid blocking legitimate access. """ + current_user = get_authenticated_user() + if not policy or policy == SQL_OPTIMIZED_POLICY: - return self._build_default_policy_where_clause() + return self._build_default_policy_where_clause(current_user) else: return self._build_conservative_where_clause() @@ -227,29 +225,27 @@ class AuthorizedSqlStore: def _get_public_access_conditions(self) -> list[str]: """Get the SQL conditions for public access.""" + # Public records are records that have no owner_principal or access_attributes + conditions = ["owner_principal = ''"] if self.database_type == SqlStoreType.postgres: # Postgres stores JSON null as 'null' - return ["access_attributes::text = 'null'"] + conditions.append("access_attributes::text = 'null'") elif self.database_type == SqlStoreType.sqlite: - return ["access_attributes = 'null'"] + conditions.append("access_attributes = 'null'") else: raise ValueError(f"Unsupported database type: {self.database_type}") + return conditions - def _build_default_policy_where_clause(self) -> str: + def _build_default_policy_where_clause(self, current_user: User | None) -> str: """Build SQL WHERE clause for the default policy. Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] This means user must match ALL attribute categories that exist in the resource. """ - current_user = get_authenticated_user() - base_conditions = self._get_public_access_conditions() - if not current_user or not current_user.attributes: - # Only allow public records - return f"({' OR '.join(base_conditions)})" - else: - user_attr_conditions = [] + user_attr_conditions = [] + if current_user and current_user.attributes: for attr_key, user_values in current_user.attributes.items(): if user_values: value_conditions = [] @@ -269,7 +265,7 @@ class AuthorizedSqlStore: all_requirements_met = f"({' AND '.join(user_attr_conditions)})" base_conditions.append(all_requirements_met) - return f"({' OR '.join(base_conditions)})" + return f"({' OR '.join(base_conditions)})" def _build_conservative_where_clause(self) -> str: """Conservative SQL filtering for custom policies. diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 3aecb0d59..6414929db 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore): engine = create_async_engine(self.config.engine_str) try: - inspector = inspect(engine) - - table_names = inspector.get_table_names() - if table not in table_names: - return - - existing_columns = inspector.get_columns(table) - column_names = [col["name"] for col in existing_columns] - - if column_name in column_names: - return - - sqlalchemy_type = TYPE_MAPPING.get(column_type) - if not sqlalchemy_type: - raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") - - # Create the ALTER TABLE statement - # Note: We need to get the dialect-specific type name - dialect = engine.dialect - type_impl = sqlalchemy_type() - compiled_type = type_impl.compile(dialect=dialect) - - nullable_clause = "" if nullable else " NOT NULL" - add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") - async with engine.begin() as conn: + + def check_column_exists(sync_conn): + inspector = inspect(sync_conn) + + table_names = inspector.get_table_names() + if table not in table_names: + return False, False # table doesn't exist, column doesn't exist + + existing_columns = inspector.get_columns(table) + column_names = [col["name"] for col in existing_columns] + + return True, column_name in column_names # table exists, column exists or not + + table_exists, column_exists = await conn.run_sync(check_column_exists) + if not table_exists or column_exists: + return + + sqlalchemy_type = TYPE_MAPPING.get(column_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") + + # Create the ALTER TABLE statement + # Note: We need to get the dialect-specific type name + dialect = engine.dialect + type_impl = sqlalchemy_type() + compiled_type = type_impl.compile(dialect=dialect) + + nullable_clause = "" if nullable else " NOT NULL" + add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") + await conn.execute(add_column_sql) - except Exception: + except Exception as e: # If any error occurs during migration, log it but don't fail # The table creation will handle adding the column + logger.error(f"Error adding column {column_name} to table {table}: {e}") pass diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 93b4d8905..bf6077532 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_polic from llama_stack.distribution.datatypes import User from llama_stack.providers.utils.sqlstore.api import ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl def get_postgres_config(): @@ -30,144 +29,213 @@ def get_postgres_config(): def get_sqlite_config(): - """Get SQLite configuration with temporary database.""" - tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) - tmp_file.close() - return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name + """Get SQLite configuration with temporary file database.""" + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_file.close() + return SqliteSqlStoreConfig(db_path=temp_file.name) + + +# Backend configurations for parametrized tests +BACKEND_CONFIGS = [ + pytest.param( + get_postgres_config, + marks=pytest.mark.skipif( + not os.environ.get("ENABLE_POSTGRES_TESTS"), + reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable", + ), + id="postgres", + ), + pytest.param(get_sqlite_config, id="sqlite"), +] + + +@pytest.fixture +def authorized_store(backend_config): + """Set up authorized store with proper cleanup.""" + config_func = backend_config + + config = config_func() + + base_sqlstore = sqlstore_impl(config) + authorized_store = AuthorizedSqlStore(base_sqlstore) + + yield authorized_store + + if hasattr(config, "db_path"): + try: + os.unlink(config.db_path) + except (OSError, FileNotFoundError): + pass + + +async def create_test_table(authorized_store, table_name): + """Create a test table with standard schema.""" + await authorized_store.create_table( + table=table_name, + schema={ + "id": ColumnType.STRING, + "data": ColumnType.STRING, + }, + ) + + +async def cleanup_records(sql_store, table_name, record_ids): + """Clean up test records.""" + for record_id in record_ids: + try: + await sql_store.delete(table_name, {"id": record_id}) + except Exception: + pass @pytest.mark.asyncio -@pytest.mark.parametrize( - "backend_config", - [ - pytest.param( - ("postgres", get_postgres_config), - marks=pytest.mark.skipif( - not os.environ.get("ENABLE_POSTGRES_TESTS"), - reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable", - ), - id="postgres", - ), - pytest.param(("sqlite", get_sqlite_config), id="sqlite"), - ], -) +@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") -async def test_json_comparison(mock_get_authenticated_user, backend_config): +async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request): """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite""" - backend_name, config_func = backend_config + backend_name = request.node.callspec.id - # Handle different config types - if backend_name == "postgres": - config = config_func() - cleanup_path = None - else: # sqlite - config, cleanup_path = config_func() + # Create test table + table_name = f"test_json_comparison_{backend_name}" + await create_test_table(authorized_store, table_name) try: - base_sqlstore = SqlAlchemySqlStoreImpl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + # Test with no authenticated user (should handle JSON null comparison) + mock_get_authenticated_user.return_value = None - # Create test table - table_name = f"test_json_comparison_{backend_name}" - await authorized_store.create_table( - table=table_name, - schema={ - "id": ColumnType.STRING, - "data": ColumnType.STRING, - }, + # Insert some test data + await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) + + # Test fetching with no user - should not error on JSON comparison + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 1 + assert result.data[0]["id"] == "1" + assert result.data[0]["access_attributes"] is None + + # Test with authenticated user + test_user = User("test-user", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = test_user + + # Insert data with user attributes + await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) + + # Fetch all - admin should see both + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 2 + + # Test with non-admin user + regular_user = User("regular-user", {"roles": ["user"]}) + mock_get_authenticated_user.return_value = regular_user + + # Should only see public record + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 1 + assert result.data[0]["id"] == "1" + + # Test the category missing branch: user with multiple attributes + multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]}) + mock_get_authenticated_user.return_value = multi_user + + # Insert record with multi-user (has both roles and teams) + await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"}) + + # Test different user types to create records with different attribute patterns + # Record with only roles (teams category will be missing) + roles_only_user = User("roles-user", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = roles_only_user + await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"}) + + # Record with only teams (roles category will be missing) + teams_only_user = User("teams-user", {"teams": ["dev"]}) + mock_get_authenticated_user.return_value = teams_only_user + await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"}) + + # Record with different roles/teams (shouldn't match our test user) + different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]}) + mock_get_authenticated_user.return_value = different_user + await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"}) + + # Now test with the multi-user who has both roles=admin and teams=dev + mock_get_authenticated_user.return_value = multi_user + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + + # Should see: + # - public record (1) - no access_attributes + # - admin record (2) - user matches roles=admin, teams missing (allowed) + # - multi_user record (3) - user matches both roles=admin and teams=dev + # - roles_only record (4) - user matches roles=admin, teams missing (allowed) + # - teams_only record (5) - user matches teams=dev, roles missing (allowed) + # Should NOT see: + # - different_user record (6) - user doesn't match roles=user or teams=qa + expected_ids = {"1", "2", "3", "4", "5"} + actual_ids = {record["id"] for record in result.data} + assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}" + + # Verify the category missing logic specifically + # Records 4 and 5 test the "category missing" branch where one attribute category is missing + category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]} + assert category_test_ids == {"4", "5"}, ( + f"Category missing logic failed: expected 4,5 but got {category_test_ids}" ) - try: - # Test with no authenticated user (should handle JSON null comparison) - mock_get_authenticated_user.return_value = None + finally: + # Clean up records + await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"]) - # Insert some test data - await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) - # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - assert len(result.data) == 1 - assert result.data[0]["id"] == "1" - assert result.data[0]["access_attributes"] is None +@pytest.mark.asyncio +@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request): + """Test that 'user is owner' policies work correctly with record ownership""" + from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope - # Test with authenticated user - test_user = User("test-user", {"roles": ["admin"]}) - mock_get_authenticated_user.return_value = test_user + backend_name = request.node.callspec.id - # Insert data with user attributes - await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) + # Create test table + table_name = f"test_ownership_{backend_name}" + await create_test_table(authorized_store, table_name) - # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - assert len(result.data) == 2 + try: + # Test with first user who creates records + user1 = User("user1", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = user1 - # Test with non-admin user - regular_user = User("regular-user", {"roles": ["user"]}) - mock_get_authenticated_user.return_value = regular_user + # Insert a record owned by user1 + await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"}) - # Should only see public record - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - assert len(result.data) == 1 - assert result.data[0]["id"] == "1" + # Test with second user + user2 = User("user2", {"roles": ["user"]}) + mock_get_authenticated_user.return_value = user2 - # Test the category missing branch: user with multiple attributes - multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]}) - mock_get_authenticated_user.return_value = multi_user + # Insert a record owned by user2 + await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"}) - # Insert record with multi-user (has both roles and teams) - await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"}) + # Create a policy that only allows access when user is the owner + owner_only_policy = [ + AccessRule( + permit=Scope(actions=[Action.READ]), + when=["user is owner"], + ), + ] - # Test different user types to create records with different attribute patterns - # Record with only roles (teams category will be missing) - roles_only_user = User("roles-user", {"roles": ["admin"]}) - mock_get_authenticated_user.return_value = roles_only_user - await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"}) + # Test user1 access - should only see their own record + mock_get_authenticated_user.return_value = user1 + result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" + assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" - # Record with only teams (roles category will be missing) - teams_only_user = User("teams-user", {"teams": ["dev"]}) - mock_get_authenticated_user.return_value = teams_only_user - await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"}) + # Test user2 access - should only see their own record + mock_get_authenticated_user.return_value = user2 + result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" + assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" - # Record with different roles/teams (shouldn't match our test user) - different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]}) - mock_get_authenticated_user.return_value = different_user - await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"}) - - # Now test with the multi-user who has both roles=admin and teams=dev - mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - - # Should see: - # - public record (1) - no access_attributes - # - admin record (2) - user matches roles=admin, teams missing (allowed) - # - multi_user record (3) - user matches both roles=admin and teams=dev - # - roles_only record (4) - user matches roles=admin, teams missing (allowed) - # - teams_only record (5) - user matches teams=dev, roles missing (allowed) - # Should NOT see: - # - different_user record (6) - user doesn't match roles=user or teams=qa - expected_ids = {"1", "2", "3", "4", "5"} - actual_ids = {record["id"] for record in result.data} - assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}" - - # Verify the category missing logic specifically - # Records 4 and 5 test the "category missing" branch where one attribute category is missing - category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]} - assert category_test_ids == {"4", "5"}, ( - f"Category missing logic failed: expected 4,5 but got {category_test_ids}" - ) - - finally: - # Clean up records - for record_id in ["1", "2", "3", "4", "5", "6"]: - try: - await base_sqlstore.delete(table_name, {"id": record_id}) - except Exception: - pass + # Test with anonymous user - should see no records + mock_get_authenticated_user.return_value = None + result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" finally: - # Clean up temporary SQLite database file if needed - if cleanup_path: - try: - os.unlink(cleanup_path) - except OSError: - pass + # Clean up records + await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"]) diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 1624c0ba7..61763719a 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -153,7 +153,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): policy_ids = set() for scenario in test_scenarios: sql_record = SqlRecord( - record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"] + record_id=scenario["id"], + table_name="resources", + owner=User(principal="test-user", attributes=scenario["access_attributes"]), ) if is_action_allowed(policy, Action.READ, sql_record, user):