fix: auth sql store: user is owner policy, fix test

# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-07-09 22:32:28 -07:00
parent 81109a0f72
commit ed8f9c03b5
5 changed files with 247 additions and 175 deletions

View file

@ -81,7 +81,7 @@ def is_action_allowed(
if not len(policy): if not len(policy):
policy = default_policy() policy = default_policy()
qualified_resource_id = resource.type + "::" + resource.identifier qualified_resource_id = f"{resource.type}::{resource.identifier}"
for rule in policy: for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when: if rule.when:

View file

@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
class SqlRecord(ProtectedResource): class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records.""" def __init__(self, record_id: str, table_name: str, owner: User):
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
self.type = f"sql_record::{table_name}" self.type = f"sql_record::{table_name}"
self.identifier = record_id self.identifier = record_id
self.owner = owner
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
class AuthorizedSqlStore: class AuthorizedSqlStore:
@ -101,22 +89,27 @@ class AuthorizedSqlStore:
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support.""" """Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema) enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema: if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema) await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture.""" """Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data) enhanced_data = dict(data)
current_user = get_authenticated_user() current_user = get_authenticated_user()
if current_user and current_user.attributes: if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes enhanced_data["access_attributes"] = current_user.attributes
else: else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data) await self.sql_store.insert(table, enhanced_data)
@ -146,9 +139,12 @@ class AuthorizedSqlStore:
for row in rows.data: for row in rows.data:
stored_access_attrs = row.get("access_attributes") stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown") record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs) sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user): if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row) filtered_rows.append(row)
@ -186,8 +182,10 @@ class AuthorizedSqlStore:
Only applies SQL filtering for the default policy to ensure correctness. Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access. For custom policies, uses conservative filtering to avoid blocking legitimate access.
""" """
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY: if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause() return self._build_default_policy_where_clause(current_user)
else: else:
return self._build_conservative_where_clause() return self._build_conservative_where_clause()
@ -227,29 +225,27 @@ class AuthorizedSqlStore:
def _get_public_access_conditions(self) -> list[str]: def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access.""" """Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == SqlStoreType.postgres: if self.database_type == SqlStoreType.postgres:
# Postgres stores JSON null as 'null' # Postgres stores JSON null as 'null'
return ["access_attributes::text = 'null'"] conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite: elif self.database_type == SqlStoreType.sqlite:
return ["access_attributes = 'null'"] conditions.append("access_attributes = 'null'")
else: else:
raise ValueError(f"Unsupported database type: {self.database_type}") raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self) -> str: def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy. """Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource. This means user must match ALL attribute categories that exist in the resource.
""" """
current_user = get_authenticated_user()
base_conditions = self._get_public_access_conditions() base_conditions = self._get_public_access_conditions()
if not current_user or not current_user.attributes:
# Only allow public records
return f"({' OR '.join(base_conditions)})"
else:
user_attr_conditions = [] user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items(): for attr_key, user_values in current_user.attributes.items():
if user_values: if user_values:
value_conditions = [] value_conditions = []

View file

@ -244,16 +244,22 @@ class SqlAlchemySqlStoreImpl(SqlStore):
engine = create_async_engine(self.config.engine_str) engine = create_async_engine(self.config.engine_str)
try: try:
inspector = inspect(engine) async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names() table_names = inspector.get_table_names()
if table not in table_names: if table not in table_names:
return return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table) existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns] column_names = [col["name"] for col in existing_columns]
if column_name in column_names: return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return return
sqlalchemy_type = TYPE_MAPPING.get(column_type) sqlalchemy_type = TYPE_MAPPING.get(column_type)
@ -269,10 +275,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
nullable_clause = "" if nullable else " NOT NULL" nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn:
await conn.execute(add_column_sql) await conn.execute(add_column_sql)
except Exception: except Exception as e:
# If any error occurs during migration, log it but don't fail # If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column # The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass pass

View file

@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_polic
from llama_stack.distribution.datatypes import User from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
def get_postgres_config(): def get_postgres_config():
@ -30,45 +29,47 @@ def get_postgres_config():
def get_sqlite_config(): def get_sqlite_config():
"""Get SQLite configuration with temporary database.""" """Get SQLite configuration with temporary file database."""
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
tmp_file.close() temp_file.close()
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name return SqliteSqlStoreConfig(db_path=temp_file.name)
@pytest.mark.asyncio # Backend configurations for parametrized tests
@pytest.mark.parametrize( BACKEND_CONFIGS = [
"backend_config",
[
pytest.param( pytest.param(
("postgres", get_postgres_config), get_postgres_config,
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not os.environ.get("ENABLE_POSTGRES_TESTS"), not os.environ.get("ENABLE_POSTGRES_TESTS"),
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable", reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
), ),
id="postgres", id="postgres",
), ),
pytest.param(("sqlite", get_sqlite_config), id="sqlite"), pytest.param(get_sqlite_config, id="sqlite"),
], ]
)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_json_comparison(mock_get_authenticated_user, backend_config): @pytest.fixture
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite""" def authorized_store(backend_config):
backend_name, config_func = backend_config """Set up authorized store with proper cleanup."""
config_func = backend_config
# Handle different config types
if backend_name == "postgres":
config = config_func() config = config_func()
cleanup_path = None
else: # sqlite
config, cleanup_path = config_func()
try: base_sqlstore = sqlstore_impl(config)
base_sqlstore = SqlAlchemySqlStoreImpl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore)
# Create test table yield authorized_store
table_name = f"test_json_comparison_{backend_name}"
if hasattr(config, "db_path"):
try:
os.unlink(config.db_path)
except (OSError, FileNotFoundError):
pass
async def create_test_table(authorized_store, table_name):
"""Create a test table with standard schema."""
await authorized_store.create_table( await authorized_store.create_table(
table=table_name, table=table_name,
schema={ schema={
@ -77,6 +78,27 @@ async def test_json_comparison(mock_get_authenticated_user, backend_config):
}, },
) )
async def cleanup_records(sql_store, table_name, record_ids):
"""Clean up test records."""
for record_id in record_ids:
try:
await sql_store.delete(table_name, {"id": record_id})
except Exception:
pass
@pytest.mark.asyncio
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
backend_name = request.node.callspec.id
# Create test table
table_name = f"test_json_comparison_{backend_name}"
await create_test_table(authorized_store, table_name)
try: try:
# Test with no authenticated user (should handle JSON null comparison) # Test with no authenticated user (should handle JSON null comparison)
mock_get_authenticated_user.return_value = None mock_get_authenticated_user.return_value = None
@ -158,16 +180,62 @@ async def test_json_comparison(mock_get_authenticated_user, backend_config):
finally: finally:
# Clean up records # Clean up records
for record_id in ["1", "2", "3", "4", "5", "6"]: await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
@pytest.mark.asyncio
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
"""Test that 'user is owner' policies work correctly with record ownership"""
from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
backend_name = request.node.callspec.id
# Create test table
table_name = f"test_ownership_{backend_name}"
await create_test_table(authorized_store, table_name)
try: try:
await base_sqlstore.delete(table_name, {"id": record_id}) # Test with first user who creates records
except Exception: user1 = User("user1", {"roles": ["admin"]})
pass mock_get_authenticated_user.return_value = user1
# Insert a record owned by user1
await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"})
# Test with second user
user2 = User("user2", {"roles": ["user"]})
mock_get_authenticated_user.return_value = user2
# Insert a record owned by user2
await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"})
# Create a policy that only allows access when user is the owner
owner_only_policy = [
AccessRule(
permit=Scope(actions=[Action.READ]),
when=["user is owner"],
),
]
# Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally: finally:
# Clean up temporary SQLite database file if needed # Clean up records
if cleanup_path: await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"])
try:
os.unlink(cleanup_path)
except OSError:
pass

View file

@ -153,7 +153,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
policy_ids = set() policy_ids = set()
for scenario in test_scenarios: for scenario in test_scenarios:
sql_record = SqlRecord( sql_record = SqlRecord(
record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"] record_id=scenario["id"],
table_name="resources",
owner=User(principal="test-user", attributes=scenario["access_attributes"]),
) )
if is_action_allowed(policy, Action.READ, sql_record, user): if is_action_allowed(policy, Action.READ, sql_record, user):